Commit d71ad15e authored by Stefan Behnel's avatar Stefan Behnel

allow yield in async def functions (which turns them into async generators)

parent dee2a0cb
...@@ -9411,6 +9411,7 @@ class YieldExprNode(ExprNode): ...@@ -9411,6 +9411,7 @@ class YieldExprNode(ExprNode):
label_num = 0 label_num = 0
is_yield_from = False is_yield_from = False
is_await = False is_await = False
in_async_gen = False
expr_keyword = 'yield' expr_keyword = 'yield'
def analyse_types(self, env): def analyse_types(self, env):
...@@ -9469,7 +9470,10 @@ class YieldExprNode(ExprNode): ...@@ -9469,7 +9470,10 @@ class YieldExprNode(ExprNode):
code.putln("/* return from generator, yielding value */") code.putln("/* return from generator, yielding value */")
code.putln("%s->resume_label = %d;" % ( code.putln("%s->resume_label = %d;" % (
Naming.generator_cname, label_num)) Naming.generator_cname, label_num))
code.putln("return %s;" % Naming.retval_cname) if self.in_async_gen and not self.is_await:
code.putln("return __pyx__PyAsyncGenWrapValue(%s);" % Naming.retval_cname)
else:
code.putln("return %s;" % Naming.retval_cname)
code.put_label(label_name) code.put_label(label_name)
for cname, save_cname, type in saved: for cname, save_cname, type in saved:
......
...@@ -2151,7 +2151,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2151,7 +2151,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("%s = PyUnicode_FromStringAndSize(\"\", 0); %s" % ( code.putln("%s = PyUnicode_FromStringAndSize(\"\", 0); %s" % (
Naming.empty_unicode, code.error_goto_if_null(Naming.empty_unicode, self.pos))) Naming.empty_unicode, code.error_goto_if_null(Naming.empty_unicode, self.pos)))
for ext_type in ('CyFunction', 'FusedFunction', 'Coroutine', 'Generator', 'StopAsyncIteration'): for ext_type in ('CyFunction', 'FusedFunction', 'Coroutine', 'Generator', 'AsyncGen', 'StopAsyncIteration'):
code.putln("#ifdef __Pyx_%s_USED" % ext_type) code.putln("#ifdef __Pyx_%s_USED" % ext_type)
code.put_error_if_neg(self.pos, "__pyx_%s_init()" % ext_type) code.put_error_if_neg(self.pos, "__pyx_%s_init()" % ext_type)
code.putln("#endif") code.putln("#endif")
......
...@@ -3969,6 +3969,8 @@ class GeneratorDefNode(DefNode): ...@@ -3969,6 +3969,8 @@ class GeneratorDefNode(DefNode):
is_generator = True is_generator = True
is_coroutine = False is_coroutine = False
is_asyncgen = False
gen_type_name = 'Generator'
needs_closure = True needs_closure = True
child_attrs = DefNode.child_attrs + ["gbody"] child_attrs = DefNode.child_attrs + ["gbody"]
...@@ -3992,7 +3994,7 @@ class GeneratorDefNode(DefNode): ...@@ -3992,7 +3994,7 @@ class GeneratorDefNode(DefNode):
code.putln('{') code.putln('{')
code.putln('__pyx_CoroutineObject *gen = __Pyx_%s_New(' code.putln('__pyx_CoroutineObject *gen = __Pyx_%s_New('
'(__pyx_coroutine_body_t) %s, (PyObject *) %s, %s, %s, %s); %s' % ( '(__pyx_coroutine_body_t) %s, (PyObject *) %s, %s, %s, %s); %s' % (
'Coroutine' if self.is_coroutine else 'Generator', self.gen_type_name,
body_cname, Naming.cur_scope_cname, name, qualname, module_name, body_cname, Naming.cur_scope_cname, name, qualname, module_name,
code.error_goto_if_null('gen', self.pos))) code.error_goto_if_null('gen', self.pos)))
code.put_decref(Naming.cur_scope_cname, py_object_type) code.put_decref(Naming.cur_scope_cname, py_object_type)
...@@ -4007,18 +4009,23 @@ class GeneratorDefNode(DefNode): ...@@ -4007,18 +4009,23 @@ class GeneratorDefNode(DefNode):
code.putln('}') code.putln('}')
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
env.use_utility_code(UtilityCode.load_cached( env.use_utility_code(UtilityCode.load_cached(self.gen_type_name, "Coroutine.c"))
'Coroutine' if self.is_coroutine else 'Generator', "Coroutine.c"))
self.gbody.generate_function_header(code, proto=True) self.gbody.generate_function_header(code, proto=True)
super(GeneratorDefNode, self).generate_function_definitions(env, code) super(GeneratorDefNode, self).generate_function_definitions(env, code)
self.gbody.generate_function_definitions(env, code) self.gbody.generate_function_definitions(env, code)
class AsyncDefNode(GeneratorDefNode): class AsyncDefNode(GeneratorDefNode):
gen_type_name = 'Coroutine'
is_coroutine = True is_coroutine = True
class AsyncGenNode(AsyncDefNode):
gen_type_name = 'AsyncGen'
is_asyncgen = True
class GeneratorBodyDefNode(DefNode): class GeneratorBodyDefNode(DefNode):
# Main code body of a generator implemented as a DefNode. # Main code body of a generator implemented as a DefNode.
# #
......
...@@ -52,6 +52,8 @@ cdef class YieldNodeCollector(TreeVisitor): ...@@ -52,6 +52,8 @@ cdef class YieldNodeCollector(TreeVisitor):
cdef public list yields cdef public list yields
cdef public list returns cdef public list returns
cdef public bint has_return_value cdef public bint has_return_value
cdef public bint has_yield
cdef public bint has_await
cdef class MarkClosureVisitor(CythonTransform): cdef class MarkClosureVisitor(CythonTransform):
cdef bint needs_closure cdef bint needs_closure
......
...@@ -192,7 +192,7 @@ class PostParse(ScopeTrackingTransform): ...@@ -192,7 +192,7 @@ class PostParse(ScopeTrackingTransform):
# unpack a lambda expression into the corresponding DefNode # unpack a lambda expression into the corresponding DefNode
collector = YieldNodeCollector() collector = YieldNodeCollector()
collector.visitchildren(node.result_expr) collector.visitchildren(node.result_expr)
if collector.yields or collector.awaits or isinstance(node.result_expr, ExprNodes.YieldExprNode): if collector.has_yield or collector.has_await or isinstance(node.result_expr, ExprNodes.YieldExprNode):
body = Nodes.ExprStatNode( body = Nodes.ExprStatNode(
node.result_expr.pos, expr=node.result_expr) node.result_expr.pos, expr=node.result_expr)
else: else:
...@@ -2457,19 +2457,22 @@ class YieldNodeCollector(TreeVisitor): ...@@ -2457,19 +2457,22 @@ class YieldNodeCollector(TreeVisitor):
def __init__(self): def __init__(self):
super(YieldNodeCollector, self).__init__() super(YieldNodeCollector, self).__init__()
self.yields = [] self.yields = []
self.awaits = []
self.returns = [] self.returns = []
self.has_return_value = False self.has_return_value = False
self.has_yield = False
self.has_await = False
def visit_Node(self, node): def visit_Node(self, node):
self.visitchildren(node) self.visitchildren(node)
def visit_YieldExprNode(self, node): def visit_YieldExprNode(self, node):
self.yields.append(node) self.yields.append(node)
self.has_yield = True
self.visitchildren(node) self.visitchildren(node)
def visit_AwaitExprNode(self, node): def visit_AwaitExprNode(self, node):
self.awaits.append(node) self.yields.append(node)
self.has_await = True
self.visitchildren(node) self.visitchildren(node)
def visit_ReturnStatNode(self, node): def visit_ReturnStatNode(self, node):
...@@ -2513,24 +2516,27 @@ class MarkClosureVisitor(CythonTransform): ...@@ -2513,24 +2516,27 @@ class MarkClosureVisitor(CythonTransform):
collector.visitchildren(node) collector.visitchildren(node)
if node.is_async_def: if node.is_async_def:
if collector.yields: coroutine_type = Nodes.AsyncGenNode if collector.has_yield else Nodes.AsyncDefNode
error(collector.yields[0].pos, "'yield' not allowed in async coroutines (use 'await')") if collector.has_yield:
yields = collector.awaits for yield_expr in collector.yields:
elif collector.yields: yield_expr.in_async_gen = True
if collector.awaits: elif collector.has_await:
error(collector.yields[0].pos, "'await' not allowed in generators (use 'yield')") found = next(y for y in collector.yields if y.is_await)
yields = collector.yields error(found.pos, "'await' not allowed in generators (use 'yield')")
return node
elif collector.has_yield:
coroutine_type = Nodes.GeneratorDefNode
else: else:
return node return node
for i, yield_expr in enumerate(yields, 1): for i, yield_expr in enumerate(collector.yields, 1):
yield_expr.label_num = i yield_expr.label_num = i
for retnode in collector.returns: for retnode in collector.returns:
retnode.in_generator = True retnode.in_generator = True
gbody = Nodes.GeneratorBodyDefNode( gbody = Nodes.GeneratorBodyDefNode(
pos=node.pos, name=node.name, body=node.body) pos=node.pos, name=node.name, body=node.body)
coroutine = (Nodes.AsyncDefNode if node.is_async_def else Nodes.GeneratorDefNode)( coroutine = coroutine_type(
pos=node.pos, name=node.name, args=node.args, pos=node.pos, name=node.name, args=node.args,
star_arg=node.star_arg, starstar_arg=node.starstar_arg, star_arg=node.star_arg, starstar_arg=node.starstar_arg,
doc=node.doc, decorators=node.decorators, doc=node.doc, decorators=node.decorators,
......
...@@ -21,6 +21,9 @@ static PyTypeObject *__pyx_AsyncGenType = 0; ...@@ -21,6 +21,9 @@ static PyTypeObject *__pyx_AsyncGenType = 0;
static PyObject *__Pyx_AsyncGen_ANext(PyObject *o); static PyObject *__Pyx_AsyncGen_ANext(PyObject *o);
static PyObject *__pyx__PyAsyncGenWrapValue(PyObject *val);
static __pyx_CoroutineObject *__Pyx_AsyncGen_New( static __pyx_CoroutineObject *__Pyx_AsyncGen_New(
__pyx_coroutine_body_t body, PyObject *closure, __pyx_coroutine_body_t body, PyObject *closure,
PyObject *name, PyObject *qualname, PyObject *module_name) { PyObject *name, PyObject *qualname, PyObject *module_name) {
...@@ -679,7 +682,8 @@ static PyObject * ...@@ -679,7 +682,8 @@ static PyObject *
__pyx__PyAsyncGenWrapValue(PyObject *val) __pyx__PyAsyncGenWrapValue(PyObject *val)
{ {
__pyx__PyAsyncGenWrappedValue *o; __pyx__PyAsyncGenWrappedValue *o;
assert(val); if (unlikely(!val))
return NULL;
if (__Pyx_ag_value_fl_free) { if (__Pyx_ag_value_fl_free) {
__Pyx_ag_value_fl_free--; __Pyx_ag_value_fl_free--;
...@@ -689,11 +693,12 @@ __pyx__PyAsyncGenWrapValue(PyObject *val) ...@@ -689,11 +693,12 @@ __pyx__PyAsyncGenWrapValue(PyObject *val)
} else { } else {
o = PyObject_New(__pyx__PyAsyncGenWrappedValue, __pyx__PyAsyncGenWrappedValueType); o = PyObject_New(__pyx__PyAsyncGenWrappedValue, __pyx__PyAsyncGenWrappedValueType);
if (o == NULL) { if (o == NULL) {
Py_DECREF(val);
return NULL; return NULL;
} }
} }
o->val = val; o->val = val;
Py_INCREF(val); // no Py_INCREF(val) - steals reference!
return (PyObject*)o; return (PyObject*)o;
} }
......
# mode: error
# tag: pep492, async
async def foo():
yield
_ERRORS = """
5:4: 'yield' not allowed in async coroutines (use 'await')
5:4: 'yield' not supported here
"""
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment