Commit 673cea98 authored by Xavier Thompson's avatar Xavier Thompson

Destroy forks with stolen continuations sooner

Previously when forked tasks didn't finish synchronously, they were
added to a list of children tasks and destroyed at the next sync.

But this meant that in the case of an infinite loop spawning new forks,
such memory could grow unboundedly and never be reclaimed.

Fixing this also required changing how exceptions in forks are propagated.
parent 1316d783
#ifndef TYPON_FORK_HPP_INCLUDED
#define TYPON_FORK_HPP_INCLUDED
#include <atomic>
#include <coroutine>
#include <cstdint>
......@@ -34,7 +35,12 @@ namespace typon
struct promise_type : Result<T>
{
Span * _span;
ForkList::Node _node;
u64 _rank;
ForkNode _node;
promise_type() noexcept
: _node{ std::coroutine_handle<promise_type>::from_promise(*this) }
{}
Fork get_return_object() noexcept
{
......@@ -57,6 +63,15 @@ namespace typon
{
return span->_coroutine;
}
if (auto & exception = coroutine.promise()._exception)
{
span->set_exception(exception, coroutine.promise()._rank);
}
auto & ref = coroutine.promise()._node._ref;
if (!ref.exchange(false, std::memory_order_acq_rel))
{
coroutine.destroy();
}
u64 n = span->_n.fetch_sub(1, std::memory_order_acq_rel);
if (n == 1)
{
......@@ -73,7 +88,6 @@ namespace typon
struct awaitable
{
std::coroutine_handle<promise_type> _coroutine;
u64 _thefts;
awaitable(std::coroutine_handle<promise_type> coroutine) noexcept
: _coroutine(coroutine)
......@@ -89,7 +103,7 @@ namespace typon
{
Span * span = &(continuation.promise()._span);
_coroutine.promise()._span = span;
_thefts = span->_thefts;
_coroutine.promise()._rank = span->_thefts;
std::coroutine_handle<> on_stack_handle = _coroutine;
Scheduler::push(span);
......@@ -98,13 +112,9 @@ namespace typon
auto await_resume()
{
auto span = _coroutine.promise()._span;
bool stolen = span->_thefts > _thefts;
if (stolen)
{
span->_children.insert(_coroutine);
}
return Forked<T>(_coroutine, !stolen);
auto thefts = _coroutine.promise()._span->_thefts;
auto rank = _coroutine.promise()._rank;
return Forked<T>(_coroutine, (thefts == rank));
}
};
......
......@@ -13,138 +13,102 @@
namespace typon
{
template <typename T>
struct Forked
struct ForkNode
{
using value_type = T;
std::coroutine_handle<> _coroutine;
std::atomic<bool> _ref {true};
};
Result<T> * _result = nullptr;
template <typename T>
struct ForkResult
{
union
{
T _value;
ForkNode * _node;
};
template <typename Promise>
Forked(std::coroutine_handle<Promise> coroutine, bool ready)
{
if (ready)
void construct_value(std::coroutine_handle<Promise> coroutine)
{
Defer defer { [&coroutine]() { coroutine.destroy(); } };
std::construct_at(std::addressof(_value), coroutine.promise().get());
}
else
{
_result = &(coroutine.promise());
}
}
Forked(Forked && other) noexcept(std::is_nothrow_move_constructible_v<T>)
{
_result = other._result;
if (!_result)
void construct_value(ForkResult && other)
{
std::construct_at(std::addressof(_value), std::move(other._value));
}
}
Forked& operator=(Forked && other)
noexcept(std::is_nothrow_move_constructible_v<T>)
{
if (this != &other)
{
Forked old { std::move(*this) };
_result = other._result;
if (!_result)
T get_value() noexcept
{
std::construct_at(std::addressof(_value), std::move(other._value));
}
}
return *this;
return _value;
}
~Forked()
{
if (!_result)
void destroy_value() noexcept
{
std::destroy_at(std::addressof(_value));
}
}
T get()
{
if (_result)
{
return _result->get();
}
return _value;
}
};
template <typename T>
requires requires { sizeof(T); } && (sizeof(T) > 2 * sizeof(void*))
struct Forked<T>
struct ForkResult<T&>
{
using value_type = T;
Result<T> * _result = nullptr;
std::coroutine_handle<> _coroutine;
union
{
T * _value;
ForkNode * _node;
};
template <typename Promise>
Forked(std::coroutine_handle<Promise> coroutine, bool ready)
{
_result = &(coroutine.promise());
if (ready)
{
_coroutine = coroutine;
if (coroutine.promise()._exception)
void construct_value(std::coroutine_handle<Promise> coroutine)
{
std::rethrow_exception(coroutine.promise()._exception);
_value = std::addressof(coroutine.promise().get());
}
void construct_value(ForkResult && other)
{
_value = other._value;
}
T& get_value() noexcept
{
return *_value;
}
Forked(const Forked &) = delete;
Forked& operator=(const Forked &) = delete;
void destroy_value() noexcept {}
};
Forked(Forked && other) noexcept
: _result(other._result)
, _coroutine(std::exchange(other._coroutine, nullptr))
{}
Forked& operator=(Forked && other) noexcept
template <>
struct ForkResult<void>
{
std::swap(_coroutine, other._coroutine);
std::swap(_result, other._result);
return *this;
}
ForkNode * _node {nullptr};
~Forked()
{
if (_coroutine)
template <typename Promise>
void construct_value(std::coroutine_handle<Promise> coroutine)
{
_coroutine.destroy();
}
coroutine.promise().get();
}
T get()
void construct_value(ForkResult && other) noexcept
{
return _result->get();
(void) other;
}
void get_value() noexcept {}
void destroy_value() noexcept {}
};
template <typename T>
requires std::is_trivially_copyable_v<T>
struct Forked<T>
struct Forked : ForkResult<T>
{
using value_type = T;
Result<T> * _result = nullptr;
union
{
value_type _value;
};
template <typename Promise>
Forked(std::coroutine_handle<Promise> coroutine, bool ready)
......@@ -152,96 +116,139 @@ namespace typon
if (ready)
{
Defer defer { [&coroutine]() { coroutine.destroy(); } };
std::construct_at(std::addressof(_value), coroutine.promise().get());
this->construct_value(coroutine);
}
else
{
this->_node = &(coroutine.promise()._node);
_result = &(coroutine.promise());
}
}
Forked(Forked && other) = default;
Forked& operator=(Forked && other) = default;
~Forked()
Forked(Forked && other) noexcept(std::is_nothrow_move_constructible_v<T>)
{
if (!_result)
_result = other._result;
if (_result)
{
std::destroy_at(std::addressof(_value));
this->_node = std::exchange(other._node, nullptr);
}
else
{
this->construct_value(std::move(other));
}
}
T get()
Forked& operator=(Forked && other)
noexcept(std::is_nothrow_move_constructible_v<T>)
{
if (this != &other)
{
Forked old { std::move(*this) };
_result = other._result;
if (_result)
{
return _result->get();
this->_node = std::exchange(other._node, nullptr);
}
return _value;
else
{
this->construct_value(std::move(other));
}
}
return *this;
}
};
template <typename T>
struct Forked<T&>
~Forked()
{
using value_type = T&;
Result<T> * _result = nullptr;
T * _value;
template <typename Promise>
Forked(std::coroutine_handle<Promise> coroutine, bool ready)
if (_result)
{
if (ready)
if (auto node = this->_node)
{
Defer defer { [&coroutine]() { coroutine.destroy(); } };
_value = std::addressof(coroutine.promise().get());
if (!node->_ref.exchange(false, std::memory_order_acq_rel))
{
node->_coroutine.destroy();
}
}
}
else
{
_result = &(coroutine.promise());
this->destroy_value();
}
}
T& get()
T get()
{
if (_result)
{
return _result->get();
}
return *_value;
return this->get_value();
}
};
template <>
struct Forked<void>
template <typename T>
requires requires { sizeof(T); } && (sizeof(T) > 2 * sizeof(void*))
struct Forked<T>
{
using value_type = void;
using value_type = T;
Result<void> * _result = nullptr;
bool _ready;
Result<T> * _result;
ForkNode * _node;
template <typename Promise>
Forked(std::coroutine_handle<Promise> coroutine, bool ready)
{
_ready = ready;
_result = &(coroutine.promise());
_node = &(coroutine.promise()._node);
if (ready)
{
Defer defer { [&coroutine]() { coroutine.destroy(); } };
coroutine.promise().get();
}
else
if (auto & exception = coroutine.promise()._exception)
{
_result = &(coroutine.promise());
std::rethrow_exception(exception);
}
}
}
void get()
Forked(const Forked &) = delete;
Forked& operator=(const Forked &) = delete;
Forked(Forked && other) noexcept
: _ready(other._ready)
, _result(other._result)
, _node(std::exchange(other._node, nullptr))
{}
Forked& operator=(Forked && other) noexcept
{
std::swap(_ready, other._ready);
std::swap(_node, other._node);
std::swap(_result, other._result);
return *this;
}
~Forked()
{
if (_result)
if (_node)
{
_result->get();
if (_ready)
{
_node->_coroutine.destroy();
}
else
{
if (!_node->_ref.exchange(false, std::memory_order_acq_rel))
{
_node->_coroutine.destroy();
}
}
}
}
T get()
{
return _result->get();
}
};
......
......@@ -40,7 +40,6 @@ namespace typon
{
if (_coroutine)
{
_coroutine.promise()._span._children.clear();
_coroutine.destroy();
}
}
......@@ -107,7 +106,7 @@ namespace typon
{
_span._thefts = 0;
_span._n.store(UMAX, std::memory_order_release);
_span._children.check_exceptions();
_span.check_exception();
}
};
......@@ -158,7 +157,7 @@ namespace typon
decltype(auto) await_resume()
{
_coroutine.promise()._span._children.check_exceptions();
_coroutine.promise()._span.check_exception();
return _coroutine.promise().get();
}
};
......
......@@ -8,77 +8,70 @@
#include <exception>
#include <limits>
#include <typon/defer.hpp>
#include <typon/theft_point.hpp>
namespace typon
{
struct ForkList
struct Span : TheftPoint
{
struct Node
using u64 = TheftPoint::u64;
struct Error
{
Node * _next;
std::coroutine_handle<> _coroutine;
std::exception_ptr * _exception;
u64 _rank;
std::exception_ptr _exception;
};
Node * _first = nullptr;
static constexpr u64 UMAX = std::numeric_limits<u64>::max();
std::coroutine_handle<> _continuation;
std::atomic<Error *> _error { nullptr };
std::atomic<u64> _n = UMAX;
Span(std::coroutine_handle<> coroutine) noexcept
: TheftPoint(coroutine)
{}
template <typename Promise>
requires requires (Promise p)
~Span()
{
{ p._exception } -> std::same_as<std::exception_ptr&>;
{ p._node } -> std::same_as<Node&>;
}
void insert(std::coroutine_handle<Promise> coroutine) noexcept
if (auto error = _error.load(std::memory_order_relaxed))
{
std::exception_ptr * exception = &(coroutine.promise()._exception);
coroutine.promise()._node = { _first, coroutine, exception };
_first = &(coroutine.promise()._node);
delete error;
}
}
void check_exceptions()
{
for (auto node = _first; node != nullptr; node = node->_next)
void check_exception()
{
std::exception_ptr & exception = *(node->_exception);
if (exception)
if (auto error = _error.load(std::memory_order_relaxed))
{
std::rethrow_exception(exception);
}
_error.store(nullptr, std::memory_order_relaxed);
Defer defer { [error]() { delete error; } };
std::rethrow_exception(error->_exception);
}
}
void clear() noexcept
void set_exception(std::exception_ptr & exception, u64 rank) noexcept
{
auto error = new Error(rank, exception);
Error * expected = nullptr;
while (!_error.compare_exchange_strong(expected, error))
{
auto next = _first;
while(next)
if (expected->_rank < rank)
{
auto node = next;
next = next->_next;
node->_coroutine.destroy();
delete error;
return;
}
_first = nullptr;
}
};
struct Span : TheftPoint
if (expected)
{
using u64 = TheftPoint::u64;
static constexpr u64 UMAX = std::numeric_limits<u64>::max();
std::coroutine_handle<> _continuation;
ForkList _children;
std::atomic<u64> _n = UMAX;
Span(std::coroutine_handle<> coroutine) noexcept
: TheftPoint(coroutine)
{}
delete expected;
}
}
void resume()
{
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment