Commit 81addd67 authored by Xavier Thompson's avatar Xavier Thompson

Make continuation-stealing interface more generic

parent 679b1887
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <atomic> #include <atomic>
#include <coroutine> #include <coroutine>
#include <cstdint> #include <cstdint>
#include <utility>
#include <typon/result.hpp> #include <typon/result.hpp>
#include <typon/scheduler.hpp> #include <typon/scheduler.hpp>
...@@ -18,7 +19,6 @@ namespace typon ...@@ -18,7 +19,6 @@ namespace typon
struct [[nodiscard]] Future struct [[nodiscard]] Future
{ {
struct promise_type; struct promise_type;
using u64 = TheftPoint::u64;
using enum std::memory_order; using enum std::memory_order;
...@@ -67,6 +67,26 @@ namespace typon ...@@ -67,6 +67,26 @@ namespace typon
} }
} }
struct Continuation : TheftPoint
{
std::coroutine_handle<> _continuation;
std::coroutine_handle<> steal() noexcept override
{
return std::exchange(_continuation, nullptr);
}
operator std::coroutine_handle<>() noexcept
{
return _continuation;
}
bool ready() noexcept
{
return bool(_continuation);
}
};
struct promise_type : Result<T> struct promise_type : Result<T>
{ {
std::atomic<std::uintptr_t> _state { no_waiter }; std::atomic<std::uintptr_t> _state { no_waiter };
...@@ -90,7 +110,7 @@ namespace typon ...@@ -90,7 +110,7 @@ namespace typon
auto theftpoint = Scheduler::peek(); auto theftpoint = Scheduler::peek();
if (Scheduler::pop()) if (Scheduler::pop())
{ {
return theftpoint->_coroutine; return *static_cast<Continuation *>(theftpoint);
} }
auto state = coroutine.promise()._state.exchange(ready, acq_rel); auto state = coroutine.promise()._state.exchange(ready, acq_rel);
if (state == discarded) if (state == discarded)
...@@ -118,7 +138,7 @@ namespace typon ...@@ -118,7 +138,7 @@ namespace typon
{ {
Future _future; Future _future;
std::coroutine_handle<promise_type> _coroutine; std::coroutine_handle<promise_type> _coroutine;
TheftPoint _theftpoint; Continuation _continuation;
awaitable(Future && f, std::coroutine_handle<promise_type> c) noexcept awaitable(Future && f, std::coroutine_handle<promise_type> c) noexcept
: _future(std::move(f)) : _future(std::move(f))
...@@ -132,20 +152,19 @@ namespace typon ...@@ -132,20 +152,19 @@ namespace typon
auto await_suspend(std::coroutine_handle<> continuation) noexcept auto await_suspend(std::coroutine_handle<> continuation) noexcept
{ {
_theftpoint._coroutine = continuation; _continuation._continuation = continuation;
std::coroutine_handle<> on_stack_handle = _coroutine; std::coroutine_handle<> on_stack_handle = _coroutine;
Scheduler::push(&(_theftpoint)); Scheduler::push(&(_continuation));
return on_stack_handle; return on_stack_handle;
} }
auto await_resume() noexcept auto await_resume() noexcept
{ {
_future._ready = !_theftpoint._thefts; _future._ready = _continuation.ready();
return std::move(_future); return std::move(_future);
} }
}; };
return awaitable { std::move(*this), _coroutine }; return awaitable { std::move(*this), _coroutine };
} }
......
...@@ -201,8 +201,7 @@ namespace typon ...@@ -201,8 +201,7 @@ namespace typon
{ {
if (auto task = stack->steal()) if (auto task = stack->steal())
{ {
task->_thefts++; coroutine = task;
coroutine = task->_coroutine;
} }
return; return;
} }
...@@ -210,8 +209,7 @@ namespace typon ...@@ -210,8 +209,7 @@ namespace typon
{ {
if (auto task = stack->pop_top()) if (auto task = stack->pop_top())
{ {
task->_thefts++; coroutine = task;
coroutine = task->_coroutine;
return; return;
} }
if (stack->_state.compare_exchange_strong(state, Stack::EMPTY)) if (stack->_state.compare_exchange_strong(state, Stack::EMPTY))
......
...@@ -18,12 +18,14 @@ namespace typon ...@@ -18,12 +18,14 @@ namespace typon
struct Span : TheftPoint struct Span : TheftPoint
{ {
using u64 = TheftPoint::u64; using u64 = std::uint_fast64_t;
static constexpr u64 UMAX = std::numeric_limits<u64>::max(); static constexpr u64 UMAX = std::numeric_limits<u64>::max();
std::coroutine_handle<> _coroutine;
std::coroutine_handle<> _continuation; std::coroutine_handle<> _continuation;
u64 _thefts = 0;
std::atomic<bool> _concurrent_error_flag { false }; std::atomic<bool> _concurrent_error_flag { false };
std::exception_ptr _concurrent_exception; std::exception_ptr _concurrent_exception;
...@@ -32,7 +34,7 @@ namespace typon ...@@ -32,7 +34,7 @@ namespace typon
std::atomic<u64> _n = UMAX; std::atomic<u64> _n = UMAX;
Span(std::coroutine_handle<> coroutine) noexcept Span(std::coroutine_handle<> coroutine) noexcept
: TheftPoint(coroutine) : _coroutine(coroutine)
{} {}
void propagate_exception() void propagate_exception()
...@@ -83,6 +85,12 @@ namespace typon ...@@ -83,6 +85,12 @@ namespace typon
_n.store(UMAX, std::memory_order_release); _n.store(UMAX, std::memory_order_release);
} }
std::coroutine_handle<> steal() noexcept override
{
_thefts++;
return _coroutine;
}
operator std::coroutine_handle<>() noexcept operator std::coroutine_handle<>() noexcept
{ {
return _coroutine; return _coroutine;
......
...@@ -39,6 +39,7 @@ ...@@ -39,6 +39,7 @@
#include <atomic> #include <atomic>
#include <coroutine>
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
...@@ -118,7 +119,7 @@ namespace typon ...@@ -118,7 +119,7 @@ namespace typon
return buffer->get(bottom); return buffer->get(bottom);
} }
TheftPoint * steal() noexcept std::coroutine_handle<> steal() noexcept
{ {
u64 top = _top.load(acquire); u64 top = _top.load(acquire);
std::atomic_thread_fence(seq_cst); std::atomic_thread_fence(seq_cst);
...@@ -131,27 +132,31 @@ namespace typon ...@@ -131,27 +132,31 @@ namespace typon
{ {
return nullptr; return nullptr;
} }
return x; return x->steal();
} }
return nullptr; return nullptr;
} }
TheftPoint * pop_top() noexcept std::coroutine_handle<> pop_top() noexcept
{ {
u64 top = _top.load(relaxed); u64 top = _top.load(relaxed);
u64 bottom = _bottom.load(relaxed); u64 bottom = _bottom.load(relaxed);
auto buffer = _buffer.load(relaxed); auto buffer = _buffer.load(relaxed);
TheftPoint * x = nullptr;
if (top < bottom) if (top < bottom)
{ {
x = buffer->get(top); TheftPoint * x = buffer->get(top);
_top.store(top + 1, relaxed); _top.store(top + 1, relaxed);
if (auto garbage = reclaim())
{
delete garbage;
}
return x->steal();
} }
if (auto garbage = reclaim()) if (auto garbage = reclaim())
{ {
delete garbage; delete garbage;
} }
return x; return nullptr;
} }
RingBuffer * reclaim() noexcept RingBuffer * reclaim() noexcept
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#define TYPON_THEFT_POINT_HPP_INCLUDED #define TYPON_THEFT_POINT_HPP_INCLUDED
#include <coroutine> #include <coroutine>
#include <cstdint>
namespace typon namespace typon
...@@ -10,16 +9,7 @@ namespace typon ...@@ -10,16 +9,7 @@ namespace typon
struct TheftPoint struct TheftPoint
{ {
using u64 = std::uint_fast64_t; virtual std::coroutine_handle<> steal() noexcept = 0;
std::coroutine_handle<> _coroutine;
u64 _thefts = 0;
TheftPoint() noexcept {}
TheftPoint(std::coroutine_handle<> coroutine) noexcept
: _coroutine(coroutine)
{}
}; };
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment