Commit da1a6710 authored by Xavier Thompson's avatar Xavier Thompson

Make continuation-stealing interface more generic

parent df650572
......@@ -4,6 +4,7 @@
#include <atomic>
#include <coroutine>
#include <cstdint>
#include <utility>
#include <typon/result.hpp>
#include <typon/scheduler.hpp>
......@@ -18,7 +19,6 @@ namespace typon
struct [[nodiscard]] Future
{
struct promise_type;
using u64 = TheftPoint::u64;
using enum std::memory_order;
......@@ -67,6 +67,26 @@ namespace typon
}
}
struct Continuation : TheftPoint
{
std::coroutine_handle<> _continuation;
std::coroutine_handle<> steal() noexcept override
{
return std::exchange(_continuation, nullptr);
}
operator std::coroutine_handle<>() noexcept
{
return _continuation;
}
bool ready() noexcept
{
return bool(_continuation);
}
};
struct promise_type : Result<T>
{
std::atomic<std::uintptr_t> _state { no_waiter };
......@@ -90,7 +110,7 @@ namespace typon
auto theftpoint = Scheduler::peek();
if (Scheduler::pop())
{
return theftpoint->_coroutine;
return *static_cast<Continuation *>(theftpoint);
}
auto state = coroutine.promise()._state.exchange(ready, acq_rel);
if (state == discarded)
......@@ -118,7 +138,7 @@ namespace typon
{
Future _future;
std::coroutine_handle<promise_type> _coroutine;
TheftPoint _theftpoint;
Continuation _continuation;
awaitable(Future && f, std::coroutine_handle<promise_type> c) noexcept
: _future(std::move(f))
......@@ -132,20 +152,19 @@ namespace typon
auto await_suspend(std::coroutine_handle<> continuation) noexcept
{
_theftpoint._coroutine = continuation;
_continuation._continuation = continuation;
std::coroutine_handle<> on_stack_handle = _coroutine;
Scheduler::push(&(_theftpoint));
Scheduler::push(&(_continuation));
return on_stack_handle;
}
auto await_resume() noexcept
{
_future._ready = !_theftpoint._thefts;
_future._ready = _continuation.ready();
return std::move(_future);
}
};
return awaitable { std::move(*this), _coroutine };
}
......
......@@ -201,8 +201,7 @@ namespace typon
{
if (auto task = stack->steal())
{
task->_thefts++;
coroutine = task->_coroutine;
coroutine = task;
}
return;
}
......@@ -210,8 +209,7 @@ namespace typon
{
if (auto task = stack->pop_top())
{
task->_thefts++;
coroutine = task->_coroutine;
coroutine = task;
return;
}
if (stack->_state.compare_exchange_strong(state, Stack::EMPTY))
......
......@@ -18,12 +18,14 @@ namespace typon
struct Span : TheftPoint
{
using u64 = TheftPoint::u64;
using u64 = std::uint_fast64_t;
static constexpr u64 UMAX = std::numeric_limits<u64>::max();
std::coroutine_handle<> _coroutine;
std::coroutine_handle<> _continuation;
u64 _thefts = 0;
std::atomic<bool> _concurrent_error_flag { false };
std::exception_ptr _concurrent_exception;
......@@ -32,7 +34,7 @@ namespace typon
std::atomic<u64> _n = UMAX;
Span(std::coroutine_handle<> coroutine) noexcept
: TheftPoint(coroutine)
: _coroutine(coroutine)
{}
void propagate_exception()
......@@ -83,6 +85,12 @@ namespace typon
_n.store(UMAX, std::memory_order_release);
}
std::coroutine_handle<> steal() noexcept override
{
_thefts++;
return _coroutine;
}
operator std::coroutine_handle<>() noexcept
{
return _coroutine;
......
......@@ -39,6 +39,7 @@
#include <atomic>
#include <coroutine>
#include <cstdint>
#include <memory>
#include <type_traits>
......@@ -118,7 +119,7 @@ namespace typon
return buffer->get(bottom);
}
TheftPoint * steal() noexcept
std::coroutine_handle<> steal() noexcept
{
u64 top = _top.load(acquire);
std::atomic_thread_fence(seq_cst);
......@@ -131,27 +132,31 @@ namespace typon
{
return nullptr;
}
return x;
return x->steal();
}
return nullptr;
}
TheftPoint * pop_top() noexcept
std::coroutine_handle<> pop_top() noexcept
{
u64 top = _top.load(relaxed);
u64 bottom = _bottom.load(relaxed);
auto buffer = _buffer.load(relaxed);
TheftPoint * x = nullptr;
if (top < bottom)
{
x = buffer->get(top);
TheftPoint * x = buffer->get(top);
_top.store(top + 1, relaxed);
if (auto garbage = reclaim())
{
delete garbage;
}
return x->steal();
}
if (auto garbage = reclaim())
{
delete garbage;
}
return x;
return nullptr;
}
RingBuffer * reclaim() noexcept
......
......@@ -2,7 +2,6 @@
#define TYPON_THEFT_POINT_HPP_INCLUDED
#include <coroutine>
#include <cstdint>
namespace typon
......@@ -10,16 +9,7 @@ namespace typon
struct TheftPoint
{
using u64 = std::uint_fast64_t;
std::coroutine_handle<> _coroutine;
u64 _thefts = 0;
TheftPoint() noexcept {}
TheftPoint(std::coroutine_handle<> coroutine) noexcept
: _coroutine(coroutine)
{}
virtual std::coroutine_handle<> steal() noexcept = 0;
};
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment