Commit c8f49ef0 authored by Xavier Thompson's avatar Xavier Thompson

Refactor fork reclamation strategies into policies

parent f79f9af7
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <coroutine> #include <coroutine>
#include <cstdint> #include <cstdint>
#include <typon/defer.hpp>
#include <typon/forked.hpp> #include <typon/forked.hpp>
#include <typon/result.hpp> #include <typon/result.hpp>
#include <typon/scheduler.hpp> #include <typon/scheduler.hpp>
...@@ -14,7 +15,73 @@ ...@@ -14,7 +15,73 @@
namespace typon namespace typon
{ {
template <typename T = void> namespace policy
{
struct Bundle
{
void on_final_suspend(std::coroutine_handle<> coroutine) noexcept
{
(void) coroutine;
}
struct OnAwaitable
{
template <typename Promise>
void on_await_suspend(std::coroutine_handle<Promise> coroutine) noexcept
{
(void) coroutine;
}
template <typename Promise>
auto on_await_resume(std::coroutine_handle<Promise> coroutine)
{
using T = typename Promise::value_type;
auto thefts = coroutine.promise()._span->_thefts;
auto rank = coroutine.promise()._rank;
bool ready = (thefts == rank);
if (!ready)
{
coroutine.promise()._span->_children.push_back(coroutine);
}
return Forked<T>(coroutine, ready, nullptr);
}
};
};
struct Refcnt
{
ForkNode _node;
void on_final_suspend(std::coroutine_handle<> coroutine) noexcept
{
(void) coroutine;
_node.decref();
}
struct OnAwaitable
{
template <typename Promise>
void on_await_suspend(std::coroutine_handle<Promise> coroutine) noexcept
{
coroutine.promise()._policy._node._coroutine = coroutine;
}
template <typename Promise>
auto on_await_resume(std::coroutine_handle<Promise> coroutine)
{
using T = typename Promise::value_type;
auto thefts = coroutine.promise()._span->_thefts;
auto rank = coroutine.promise()._rank;
bool ready = (thefts == rank);
auto node = &(coroutine.promise()._policy._node);
return Forked<T>(coroutine, ready, node);
}
};
};
}
template <typename T = void, typename Policy = policy::Refcnt>
struct [[nodiscard]] Fork struct [[nodiscard]] Fork
{ {
struct promise_type; struct promise_type;
...@@ -36,11 +103,7 @@ namespace typon ...@@ -36,11 +103,7 @@ namespace typon
{ {
Span * _span; Span * _span;
u64 _rank; u64 _rank;
ForkNode _node; [[no_unique_address]] Policy _policy;
promise_type() noexcept
: _node{ std::coroutine_handle<promise_type>::from_promise(*this) }
{}
Fork get_return_object() noexcept Fork get_return_object() noexcept
{ {
...@@ -63,19 +126,11 @@ namespace typon ...@@ -63,19 +126,11 @@ namespace typon
{ {
return span->_coroutine; return span->_coroutine;
} }
auto rank = coroutine.promise()._rank;
if (auto & exception = coroutine.promise()._exception) if (auto & exception = coroutine.promise()._exception)
{ {
span->set_exception(exception, rank); span->set_exception(exception, coroutine.promise()._rank);
}
if (!(rank & 1))
{
auto & ref = coroutine.promise()._node._ref;
if (!ref.exchange(false, std::memory_order_acq_rel))
{
coroutine.destroy();
}
} }
coroutine.promise()._policy.on_final_suspend(coroutine);
u64 n = span->_n.fetch_sub(1, std::memory_order_acq_rel); u64 n = span->_n.fetch_sub(1, std::memory_order_acq_rel);
if (n == 1) if (n == 1)
{ {
...@@ -92,42 +147,16 @@ namespace typon ...@@ -92,42 +147,16 @@ namespace typon
struct awaitable : std::suspend_always struct awaitable : std::suspend_always
{ {
std::coroutine_handle<promise_type> _coroutine; std::coroutine_handle<promise_type> _coroutine;
[[no_unique_address]] Policy::OnAwaitable _policy;
template <typename Promise> template <typename Promise>
auto await_suspend(std::coroutine_handle<Promise> continuation) noexcept auto await_suspend(std::coroutine_handle<Promise> continuation) noexcept
{ {
Span * span = &(continuation.promise()._span); Span * span = &(continuation.promise()._span);
_coroutine.promise()._span = span; _coroutine.promise()._span = span;
_coroutine.promise()._rank = (span->_thefts << 1); _coroutine.promise()._rank = span->_thefts;
std::coroutine_handle<> on_stack_handle = _coroutine; _policy.on_await_suspend(_coroutine);
Scheduler::push(span);
return on_stack_handle;
}
auto await_resume()
{
auto thefts = _coroutine.promise()._span->_thefts;
auto rank = _coroutine.promise()._rank;
return Forked<T>(_coroutine, (thefts == (rank >> 1)), true);
}
};
auto operator co_await() &&
{
return awaitable { {}, _coroutine };
}
struct noloop_awaitable : std::suspend_always
{
std::coroutine_handle<promise_type> _coroutine;
template <typename Promise>
auto await_suspend(std::coroutine_handle<Promise> continuation) noexcept
{
Span * span = &(continuation.promise()._span);
_coroutine.promise()._span = span;
_coroutine.promise()._rank = (span->_thefts << 1) + 1;
std::coroutine_handle<> on_stack_handle = _coroutine; std::coroutine_handle<> on_stack_handle = _coroutine;
Scheduler::push(span); Scheduler::push(span);
...@@ -136,20 +165,13 @@ namespace typon ...@@ -136,20 +165,13 @@ namespace typon
auto await_resume() auto await_resume()
{ {
auto thefts = _coroutine.promise()._span->_thefts; return _policy.on_await_resume(_coroutine);
auto rank = _coroutine.promise()._rank;
bool ready = thefts == (rank >> 1);
if (!ready)
{
_coroutine.promise()._span->_children.push_back(_coroutine);
}
return Forked<T>(_coroutine, ready, false);
} }
}; };
auto noloop() && auto operator co_await() &&
{ {
return noloop_awaitable { {}, _coroutine }; return awaitable { {}, _coroutine, {} };
} }
}; };
......
...@@ -17,6 +17,14 @@ namespace typon ...@@ -17,6 +17,14 @@ namespace typon
{ {
std::coroutine_handle<> _coroutine; std::coroutine_handle<> _coroutine;
std::atomic<bool> _ref {true}; std::atomic<bool> _ref {true};
void decref() noexcept
{
if (!_ref.exchange(false, std::memory_order_acq_rel))
{
_coroutine.destroy();
}
}
}; };
...@@ -110,8 +118,8 @@ namespace typon ...@@ -110,8 +118,8 @@ namespace typon
Result<T> * _result = nullptr; Result<T> * _result = nullptr;
template <typename Promise> template <typename Coroutine>
Forked(std::coroutine_handle<Promise> coroutine, bool ready, bool owning) Forked(Coroutine coroutine, bool ready, ForkNode * node)
{ {
if (ready) if (ready)
{ {
...@@ -120,7 +128,7 @@ namespace typon ...@@ -120,7 +128,7 @@ namespace typon
} }
else else
{ {
this->_node = owning ? &(coroutine.promise()._node) : nullptr; this->_node = node;
_result = &(coroutine.promise()); _result = &(coroutine.promise());
} }
} }
...@@ -163,10 +171,7 @@ namespace typon ...@@ -163,10 +171,7 @@ namespace typon
{ {
if (auto node = this->_node) if (auto node = this->_node)
{ {
if (!node->_ref.exchange(false, std::memory_order_acq_rel)) node->decref();
{
node->_coroutine.destroy();
}
} }
} }
else else
...@@ -194,20 +199,21 @@ namespace typon ...@@ -194,20 +199,21 @@ namespace typon
bool _ready; bool _ready;
Result<T> * _result; Result<T> * _result;
ForkNode * _node; void * _coroutine;
template <typename Promise> template <typename Coroutine>
Forked(std::coroutine_handle<Promise> coroutine, bool ready, bool owning) Forked(Coroutine coroutine, bool ready, ForkNode * node)
{ {
_ready = ready; _ready = ready;
_result = &(coroutine.promise()); _result = &(coroutine.promise());
_node = (owning | ready) ? &(coroutine.promise()._node) : nullptr;
if (ready) if (ready)
{ {
if (auto & exception = coroutine.promise()._exception) _coroutine = coroutine.address();
{ coroutine.promise().get();
std::rethrow_exception(exception);
} }
else
{
_coroutine = node;
} }
} }
...@@ -217,31 +223,28 @@ namespace typon ...@@ -217,31 +223,28 @@ namespace typon
Forked(Forked && other) noexcept Forked(Forked && other) noexcept
: _ready(other._ready) : _ready(other._ready)
, _result(other._result) , _result(other._result)
, _node(std::exchange(other._node, nullptr)) , _coroutine(std::exchange(other._coroutine, nullptr))
{} {}
Forked& operator=(Forked && other) noexcept Forked& operator=(Forked && other) noexcept
{ {
std::swap(_ready, other._ready); std::swap(_ready, other._ready);
std::swap(_node, other._node);
std::swap(_result, other._result); std::swap(_result, other._result);
std::swap(_coroutine, other._coroutine);
return *this; return *this;
} }
~Forked() ~Forked()
{ {
if (_node) if (_coroutine)
{ {
if (_ready) if (_ready)
{ {
_node->_coroutine.destroy(); std::coroutine_handle<void>::from_address(_coroutine).destroy();
} }
else else
{ {
if (!_node->_ref.exchange(false, std::memory_order_acq_rel)) reinterpret_cast<ForkNode *>(_coroutine)->decref();
{
_node->_coroutine.destroy();
}
} }
} }
} }
......
...@@ -26,6 +26,16 @@ namespace typon ...@@ -26,6 +26,16 @@ namespace typon
} }
template <typename Policy, typename Task>
Fork<typename Task::promise_type::value_type, Policy> fork(Task task)
{
// Put the task in a local variable to ensure its destructor will
// be called on co_return instead of only on coroutine destruction.
Task local_task = std::move(task);
co_return co_await std::move(local_task);
}
template <typename Task> template <typename Task>
Future<typename Task::promise_type::value_type> future(Task task) Future<typename Task::promise_type::value_type> future(Task task)
{ {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment