Commit 9ee7ba91 authored by Kirill Smelkov's avatar Kirill Smelkov

sync += WorkGroup

WorkGroup provides way to spawn goroutines that work on a common task
and wait for their completion. It is modelled after
https://godoc.org/golang.org/x/sync/errgroup but is not equal to it.

See WorkGroup docstring for details.
parent e6bea2cf
......@@ -182,7 +182,8 @@ handle concurrency in structured ways.
- `golang.context` provides contexts to propagate cancellation and task-scoped
values among spawned goroutines.
- `golang.sync` provides low-level primitives - for example
- `golang.sync` provides `sync.WorkGroup` to spawn group of goroutines working
on a common task. It also provides low-level primitives - for example
`sync.Once` and `sync.WaitGroup` - that are sometimes useful too.
See `Go Concurrency Patterns: Context`__ for overview of contexts.
......
......@@ -24,8 +24,11 @@ See the following link about Go sync package:
https://golang.org/pkg/sync
"""
import threading
from golang import panic
import threading, sys
from golang import go, defer, func, panic
from golang import context
import six
# Once allows to execute an action only once.
#
......@@ -73,3 +76,65 @@ class WaitGroup(object):
return
event = wg._event
event.wait()
# WorkGroup is a group of goroutines working on a common task.
#
# Use .go() to spawn goroutines, and .wait() to wait for all of them to
# complete, for example:
#
# wg = WorkGroup(ctx)
# wg.go(f1)
# wg.go(f2)
# wg.wait()
#
# Every spawned function accepts context related to the whole work and derived
# from ctx used to initialize WorkGroup, for example:
#
# def f1(ctx):
# ...
#
# Whenever a function returns error (raises exception), the work context is
# canceled indicating to other spawned goroutines that they have to cancel their
# work. .wait() waits for all spawned goroutines to complete and returns/raises
# error, if any, from the first failed subtask.
#
# WorkGroup is modelled after https://godoc.org/golang.org/x/sync/errgroup but
# is not equal to it.
class WorkGroup(object):
def __init__(g, ctx):
g._ctx, g._cancel = context.with_cancel(ctx)
g._wg = WaitGroup()
g._mu = threading.Lock()
g._err = None
def go(g, f, *argv, **kw):
g._wg.add(1)
@func
def _():
defer(g._wg.done)
try:
f(g._ctx, *argv, **kw)
except Exception as exc:
with g._mu:
if g._err is None:
# this goroutine is the first failed task
g._err = exc
if six.PY2:
# py3 has __traceback__ automatically
exc.__traceback__ = sys.exc_info()[2]
g._cancel()
go(_)
def wait(g):
g._wg.wait()
g._cancel()
if g._err is not None:
# reraise the exception so that original traceback is there
if six.PY3:
raise g._err
else:
six.reraise(g._err, None, g._err.__traceback__)
......@@ -19,8 +19,8 @@
# See https://www.nexedi.com/licensing for rationale and options.
from golang import go, chan, _PanicError
from golang import sync
import time
from golang import sync, context
import time, threading
from pytest import raises
def test_once():
......@@ -80,3 +80,83 @@ def test_waitgroup():
with raises(_PanicError):
wg.done()
def test_workgroup():
ctx, cancel = context.with_cancel(context.background())
mu = threading.Lock()
# t1=ok, t2=ok
wg = sync.WorkGroup(ctx)
l = [0, 0]
for i in range(2):
def _(ctx, i):
with mu:
l[i] = i+1
wg.go(_, i)
wg.wait()
assert l == [1, 2]
# t1=fail, t2=ok, does not look at ctx
wg = sync.WorkGroup(ctx)
l = [0, 0]
for i in range(2):
def _(ctx, i):
Iam__ = 0
with mu:
l[i] = i+1
if i == 0:
raise RuntimeError('aaa')
def f(ctx, i):
Iam_f = 0
_(ctx, i)
wg.go(f, i)
with raises(RuntimeError) as exc:
wg.wait()
assert exc.type is RuntimeError
assert exc.value.args == ('aaa',)
assert 'Iam__' in exc.traceback[-1].locals
assert 'Iam_f' in exc.traceback[-2].locals
assert l == [1, 2]
# t1=fail, t2=wait cancel, fail
wg = sync.WorkGroup(ctx)
l = [0, 0]
for i in range(2):
def _(ctx, i):
Iam__ = 0
with mu:
l[i] = i+1
if i == 0:
raise RuntimeError('bbb')
if i == 1:
ctx.done().recv()
raise ValueError('ccc') # != RuntimeError
def f(ctx, i):
Iam_f = 0
_(ctx, i)
wg.go(f, i)
with raises(RuntimeError) as exc:
wg.wait()
assert exc.type is RuntimeError
assert exc.value.args == ('bbb',)
assert 'Iam__' in exc.traceback[-1].locals
assert 'Iam_f' in exc.traceback[-2].locals
assert l == [1, 2]
# t1=ok,wait cancel t2=ok,wait cancel
# cancel parent
wg = sync.WorkGroup(ctx)
l = [0, 0]
for i in range(2):
def _(ctx, i):
with mu:
l[i] = i+1
ctx.done().recv()
wg.go(_, i)
cancel() # parent cancel - must be propagated into workgroup
wg.wait()
assert l == [1, 2]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment