Commit d279c180 authored by Xavier Thompson's avatar Xavier Thompson

Destroy forks with stolen continuations sooner

Previously when forked tasks didn't finish synchronously, they were
added to a list of children tasks and destroyed at the next sync.

But this meant that in the case of an infinite loop spawning new forks,
such memory could grow unboundedly and never be reclaimed.

Fixing this also required changing how exceptions in forks are propagated.
parent 4f16398b
#ifndef TYPON_FORK_HPP_INCLUDED #ifndef TYPON_FORK_HPP_INCLUDED
#define TYPON_FORK_HPP_INCLUDED #define TYPON_FORK_HPP_INCLUDED
#include <atomic>
#include <coroutine> #include <coroutine>
#include <cstdint> #include <cstdint>
...@@ -34,7 +35,12 @@ namespace typon ...@@ -34,7 +35,12 @@ namespace typon
struct promise_type : Result<T> struct promise_type : Result<T>
{ {
Span * _span; Span * _span;
ForkList::Node _node; u64 _rank;
ForkNode _node;
promise_type() noexcept
: _node{ std::coroutine_handle<promise_type>::from_promise(*this) }
{}
Fork get_return_object() noexcept Fork get_return_object() noexcept
{ {
...@@ -57,6 +63,15 @@ namespace typon ...@@ -57,6 +63,15 @@ namespace typon
{ {
return span->_coroutine; return span->_coroutine;
} }
if (auto & exception = coroutine.promise()._exception)
{
span->set_exception(exception, coroutine.promise()._rank);
}
auto & ref = coroutine.promise()._node._ref;
if (!ref.exchange(false, std::memory_order_acq_rel))
{
coroutine.destroy();
}
u64 n = span->_n.fetch_sub(1, std::memory_order_acq_rel); u64 n = span->_n.fetch_sub(1, std::memory_order_acq_rel);
if (n == 1) if (n == 1)
{ {
...@@ -73,7 +88,6 @@ namespace typon ...@@ -73,7 +88,6 @@ namespace typon
struct awaitable struct awaitable
{ {
std::coroutine_handle<promise_type> _coroutine; std::coroutine_handle<promise_type> _coroutine;
u64 _thefts;
awaitable(std::coroutine_handle<promise_type> coroutine) noexcept awaitable(std::coroutine_handle<promise_type> coroutine) noexcept
: _coroutine(coroutine) : _coroutine(coroutine)
...@@ -89,7 +103,7 @@ namespace typon ...@@ -89,7 +103,7 @@ namespace typon
{ {
Span * span = &(continuation.promise()._span); Span * span = &(continuation.promise()._span);
_coroutine.promise()._span = span; _coroutine.promise()._span = span;
_thefts = span->_thefts; _coroutine.promise()._rank = span->_thefts;
std::coroutine_handle<> on_stack_handle = _coroutine; std::coroutine_handle<> on_stack_handle = _coroutine;
Scheduler::push(span); Scheduler::push(span);
...@@ -98,13 +112,9 @@ namespace typon ...@@ -98,13 +112,9 @@ namespace typon
auto await_resume() auto await_resume()
{ {
auto span = _coroutine.promise()._span; auto thefts = _coroutine.promise()._span->_thefts;
bool stolen = span->_thefts > _thefts; auto rank = _coroutine.promise()._rank;
if (stolen) return Forked<T>(_coroutine, (thefts == rank));
{
span->_children.insert(_coroutine);
}
return Forked<T>(_coroutine, !stolen);
} }
}; };
......
...@@ -13,138 +13,102 @@ ...@@ -13,138 +13,102 @@
namespace typon namespace typon
{ {
template <typename T> struct ForkNode
struct Forked
{ {
using value_type = T; std::coroutine_handle<> _coroutine;
std::atomic<bool> _ref {true};
};
Result<T> * _result = nullptr;
template <typename T>
struct ForkResult
{
union union
{ {
T _value; T _value;
ForkNode * _node;
}; };
template <typename Promise> template <typename Promise>
Forked(std::coroutine_handle<Promise> coroutine, bool ready) void construct_value(std::coroutine_handle<Promise> coroutine)
{ {
if (ready) std::construct_at(std::addressof(_value), coroutine.promise().get());
{
Defer defer { [&coroutine]() { coroutine.destroy(); } };
std::construct_at(std::addressof(_value), coroutine.promise().get());
}
else
{
_result = &(coroutine.promise());
}
} }
Forked(Forked && other) noexcept(std::is_nothrow_move_constructible_v<T>) void construct_value(ForkResult && other)
{ {
_result = other._result; std::construct_at(std::addressof(_value), std::move(other._value));
if (!_result)
{
std::construct_at(std::addressof(_value), std::move(other._value));
}
} }
Forked& operator=(Forked && other) T get_value() noexcept
noexcept(std::is_nothrow_move_constructible_v<T>)
{ {
if (this != &other) return _value;
{
Forked old { std::move(*this) };
_result = other._result;
if (!_result)
{
std::construct_at(std::addressof(_value), std::move(other._value));
}
}
return *this;
} }
~Forked() void destroy_value() noexcept
{ {
if (!_result) std::destroy_at(std::addressof(_value));
{
std::destroy_at(std::addressof(_value));
}
}
T get()
{
if (_result)
{
return _result->get();
}
return _value;
} }
}; };
template <typename T> template <typename T>
requires requires { sizeof(T); } && (sizeof(T) > 2 * sizeof(void*)) struct ForkResult<T&>
struct Forked<T>
{ {
using value_type = T; union
{
Result<T> * _result = nullptr; T * _value;
std::coroutine_handle<> _coroutine; ForkNode * _node;
};
template <typename Promise> template <typename Promise>
Forked(std::coroutine_handle<Promise> coroutine, bool ready) void construct_value(std::coroutine_handle<Promise> coroutine)
{ {
_result = &(coroutine.promise()); _value = std::addressof(coroutine.promise().get());
if (ready)
{
_coroutine = coroutine;
if (coroutine.promise()._exception)
{
std::rethrow_exception(coroutine.promise()._exception);
}
}
} }
Forked(const Forked &) = delete; void construct_value(ForkResult && other)
Forked& operator=(const Forked &) = delete; {
_value = other._value;
Forked(Forked && other) noexcept }
: _result(other._result)
, _coroutine(std::exchange(other._coroutine, nullptr))
{}
Forked& operator=(Forked && other) noexcept T& get_value() noexcept
{ {
std::swap(_coroutine, other._coroutine); return *_value;
std::swap(_result, other._result);
return *this;
} }
~Forked() void destroy_value() noexcept {}
};
template <>
struct ForkResult<void>
{
ForkNode * _node {nullptr};
template <typename Promise>
void construct_value(std::coroutine_handle<Promise> coroutine)
{ {
if (_coroutine) coroutine.promise().get();
{
_coroutine.destroy();
}
} }
T get() void construct_value(ForkResult && other) noexcept
{ {
return _result->get(); (void) other;
} }
void get_value() noexcept {}
void destroy_value() noexcept {}
}; };
template <typename T> template <typename T>
requires std::is_trivially_copyable_v<T> struct Forked : ForkResult<T>
struct Forked<T>
{ {
using value_type = T; using value_type = T;
Result<T> * _result = nullptr; Result<T> * _result = nullptr;
union
{
value_type _value;
};
template <typename Promise> template <typename Promise>
Forked(std::coroutine_handle<Promise> coroutine, bool ready) Forked(std::coroutine_handle<Promise> coroutine, bool ready)
...@@ -152,97 +116,140 @@ namespace typon ...@@ -152,97 +116,140 @@ namespace typon
if (ready) if (ready)
{ {
Defer defer { [&coroutine]() { coroutine.destroy(); } }; Defer defer { [&coroutine]() { coroutine.destroy(); } };
std::construct_at(std::addressof(_value), coroutine.promise().get()); this->construct_value(coroutine);
} }
else else
{ {
this->_node = &(coroutine.promise()._node);
_result = &(coroutine.promise()); _result = &(coroutine.promise());
} }
} }
Forked(Forked && other) = default; Forked(Forked && other) noexcept(std::is_nothrow_move_constructible_v<T>)
Forked& operator=(Forked && other) = default;
~Forked()
{ {
if (!_result) _result = other._result;
if (_result)
{ {
std::destroy_at(std::addressof(_value)); this->_node = std::exchange(other._node, nullptr);
}
else
{
this->construct_value(std::move(other));
} }
} }
T get() Forked& operator=(Forked && other)
noexcept(std::is_nothrow_move_constructible_v<T>)
{ {
if (_result) if (this != &other)
{ {
return _result->get(); Forked old { std::move(*this) };
_result = other._result;
if (_result)
{
this->_node = std::exchange(other._node, nullptr);
}
else
{
this->construct_value(std::move(other));
}
} }
return _value; return *this;
} }
};
template <typename T>
struct Forked<T&>
{
using value_type = T&;
Result<T> * _result = nullptr; ~Forked()
T * _value;
template <typename Promise>
Forked(std::coroutine_handle<Promise> coroutine, bool ready)
{ {
if (ready) if (_result)
{ {
Defer defer { [&coroutine]() { coroutine.destroy(); } }; if (auto node = this->_node)
_value = std::addressof(coroutine.promise().get()); {
if (!node->_ref.exchange(false, std::memory_order_acq_rel))
{
node->_coroutine.destroy();
}
}
} }
else else
{ {
_result = &(coroutine.promise()); this->destroy_value();
} }
} }
T& get() T get()
{ {
if (_result) if (_result)
{ {
return _result->get(); return _result->get();
} }
return *_value; return this->get_value();
} }
}; };
template <> template <typename T>
struct Forked<void> requires requires { sizeof(T); } && (sizeof(T) > 2 * sizeof(void*))
struct Forked<T>
{ {
using value_type = void; using value_type = T;
Result<void> * _result = nullptr; bool _ready;
Result<T> * _result;
ForkNode * _node;
template <typename Promise> template <typename Promise>
Forked(std::coroutine_handle<Promise> coroutine, bool ready) Forked(std::coroutine_handle<Promise> coroutine, bool ready)
{ {
_ready = ready;
_result = &(coroutine.promise());
_node = &(coroutine.promise()._node);
if (ready) if (ready)
{ {
Defer defer { [&coroutine]() { coroutine.destroy(); } }; if (auto & exception = coroutine.promise()._exception)
coroutine.promise().get(); {
} std::rethrow_exception(exception);
else }
{
_result = &(coroutine.promise());
} }
} }
void get() Forked(const Forked &) = delete;
Forked& operator=(const Forked &) = delete;
Forked(Forked && other) noexcept
: _ready(other._ready)
, _result(other._result)
, _node(std::exchange(other._node, nullptr))
{}
Forked& operator=(Forked && other) noexcept
{ {
if (_result) std::swap(_ready, other._ready);
std::swap(_node, other._node);
std::swap(_result, other._result);
return *this;
}
~Forked()
{
if (_node)
{ {
_result->get(); if (_ready)
{
_node->_coroutine.destroy();
}
else
{
if (!_node->_ref.exchange(false, std::memory_order_acq_rel))
{
_node->_coroutine.destroy();
}
}
} }
} }
T get()
{
return _result->get();
}
}; };
} }
......
...@@ -40,7 +40,6 @@ namespace typon ...@@ -40,7 +40,6 @@ namespace typon
{ {
if (_coroutine) if (_coroutine)
{ {
_coroutine.promise()._span._children.clear();
_coroutine.destroy(); _coroutine.destroy();
} }
} }
...@@ -107,7 +106,7 @@ namespace typon ...@@ -107,7 +106,7 @@ namespace typon
{ {
_span._thefts = 0; _span._thefts = 0;
_span._n.store(UMAX, std::memory_order_release); _span._n.store(UMAX, std::memory_order_release);
_span._children.check_exceptions(); _span.check_exception();
} }
}; };
...@@ -158,7 +157,7 @@ namespace typon ...@@ -158,7 +157,7 @@ namespace typon
decltype(auto) await_resume() decltype(auto) await_resume()
{ {
_coroutine.promise()._span._children.check_exceptions(); _coroutine.promise()._span.check_exception();
return _coroutine.promise().get(); return _coroutine.promise().get();
} }
}; };
......
...@@ -8,77 +8,70 @@ ...@@ -8,77 +8,70 @@
#include <exception> #include <exception>
#include <limits> #include <limits>
#include <typon/defer.hpp>
#include <typon/theft_point.hpp> #include <typon/theft_point.hpp>
namespace typon namespace typon
{ {
struct ForkList struct Span : TheftPoint
{ {
struct Node using u64 = TheftPoint::u64;
struct Error
{ {
Node * _next; u64 _rank;
std::coroutine_handle<> _coroutine; std::exception_ptr _exception;
std::exception_ptr * _exception;
}; };
Node * _first = nullptr; static constexpr u64 UMAX = std::numeric_limits<u64>::max();
std::coroutine_handle<> _continuation;
std::atomic<Error *> _error { nullptr };
template <typename Promise> std::atomic<u64> _n = UMAX;
requires requires (Promise p)
Span(std::coroutine_handle<> coroutine) noexcept
: TheftPoint(coroutine)
{}
~Span()
{
if (auto error = _error.load(std::memory_order_relaxed))
{ {
{ p._exception } -> std::same_as<std::exception_ptr&>; delete error;
{ p._node } -> std::same_as<Node&>;
} }
void insert(std::coroutine_handle<Promise> coroutine) noexcept
{
std::exception_ptr * exception = &(coroutine.promise()._exception);
coroutine.promise()._node = { _first, coroutine, exception };
_first = &(coroutine.promise()._node);
} }
void check_exceptions() void check_exception()
{ {
for (auto node = _first; node != nullptr; node = node->_next) if (auto error = _error.load(std::memory_order_relaxed))
{ {
std::exception_ptr & exception = *(node->_exception); _error.store(nullptr, std::memory_order_relaxed);
if (exception) Defer defer { [error]() { delete error; } };
{ std::rethrow_exception(error->_exception);
std::rethrow_exception(exception);
}
} }
} }
void clear() noexcept void set_exception(std::exception_ptr & exception, u64 rank) noexcept
{ {
auto next = _first; auto error = new Error(rank, exception);
while(next) Error * expected = nullptr;
while (!_error.compare_exchange_strong(expected, error))
{
if (expected->_rank < rank)
{
delete error;
return;
}
}
if (expected)
{ {
auto node = next; delete expected;
next = next->_next;
node->_coroutine.destroy();
} }
_first = nullptr;
} }
};
struct Span : TheftPoint
{
using u64 = TheftPoint::u64;
static constexpr u64 UMAX = std::numeric_limits<u64>::max();
std::coroutine_handle<> _continuation;
ForkList _children;
std::atomic<u64> _n = UMAX;
Span(std::coroutine_handle<> coroutine) noexcept
: TheftPoint(coroutine)
{}
void resume() void resume()
{ {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment