Commit 3fccb633 authored by Xavier Thompson's avatar Xavier Thompson

Rethink exception propagation from forks

Previously, an exception thrown from a fork was rethrown:
- a) when the call to fork() returns if the fork completes synchronously
- b) otherwise, at the next explicit call to Sync() if there is one
- c) otherwise, when the call to the enclosing Join coroutine returns

In the case of infinite fork loops, this meant exceptions might never be
propagated.

Now, when an exception is thrown from a fork it's always rethrown when
the call to the enclosing Join coroutine returns. The body of the Join
coroutine just stops executing as soon as possible once a fork signals
an exception. This will be at the call to fork() if the fork completes
synchronously, or at any ensuing call to Sync() or fork() otherwise.

Essentially once an exception is signaled from a parallel fork, the next
call to fork() behaves like Sync() instead of creating a fork, and once
all the parallel forks have completed, execution resumes directly at the
call to the enclosing Join coroutine, where the exception is rethrown.
parent fe0843f5
...@@ -4,9 +4,12 @@ ...@@ -4,9 +4,12 @@
#include <atomic> #include <atomic>
#include <coroutine> #include <coroutine>
#include <cstdint> #include <cstdint>
#include <type_traits>
#include <typon/defer.hpp> #include <typon/defer.hpp>
#include <typon/fork_refcount.hpp>
#include <typon/forked.hpp> #include <typon/forked.hpp>
#include <typon/meta.hpp>
#include <typon/result.hpp> #include <typon/result.hpp>
#include <typon/scheduler.hpp> #include <typon/scheduler.hpp>
#include <typon/span.hpp> #include <typon/span.hpp>
...@@ -15,109 +18,13 @@ ...@@ -15,109 +18,13 @@
namespace typon namespace typon
{ {
namespace policy template <typename T = void>
{
struct Bundle
{
void on_final_suspend(std::coroutine_handle<> coroutine) noexcept
{
(void) coroutine;
}
struct OnAwaitable
{
template <typename Promise>
void on_await_suspend(std::coroutine_handle<Promise> coroutine) noexcept
{
(void) coroutine;
}
template <typename Promise>
auto on_await_resume(std::coroutine_handle<Promise> coroutine)
{
using T = typename Promise::value_type;
auto thefts = coroutine.promise()._span->_thefts;
auto rank = coroutine.promise()._rank;
bool ready = (thefts == rank);
if (!ready)
{
coroutine.promise()._span->_children.push_back(coroutine);
}
return Forked<T>(coroutine, ready, nullptr);
}
};
};
struct Refcnt
{
ForkNode _node;
void on_final_suspend(std::coroutine_handle<> coroutine) noexcept
{
(void) coroutine;
_node.decref();
}
struct OnAwaitable
{
template <typename Promise>
void on_await_suspend(std::coroutine_handle<Promise> coroutine) noexcept
{
coroutine.promise()._policy._node._coroutine = coroutine;
}
template <typename Promise>
auto on_await_resume(std::coroutine_handle<Promise> coroutine)
{
using T = typename Promise::value_type;
auto thefts = coroutine.promise()._span->_thefts;
auto rank = coroutine.promise()._rank;
bool ready = (thefts == rank);
auto node = &(coroutine.promise()._policy._node);
return Forked<T>(coroutine, ready, node);
}
};
};
struct Drop
{
void on_final_suspend(std::coroutine_handle<> coroutine) noexcept
{
coroutine.destroy();
}
struct OnAwaitable
{
Span * _span;
Span::u64 _rank;
template <typename Promise>
void on_await_suspend(std::coroutine_handle<Promise> coroutine) noexcept
{
_span = coroutine.promise()._span;
_rank = _span->_thefts;
}
template <typename Promise>
void on_await_resume(std::coroutine_handle<Promise> coroutine)
{
if (_span->_thefts == _rank)
{
Defer defer { [coroutine]() { coroutine.destroy(); } };
coroutine.promise().get();
}
}
};
};
}
template <typename T = void, typename P = policy::Refcnt>
struct [[nodiscard]] Fork struct [[nodiscard]] Fork
{ {
struct promise_type; struct promise_type;
using u64 = Span::u64; using u64 = Span::u64;
using Policy = std::conditional_t<std::is_same_v<T, void>, policy::Drop, P>;
static constexpr bool is_void { std::is_same_v<T, void> };
std::coroutine_handle<promise_type> _coroutine; std::coroutine_handle<promise_type> _coroutine;
...@@ -133,9 +40,10 @@ namespace typon ...@@ -133,9 +40,10 @@ namespace typon
struct promise_type : Result<T> struct promise_type : Result<T>
{ {
using Refcount = std::conditional_t<is_void, meta::Empty, ForkRefcount>;
Span * _span; Span * _span;
u64 _rank; [[no_unique_address]] Refcount _refcount;
[[no_unique_address]] Policy _policy;
Fork get_return_object() noexcept Fork get_return_object() noexcept
{ {
...@@ -154,19 +62,35 @@ namespace typon ...@@ -154,19 +62,35 @@ namespace typon
std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> coroutine) noexcept std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> coroutine) noexcept
{ {
auto span = coroutine.promise()._span; auto span = coroutine.promise()._span;
auto exception = std::move(coroutine.promise()._exception);
if constexpr(is_void)
{
coroutine.destroy();
}
if (Scheduler::pop()) if (Scheduler::pop())
{ {
if (exception)
{
span->set_sequential_exception(exception);
if ((span->_thefts == 0) || span->notify_sync())
{
return span->_continuation;
}
return std::noop_coroutine();
}
return span->_coroutine; return span->_coroutine;
} }
if (auto & exception = coroutine.promise()._exception) if (exception)
{
span->set_concurrent_exception(exception);
}
if constexpr(!is_void)
{ {
span->set_exception(exception, coroutine.promise()._rank); coroutine.promise()._refcount.decref();
} }
coroutine.promise()._policy.on_final_suspend(coroutine); if (span->notify_fork())
u64 n = span->_n.fetch_sub(1, std::memory_order_acq_rel);
if (n == 1)
{ {
return span->continuation(); return span->fork_continuation();
} }
return std::noop_coroutine(); return std::noop_coroutine();
} }
...@@ -179,31 +103,51 @@ namespace typon ...@@ -179,31 +103,51 @@ namespace typon
struct awaitable : std::suspend_always struct awaitable : std::suspend_always
{ {
std::coroutine_handle<promise_type> _coroutine; std::coroutine_handle<promise_type> _coroutine;
[[no_unique_address]] Policy::OnAwaitable _policy; u64 _thefts;
awaitable(std::coroutine_handle<promise_type> coroutine)
: _coroutine(coroutine)
{}
template <typename Promise> template <typename Promise>
auto await_suspend(std::coroutine_handle<Promise> continuation) noexcept std::coroutine_handle<> await_suspend(std::coroutine_handle<Promise> continuation) noexcept
{ {
Span * span = &(continuation.promise()._span); Span * span = &(continuation.promise()._span);
_coroutine.promise()._span = span; _coroutine.promise()._span = span;
_coroutine.promise()._rank = span->_thefts; _thefts = span->_thefts;
if constexpr(!is_void)
_policy.on_await_suspend(_coroutine); {
_coroutine.promise()._refcount.set(_coroutine);
}
if (_thefts && span->has_concurrent_exception())
{
// Destroy the fork because it will not be run.
_coroutine.destroy();
if (span->notify_sync())
{
return span->_continuation;
}
return std::noop_coroutine();
}
std::coroutine_handle<> on_stack_handle = _coroutine; std::coroutine_handle<> on_stack_handle = _coroutine;
Scheduler::push(span); Scheduler::push(span);
return on_stack_handle; return on_stack_handle;
} }
auto await_resume() auto await_resume() noexcept
{ {
return _policy.on_await_resume(_coroutine); if constexpr(!is_void)
{
auto span = _coroutine.promise()._span;
bool ready = (span->_thefts == _thefts);
return Forked<T>(_coroutine, ready);
}
} }
}; };
auto operator co_await() && auto operator co_await() &&
{ {
return awaitable { {}, _coroutine, {} }; return awaitable { _coroutine };
} }
}; };
......
#ifndef TYPON_FORK_REFCOUNT_HPP_INCLUDED
#define TYPON_FORK_REFCOUNT_HPP_INCLUDED
#include <atomic>
#include <coroutine>
namespace typon
{
struct ForkRefcount
{
std::coroutine_handle<> _coroutine;
std::atomic<bool> _refcount {true};
void set(std::coroutine_handle<> coroutine) noexcept
{
_coroutine = coroutine;
}
void decref() noexcept
{
if (!_refcount.exchange(false, std::memory_order_acq_rel))
{
_coroutine.destroy();
}
}
};
}
#endif // TYPON_FORK_REFCOUNT_HPP_INCLUDED
...@@ -7,106 +7,123 @@ ...@@ -7,106 +7,123 @@
#include <utility> #include <utility>
#include <typon/defer.hpp> #include <typon/defer.hpp>
#include <typon/fork_refcount.hpp>
#include <typon/result.hpp> #include <typon/result.hpp>
namespace typon namespace typon
{ {
struct ForkNode
{
std::coroutine_handle<> _coroutine;
std::atomic<bool> _ref {true};
void decref() noexcept
{
if (!_ref.exchange(false, std::memory_order_acq_rel))
{
_coroutine.destroy();
}
}
};
template <typename T> template <typename T>
struct ForkResult struct Forked
{ {
using value_type = T;
Result<T> * _result = nullptr;
union union
{ {
T _value; T _value;
ForkNode * _node; ForkRefcount * _refcount;
}; };
template <typename Promise> template <typename Promise>
void construct_value(std::coroutine_handle<Promise> coroutine) Forked(std::coroutine_handle<Promise> coroutine, bool ready)
{ {
std::construct_at(std::addressof(_value), coroutine.promise().get()); if (ready)
{
std::construct_at(std::addressof(_value), coroutine.promise().value());
coroutine.destroy();
}
else
{
_refcount = &(coroutine.promise()._refcount);
_result = &(coroutine.promise());
}
} }
void construct_value(ForkResult && other) Forked(Forked && other) noexcept(std::is_nothrow_move_constructible_v<T>)
{
_result = other._result;
if (_result)
{
_refcount = std::exchange(other._refcount, nullptr);
}
else
{ {
std::construct_at(std::addressof(_value), std::move(other._value)); std::construct_at(std::addressof(_value), std::move(other._value));
} }
}
T get_value() noexcept Forked& operator=(Forked && other)
noexcept(std::is_nothrow_move_constructible_v<T>)
{ {
return _value; if (this != &other)
{
Forked old { std::move(*this) };
_result = other._result;
if (_result)
{
_refcount = std::exchange(other._refcount, nullptr);
} }
else
void destroy_value() noexcept
{ {
std::destroy_at(std::addressof(_value)); std::construct_at(std::addressof(_value), std::move(other._value));
}
}
return *this;
} }
};
template <typename T> ~Forked()
struct ForkResult<T&>
{ {
union if (_result)
{ {
T * _value; _refcount->decref();
ForkNode * _node; }
}; else
template <typename Promise>
void construct_value(std::coroutine_handle<Promise> coroutine)
{ {
_value = std::addressof(coroutine.promise().get()); std::destroy_at(std::addressof(_value));
}
} }
void construct_value(ForkResult && other) T get() &
{
if (_result)
{ {
_value = other._value; return _result->value();
}
return _value;
} }
T& get_value() noexcept T get() &&
{
if (_result)
{ {
return *_value; return _result->value();
}
return std::move(_value);
} }
void destroy_value() noexcept {}
}; };
template <typename T> template <typename T>
struct Forked : ForkResult<T> struct Forked<T&>
{ {
using value_type = T; using value_type = T;
Result<T> * _result = nullptr; Result<T> * _result = nullptr;
void * _data;
template <typename Coroutine> template <typename Promise>
Forked(Coroutine coroutine, bool ready, ForkNode * node) Forked(std::coroutine_handle<Promise> coroutine, bool ready)
{ {
if (ready) if (ready)
{ {
Defer defer { [&coroutine]() { coroutine.destroy(); } }; _data = std::addressof(coroutine.promise().value());
this->construct_value(coroutine); coroutine.destroy();
} }
else else
{ {
this->_node = node; _data = &(coroutine.promise()._refcount);
_result = &(coroutine.promise()); _result = &(coroutine.promise());
} }
} }
...@@ -114,32 +131,14 @@ namespace typon ...@@ -114,32 +131,14 @@ namespace typon
Forked(Forked && other) noexcept(std::is_nothrow_move_constructible_v<T>) Forked(Forked && other) noexcept(std::is_nothrow_move_constructible_v<T>)
{ {
_result = other._result; _result = other._result;
if (_result) _data = std::exchange(other._data, nullptr);
{
this->_node = std::exchange(other._node, nullptr);
}
else
{
this->construct_value(std::move(other));
}
} }
Forked& operator=(Forked && other) Forked& operator=(Forked other)
noexcept(std::is_nothrow_move_constructible_v<T>) noexcept(std::is_nothrow_move_constructible_v<T>)
{ {
if (this != &other) std::swap(_result, other._result);
{ std::swap(_data, other._data);
Forked old { std::move(*this) };
_result = other._result;
if (_result)
{
this->_node = std::exchange(other._node, nullptr);
}
else
{
this->construct_value(std::move(other));
}
}
return *this; return *this;
} }
...@@ -147,24 +146,17 @@ namespace typon ...@@ -147,24 +146,17 @@ namespace typon
{ {
if (_result) if (_result)
{ {
if (auto node = this->_node) reinterpret_cast<ForkRefcount *>(_data)->decref();
{
node->decref();
}
}
else
{
this->destroy_value();
} }
} }
T get() T& get()
{ {
if (_result) if (_result)
{ {
return _result->get(); return _result->value();
} }
return this->get_value(); return *(reinterpret_cast<T *>(_data));
} }
}; };
...@@ -179,8 +171,8 @@ namespace typon ...@@ -179,8 +171,8 @@ namespace typon
Result<T> * _result; Result<T> * _result;
void * _coroutine; void * _coroutine;
template <typename Coroutine> template <typename Promise>
Forked(Coroutine coroutine, bool ready, ForkNode * node) Forked(std::coroutine_handle<Promise> coroutine, bool ready)
{ {
_ready = ready; _ready = ready;
_result = &(coroutine.promise()); _result = &(coroutine.promise());
...@@ -191,7 +183,7 @@ namespace typon ...@@ -191,7 +183,7 @@ namespace typon
} }
else else
{ {
_coroutine = node; _coroutine = &(coroutine.promise()._refcount);
} }
} }
...@@ -222,14 +214,14 @@ namespace typon ...@@ -222,14 +214,14 @@ namespace typon
} }
else else
{ {
reinterpret_cast<ForkNode *>(_coroutine)->decref(); reinterpret_cast<ForkRefcount *>(_coroutine)->decref();
} }
} }
} }
T get() T get()
{ {
return _result->get(); return _result->value();
} }
}; };
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <coroutine> #include <coroutine>
#include <utility> #include <utility>
#include <typon/meta.hpp>
#include <typon/result.hpp> #include <typon/result.hpp>
#include <typon/span.hpp> #include <typon/span.hpp>
...@@ -19,9 +20,11 @@ namespace typon ...@@ -19,9 +20,11 @@ namespace typon
{ {
struct promise_type; struct promise_type;
std::coroutine_handle<promise_type> _coroutine; using coroutine_type = std::coroutine_handle<promise_type>;
Join(std::coroutine_handle<promise_type> coroutine) noexcept : _coroutine(coroutine) {} coroutine_type _coroutine;
Join(coroutine_type coroutine) noexcept : _coroutine(coroutine) {}
Join(const Join &) = delete; Join(const Join &) = delete;
Join & operator=(const Join &) = delete; Join & operator=(const Join &) = delete;
...@@ -44,7 +47,7 @@ namespace typon ...@@ -44,7 +47,7 @@ namespace typon
} }
} }
struct promise_type : Result<T> struct promise_type : Result<T, meta::Empty>
{ {
using u64 = Span::u64; using u64 = Span::u64;
static constexpr u64 UMAX = Span::UMAX; static constexpr u64 UMAX = Span::UMAX;
...@@ -52,12 +55,12 @@ namespace typon ...@@ -52,12 +55,12 @@ namespace typon
Span _span; Span _span;
promise_type() noexcept promise_type() noexcept
: _span(std::coroutine_handle<promise_type>::from_promise(*this)) : _span(coroutine_type::from_promise(*this))
{} {}
Join get_return_object() noexcept Join get_return_object() noexcept
{ {
return { std::coroutine_handle<promise_type>::from_promise(*this) }; return { coroutine_type::from_promise(*this) };
} }
std::suspend_always initial_suspend() noexcept std::suspend_always initial_suspend() noexcept
...@@ -65,6 +68,11 @@ namespace typon ...@@ -65,6 +68,11 @@ namespace typon
return {}; return {};
} }
void unhandled_exception() noexcept
{
_span.set_sequential_exception(std::current_exception());
}
template <typename U> template <typename U>
decltype(auto) await_transform(U && expr) noexcept decltype(auto) await_transform(U && expr) noexcept
{ {
...@@ -79,34 +87,22 @@ namespace typon ...@@ -79,34 +87,22 @@ namespace typon
bool await_ready() noexcept bool await_ready() noexcept
{ {
if (u64 thefts = _span._thefts) return (_span._thefts == 0);
{
u64 n = _span._n.load(std::memory_order_acquire);
if (n - (UMAX - thefts) == 0)
{
return true;
}
return false;
}
return true;
} }
std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> coroutine) noexcept std::coroutine_handle<> await_suspend(coroutine_type coroutine) noexcept
{ {
u64 thefts = _span._thefts; (void) coroutine;
u64 n = _span._n.fetch_sub(UMAX - thefts, std::memory_order_acq_rel); if (_span.notify_sync())
if (n - (UMAX - thefts) == 0)
{ {
return coroutine; return _span.sync_continuation();
} }
return std::noop_coroutine(); return std::noop_coroutine();
} }
void await_resume() void await_resume() noexcept
{ {
_span._thefts = 0; _span.reset_sync();
_span._n.store(UMAX, std::memory_order_release);
_span.check_exception();
} }
}; };
...@@ -117,16 +113,10 @@ namespace typon ...@@ -117,16 +113,10 @@ namespace typon
{ {
struct awaitable : std::suspend_always struct awaitable : std::suspend_always
{ {
std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> coroutine) noexcept std::coroutine_handle<> await_suspend(coroutine_type coroutine) noexcept
{ {
Span & span = coroutine.promise()._span; Span & span = coroutine.promise()._span;
u64 thefts = span._thefts; if ((span._thefts == 0) || span.notify_sync())
if (thefts == 0)
{
return span._continuation;
}
u64 n = span._n.fetch_sub(UMAX - thefts, std::memory_order_acq_rel);
if (n - (UMAX - thefts) == 0)
{ {
return span._continuation; return span._continuation;
} }
...@@ -142,14 +132,14 @@ namespace typon ...@@ -142,14 +132,14 @@ namespace typon
{ {
struct awaitable struct awaitable
{ {
std::coroutine_handle<promise_type> _coroutine; coroutine_type _coroutine;
bool await_ready() noexcept bool await_ready() noexcept
{ {
return false; return false;
} }
std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation) noexcept auto await_suspend(std::coroutine_handle<> continuation) noexcept
{ {
_coroutine.promise()._span._continuation = continuation; _coroutine.promise()._span._continuation = continuation;
return _coroutine; return _coroutine;
...@@ -157,8 +147,8 @@ namespace typon ...@@ -157,8 +147,8 @@ namespace typon
decltype(auto) await_resume() decltype(auto) await_resume()
{ {
_coroutine.promise()._span.check_exception(); _coroutine.promise()._span.propagate_exception();
return _coroutine.promise().get(); return _coroutine.promise().value();
} }
}; };
......
#ifndef TYPON_META_HPP_INCLUDED
#define TYPON_META_HPP_INCLUDED
namespace typon::meta
{
struct Empty {};
}
#endif // TYPON_META_HPP_INCLUDED
...@@ -11,13 +11,13 @@ ...@@ -11,13 +11,13 @@
namespace typon namespace typon
{ {
template <typename T> template <typename T, typename E = std::exception_ptr>
struct Result struct Result
{ {
using value_type = T; using value_type = T;
bool _valid = false; bool _valid = false;
std::exception_ptr _exception; [[no_unique_address]] E _exception;
union union
{ {
T _value; T _value;
...@@ -39,38 +39,57 @@ namespace typon ...@@ -39,38 +39,57 @@ namespace typon
} }
void unhandled_exception() noexcept void unhandled_exception() noexcept
{
if constexpr(std::is_assignable_v<E, std::exception_ptr>)
{ {
_exception = std::current_exception(); _exception = std::current_exception();
} }
}
T& get() & T& get() &
{
if constexpr(std::is_same_v<E, std::exception_ptr>)
{ {
if (_exception) if (_exception)
{ {
std::rethrow_exception(std::exchange(_exception, nullptr)); std::rethrow_exception(std::exchange(_exception, nullptr));
} }
}
return _value; return _value;
} }
T&& get() && T&& get() &&
{
if constexpr(std::is_same_v<E, std::exception_ptr>)
{ {
if (_exception) if (_exception)
{ {
std::rethrow_exception(std::exchange(_exception, nullptr)); std::rethrow_exception(std::exchange(_exception, nullptr));
} }
}
return std::move(_value);
}
T& value() & noexcept
{
return _value;
}
T&& value() && noexcept
{
return std::move(_value); return std::move(_value);
} }
}; };
template <typename T> template <typename T, typename E>
struct Result<T&> struct Result<T&, E>
{ {
using value_type = T&; using value_type = T&;
T* _value; T* _value;
std::exception_ptr _exception; [[no_unique_address]] E _exception;
void return_value(T& expr) noexcept void return_value(T& expr) noexcept
{ {
...@@ -78,42 +97,61 @@ namespace typon ...@@ -78,42 +97,61 @@ namespace typon
} }
void unhandled_exception() noexcept void unhandled_exception() noexcept
{
if constexpr(std::is_assignable_v<E, std::exception_ptr>)
{ {
_exception = std::current_exception(); _exception = std::current_exception();
} }
}
T& get() & T& get() &
{
if constexpr(std::is_same_v<E, std::exception_ptr>)
{ {
if (_exception) if (_exception)
{ {
std::rethrow_exception(std::exchange(_exception, nullptr)); std::rethrow_exception(std::exchange(_exception, nullptr));
} }
}
return *_value;
}
T& value() & noexcept
{
return *_value; return *_value;
} }
}; };
template <> template <typename E>
struct Result<void> struct Result<void, E>
{ {
using value_type = void; using value_type = void;
std::exception_ptr _exception; [[no_unique_address]] E _exception;
void return_void() noexcept {} void return_void() noexcept {}
void unhandled_exception() noexcept void unhandled_exception() noexcept
{
if constexpr(std::is_assignable_v<E, std::exception_ptr>)
{ {
_exception = std::current_exception(); _exception = std::current_exception();
} }
}
void get() void get()
{
if constexpr(std::is_same_v<E, std::exception_ptr>)
{ {
if (_exception) if (_exception)
{ {
std::rethrow_exception(std::exchange(_exception, nullptr)); std::rethrow_exception(std::exchange(_exception, nullptr));
} }
} }
}
void value() noexcept {}
}; };
} }
......
...@@ -20,19 +20,14 @@ namespace typon ...@@ -20,19 +20,14 @@ namespace typon
{ {
using u64 = TheftPoint::u64; using u64 = TheftPoint::u64;
struct Error
{
u64 _rank;
std::exception_ptr _exception;
};
static constexpr u64 UMAX = std::numeric_limits<u64>::max(); static constexpr u64 UMAX = std::numeric_limits<u64>::max();
std::coroutine_handle<> _continuation; std::coroutine_handle<> _continuation;
std::atomic<Error *> _error { nullptr };
std::vector<std::coroutine_handle<>> _children; std::atomic<bool> _concurrent_error_flag { false };
std::exception_ptr _concurrent_exception;
std::exception_ptr _sequential_exception;
std::atomic<u64> _n = UMAX; std::atomic<u64> _n = UMAX;
...@@ -40,55 +35,52 @@ namespace typon ...@@ -40,55 +35,52 @@ namespace typon
: TheftPoint(coroutine) : TheftPoint(coroutine)
{} {}
~Span() void propagate_exception()
{ {
if (auto error = _error.load(std::memory_order_relaxed)) if (_sequential_exception)
{
std::rethrow_exception(_sequential_exception);
}
if (_concurrent_exception)
{ {
delete error; std::rethrow_exception(_concurrent_exception);
} }
clear_children();
} }
void clear_children() noexcept bool has_concurrent_exception() noexcept
{ {
for (auto & child : _children) return _concurrent_error_flag.load(std::memory_order_acquire);
{
child.destroy();
}
_children.clear();
} }
void check_exception() void set_concurrent_exception(std::exception_ptr & exception) noexcept
{ {
if (auto error = _error.load(std::memory_order_relaxed)) if (!_concurrent_error_flag.exchange(true, std::memory_order_acq_rel))
{ {
_error.store(nullptr, std::memory_order_relaxed); _concurrent_exception = exception;
Defer defer { [error]() { delete error; } };
std::rethrow_exception(error->_exception);
} }
} }
void set_exception(std::exception_ptr & exception, u64 rank) noexcept void set_sequential_exception(std::exception_ptr exception) noexcept
{
auto error = new Error(rank, exception);
Error * expected = nullptr;
while (!_error.compare_exchange_strong(expected, error))
{ {
if (expected->_rank < rank) _sequential_exception = std::move(exception);
{
delete error;
return;
}
} }
if (expected)
bool notify_sync() noexcept
{ {
delete expected; u64 n = _n.fetch_sub(UMAX - _thefts, std::memory_order_acq_rel);
return (n - (UMAX - _thefts) == 0);
} }
bool notify_fork() noexcept
{
u64 n = _n.fetch_sub(1, std::memory_order_acq_rel);
return (n == 1);
} }
void resume() void reset_sync() noexcept
{ {
_coroutine.resume(); _thefts = 0;
_n.store(UMAX, std::memory_order_release);
} }
operator std::coroutine_handle<>() noexcept operator std::coroutine_handle<>() noexcept
...@@ -96,9 +88,22 @@ namespace typon ...@@ -96,9 +88,22 @@ namespace typon
return _coroutine; return _coroutine;
} }
std::coroutine_handle<> continuation() noexcept std::coroutine_handle<> fork_continuation() noexcept
{
// It's safe to access _concurrent_exception here
// because this is only called when all strands are done
if (_coroutine.done() || _sequential_exception || _concurrent_exception)
{
return _continuation;
}
return _coroutine;
}
std::coroutine_handle<> sync_continuation() noexcept
{ {
if (_coroutine.done()) // It's safe to access _concurrent_exception here
// because this is only called when all strands are done
if (_sequential_exception || _concurrent_exception)
{ {
return _continuation; return _continuation;
} }
......
...@@ -26,16 +26,6 @@ namespace typon ...@@ -26,16 +26,6 @@ namespace typon
} }
template <typename Policy, typename Task>
Fork<typename Task::promise_type::value_type, Policy> fork(Task task)
{
// Put the task in a local variable to ensure its destructor will
// be called on co_return instead of only on coroutine destruction.
Task local_task = std::move(task);
co_return co_await std::move(local_task);
}
template <typename Task> template <typename Task>
Future<typename Task::promise_type::value_type> future(Task task) Future<typename Task::promise_type::value_type> future(Task task)
{ {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment