Commit a129d192 authored by Stefan Behnel's avatar Stefan Behnel

Merge branch '_pep492_async_await'

parents 19ffc8b6 9ee3d69a
...@@ -398,9 +398,16 @@ def init_builtins(): ...@@ -398,9 +398,16 @@ def init_builtins():
init_builtin_structs() init_builtin_structs()
init_builtin_types() init_builtin_types()
init_builtin_funcs() init_builtin_funcs()
builtin_scope.declare_var( builtin_scope.declare_var(
'__debug__', PyrexTypes.c_const_type(PyrexTypes.c_bint_type), '__debug__', PyrexTypes.c_const_type(PyrexTypes.c_bint_type),
pos=None, cname='(!Py_OptimizeFlag)', is_cdef=True) pos=None, cname='(!Py_OptimizeFlag)', is_cdef=True)
entry = builtin_scope.declare_var(
'StopAsyncIteration', PyrexTypes.py_object_type,
pos=None, cname='__Pyx_PyExc_StopAsyncIteration')
entry.utility_code = UtilityCode.load_cached("StopAsyncIteration", "Coroutine.c")
global list_type, tuple_type, dict_type, set_type, frozenset_type global list_type, tuple_type, dict_type, set_type, frozenset_type
global bytes_type, str_type, unicode_type, basestring_type, slice_type global bytes_type, str_type, unicode_type, basestring_type, slice_type
global float_type, bool_type, type_type, complex_type, bytearray_type global float_type, bool_type, type_type, complex_type, bytearray_type
......
...@@ -49,7 +49,7 @@ non_portable_builtins_map = { ...@@ -49,7 +49,7 @@ non_portable_builtins_map = {
'basestring' : ('PY_MAJOR_VERSION >= 3', 'str'), 'basestring' : ('PY_MAJOR_VERSION >= 3', 'str'),
'xrange' : ('PY_MAJOR_VERSION >= 3', 'range'), 'xrange' : ('PY_MAJOR_VERSION >= 3', 'range'),
'raw_input' : ('PY_MAJOR_VERSION >= 3', 'input'), 'raw_input' : ('PY_MAJOR_VERSION >= 3', 'input'),
} }
basicsize_builtins_map = { basicsize_builtins_map = {
# builtins whose type has a different tp_basicsize than sizeof(...) # builtins whose type has a different tp_basicsize than sizeof(...)
...@@ -63,6 +63,13 @@ uncachable_builtins = [ ...@@ -63,6 +63,13 @@ uncachable_builtins = [
'_', # e.g. gettext '_', # e.g. gettext
] ]
special_py_methods = set([
'__cinit__', '__dealloc__', '__richcmp__', '__next__',
'__await__', '__aiter__', '__anext__',
'__getreadbuffer__', '__getwritebuffer__', '__getsegcount__',
'__getcharbuffer__', '__getbuffer__', '__releasebuffer__'
])
modifier_output_mapper = { modifier_output_mapper = {
'inline': 'CYTHON_INLINE' 'inline': 'CYTHON_INLINE'
}.get }.get
...@@ -1999,7 +2006,7 @@ class CCodeWriter(object): ...@@ -1999,7 +2006,7 @@ class CCodeWriter(object):
def put_pymethoddef(self, entry, term, allow_skip=True): def put_pymethoddef(self, entry, term, allow_skip=True):
if entry.is_special or entry.name == '__getattribute__': if entry.is_special or entry.name == '__getattribute__':
if entry.name not in ['__cinit__', '__dealloc__', '__richcmp__', '__next__', '__getreadbuffer__', '__getwritebuffer__', '__getsegcount__', '__getcharbuffer__', '__getbuffer__', '__releasebuffer__']: if entry.name not in special_py_methods:
if entry.name == '__getattr__' and not self.globalstate.directives['fast_getattr']: if entry.name == '__getattr__' and not self.globalstate.directives['fast_getattr']:
pass pass
# Python's typeobject.c will automatically fill in our slot # Python's typeobject.c will automatically fill in our slot
......
This diff is collapsed.
...@@ -991,6 +991,9 @@ class ControlFlowAnalysis(CythonTransform): ...@@ -991,6 +991,9 @@ class ControlFlowAnalysis(CythonTransform):
self.mark_assignment(target, node.item) self.mark_assignment(target, node.item)
def visit_AsyncForStatNode(self, node):
return self.visit_ForInStatNode(node)
def visit_ForInStatNode(self, node): def visit_ForInStatNode(self, node):
condition_block = self.flow.nextblock() condition_block = self.flow.nextblock()
next_block = self.flow.newblock() next_block = self.flow.newblock()
...@@ -1002,6 +1005,9 @@ class ControlFlowAnalysis(CythonTransform): ...@@ -1002,6 +1005,9 @@ class ControlFlowAnalysis(CythonTransform):
if isinstance(node, Nodes.ForInStatNode): if isinstance(node, Nodes.ForInStatNode):
self.mark_forloop_target(node) self.mark_forloop_target(node)
elif isinstance(node, Nodes.AsyncForStatNode):
# not entirely correct, but good enough for now
self.mark_assignment(node.target, node.item)
else: # Parallel else: # Parallel
self.mark_assignment(node.target) self.mark_assignment(node.target)
......
...@@ -2071,16 +2071,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2071,16 +2071,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("%s = PyBytes_FromStringAndSize(\"\", 0); %s" % ( code.putln("%s = PyBytes_FromStringAndSize(\"\", 0); %s" % (
Naming.empty_bytes, code.error_goto_if_null(Naming.empty_bytes, self.pos))) Naming.empty_bytes, code.error_goto_if_null(Naming.empty_bytes, self.pos)))
code.putln("#ifdef __Pyx_CyFunction_USED") for ext_type in ('CyFunction', 'FusedFunction', 'Coroutine', 'Generator', 'StopAsyncIteration'):
code.put_error_if_neg(self.pos, "__Pyx_CyFunction_init()") code.putln("#ifdef __Pyx_%s_USED" % ext_type)
code.putln("#endif") code.put_error_if_neg(self.pos, "__pyx_%s_init()" % ext_type)
code.putln("#ifdef __Pyx_FusedFunction_USED")
code.put_error_if_neg(self.pos, "__pyx_FusedFunction_init()")
code.putln("#endif")
code.putln("#ifdef __Pyx_Generator_USED")
code.put_error_if_neg(self.pos, "__pyx_Generator_init()")
code.putln("#endif") code.putln("#endif")
code.putln("/*--- Library function declarations ---*/") code.putln("/*--- Library function declarations ---*/")
......
...@@ -1576,7 +1576,9 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1576,7 +1576,9 @@ class FuncDefNode(StatNode, BlockNode):
# directive_returns [ExprNode] type defined by cython.returns(...) # directive_returns [ExprNode] type defined by cython.returns(...)
# star_arg PyArgDeclNode or None * argument # star_arg PyArgDeclNode or None * argument
# starstar_arg PyArgDeclNode or None ** argument # starstar_arg PyArgDeclNode or None ** argument
#
# is_async_def boolean is a Coroutine function
#
# has_fused_arguments boolean # has_fused_arguments boolean
# Whether this cdef function has fused parameters. This is needed # Whether this cdef function has fused parameters. This is needed
# by AnalyseDeclarationsTransform, so it can replace CFuncDefNodes # by AnalyseDeclarationsTransform, so it can replace CFuncDefNodes
...@@ -1588,6 +1590,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1588,6 +1590,7 @@ class FuncDefNode(StatNode, BlockNode):
pymethdef_required = False pymethdef_required = False
is_generator = False is_generator = False
is_generator_body = False is_generator_body = False
is_async_def = False
modifiers = [] modifiers = []
has_fused_arguments = False has_fused_arguments = False
star_arg = None star_arg = None
...@@ -3936,6 +3939,7 @@ class GeneratorDefNode(DefNode): ...@@ -3936,6 +3939,7 @@ class GeneratorDefNode(DefNode):
# #
is_generator = True is_generator = True
is_coroutine = False
needs_closure = True needs_closure = True
child_attrs = DefNode.child_attrs + ["gbody"] child_attrs = DefNode.child_attrs + ["gbody"]
...@@ -3956,8 +3960,9 @@ class GeneratorDefNode(DefNode): ...@@ -3956,8 +3960,9 @@ class GeneratorDefNode(DefNode):
qualname = code.intern_identifier(self.qualname) qualname = code.intern_identifier(self.qualname)
code.putln('{') code.putln('{')
code.putln('__pyx_GeneratorObject *gen = __Pyx_Generator_New(' code.putln('__pyx_CoroutineObject *gen = __Pyx_%s_New('
'(__pyx_generator_body_t) %s, (PyObject *) %s, %s, %s); %s' % ( '(__pyx_coroutine_body_t) %s, (PyObject *) %s, %s, %s); %s' % (
'Coroutine' if self.is_coroutine else 'Generator',
body_cname, Naming.cur_scope_cname, name, qualname, body_cname, Naming.cur_scope_cname, name, qualname,
code.error_goto_if_null('gen', self.pos))) code.error_goto_if_null('gen', self.pos)))
code.put_decref(Naming.cur_scope_cname, py_object_type) code.put_decref(Naming.cur_scope_cname, py_object_type)
...@@ -3972,13 +3977,18 @@ class GeneratorDefNode(DefNode): ...@@ -3972,13 +3977,18 @@ class GeneratorDefNode(DefNode):
code.putln('}') code.putln('}')
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
env.use_utility_code(UtilityCode.load_cached("Generator", "Generator.c")) env.use_utility_code(UtilityCode.load_cached(
'Coroutine' if self.is_coroutine else 'Generator', "Coroutine.c"))
self.gbody.generate_function_header(code, proto=True) self.gbody.generate_function_header(code, proto=True)
super(GeneratorDefNode, self).generate_function_definitions(env, code) super(GeneratorDefNode, self).generate_function_definitions(env, code)
self.gbody.generate_function_definitions(env, code) self.gbody.generate_function_definitions(env, code)
class AsyncDefNode(GeneratorDefNode):
is_coroutine = True
class GeneratorBodyDefNode(DefNode): class GeneratorBodyDefNode(DefNode):
# Main code body of a generator implemented as a DefNode. # Main code body of a generator implemented as a DefNode.
# #
...@@ -4005,7 +4015,7 @@ class GeneratorBodyDefNode(DefNode): ...@@ -4005,7 +4015,7 @@ class GeneratorBodyDefNode(DefNode):
self.declare_generator_body(env) self.declare_generator_body(env)
def generate_function_header(self, code, proto=False): def generate_function_header(self, code, proto=False):
header = "static PyObject *%s(__pyx_GeneratorObject *%s, PyObject *%s)" % ( header = "static PyObject *%s(__pyx_CoroutineObject *%s, PyObject *%s)" % (
self.entry.func_cname, self.entry.func_cname,
Naming.generator_cname, Naming.generator_cname,
Naming.sent_value_cname) Naming.sent_value_cname)
...@@ -4070,7 +4080,7 @@ class GeneratorBodyDefNode(DefNode): ...@@ -4070,7 +4080,7 @@ class GeneratorBodyDefNode(DefNode):
code.put_label(code.error_label) code.put_label(code.error_label)
if Future.generator_stop in env.global_scope().context.future_directives: if Future.generator_stop in env.global_scope().context.future_directives:
# PEP 479: turn accidental StopIteration exceptions into a RuntimeError # PEP 479: turn accidental StopIteration exceptions into a RuntimeError
code.globalstate.use_utility_code(UtilityCode.load_cached("pep479", "Generator.c")) code.globalstate.use_utility_code(UtilityCode.load_cached("pep479", "Coroutine.c"))
code.putln("if (unlikely(PyErr_ExceptionMatches(PyExc_StopIteration))) " code.putln("if (unlikely(PyErr_ExceptionMatches(PyExc_StopIteration))) "
"__Pyx_Generator_Replace_StopIteration();") "__Pyx_Generator_Replace_StopIteration();")
for cname, type in code.funcstate.all_managed_temps(): for cname, type in code.funcstate.all_managed_temps():
...@@ -4082,7 +4092,7 @@ class GeneratorBodyDefNode(DefNode): ...@@ -4082,7 +4092,7 @@ class GeneratorBodyDefNode(DefNode):
code.put_xdecref(Naming.retval_cname, py_object_type) code.put_xdecref(Naming.retval_cname, py_object_type)
code.putln('%s->resume_label = -1;' % Naming.generator_cname) code.putln('%s->resume_label = -1;' % Naming.generator_cname)
# clean up as early as possible to help breaking any reference cycles # clean up as early as possible to help breaking any reference cycles
code.putln('__Pyx_Generator_clear((PyObject*)%s);' % Naming.generator_cname) code.putln('__Pyx_Coroutine_clear((PyObject*)%s);' % Naming.generator_cname)
code.put_finish_refcount_context() code.put_finish_refcount_context()
code.putln('return NULL;') code.putln('return NULL;')
code.putln("}") code.putln("}")
...@@ -5512,7 +5522,7 @@ class ReturnStatNode(StatNode): ...@@ -5512,7 +5522,7 @@ class ReturnStatNode(StatNode):
elif self.in_generator: elif self.in_generator:
# return value == raise StopIteration(value), but uncatchable # return value == raise StopIteration(value), but uncatchable
code.globalstate.use_utility_code( code.globalstate.use_utility_code(
UtilityCode.load_cached("ReturnWithStopIteration", "Generator.c")) UtilityCode.load_cached("ReturnWithStopIteration", "Coroutine.c"))
code.putln("%s = NULL; __Pyx_ReturnWithStopIteration(%s);" % ( code.putln("%s = NULL; __Pyx_ReturnWithStopIteration(%s);" % (
Naming.retval_cname, Naming.retval_cname,
self.value.py_result())) self.value.py_result()))
...@@ -6059,40 +6069,49 @@ class DictIterationNextNode(Node): ...@@ -6059,40 +6069,49 @@ class DictIterationNextNode(Node):
target.generate_assignment_code(result, code) target.generate_assignment_code(result, code)
var.release(code) var.release(code)
def ForStatNode(pos, **kw): def ForStatNode(pos, **kw):
if 'iterator' in kw: if 'iterator' in kw:
if kw['iterator'].is_async:
return AsyncForStatNode(pos, **kw)
else:
return ForInStatNode(pos, **kw) return ForInStatNode(pos, **kw)
else: else:
return ForFromStatNode(pos, **kw) return ForFromStatNode(pos, **kw)
class ForInStatNode(LoopNode, StatNode):
# for statement class _ForInStatNode(LoopNode, StatNode):
# Base class of 'for-in' statements.
# #
# target ExprNode # target ExprNode
# iterator IteratorNode # iterator IteratorNode | AwaitExprNode(AsyncIteratorNode)
# body StatNode # body StatNode
# else_clause StatNode # else_clause StatNode
# item NextNode used internally # item NextNode | AwaitExprNode(AsyncNextNode)
# is_async boolean true for 'async for' statements
child_attrs = ["target", "iterator", "body", "else_clause"] child_attrs = ["target", "item", "iterator", "body", "else_clause"]
item = None item = None
is_async = False
def _create_item_node(self):
raise NotImplementedError("must be implemented by subclasses")
def analyse_declarations(self, env): def analyse_declarations(self, env):
from . import ExprNodes
self.target.analyse_target_declaration(env) self.target.analyse_target_declaration(env)
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
if self.else_clause: if self.else_clause:
self.else_clause.analyse_declarations(env) self.else_clause.analyse_declarations(env)
self.item = ExprNodes.NextNode(self.iterator) self._create_item_node()
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.target = self.target.analyse_target_types(env) self.target = self.target.analyse_target_types(env)
self.iterator = self.iterator.analyse_expressions(env) self.iterator = self.iterator.analyse_expressions(env)
from . import ExprNodes self._create_item_node() # must rewrap self.item after analysis
self.item = ExprNodes.NextNode(self.iterator) # must rewrap after analysis
self.item = self.item.analyse_expressions(env) self.item = self.item.analyse_expressions(env)
if (self.iterator.type.is_ptr or self.iterator.type.is_array) and \ if (not self.is_async and
self.target.type.assignable_from(self.iterator.type): (self.iterator.type.is_ptr or self.iterator.type.is_array) and
self.target.type.assignable_from(self.iterator.type)):
# C array slice optimization. # C array slice optimization.
pass pass
else: else:
...@@ -6158,6 +6177,37 @@ class ForInStatNode(LoopNode, StatNode): ...@@ -6158,6 +6177,37 @@ class ForInStatNode(LoopNode, StatNode):
self.item.annotate(code) self.item.annotate(code)
class ForInStatNode(_ForInStatNode):
# 'for' statement
is_async = False
def _create_item_node(self):
from .ExprNodes import NextNode
self.item = NextNode(self.iterator)
class AsyncForStatNode(_ForInStatNode):
# 'async for' statement
#
# iterator AwaitExprNode(AsyncIteratorNode)
# item AwaitIterNextExprNode(AsyncIteratorNode)
is_async = True
def __init__(self, pos, iterator, **kw):
assert 'item' not in kw
from . import ExprNodes
# AwaitExprNodes must appear before running MarkClosureVisitor
kw['iterator'] = ExprNodes.AwaitExprNode(iterator.pos, arg=iterator)
kw['item'] = ExprNodes.AwaitIterNextExprNode(iterator.pos, arg=None)
_ForInStatNode.__init__(self, pos, **kw)
def _create_item_node(self):
from . import ExprNodes
self.item.arg = ExprNodes.AsyncNextNode(self.iterator)
class ForFromStatNode(LoopNode, StatNode): class ForFromStatNode(LoopNode, StatNode):
# for name from expr rel name rel expr # for name from expr rel name rel expr
# #
...@@ -6444,7 +6494,7 @@ class WithStatNode(StatNode): ...@@ -6444,7 +6494,7 @@ class WithStatNode(StatNode):
code.putln("%s = __Pyx_PyObject_LookupSpecial(%s, %s); %s" % ( code.putln("%s = __Pyx_PyObject_LookupSpecial(%s, %s); %s" % (
self.exit_var, self.exit_var,
self.manager.py_result(), self.manager.py_result(),
code.intern_identifier(EncodedString('__exit__')), code.intern_identifier(EncodedString('__aexit__' if self.is_async else '__exit__')),
code.error_goto_if_null(self.exit_var, self.pos), code.error_goto_if_null(self.exit_var, self.pos),
)) ))
code.put_gotref(self.exit_var) code.put_gotref(self.exit_var)
...@@ -7108,7 +7158,7 @@ class GILStatNode(NogilTryFinallyStatNode): ...@@ -7108,7 +7158,7 @@ class GILStatNode(NogilTryFinallyStatNode):
from .ParseTreeTransforms import YieldNodeCollector from .ParseTreeTransforms import YieldNodeCollector
collector = YieldNodeCollector() collector = YieldNodeCollector()
collector.visitchildren(body) collector.visitchildren(body)
if not collector.yields: if not collector.yields and not collector.awaits:
return return
if state == 'gil': if state == 'gil':
...@@ -7205,8 +7255,8 @@ utility_code_for_cimports = { ...@@ -7205,8 +7255,8 @@ utility_code_for_cimports = {
utility_code_for_imports = { utility_code_for_imports = {
# utility code used when special modules are imported. # utility code used when special modules are imported.
# TODO: Consider a generic user-level mechanism for importing # TODO: Consider a generic user-level mechanism for importing
'asyncio': ("__Pyx_patch_asyncio", "PatchAsyncIO", "Generator.c"), 'asyncio': ("__Pyx_patch_asyncio", "PatchAsyncIO", "Coroutine.c"),
'inspect': ("__Pyx_patch_inspect", "PatchInspect", "Generator.c"), 'inspect': ("__Pyx_patch_inspect", "PatchInspect", "Coroutine.c"),
} }
......
...@@ -200,7 +200,7 @@ class PostParse(ScopeTrackingTransform): ...@@ -200,7 +200,7 @@ class PostParse(ScopeTrackingTransform):
node.lambda_name = EncodedString(u'lambda%d' % lambda_id) node.lambda_name = EncodedString(u'lambda%d' % lambda_id)
collector = YieldNodeCollector() collector = YieldNodeCollector()
collector.visitchildren(node.result_expr) collector.visitchildren(node.result_expr)
if collector.yields or isinstance(node.result_expr, ExprNodes.YieldExprNode): if collector.yields or collector.awaits or isinstance(node.result_expr, ExprNodes.YieldExprNode):
body = Nodes.ExprStatNode( body = Nodes.ExprStatNode(
node.result_expr.pos, expr=node.result_expr) node.result_expr.pos, expr=node.result_expr)
else: else:
...@@ -1219,15 +1219,19 @@ class WithTransform(CythonTransform, SkipDeclarations): ...@@ -1219,15 +1219,19 @@ class WithTransform(CythonTransform, SkipDeclarations):
def visit_WithStatNode(self, node): def visit_WithStatNode(self, node):
self.visitchildren(node, 'body') self.visitchildren(node, 'body')
pos = node.pos pos = node.pos
is_async = node.is_async
body, target, manager = node.body, node.target, node.manager body, target, manager = node.body, node.target, node.manager
node.enter_call = ExprNodes.SimpleCallNode( node.enter_call = ExprNodes.SimpleCallNode(
pos, function=ExprNodes.AttributeNode( pos, function=ExprNodes.AttributeNode(
pos, obj=ExprNodes.CloneNode(manager), pos, obj=ExprNodes.CloneNode(manager),
attribute=EncodedString('__enter__'), attribute=EncodedString('__aenter__' if is_async else '__enter__'),
is_special_lookup=True), is_special_lookup=True),
args=[], args=[],
is_temp=True) is_temp=True)
if is_async:
node.enter_call = ExprNodes.AwaitExprNode(pos, arg=node.enter_call)
if target is not None: if target is not None:
body = Nodes.StatListNode( body = Nodes.StatListNode(
pos, stats=[ pos, stats=[
...@@ -1245,7 +1249,8 @@ class WithTransform(CythonTransform, SkipDeclarations): ...@@ -1245,7 +1249,8 @@ class WithTransform(CythonTransform, SkipDeclarations):
pos, operand=ExprNodes.WithExitCallNode( pos, operand=ExprNodes.WithExitCallNode(
pos, with_stat=node, pos, with_stat=node,
test_if_run=False, test_if_run=False,
args=excinfo_target)), args=excinfo_target,
await=ExprNodes.AwaitExprNode(pos, arg=None) if is_async else None)),
body=Nodes.ReraiseStatNode(pos), body=Nodes.ReraiseStatNode(pos),
), ),
], ],
...@@ -1266,8 +1271,8 @@ class WithTransform(CythonTransform, SkipDeclarations): ...@@ -1266,8 +1271,8 @@ class WithTransform(CythonTransform, SkipDeclarations):
pos, with_stat=node, pos, with_stat=node,
test_if_run=True, test_if_run=True,
args=ExprNodes.TupleNode( args=ExprNodes.TupleNode(
pos, args=[ExprNodes.NoneNode(pos) for _ in range(3)] pos, args=[ExprNodes.NoneNode(pos) for _ in range(3)]),
))), await=ExprNodes.AwaitExprNode(pos, arg=None) if is_async else None)),
handle_error_case=False, handle_error_case=False,
) )
return node return node
...@@ -2205,6 +2210,7 @@ class YieldNodeCollector(TreeVisitor): ...@@ -2205,6 +2210,7 @@ class YieldNodeCollector(TreeVisitor):
def __init__(self): def __init__(self):
super(YieldNodeCollector, self).__init__() super(YieldNodeCollector, self).__init__()
self.yields = [] self.yields = []
self.awaits = []
self.returns = [] self.returns = []
self.has_return_value = False self.has_return_value = False
...@@ -2215,6 +2221,10 @@ class YieldNodeCollector(TreeVisitor): ...@@ -2215,6 +2221,10 @@ class YieldNodeCollector(TreeVisitor):
self.yields.append(node) self.yields.append(node)
self.visitchildren(node) self.visitchildren(node)
def visit_AwaitExprNode(self, node):
self.awaits.append(node)
self.visitchildren(node)
def visit_ReturnStatNode(self, node): def visit_ReturnStatNode(self, node):
self.visitchildren(node) self.visitchildren(node)
if node.value: if node.value:
...@@ -2250,27 +2260,36 @@ class MarkClosureVisitor(CythonTransform): ...@@ -2250,27 +2260,36 @@ class MarkClosureVisitor(CythonTransform):
collector = YieldNodeCollector() collector = YieldNodeCollector()
collector.visitchildren(node) collector.visitchildren(node)
if node.is_async_def:
if collector.yields: if collector.yields:
if isinstance(node, Nodes.CFuncDefNode): error(collector.yields[0].pos, "'yield' not allowed in async coroutines (use 'await')")
# Will report error later yields = collector.awaits
elif collector.yields:
if collector.awaits:
error(collector.yields[0].pos, "'await' not allowed in generators (use 'yield')")
yields = collector.yields
else:
return node return node
for i, yield_expr in enumerate(collector.yields, 1):
for i, yield_expr in enumerate(yields, 1):
yield_expr.label_num = i yield_expr.label_num = i
for retnode in collector.returns: for retnode in collector.returns:
retnode.in_generator = True retnode.in_generator = True
gbody = Nodes.GeneratorBodyDefNode( gbody = Nodes.GeneratorBodyDefNode(
pos=node.pos, name=node.name, body=node.body) pos=node.pos, name=node.name, body=node.body)
generator = Nodes.GeneratorDefNode( coroutine = (Nodes.AsyncDefNode if node.is_async_def else Nodes.GeneratorDefNode)(
pos=node.pos, name=node.name, args=node.args, pos=node.pos, name=node.name, args=node.args,
star_arg=node.star_arg, starstar_arg=node.starstar_arg, star_arg=node.star_arg, starstar_arg=node.starstar_arg,
doc=node.doc, decorators=node.decorators, doc=node.doc, decorators=node.decorators,
gbody=gbody, lambda_name=node.lambda_name) gbody=gbody, lambda_name=node.lambda_name)
return generator return coroutine
return node
def visit_CFuncDefNode(self, node): def visit_CFuncDefNode(self, node):
self.visit_FuncDefNode(node) self.needs_closure = False
self.visitchildren(node)
node.needs_closure = self.needs_closure
self.needs_closure = True
if node.needs_closure and node.overridable: if node.needs_closure and node.overridable:
error(node.pos, "closures inside cpdef functions not yet supported") error(node.pos, "closures inside cpdef functions not yet supported")
return node return node
...@@ -2287,6 +2306,7 @@ class MarkClosureVisitor(CythonTransform): ...@@ -2287,6 +2306,7 @@ class MarkClosureVisitor(CythonTransform):
self.needs_closure = True self.needs_closure = True
return node return node
class CreateClosureClasses(CythonTransform): class CreateClosureClasses(CythonTransform):
# Output closure classes in module scope for all functions # Output closure classes in module scope for all functions
# that really need it. # that really need it.
......
...@@ -44,6 +44,7 @@ cdef p_typecast(PyrexScanner s) ...@@ -44,6 +44,7 @@ cdef p_typecast(PyrexScanner s)
cdef p_sizeof(PyrexScanner s) cdef p_sizeof(PyrexScanner s)
cdef p_yield_expression(PyrexScanner s) cdef p_yield_expression(PyrexScanner s)
cdef p_yield_statement(PyrexScanner s) cdef p_yield_statement(PyrexScanner s)
cdef p_async_statement(PyrexScanner s, ctx)
cdef p_power(PyrexScanner s) cdef p_power(PyrexScanner s)
cdef p_new_expr(PyrexScanner s) cdef p_new_expr(PyrexScanner s)
cdef p_trailer(PyrexScanner s, node1) cdef p_trailer(PyrexScanner s, node1)
...@@ -107,18 +108,18 @@ cdef p_if_statement(PyrexScanner s) ...@@ -107,18 +108,18 @@ cdef p_if_statement(PyrexScanner s)
cdef p_if_clause(PyrexScanner s) cdef p_if_clause(PyrexScanner s)
cdef p_else_clause(PyrexScanner s) cdef p_else_clause(PyrexScanner s)
cdef p_while_statement(PyrexScanner s) cdef p_while_statement(PyrexScanner s)
cdef p_for_statement(PyrexScanner s) cdef p_for_statement(PyrexScanner s, bint is_async=*)
cdef dict p_for_bounds(PyrexScanner s, bint allow_testlist = *) cdef dict p_for_bounds(PyrexScanner s, bint allow_testlist=*, bint is_async=*)
cdef p_for_from_relation(PyrexScanner s) cdef p_for_from_relation(PyrexScanner s)
cdef p_for_from_step(PyrexScanner s) cdef p_for_from_step(PyrexScanner s)
cdef p_target(PyrexScanner s, terminator) cdef p_target(PyrexScanner s, terminator)
cdef p_for_target(PyrexScanner s) cdef p_for_target(PyrexScanner s)
cdef p_for_iterator(PyrexScanner s, bint allow_testlist = *) cdef p_for_iterator(PyrexScanner s, bint allow_testlist=*, bint is_async=*)
cdef p_try_statement(PyrexScanner s) cdef p_try_statement(PyrexScanner s)
cdef p_except_clause(PyrexScanner s) cdef p_except_clause(PyrexScanner s)
cdef p_include_statement(PyrexScanner s, ctx) cdef p_include_statement(PyrexScanner s, ctx)
cdef p_with_statement(PyrexScanner s) cdef p_with_statement(PyrexScanner s)
cdef p_with_items(PyrexScanner s) cdef p_with_items(PyrexScanner s, bint is_async=*)
cdef p_with_template(PyrexScanner s) cdef p_with_template(PyrexScanner s)
cdef p_simple_statement(PyrexScanner s, bint first_statement = *) cdef p_simple_statement(PyrexScanner s, bint first_statement = *)
cdef p_simple_statement_list(PyrexScanner s, ctx, bint first_statement = *) cdef p_simple_statement_list(PyrexScanner s, ctx, bint first_statement = *)
...@@ -128,7 +129,7 @@ cdef p_IF_statement(PyrexScanner s, ctx) ...@@ -128,7 +129,7 @@ cdef p_IF_statement(PyrexScanner s, ctx)
cdef p_statement(PyrexScanner s, ctx, bint first_statement = *) cdef p_statement(PyrexScanner s, ctx, bint first_statement = *)
cdef p_statement_list(PyrexScanner s, ctx, bint first_statement = *) cdef p_statement_list(PyrexScanner s, ctx, bint first_statement = *)
cdef p_suite(PyrexScanner s, ctx = *) cdef p_suite(PyrexScanner s, ctx = *)
cdef tuple p_suite_with_docstring(PyrexScanner s, ctx, with_doc_only = *) cdef tuple p_suite_with_docstring(PyrexScanner s, ctx, bint with_doc_only=*)
cdef tuple _extract_docstring(node) cdef tuple _extract_docstring(node)
cdef p_positional_and_keyword_args(PyrexScanner s, end_sy_set, templates = *) cdef p_positional_and_keyword_args(PyrexScanner s, end_sy_set, templates = *)
...@@ -176,7 +177,7 @@ cdef p_c_modifiers(PyrexScanner s) ...@@ -176,7 +177,7 @@ cdef p_c_modifiers(PyrexScanner s)
cdef p_c_func_or_var_declaration(PyrexScanner s, pos, ctx) cdef p_c_func_or_var_declaration(PyrexScanner s, pos, ctx)
cdef p_ctypedef_statement(PyrexScanner s, ctx) cdef p_ctypedef_statement(PyrexScanner s, ctx)
cdef p_decorators(PyrexScanner s) cdef p_decorators(PyrexScanner s)
cdef p_def_statement(PyrexScanner s, list decorators = *) cdef p_def_statement(PyrexScanner s, list decorators=*, bint is_async_def=*)
cdef p_varargslist(PyrexScanner s, terminator=*, bint annotated = *) cdef p_varargslist(PyrexScanner s, terminator=*, bint annotated = *)
cdef p_py_arg_decl(PyrexScanner s, bint annotated = *) cdef p_py_arg_decl(PyrexScanner s, bint annotated = *)
cdef p_class_statement(PyrexScanner s, decorators) cdef p_class_statement(PyrexScanner s, decorators)
......
...@@ -55,6 +55,7 @@ class Ctx(object): ...@@ -55,6 +55,7 @@ class Ctx(object):
d.update(kwds) d.update(kwds)
return ctx return ctx
def p_ident(s, message="Expected an identifier"): def p_ident(s, message="Expected an identifier"):
if s.sy == 'IDENT': if s.sy == 'IDENT':
name = s.systring name = s.systring
...@@ -350,6 +351,7 @@ def p_sizeof(s): ...@@ -350,6 +351,7 @@ def p_sizeof(s):
s.expect(')') s.expect(')')
return node return node
def p_yield_expression(s): def p_yield_expression(s):
# s.sy == "yield" # s.sy == "yield"
pos = s.position() pos = s.position()
...@@ -370,19 +372,47 @@ def p_yield_expression(s): ...@@ -370,19 +372,47 @@ def p_yield_expression(s):
else: else:
return ExprNodes.YieldExprNode(pos, arg=arg) return ExprNodes.YieldExprNode(pos, arg=arg)
def p_yield_statement(s): def p_yield_statement(s):
# s.sy == "yield" # s.sy == "yield"
yield_expr = p_yield_expression(s) yield_expr = p_yield_expression(s)
return Nodes.ExprStatNode(yield_expr.pos, expr=yield_expr) return Nodes.ExprStatNode(yield_expr.pos, expr=yield_expr)
#power: atom trailer* ('**' factor)*
def p_async_statement(s, ctx, decorators):
# s.sy >> 'async' ...
if s.sy == 'def':
# 'async def' statements aren't allowed in pxd files
if 'pxd' in ctx.level:
s.error('def statement not allowed here')
s.level = ctx.level
return p_def_statement(s, decorators, is_async_def=True)
elif decorators:
s.error("Decorators can only be followed by functions or classes")
elif s.sy == 'for':
return p_for_statement(s, is_async=True)
elif s.sy == 'with':
s.next()
return p_with_items(s, is_async=True)
else:
s.error("expected one of 'def', 'for', 'with' after 'async'")
#power: atom_expr ('**' factor)*
#atom_expr: ['await'] atom trailer*
def p_power(s): def p_power(s):
if s.systring == 'new' and s.peek()[0] == 'IDENT': if s.systring == 'new' and s.peek()[0] == 'IDENT':
return p_new_expr(s) return p_new_expr(s)
await_pos = None
if s.sy == 'await':
await_pos = s.position()
s.next()
n1 = p_atom(s) n1 = p_atom(s)
while s.sy in ('(', '[', '.'): while s.sy in ('(', '[', '.'):
n1 = p_trailer(s, n1) n1 = p_trailer(s, n1)
if await_pos:
n1 = ExprNodes.AwaitExprNode(await_pos, arg=n1)
if s.sy == '**': if s.sy == '**':
pos = s.position() pos = s.position()
s.next() s.next()
...@@ -390,6 +420,7 @@ def p_power(s): ...@@ -390,6 +420,7 @@ def p_power(s):
n1 = ExprNodes.binop_node(pos, '**', n1, n2) n1 = ExprNodes.binop_node(pos, '**', n1, n2)
return n1 return n1
def p_new_expr(s): def p_new_expr(s):
# s.systring == 'new'. # s.systring == 'new'.
pos = s.position() pos = s.position()
...@@ -1568,23 +1599,25 @@ def p_while_statement(s): ...@@ -1568,23 +1599,25 @@ def p_while_statement(s):
condition = test, body = body, condition = test, body = body,
else_clause = else_clause) else_clause = else_clause)
def p_for_statement(s):
def p_for_statement(s, is_async=False):
# s.sy == 'for' # s.sy == 'for'
pos = s.position() pos = s.position()
s.next() s.next()
kw = p_for_bounds(s, allow_testlist=True) kw = p_for_bounds(s, allow_testlist=True, is_async=is_async)
body = p_suite(s) body = p_suite(s)
else_clause = p_else_clause(s) else_clause = p_else_clause(s)
kw.update(body = body, else_clause = else_clause) kw.update(body=body, else_clause=else_clause, is_async=is_async)
return Nodes.ForStatNode(pos, **kw) return Nodes.ForStatNode(pos, **kw)
def p_for_bounds(s, allow_testlist=True):
def p_for_bounds(s, allow_testlist=True, is_async=False):
target = p_for_target(s) target = p_for_target(s)
if s.sy == 'in': if s.sy == 'in':
s.next() s.next()
iterator = p_for_iterator(s, allow_testlist) iterator = p_for_iterator(s, allow_testlist, is_async=is_async)
return dict( target = target, iterator = iterator ) return dict(target=target, iterator=iterator)
elif not s.in_python_file: elif not s.in_python_file and not is_async:
if s.sy == 'from': if s.sy == 'from':
s.next() s.next()
bound1 = p_bit_expr(s) bound1 = p_bit_expr(s)
...@@ -1654,16 +1687,19 @@ def p_target(s, terminator): ...@@ -1654,16 +1687,19 @@ def p_target(s, terminator):
else: else:
return expr return expr
def p_for_target(s): def p_for_target(s):
return p_target(s, 'in') return p_target(s, 'in')
def p_for_iterator(s, allow_testlist=True):
def p_for_iterator(s, allow_testlist=True, is_async=False):
pos = s.position() pos = s.position()
if allow_testlist: if allow_testlist:
expr = p_testlist(s) expr = p_testlist(s)
else: else:
expr = p_or_test(s) expr = p_or_test(s)
return ExprNodes.IteratorNode(pos, sequence = expr) return (ExprNodes.AsyncIteratorNode if is_async else ExprNodes.IteratorNode)(pos, sequence=expr)
def p_try_statement(s): def p_try_statement(s):
# s.sy == 'try' # s.sy == 'try'
...@@ -1745,6 +1781,7 @@ def p_include_statement(s, ctx): ...@@ -1745,6 +1781,7 @@ def p_include_statement(s, ctx):
else: else:
return Nodes.PassStatNode(pos) return Nodes.PassStatNode(pos)
def p_with_statement(s): def p_with_statement(s):
s.next() # 'with' s.next() # 'with'
if s.systring == 'template' and not s.in_python_file: if s.systring == 'template' and not s.in_python_file:
...@@ -1753,9 +1790,12 @@ def p_with_statement(s): ...@@ -1753,9 +1790,12 @@ def p_with_statement(s):
node = p_with_items(s) node = p_with_items(s)
return node return node
def p_with_items(s):
def p_with_items(s, is_async=False):
pos = s.position() pos = s.position()
if not s.in_python_file and s.sy == 'IDENT' and s.systring in ('nogil', 'gil'): if not s.in_python_file and s.sy == 'IDENT' and s.systring in ('nogil', 'gil'):
if is_async:
s.error("with gil/nogil cannot be async")
state = s.systring state = s.systring
s.next() s.next()
if s.sy == ',': if s.sy == ',':
...@@ -1763,7 +1803,7 @@ def p_with_items(s): ...@@ -1763,7 +1803,7 @@ def p_with_items(s):
body = p_with_items(s) body = p_with_items(s)
else: else:
body = p_suite(s) body = p_suite(s)
return Nodes.GILStatNode(pos, state = state, body = body) return Nodes.GILStatNode(pos, state=state, body=body)
else: else:
manager = p_test(s) manager = p_test(s)
target = None target = None
...@@ -1772,11 +1812,11 @@ def p_with_items(s): ...@@ -1772,11 +1812,11 @@ def p_with_items(s):
target = p_starred_expr(s) target = p_starred_expr(s)
if s.sy == ',': if s.sy == ',':
s.next() s.next()
body = p_with_items(s) body = p_with_items(s, is_async=is_async)
else: else:
body = p_suite(s) body = p_suite(s)
return Nodes.WithStatNode(pos, manager = manager, return Nodes.WithStatNode(pos, manager=manager, target=target, body=body, is_async=is_async)
target = target, body = body)
def p_with_template(s): def p_with_template(s):
pos = s.position() pos = s.position()
...@@ -1929,12 +1969,14 @@ def p_statement(s, ctx, first_statement = 0): ...@@ -1929,12 +1969,14 @@ def p_statement(s, ctx, first_statement = 0):
s.error('decorator not allowed here') s.error('decorator not allowed here')
s.level = ctx.level s.level = ctx.level
decorators = p_decorators(s) decorators = p_decorators(s)
bad_toks = 'def', 'cdef', 'cpdef', 'class' if not ctx.allow_struct_enum_decorator and s.sy not in ('def', 'cdef', 'cpdef', 'class'):
if not ctx.allow_struct_enum_decorator and s.sy not in bad_toks: if s.sy == 'IDENT' and s.systring == 'async':
pass # handled below
else:
s.error("Decorators can only be followed by functions or classes") s.error("Decorators can only be followed by functions or classes")
elif s.sy == 'pass' and cdef_flag: elif s.sy == 'pass' and cdef_flag:
# empty cdef block # empty cdef block
return p_pass_statement(s, with_newline = 1) return p_pass_statement(s, with_newline=1)
overridable = 0 overridable = 0
if s.sy == 'cdef': if s.sy == 'cdef':
...@@ -1948,11 +1990,11 @@ def p_statement(s, ctx, first_statement = 0): ...@@ -1948,11 +1990,11 @@ def p_statement(s, ctx, first_statement = 0):
if ctx.level not in ('module', 'module_pxd', 'function', 'c_class', 'c_class_pxd'): if ctx.level not in ('module', 'module_pxd', 'function', 'c_class', 'c_class_pxd'):
s.error('cdef statement not allowed here') s.error('cdef statement not allowed here')
s.level = ctx.level s.level = ctx.level
node = p_cdef_statement(s, ctx(overridable = overridable)) node = p_cdef_statement(s, ctx(overridable=overridable))
if decorators is not None: if decorators is not None:
tup = Nodes.CFuncDefNode, Nodes.CVarDefNode, Nodes.CClassDefNode tup = (Nodes.CFuncDefNode, Nodes.CVarDefNode, Nodes.CClassDefNode)
if ctx.allow_struct_enum_decorator: if ctx.allow_struct_enum_decorator:
tup += Nodes.CStructOrUnionDefNode, Nodes.CEnumDefNode tup += (Nodes.CStructOrUnionDefNode, Nodes.CEnumDefNode)
if not isinstance(node, tup): if not isinstance(node, tup):
s.error("Decorators can only be followed by functions or classes") s.error("Decorators can only be followed by functions or classes")
node.decorators = decorators node.decorators = decorators
...@@ -1995,9 +2037,25 @@ def p_statement(s, ctx, first_statement = 0): ...@@ -1995,9 +2037,25 @@ def p_statement(s, ctx, first_statement = 0):
return p_try_statement(s) return p_try_statement(s)
elif s.sy == 'with': elif s.sy == 'with':
return p_with_statement(s) return p_with_statement(s)
elif s.sy == 'async':
s.next()
return p_async_statement(s, ctx, decorators)
else: else:
return p_simple_statement_list( if s.sy == 'IDENT' and s.systring == 'async':
s, ctx, first_statement = first_statement) # PEP 492 enables the async/await keywords when it spots "async def ..."
s.next()
if s.sy == 'def':
s.enable_keyword('async')
s.enable_keyword('await')
result = p_async_statement(s, ctx, decorators)
s.enable_keyword('await')
s.disable_keyword('async')
return result
elif decorators:
s.error("Decorators can only be followed by functions or classes")
s.put_back('IDENT', 'async')
return p_simple_statement_list(s, ctx, first_statement=first_statement)
def p_statement_list(s, ctx, first_statement = 0): def p_statement_list(s, ctx, first_statement = 0):
# Parse a series of statements separated by newlines. # Parse a series of statements separated by newlines.
...@@ -3002,7 +3060,8 @@ def p_decorators(s): ...@@ -3002,7 +3060,8 @@ def p_decorators(s):
s.expect_newline("Expected a newline after decorator") s.expect_newline("Expected a newline after decorator")
return decorators return decorators
def p_def_statement(s, decorators=None):
def p_def_statement(s, decorators=None, is_async_def=False):
# s.sy == 'def' # s.sy == 'def'
pos = s.position() pos = s.position()
s.next() s.next()
...@@ -3017,10 +3076,11 @@ def p_def_statement(s, decorators=None): ...@@ -3017,10 +3076,11 @@ def p_def_statement(s, decorators=None):
s.next() s.next()
return_type_annotation = p_test(s) return_type_annotation = p_test(s)
doc, body = p_suite_with_docstring(s, Ctx(level='function')) doc, body = p_suite_with_docstring(s, Ctx(level='function'))
return Nodes.DefNode(pos, name = name, args = args, return Nodes.DefNode(
star_arg = star_arg, starstar_arg = starstar_arg, pos, name=name, args=args, star_arg=star_arg, starstar_arg=starstar_arg,
doc = doc, body = body, decorators = decorators, doc=doc, body=body, decorators=decorators, is_async_def=is_async_def,
return_type_annotation = return_type_annotation) return_type_annotation=return_type_annotation)
def p_varargslist(s, terminator=')', annotated=1): def p_varargslist(s, terminator=')', annotated=1):
args = p_c_arg_list(s, in_pyfunc = 1, nonempty_declarators = 1, args = p_c_arg_list(s, in_pyfunc = 1, nonempty_declarators = 1,
......
...@@ -172,12 +172,12 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -172,12 +172,12 @@ def create_pipeline(context, mode, exclude_classes=()):
InterpretCompilerDirectives(context, context.compiler_directives), InterpretCompilerDirectives(context, context.compiler_directives),
ParallelRangeTransform(context), ParallelRangeTransform(context),
AdjustDefByDirectives(context), AdjustDefByDirectives(context),
WithTransform(context),
MarkClosureVisitor(context), MarkClosureVisitor(context),
_align_function_definitions, _align_function_definitions,
RemoveUnreachableCode(context), RemoveUnreachableCode(context),
ConstantFolding(), ConstantFolding(),
FlattenInListTransform(), FlattenInListTransform(),
WithTransform(context),
DecoratorTransform(context), DecoratorTransform(context),
ForwardDeclareTypes(context), ForwardDeclareTypes(context),
AnalyseDeclarationsTransform(context), AnalyseDeclarationsTransform(context),
......
...@@ -30,6 +30,7 @@ cdef class PyrexScanner(Scanner): ...@@ -30,6 +30,7 @@ cdef class PyrexScanner(Scanner):
cdef public bint in_python_file cdef public bint in_python_file
cdef public source_encoding cdef public source_encoding
cdef set keywords cdef set keywords
cdef public dict keywords_stack
cdef public list indentation_stack cdef public list indentation_stack
cdef public indentation_char cdef public indentation_char
cdef public int bracket_nesting_level cdef public int bracket_nesting_level
...@@ -57,3 +58,5 @@ cdef class PyrexScanner(Scanner): ...@@ -57,3 +58,5 @@ cdef class PyrexScanner(Scanner):
cdef expect_indent(self) cdef expect_indent(self)
cdef expect_dedent(self) cdef expect_dedent(self)
cdef expect_newline(self, message=*, bint ignore_semicolon=*) cdef expect_newline(self, message=*, bint ignore_semicolon=*)
cdef enable_keyword(self, name)
cdef disable_keyword(self, name)
...@@ -319,6 +319,7 @@ class PyrexScanner(Scanner): ...@@ -319,6 +319,7 @@ class PyrexScanner(Scanner):
self.in_python_file = False self.in_python_file = False
self.keywords = set(pyx_reserved_words) self.keywords = set(pyx_reserved_words)
self.trace = trace_scanner self.trace = trace_scanner
self.keywords_stack = {}
self.indentation_stack = [0] self.indentation_stack = [0]
self.indentation_char = None self.indentation_char = None
self.bracket_nesting_level = 0 self.bracket_nesting_level = 0
...@@ -497,3 +498,18 @@ class PyrexScanner(Scanner): ...@@ -497,3 +498,18 @@ class PyrexScanner(Scanner):
self.expect('NEWLINE', message) self.expect('NEWLINE', message)
if useless_trailing_semicolon is not None: if useless_trailing_semicolon is not None:
warning(useless_trailing_semicolon, "useless trailing semicolon") warning(useless_trailing_semicolon, "useless trailing semicolon")
def enable_keyword(self, name):
if name in self.keywords_stack:
self.keywords_stack[name] += 1
else:
self.keywords_stack[name] = 1
self.keywords.add(name)
def disable_keyword(self, name):
count = self.keywords_stack.get(name, 1)
if count == 1:
self.keywords.discard(name)
del self.keywords_stack[name]
else:
self.keywords_stack[name] = count - 1
...@@ -431,8 +431,8 @@ class SuiteSlot(SlotDescriptor): ...@@ -431,8 +431,8 @@ class SuiteSlot(SlotDescriptor):
# #
# sub_slots [SlotDescriptor] # sub_slots [SlotDescriptor]
def __init__(self, sub_slots, slot_type, slot_name): def __init__(self, sub_slots, slot_type, slot_name, ifdef=None):
SlotDescriptor.__init__(self, slot_name) SlotDescriptor.__init__(self, slot_name, ifdef=ifdef)
self.sub_slots = sub_slots self.sub_slots = sub_slots
self.slot_type = slot_type self.slot_type = slot_type
substructures.append(self) substructures.append(self)
...@@ -454,6 +454,8 @@ class SuiteSlot(SlotDescriptor): ...@@ -454,6 +454,8 @@ class SuiteSlot(SlotDescriptor):
def generate_substructure(self, scope, code): def generate_substructure(self, scope, code):
if not self.is_empty(scope): if not self.is_empty(scope):
code.putln("") code.putln("")
if self.ifdef:
code.putln("#if %s" % self.ifdef)
code.putln( code.putln(
"static %s %s = {" % ( "static %s %s = {" % (
self.slot_type, self.slot_type,
...@@ -461,6 +463,8 @@ class SuiteSlot(SlotDescriptor): ...@@ -461,6 +463,8 @@ class SuiteSlot(SlotDescriptor):
for slot in self.sub_slots: for slot in self.sub_slots:
slot.generate(scope, code) slot.generate(scope, code)
code.putln("};") code.putln("};")
if self.ifdef:
code.putln("#endif")
substructures = [] # List of all SuiteSlot instances substructures = [] # List of all SuiteSlot instances
...@@ -506,6 +510,29 @@ class BaseClassSlot(SlotDescriptor): ...@@ -506,6 +510,29 @@ class BaseClassSlot(SlotDescriptor):
base_type.typeptr_cname)) base_type.typeptr_cname))
class AlternativeSlot(SlotDescriptor):
"""Slot descriptor that delegates to different slots using C macros."""
def __init__(self, alternatives):
SlotDescriptor.__init__(self, "")
self.alternatives = alternatives
def generate(self, scope, code):
# state machine: "#if ... (#elif ...)* #else ... #endif"
test = 'if'
for guard, slot in self.alternatives:
if guard:
assert test in ('if', 'elif'), test
else:
assert test == 'elif', test
test = 'else'
code.putln("#%s %s" % (test, guard))
slot.generate(scope, code)
if test == 'if':
test = 'elif'
assert test == 'else', test
code.putln("#endif")
# The following dictionary maps __xxx__ method names to slot descriptors. # The following dictionary maps __xxx__ method names to slot descriptors.
method_name_to_slot = {} method_name_to_slot = {}
...@@ -748,6 +775,12 @@ PyBufferProcs = ( ...@@ -748,6 +775,12 @@ PyBufferProcs = (
MethodSlot(releasebufferproc, "bf_releasebuffer", "__releasebuffer__") MethodSlot(releasebufferproc, "bf_releasebuffer", "__releasebuffer__")
) )
PyAsyncMethods = (
MethodSlot(unaryfunc, "am_await", "__await__"),
MethodSlot(unaryfunc, "am_aiter", "__aiter__"),
MethodSlot(unaryfunc, "am_anext", "__anext__"),
)
#------------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------------
# #
# The main slot table. This table contains descriptors for all the # The main slot table. This table contains descriptors for all the
...@@ -761,7 +794,11 @@ slot_table = ( ...@@ -761,7 +794,11 @@ slot_table = (
EmptySlot("tp_print"), #MethodSlot(printfunc, "tp_print", "__print__"), EmptySlot("tp_print"), #MethodSlot(printfunc, "tp_print", "__print__"),
EmptySlot("tp_getattr"), EmptySlot("tp_getattr"),
EmptySlot("tp_setattr"), EmptySlot("tp_setattr"),
MethodSlot(cmpfunc, "tp_compare", "__cmp__", py3 = '<RESERVED>'), AlternativeSlot([
("PY_MAJOR_VERSION < 3", MethodSlot(cmpfunc, "tp_compare", "__cmp__")),
("PY_VERSION_HEX < 0x030500B1", EmptySlot("tp_reserved")),
("", SuiteSlot(PyAsyncMethods, "PyAsyncMethods", "tp_as_async", ifdef="PY_VERSION_HEX >= 0x030500B1")),
]),
MethodSlot(reprfunc, "tp_repr", "__repr__"), MethodSlot(reprfunc, "tp_repr", "__repr__"),
SuiteSlot(PyNumberMethods, "PyNumberMethods", "tp_as_number"), SuiteSlot(PyNumberMethods, "PyNumberMethods", "tp_as_number"),
......
...@@ -13,7 +13,8 @@ eval_input: testlist NEWLINE* ENDMARKER ...@@ -13,7 +13,8 @@ eval_input: testlist NEWLINE* ENDMARKER
decorator: '@' dotted_PY_NAME [ '(' [arglist] ')' ] NEWLINE decorator: '@' dotted_PY_NAME [ '(' [arglist] ')' ] NEWLINE
decorators: decorator+ decorators: decorator+
decorated: decorators (classdef | funcdef | cdef_stmt) decorated: decorators (classdef | funcdef | async_funcdef | cdef_stmt)
async_funcdef: 'async' funcdef
funcdef: 'def' PY_NAME parameters ['->' test] ':' suite funcdef: 'def' PY_NAME parameters ['->' test] ':' suite
parameters: '(' [typedargslist] ')' parameters: '(' [typedargslist] ')'
typedargslist: (tfpdef ['=' (test | '*')] (',' tfpdef ['=' (test | '*')])* [',' typedargslist: (tfpdef ['=' (test | '*')] (',' tfpdef ['=' (test | '*')])* [','
...@@ -96,7 +97,8 @@ shift_expr: arith_expr (('<<'|'>>') arith_expr)* ...@@ -96,7 +97,8 @@ shift_expr: arith_expr (('<<'|'>>') arith_expr)*
arith_expr: term (('+'|'-') term)* arith_expr: term (('+'|'-') term)*
term: factor (('*'|'/'|'%'|'//') factor)* term: factor (('*'|'/'|'%'|'//') factor)*
factor: ('+'|'-'|'~') factor | power | address | size_of | cast factor: ('+'|'-'|'~') factor | power | address | size_of | cast
power: atom trailer* ['**' factor] power: atom_expr ['**' factor]
atom_expr: ['await'] atom trailer*
atom: ('(' [yield_expr|testlist_comp] ')' | atom: ('(' [yield_expr|testlist_comp] ')' |
'[' [testlist_comp] ']' | '[' [testlist_comp] ']' |
'{' [dictorsetmaker] '}' | '{' [dictorsetmaker] '}' |
......
...@@ -68,7 +68,7 @@ static CYTHON_INLINE void __Pyx_CyFunction_SetAnnotationsDict(PyObject *m, ...@@ -68,7 +68,7 @@ static CYTHON_INLINE void __Pyx_CyFunction_SetAnnotationsDict(PyObject *m,
PyObject *dict); PyObject *dict);
static int __Pyx_CyFunction_init(void); static int __pyx_CyFunction_init(void);
//////////////////// CythonFunction //////////////////// //////////////////// CythonFunction ////////////////////
//@substitute: naming //@substitute: naming
...@@ -693,7 +693,7 @@ static PyTypeObject __pyx_CyFunctionType_type = { ...@@ -693,7 +693,7 @@ static PyTypeObject __pyx_CyFunctionType_type = {
}; };
static int __Pyx_CyFunction_init(void) { static int __pyx_CyFunction_init(void) {
#if !CYTHON_COMPILING_IN_PYPY #if !CYTHON_COMPILING_IN_PYPY
// avoid a useless level of call indirection // avoid a useless level of call indirection
__pyx_CyFunctionType_type.tp_call = PyCFunction_Call; __pyx_CyFunctionType_type.tp_call = PyCFunction_Call;
......
# mode: error
# tag: pep492, async
async def foo():
def foo(a=await list()):
pass
_ERRORS = """
5:14: 'await' not supported here
"""
# mode: error
# tag: pep492, async
async def foo():
def foo(a:await list()):
pass
_ERRORS = """
5:14: 'await' not supported here
5:14: 'await' not supported here
"""
# mode: error
# tag: pep492, async
async def foo():
[i async for i in els]
_ERRORS = """
5:7: Expected ']', found 'async'
"""
# mode: error
# tag: pep492, async
async def foo():
async def foo(): await list()
_ERRORS = """
# ??? - this fails in CPython, not sure why ...
"""
# mode: error
# tag: pep492, async
def foo():
await list()
_ERRORS = """
5:10: Syntax error in simple statement list
"""
# mode: error
# tag: pep492, async
async def foo():
yield
_ERRORS = """
5:4: 'yield' not allowed in async coroutines (use 'await')
5:4: 'yield' not supported here
"""
# mode: error
# tag: pep492, async
async def foo():
yield from []
_ERRORS = """
5:4: 'yield from' not supported here
5:4: 'yield' not allowed in async coroutines (use 'await')
"""
# mode: error
# tag: pep492, async
async def foo():
await await fut
_ERRORS = """
5:10: Expected an identifier or literal
"""
# mode: error
# tag: pep492, async
async def foo():
await
_ERRORS = """
5:9: Expected an identifier or literal
"""
# mode: run
# tag: pep492, asyncfor, await
import sys
if sys.version_info >= (3, 5, 0, 'beta'):
# pass Cython implemented AsyncIter() into a Python async-for loop
__doc__ = u"""
>>> def test_py35():
... buffer = []
... async def coro():
... async for i1, i2 in AsyncIter(1):
... buffer.append(i1 + i2)
... return coro, buffer
>>> testfunc, buffer = test_py35()
>>> buffer
[]
>>> yielded, _ = run_async(testfunc(), check_type=False)
>>> yielded == [i * 100 for i in range(1, 11)] or yielded
True
>>> buffer == [i*2 for i in range(1, 101)] or buffer
True
"""
cdef class AsyncYieldFrom:
cdef object obj
def __init__(self, obj):
self.obj = obj
def __await__(self):
yield from self.obj
cdef class AsyncYield:
cdef object value
def __init__(self, value):
self.value = value
def __await__(self):
yield self.value
def run_async(coro, check_type='coroutine'):
if check_type:
assert coro.__class__.__name__ == check_type, \
'type(%s) != %s' % (coro.__class__, check_type)
buffer = []
result = None
while True:
try:
buffer.append(coro.send(None))
except StopIteration as ex:
result = ex.args[0] if ex.args else None
break
return buffer, result
cdef class AsyncIter:
cdef long i
cdef long aiter_calls
cdef long max_iter_calls
def __init__(self, long max_iter_calls=1):
self.i = 0
self.aiter_calls = 0
self.max_iter_calls = max_iter_calls
async def __aiter__(self):
self.aiter_calls += 1
return self
async def __anext__(self):
self.i += 1
assert self.aiter_calls <= self.max_iter_calls
if not (self.i % 10):
await AsyncYield(self.i * 10)
if self.i > 100:
raise StopAsyncIteration
return self.i, self.i
def test_for_1():
"""
>>> testfunc, buffer = test_for_1()
>>> buffer
[]
>>> yielded, _ = run_async(testfunc())
>>> yielded == [i * 100 for i in range(1, 11)] or yielded
True
>>> buffer == [i*2 for i in range(1, 101)] or buffer
True
"""
buffer = []
async def test1():
async for i1, i2 in AsyncIter(1):
buffer.append(i1 + i2)
return test1, buffer
def test_for_2():
"""
>>> testfunc, buffer = test_for_2()
>>> buffer
[]
>>> yielded, _ = run_async(testfunc())
>>> yielded == [100, 200] or yielded
True
>>> buffer == [i for i in range(1, 21)] + ['end'] or buffer
True
"""
buffer = []
async def test2():
nonlocal buffer
async for i in AsyncIter(2):
buffer.append(i[0])
if i[0] == 20:
break
else:
buffer.append('what?')
buffer.append('end')
return test2, buffer
def test_for_3():
"""
>>> testfunc, buffer = test_for_3()
>>> buffer
[]
>>> yielded, _ = run_async(testfunc())
>>> yielded == [i * 100 for i in range(1, 11)] or yielded
True
>>> buffer == [i for i in range(1, 21)] + ['what?', 'end'] or buffer
True
"""
buffer = []
async def test3():
nonlocal buffer
async for i in AsyncIter(3):
if i[0] > 20:
continue
buffer.append(i[0])
else:
buffer.append('what?')
buffer.append('end')
return test3, buffer
cdef class NonAwaitableFromAnext:
async def __aiter__(self):
return self
def __anext__(self):
return 123
def test_broken_anext():
"""
>>> testfunc = test_broken_anext()
>>> try: run_async(testfunc())
... except TypeError as exc:
... assert ' int ' in str(exc)
... else:
... print("NOT RAISED!")
"""
async def foo():
async for i in NonAwaitableFromAnext():
print('never going to happen')
return foo
cdef class Manager:
cdef readonly list counter
def __init__(self, counter):
self.counter = counter
async def __aenter__(self):
self.counter[0] += 10000
async def __aexit__(self, *args):
self.counter[0] += 100000
cdef class Iterable:
cdef long i
def __init__(self):
self.i = 0
async def __aiter__(self):
return self
async def __anext__(self):
if self.i > 10:
raise StopAsyncIteration
self.i += 1
return self.i
def test_with_for():
"""
>>> test_with_for()
111011
333033
20555255
"""
I = [0]
manager = Manager(I)
iterable = Iterable()
mrefs_before = sys.getrefcount(manager)
irefs_before = sys.getrefcount(iterable)
async def main():
async with manager:
async for i in iterable:
I[0] += 1
I[0] += 1000
run_async(main())
print(I[0])
assert sys.getrefcount(manager) == mrefs_before
assert sys.getrefcount(iterable) == irefs_before
##############
async def main():
nonlocal I
async with Manager(I):
async for i in Iterable():
I[0] += 1
I[0] += 1000
async with Manager(I):
async for i in Iterable():
I[0] += 1
I[0] += 1000
run_async(main())
print(I[0])
##############
async def main():
async with Manager(I):
I[0] += 100
async for i in Iterable():
I[0] += 1
else:
I[0] += 10000000
I[0] += 1000
async with Manager(I):
I[0] += 100
async for i in Iterable():
I[0] += 1
else:
I[0] += 10000000
I[0] += 1000
run_async(main())
print(I[0])
cdef class AI:
async def __aiter__(self):
1/0
def test_aiter_raises():
"""
>>> test_aiter_raises()
RAISED
0
"""
CNT = 0
async def foo():
nonlocal CNT
async for i in AI():
CNT += 1
CNT += 10
try:
run_async(foo())
except ZeroDivisionError:
print("RAISED")
else:
print("NOT RAISED")
return CNT
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment