Commit 6eb80104 authored by Kirill Smelkov's avatar Kirill Smelkov

sync.WorkGroup: Provide "with" support

So that it becomes possible to write

    with WorkGroup(ctx) as wg:
        wg.go(f1)
        wg.go(f2)

instead of

    wg = WorkGroup(ctx)
    defer(wg.wait)
    wg.go(f1)
    wg.go(f2)

or

    wg = WorkGroup(ctx)
    wg.go(f1)
    wg.go(f2)
    wg.wait()

This is sometimes handy and is referred to as "structured concurrency"
in Python world.

sync.Sema, sync.Mutex, sync.RWMutex already support "with".
sync.WaitGroup is imho too low-level, but we might consider adding
"with" support for it in the future as well.

In general pygolang way is to use defer instead of plugging all classes
with __enter__/__exit__ "with" support, but for small well-known class of
concurrency-related things its seems "with" support is worth it:

- having "with" for sync.Mutex+co allows it to be used as a drop-in
  replacement instead of threading.Lock+co, and
- having "with" for sync.WorkGroup - the most commonly-used tool to
  spawn jobs and wait for their completion - makes it on-par with
  "structured concurrency".

/reviewed-on !12
parent 85257b2a
Pipeline #12805 failed with stage
in 0 seconds
......@@ -22,7 +22,7 @@
from __future__ import print_function, absolute_import
from cython cimport final
from cpython cimport PyObject
from cpython cimport PyObject, PY_MAJOR_VERSION
from golang cimport nil, newref, topyexc
from golang cimport context
from golang.pyx cimport runtime
......@@ -34,6 +34,8 @@ cdef extern from "golang/sync.h" namespace "golang::sync" nogil:
from libcpp.cast cimport dynamic_cast
import sys as pysys
@final
cdef class PySema:
......@@ -196,6 +198,13 @@ cdef class PyWorkGroup:
work. .wait() waits for all spawned goroutines to complete and returns/raises
error, if any, from the first failed subtask.
WorkGroup can be also used via `with` statement where .wait() is
automatically called at the end of the block, for example:
with WorkGroup(ctx) as wg:
wg.go(f1)
wg.go(f2)
WorkGroup is modelled after https://godoc.org/golang.org/x/sync/errgroup but
is not equal to it.
"""
......@@ -236,6 +245,40 @@ cdef class PyWorkGroup:
# reraise pyerr with original traceback
pyerr_reraise(pyerr)
# with support
def __enter__(PyWorkGroup pyg):
return pyg
def __exit__(PyWorkGroup pyg, exc_typ, exc_val, exc_tb):
# py2: prepare exc_val to be chained into
if PY_MAJOR_VERSION == 2 and exc_val is not None:
_pyexc_contextify(exc_val, None)
# if .wait() raises, we want raised exception to be chained into
# exc_val via .__context__, so that
#
# wg = sync.WorkGroup(ctx)
# defer(wg.wait)
# ...
#
# and
#
# with sync.WorkGroup(ctx) as wg:
# ...
#
# are equivalent.
#
# Even if Python3 implements exception chaining natively, it does not
# automatically chain exceptions in __exit__. Implement the chaining ourselves.
try:
pyg.wait()
except:
if PY_MAJOR_VERSION == 2:
if exc_val is not None and not hasattr(exc_val, '__traceback__'):
exc_val.__traceback__ = exc_tb
exc = pysys.exc_info()[1]
_pyexc_contextify(exc, exc_val)
raise
# _PyCtxFunc complements PyWorkGroup.go() : it's operator()(ctx) verifies that
# ctx is expected context and further calls python function without any arguments.
# PyWorkGroup.go() arranges to use python functions that are bound to PyContext
......@@ -271,6 +314,19 @@ cdef extern from * nogil:
# ---- misc ----
# _pyexc_contextify makes sure pyexc has .__context__, .__cause__ and
# .__suppress_context__ attributes.
#
# .__context__ if not already present, or if it was previously None, is set to pyexccontext.
cdef _pyexc_contextify(object pyexc, pyexccontext):
if not hasattr(pyexc, '__context__') or pyexc.__context__ is None:
pyexc.__context__ = pyexccontext
if not hasattr(pyexc, '__cause__'):
pyexc.__cause__ = None
if not hasattr(pyexc, '__suppress_context__'):
pyexc.__suppress_context__ = False
cdef nogil:
void semaacquire_pyexc(Sema *sema) except +topyexc:
......
......@@ -20,9 +20,10 @@
from __future__ import print_function, absolute_import
from golang import go, chan, select, default
from golang import go, chan, select, default, func, defer
from golang import sync, context, time
from pytest import raises, mark
from _pytest._code import Traceback
from golang.golang_test import import_pyx_tests, panics
from golang.time_test import dt
from six.moves import range as xrange
......@@ -245,6 +246,17 @@ PyErr_Restore_traceback_ok = True
if 'PyPy' in sys.version and sys.pypy_version_info < (7,3):
PyErr_Restore_traceback_ok = False
# WorkGroup must catch/propagate all exception classes.
# Python2 allows to raise old-style classes not derived from BaseException.
# Python3 allows to raise only BaseException derivatives.
if six.PY2:
class MyError:
def __init__(self, *args):
self.args = args
else:
class MyError(BaseException):
pass
def test_workgroup():
ctx, cancel = context.with_cancel(context.background())
mu = sync.Mutex()
......@@ -260,16 +272,6 @@ def test_workgroup():
wg.wait()
assert l == [1, 2]
# WorkGroup must catch/propagate all exception classes.
# Python2 allows to raise old-style classes not derived from BaseException.
# Python3 allows to raise only BaseException derivatives.
if six.PY2:
class MyError:
def __init__(self, *args):
self.args = args
else:
class MyError(BaseException):
pass
# t1=fail, t2=ok, does not look at ctx
wg = sync.WorkGroup(ctx)
......@@ -337,6 +339,92 @@ def test_workgroup():
wg.wait()
assert l == [1, 2]
@func
def test_workgroup_with():
# verify with support for sync.WorkGroup
ctx, cancel = context.with_cancel(context.background())
defer(cancel)
mu = sync.Mutex()
# t1=ok, t2=ok
l = [0, 0]
with sync.WorkGroup(ctx) as wg:
for i in range(2):
def _(ctx, i):
with mu:
l[i] = i+1
wg.go(_, i)
assert l == [1, 2]
# t1=fail, t2=wait cancel, fail
with raises(MyError) as exci:
with sync.WorkGroup(ctx) as wg:
def _(ctx):
Iam_t1 = 0
raise MyError('hello (fail)')
wg.go(_)
def _(ctx):
ctx.done().recv()
raise MyError('world (after zzz)')
wg.go(_)
e = exci.value
assert e.__class__ is MyError
assert e.args == ('hello (fail)',)
assert e.__cause__ is None
assert e.__context__ is None
assert e.__suppress_context__ == False
if PyErr_Restore_traceback_ok:
assert 'Iam_t1' in exci.traceback[-1].locals
# t=ok, but code from under with raises
l = [0]
with raises(MyError) as exci:
with sync.WorkGroup(ctx) as wg:
def _(ctx):
l[0] = 1
wg.go(_)
def bad():
raise MyError('wow')
bad()
e = exci.value
assert e.__class__ is MyError
assert e.args == ('wow',)
assert e.__cause__ is None
assert e.__context__ is None
assert e.__suppress_context__ == False
assert exci.traceback[-1].name == 'bad'
assert l[0] == 1
# t=fail, code from under with also raises
with raises(MyError) as exci:
with sync.WorkGroup(ctx) as wg:
def f(ctx):
raise MyError('fail from go')
wg.go(f)
def g():
raise MyError('just raise')
g()
e = exci.value
assert e.__class__ is MyError
assert e.args == ('fail from go',)
assert e.__cause__ is None
assert e.__context__ is not None
assert e.__suppress_context__ == False
assert exci.traceback[-1].name == 'f'
e2 = e.__context__
assert e2.__class__ is MyError
assert e2.args == ('just raise',)
assert e2.__cause__ is None
assert e2.__context__ is None
assert e2.__suppress_context__ == False
assert e2.__traceback__ is not None
t2 = Traceback(e2.__traceback__)
assert t2[-1].name == 'g'
# create/wait workgroup with 1 empty worker.
def bench_workgroup_empty(b):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment