Commit b9e0bdcf authored by Stefan Behnel's avatar Stefan Behnel

implement "async def" statement and "await" expression (PEP 492)

parent 4ed35f23
......@@ -49,7 +49,8 @@ non_portable_builtins_map = {
'basestring' : ('PY_MAJOR_VERSION >= 3', 'str'),
'xrange' : ('PY_MAJOR_VERSION >= 3', 'range'),
'raw_input' : ('PY_MAJOR_VERSION >= 3', 'input'),
}
'StopAsyncIteration': ('PY_VERSION_HEX < 0x030500B1', 'StopIteration'),
}
basicsize_builtins_map = {
# builtins whose type has a different tp_basicsize than sizeof(...)
......
......@@ -8593,10 +8593,12 @@ class YieldExprNode(ExprNode):
type = py_object_type
label_num = 0
is_yield_from = False
is_await = False
expr_keyword = 'yield'
def analyse_types(self, env):
if not self.label_num:
error(self.pos, "'yield' not supported here")
error(self.pos, "'%s' not supported here" % self.expr_keyword)
self.is_temp = 1
if self.arg is not None:
self.arg = self.arg.analyse_types(env)
......@@ -8661,6 +8663,7 @@ class YieldExprNode(ExprNode):
class YieldFromExprNode(YieldExprNode):
# "yield from GEN" expression
is_yield_from = True
expr_keyword = 'yield from'
def coerce_yield_argument(self, env):
if not self.arg.type.is_string:
......@@ -8668,14 +8671,17 @@ class YieldFromExprNode(YieldExprNode):
error(self.pos, "yielding from non-Python object not supported")
self.arg = self.arg.coerce_to_pyobject(env)
def generate_evaluation_code(self, code):
code.globalstate.use_utility_code(UtilityCode.load_cached("YieldFrom", "Coroutine.c"))
def yield_from_func(self, code):
code.globalstate.use_utility_code(UtilityCode.load_cached("GeneratorYieldFrom", "Coroutine.c"))
return "__Pyx_Generator_Yield_From"
def generate_evaluation_code(self, code):
self.arg.generate_evaluation_code(code)
code.putln("%s = __Pyx_Generator_Yield_From(%s, %s);" % (
code.putln("%s = %s(%s, %s);" % (
Naming.retval_cname,
self.yield_from_func(code),
Naming.generator_cname,
self.arg.result_as(py_object_type)))
self.arg.py_result()))
self.arg.generate_disposal_code(code)
self.arg.free_temps(code)
code.put_xgotref(Naming.retval_cname)
......@@ -8687,9 +8693,7 @@ class YieldFromExprNode(YieldExprNode):
if self.result_is_used:
# YieldExprNode has allocated the result temp for us
code.putln("%s = NULL;" % self.result())
code.putln("if (unlikely(__Pyx_PyGen_FetchStopIterationValue(&%s) < 0)) %s" % (
self.result(),
code.error_goto(self.pos)))
code.put_error_if_neg(self.pos, "__Pyx_PyGen_FetchStopIterationValue(&%s)" % self.result())
code.put_gotref(self.result())
else:
code.putln("PyObject* exc_type = PyErr_Occurred();")
......@@ -8700,6 +8704,25 @@ class YieldFromExprNode(YieldExprNode):
code.putln("}")
code.putln("}")
class AwaitExprNode(YieldFromExprNode):
# 'await' expression node
#
# arg ExprNode the Awaitable value to await
# label_num integer yield label number
# is_yield_from boolean is a YieldFromExprNode to delegate to another generator
is_await = True
expr_keyword = 'await'
def coerce_yield_argument(self, env):
# FIXME: use same check as in YieldFromExprNode.coerce_yield_argument() ?
self.arg = self.arg.coerce_to_pyobject(env)
def yield_from_func(self, code):
code.globalstate.use_utility_code(UtilityCode.load_cached("CoroutineYieldFrom", "Coroutine.c"))
return "__Pyx_Coroutine_Yield_From"
class GlobalsExprNode(AtomicExprNode):
type = dict_type
is_temp = 1
......
......@@ -1574,9 +1574,11 @@ class FuncDefNode(StatNode, BlockNode):
# pymethdef_required boolean Force Python method struct generation
# directive_locals { string : ExprNode } locals defined by cython.locals(...)
# directive_returns [ExprNode] type defined by cython.returns(...)
# star_arg PyArgDeclNode or None * argument
# starstar_arg PyArgDeclNode or None ** argument
# star_arg PyArgDeclNode or None * argument
# starstar_arg PyArgDeclNode or None ** argument
#
# is_async_def boolean is a Coroutine function
#
# has_fused_arguments boolean
# Whether this cdef function has fused parameters. This is needed
# by AnalyseDeclarationsTransform, so it can replace CFuncDefNodes
......@@ -1588,6 +1590,7 @@ class FuncDefNode(StatNode, BlockNode):
pymethdef_required = False
is_generator = False
is_generator_body = False
is_async_def = False
modifiers = []
has_fused_arguments = False
star_arg = None
......@@ -3936,6 +3939,7 @@ class GeneratorDefNode(DefNode):
#
is_generator = True
is_coroutine = False
needs_closure = True
child_attrs = DefNode.child_attrs + ["gbody"]
......@@ -3956,8 +3960,9 @@ class GeneratorDefNode(DefNode):
qualname = code.intern_identifier(self.qualname)
code.putln('{')
code.putln('__pyx_CoroutineObject *gen = __Pyx_Generator_New('
code.putln('__pyx_CoroutineObject *gen = __Pyx_%s_New('
'(__pyx_coroutine_body_t) %s, (PyObject *) %s, %s, %s); %s' % (
'Coroutine' if self.is_coroutine else 'Generator',
body_cname, Naming.cur_scope_cname, name, qualname,
code.error_goto_if_null('gen', self.pos)))
code.put_decref(Naming.cur_scope_cname, py_object_type)
......@@ -3972,13 +3977,18 @@ class GeneratorDefNode(DefNode):
code.putln('}')
def generate_function_definitions(self, env, code):
env.use_utility_code(UtilityCode.load_cached("Generator", "Coroutine.c"))
env.use_utility_code(UtilityCode.load_cached(
'Coroutine' if self.is_coroutine else 'Generator', "Coroutine.c"))
self.gbody.generate_function_header(code, proto=True)
super(GeneratorDefNode, self).generate_function_definitions(env, code)
self.gbody.generate_function_definitions(env, code)
class AsyncDefNode(GeneratorDefNode):
is_coroutine = True
class GeneratorBodyDefNode(DefNode):
# Main code body of a generator implemented as a DefNode.
#
......@@ -7108,7 +7118,7 @@ class GILStatNode(NogilTryFinallyStatNode):
from .ParseTreeTransforms import YieldNodeCollector
collector = YieldNodeCollector()
collector.visitchildren(body)
if not collector.yields:
if not collector.yields and not collector.awaits:
return
if state == 'gil':
......
......@@ -200,7 +200,7 @@ class PostParse(ScopeTrackingTransform):
node.lambda_name = EncodedString(u'lambda%d' % lambda_id)
collector = YieldNodeCollector()
collector.visitchildren(node.result_expr)
if collector.yields or isinstance(node.result_expr, ExprNodes.YieldExprNode):
if collector.yields or collector.awaits or isinstance(node.result_expr, ExprNodes.YieldExprNode):
body = Nodes.ExprStatNode(
node.result_expr.pos, expr=node.result_expr)
else:
......@@ -2205,6 +2205,7 @@ class YieldNodeCollector(TreeVisitor):
def __init__(self):
super(YieldNodeCollector, self).__init__()
self.yields = []
self.awaits = []
self.returns = []
self.has_return_value = False
......@@ -2215,6 +2216,10 @@ class YieldNodeCollector(TreeVisitor):
self.yields.append(node)
self.visitchildren(node)
def visit_AwaitExprNode(self, node):
self.awaits.append(node)
self.visitchildren(node)
def visit_ReturnStatNode(self, node):
self.visitchildren(node)
if node.value:
......@@ -2250,27 +2255,36 @@ class MarkClosureVisitor(CythonTransform):
collector = YieldNodeCollector()
collector.visitchildren(node)
if collector.yields:
if isinstance(node, Nodes.CFuncDefNode):
# Will report error later
return node
for i, yield_expr in enumerate(collector.yields, 1):
yield_expr.label_num = i
for retnode in collector.returns:
retnode.in_generator = True
if node.is_async_def:
if collector.yields:
error(collector.yields[0].pos, "'yield' not allowed in async coroutines (use 'await')")
yields = collector.awaits
elif collector.yields:
if collector.awaits:
error(collector.yields[0].pos, "'await' not allowed in generators (use 'yield')")
yields = collector.yields
else:
return node
gbody = Nodes.GeneratorBodyDefNode(
pos=node.pos, name=node.name, body=node.body)
generator = Nodes.GeneratorDefNode(
pos=node.pos, name=node.name, args=node.args,
star_arg=node.star_arg, starstar_arg=node.starstar_arg,
doc=node.doc, decorators=node.decorators,
gbody=gbody, lambda_name=node.lambda_name)
return generator
return node
for i, yield_expr in enumerate(yields, 1):
yield_expr.label_num = i
for retnode in collector.returns:
retnode.in_generator = True
gbody = Nodes.GeneratorBodyDefNode(
pos=node.pos, name=node.name, body=node.body)
coroutine = (Nodes.AsyncDefNode if node.is_async_def else Nodes.GeneratorDefNode)(
pos=node.pos, name=node.name, args=node.args,
star_arg=node.star_arg, starstar_arg=node.starstar_arg,
doc=node.doc, decorators=node.decorators,
gbody=gbody, lambda_name=node.lambda_name)
return coroutine
def visit_CFuncDefNode(self, node):
self.visit_FuncDefNode(node)
self.needs_closure = False
self.visitchildren(node)
node.needs_closure = self.needs_closure
self.needs_closure = True
if node.needs_closure and node.overridable:
error(node.pos, "closures inside cpdef functions not yet supported")
return node
......@@ -2287,6 +2301,7 @@ class MarkClosureVisitor(CythonTransform):
self.needs_closure = True
return node
class CreateClosureClasses(CythonTransform):
# Output closure classes in module scope for all functions
# that really need it.
......
......@@ -44,6 +44,8 @@ cdef p_typecast(PyrexScanner s)
cdef p_sizeof(PyrexScanner s)
cdef p_yield_expression(PyrexScanner s)
cdef p_yield_statement(PyrexScanner s)
cdef p_await_expression(PyrexScanner s)
cdef p_async_statement(PyrexScanner s, ctx)
cdef p_power(PyrexScanner s)
cdef p_new_expr(PyrexScanner s)
cdef p_trailer(PyrexScanner s, node1)
......@@ -128,7 +130,7 @@ cdef p_IF_statement(PyrexScanner s, ctx)
cdef p_statement(PyrexScanner s, ctx, bint first_statement = *)
cdef p_statement_list(PyrexScanner s, ctx, bint first_statement = *)
cdef p_suite(PyrexScanner s, ctx = *)
cdef tuple p_suite_with_docstring(PyrexScanner s, ctx, with_doc_only = *)
cdef tuple p_suite_with_docstring(PyrexScanner s, ctx, bint with_doc_only=*)
cdef tuple _extract_docstring(node)
cdef p_positional_and_keyword_args(PyrexScanner s, end_sy_set, templates = *)
......@@ -176,7 +178,7 @@ cdef p_c_modifiers(PyrexScanner s)
cdef p_c_func_or_var_declaration(PyrexScanner s, pos, ctx)
cdef p_ctypedef_statement(PyrexScanner s, ctx)
cdef p_decorators(PyrexScanner s)
cdef p_def_statement(PyrexScanner s, list decorators = *)
cdef p_def_statement(PyrexScanner s, list decorators=*, bint is_async_def=*)
cdef p_varargslist(PyrexScanner s, terminator=*, bint annotated = *)
cdef p_py_arg_decl(PyrexScanner s, bint annotated = *)
cdef p_class_statement(PyrexScanner s, decorators)
......
......@@ -55,6 +55,7 @@ class Ctx(object):
d.update(kwds)
return ctx
def p_ident(s, message="Expected an identifier"):
if s.sy == 'IDENT':
name = s.systring
......@@ -350,6 +351,7 @@ def p_sizeof(s):
s.expect(')')
return node
def p_yield_expression(s):
# s.sy == "yield"
pos = s.position()
......@@ -370,19 +372,52 @@ def p_yield_expression(s):
else:
return ExprNodes.YieldExprNode(pos, arg=arg)
def p_yield_statement(s):
# s.sy == "yield"
yield_expr = p_yield_expression(s)
return Nodes.ExprStatNode(yield_expr.pos, expr=yield_expr)
#power: atom trailer* ('**' factor)*
def p_async_statement(s, ctx, decorators):
# s.sy >> 'async' ...
if s.sy == 'def':
# 'async def' statements aren't allowed in pxd files
if 'pxd' in ctx.level:
s.error('def statement not allowed here')
s.level = ctx.level
return p_def_statement(s, decorators, is_async_def=True)
elif decorators:
s.error("Decorators can only be followed by functions or classes")
elif s.sy == 'for':
#s.error("'async for' is not currently supported", fatal=False)
return p_statement(s, ctx) # TODO: implement
elif s.sy == 'with':
#s.error("'async with' is not currently supported", fatal=False)
return p_statement(s, ctx) # TODO: implement
else:
s.error("expected one of 'def', 'for', 'with' after 'async'")
def p_await_expression(s):
n1 = p_atom(s)
#power: atom_expr ('**' factor)*
#atom_expr: ['await'] atom trailer*
def p_power(s):
if s.systring == 'new' and s.peek()[0] == 'IDENT':
return p_new_expr(s)
await_pos = None
if s.sy == 'await':
await_pos = s.position()
s.next()
n1 = p_atom(s)
while s.sy in ('(', '[', '.'):
n1 = p_trailer(s, n1)
if await_pos:
n1 = ExprNodes.AwaitExprNode(await_pos, arg=n1)
if s.sy == '**':
pos = s.position()
s.next()
......@@ -390,6 +425,7 @@ def p_power(s):
n1 = ExprNodes.binop_node(pos, '**', n1, n2)
return n1
def p_new_expr(s):
# s.systring == 'new'.
pos = s.position()
......@@ -1929,12 +1965,14 @@ def p_statement(s, ctx, first_statement = 0):
s.error('decorator not allowed here')
s.level = ctx.level
decorators = p_decorators(s)
bad_toks = 'def', 'cdef', 'cpdef', 'class'
if not ctx.allow_struct_enum_decorator and s.sy not in bad_toks:
s.error("Decorators can only be followed by functions or classes")
if not ctx.allow_struct_enum_decorator and s.sy not in ('def', 'cdef', 'cpdef', 'class'):
if s.sy == 'IDENT' and s.systring == 'async':
pass # handled below
else:
s.error("Decorators can only be followed by functions or classes")
elif s.sy == 'pass' and cdef_flag:
# empty cdef block
return p_pass_statement(s, with_newline = 1)
return p_pass_statement(s, with_newline=1)
overridable = 0
if s.sy == 'cdef':
......@@ -1948,11 +1986,11 @@ def p_statement(s, ctx, first_statement = 0):
if ctx.level not in ('module', 'module_pxd', 'function', 'c_class', 'c_class_pxd'):
s.error('cdef statement not allowed here')
s.level = ctx.level
node = p_cdef_statement(s, ctx(overridable = overridable))
node = p_cdef_statement(s, ctx(overridable=overridable))
if decorators is not None:
tup = Nodes.CFuncDefNode, Nodes.CVarDefNode, Nodes.CClassDefNode
tup = (Nodes.CFuncDefNode, Nodes.CVarDefNode, Nodes.CClassDefNode)
if ctx.allow_struct_enum_decorator:
tup += Nodes.CStructOrUnionDefNode, Nodes.CEnumDefNode
tup += (Nodes.CStructOrUnionDefNode, Nodes.CEnumDefNode)
if not isinstance(node, tup):
s.error("Decorators can only be followed by functions or classes")
node.decorators = decorators
......@@ -1995,9 +2033,25 @@ def p_statement(s, ctx, first_statement = 0):
return p_try_statement(s)
elif s.sy == 'with':
return p_with_statement(s)
elif s.sy == 'async':
s.next()
return p_async_statement(s, ctx, decorators)
else:
return p_simple_statement_list(
s, ctx, first_statement = first_statement)
if s.sy == 'IDENT' and s.systring == 'async':
# PEP 492 enables the async/await keywords when it spots "async def ..."
s.next()
if s.sy == 'def':
s.enable_keyword('async')
s.enable_keyword('await')
result = p_async_statement(s, ctx, decorators)
s.enable_keyword('await')
s.disable_keyword('async')
return result
elif decorators:
s.error("Decorators can only be followed by functions or classes")
s.put_back('IDENT', 'async')
return p_simple_statement_list(s, ctx, first_statement=first_statement)
def p_statement_list(s, ctx, first_statement = 0):
# Parse a series of statements separated by newlines.
......@@ -3002,7 +3056,8 @@ def p_decorators(s):
s.expect_newline("Expected a newline after decorator")
return decorators
def p_def_statement(s, decorators=None):
def p_def_statement(s, decorators=None, is_async_def=False):
# s.sy == 'def'
pos = s.position()
s.next()
......@@ -3017,10 +3072,11 @@ def p_def_statement(s, decorators=None):
s.next()
return_type_annotation = p_test(s)
doc, body = p_suite_with_docstring(s, Ctx(level='function'))
return Nodes.DefNode(pos, name = name, args = args,
star_arg = star_arg, starstar_arg = starstar_arg,
doc = doc, body = body, decorators = decorators,
return_type_annotation = return_type_annotation)
return Nodes.DefNode(
pos, name=name, args=args, star_arg=star_arg, starstar_arg=starstar_arg,
doc=doc, body=body, decorators=decorators, is_async_def=is_async_def,
return_type_annotation=return_type_annotation)
def p_varargslist(s, terminator=')', annotated=1):
args = p_c_arg_list(s, in_pyfunc = 1, nonempty_declarators = 1,
......
......@@ -30,6 +30,7 @@ cdef class PyrexScanner(Scanner):
cdef public bint in_python_file
cdef public source_encoding
cdef set keywords
cdef public dict keywords_stack
cdef public list indentation_stack
cdef public indentation_char
cdef public int bracket_nesting_level
......@@ -57,3 +58,5 @@ cdef class PyrexScanner(Scanner):
cdef expect_indent(self)
cdef expect_dedent(self)
cdef expect_newline(self, message=*, bint ignore_semicolon=*)
cdef enable_keyword(self, name)
cdef disable_keyword(self, name)
......@@ -319,6 +319,7 @@ class PyrexScanner(Scanner):
self.in_python_file = False
self.keywords = set(pyx_reserved_words)
self.trace = trace_scanner
self.keywords_stack = {}
self.indentation_stack = [0]
self.indentation_char = None
self.bracket_nesting_level = 0
......@@ -497,3 +498,18 @@ class PyrexScanner(Scanner):
self.expect('NEWLINE', message)
if useless_trailing_semicolon is not None:
warning(useless_trailing_semicolon, "useless trailing semicolon")
def enable_keyword(self, name):
if name in self.keywords_stack:
self.keywords_stack[name] += 1
else:
self.keywords_stack[name] = 1
self.keywords.add(name)
def disable_keyword(self, name):
count = self.keywords_stack.get(name, 1)
if count == 1:
self.keywords.discard(name)
del self.keywords_stack[name]
else:
self.keywords_stack[name] = count - 1
......@@ -13,7 +13,8 @@ eval_input: testlist NEWLINE* ENDMARKER
decorator: '@' dotted_PY_NAME [ '(' [arglist] ')' ] NEWLINE
decorators: decorator+
decorated: decorators (classdef | funcdef | cdef_stmt)
decorated: decorators (classdef | funcdef | async_funcdef | cdef_stmt)
async_funcdef: 'async' funcdef
funcdef: 'def' PY_NAME parameters ['->' test] ':' suite
parameters: '(' [typedargslist] ')'
typedargslist: (tfpdef ['=' (test | '*')] (',' tfpdef ['=' (test | '*')])* [','
......@@ -96,7 +97,8 @@ shift_expr: arith_expr (('<<'|'>>') arith_expr)*
arith_expr: term (('+'|'-') term)*
term: factor (('*'|'/'|'%'|'//') factor)*
factor: ('+'|'-'|'~') factor | power | address | size_of | cast
power: atom trailer* ['**' factor]
power: atom_expr ['**' factor]
atom_expr: ['await'] atom trailer*
atom: ('(' [yield_expr|testlist_comp] ')' |
'[' [testlist_comp] ']' |
'{' [dictorsetmaker] '}' |
......
//////////////////// YieldFrom.proto ////////////////////
//////////////////// GeneratorYieldFrom.proto ////////////////////
static CYTHON_INLINE PyObject* __Pyx_Generator_Yield_From(__pyx_CoroutineObject *gen, PyObject *source);
//////////////////// YieldFrom ////////////////////
//////////////////// GeneratorYieldFrom ////////////////////
//@requires: Generator
static CYTHON_INLINE PyObject* __Pyx_Generator_Yield_From(__pyx_CoroutineObject *gen, PyObject *source) {
......@@ -21,6 +21,125 @@ static CYTHON_INLINE PyObject* __Pyx_Generator_Yield_From(__pyx_CoroutineObject
}
//////////////////// CoroutineYieldFrom.proto ////////////////////
static CYTHON_INLINE PyObject* __Pyx_Coroutine_Yield_From(__pyx_CoroutineObject *gen, PyObject *source);
//////////////////// CoroutineYieldFrom ////////////////////
//@requires: Coroutine
//@requires: GetAwaitIter
static CYTHON_INLINE PyObject* __Pyx_Coroutine_Yield_From(__pyx_CoroutineObject *gen, PyObject *source) {
PyObject *retval;
if (__Pyx_Coroutine_CheckExact(source)) {
retval = __Pyx_Generator_Next(source);
if (retval) {
Py_INCREF(source);
gen->yieldfrom = source;
return retval;
}
} else {
PyObject *source_gen = __Pyx__Coroutine_GetAwaitableIter(source);
if (unlikely(!source_gen))
return NULL;
// source_gen is now the iterator, make the first next() call
if (__Pyx_Coroutine_CheckExact(source_gen)) {
retval = __Pyx_Generator_Next(source_gen);
} else {
retval = Py_TYPE(source_gen)->tp_iternext(source_gen);
}
if (retval) {
gen->yieldfrom = source_gen;
return retval;
}
Py_DECREF(source_gen);
}
return NULL;
}
//////////////////// GetAwaitIter.proto ////////////////////
static CYTHON_INLINE PyObject *__Pyx_Coroutine_GetAwaitableIter(PyObject *o); /*proto*/
static PyObject *__Pyx__Coroutine_GetAwaitableIter(PyObject *o); /*proto*/
//////////////////// GetAwaitIter ////////////////////
//@requires: Coroutine
//@requires: ObjectHandling.c::PyObjectGetAttrStr
//@requires: ObjectHandling.c::PyObjectCallNoArg
//@requires: ObjectHandling.c::PyObjectCallOneArg
static CYTHON_INLINE PyObject *__Pyx_Coroutine_GetAwaitableIter(PyObject *o) {
#ifdef __Pyx_Coroutine_USED
if (__Pyx_Coroutine_CheckExact(o)) {
Py_INCREF(o);
return o;
}
#endif
return __Pyx__Coroutine_GetAwaitableIter(o);
}
// copied and adapted from genobject.c in Py3.5
static PyObject *__Pyx__Coroutine_GetAwaitableIter(PyObject *o) {
PyObject *res;
#if PY_VERSION_HEX >= 0x030500B1
unaryfunc getter = NULL;
PyTypeObject *ot;
ot = Py_TYPE(o);
if (likely(ot->tp_as_async)) {
getter = (unaryfunc) ot->tp_as_async->am_await;
}
if (unlikely(getter)) goto slot_error;
res = (*getter)(o);
#else
PyObject *method = __Pyx_PyObject_GetAttrStr(o, PYIDENT("__await__"));
if (unlikely(!method)) goto slot_error;
#if CYTHON_COMPILING_IN_CPYTHON
if (likely(PyMethod_Check(method))) {
PyObject *self = PyMethod_GET_SELF(method);
if (likely(self)) {
PyObject *function = PyMethod_GET_FUNCTION(method);
res = __Pyx_PyObject_CallOneArg(function, self);
} else
res = __Pyx_PyObject_CallNoArg(method);
} else
#endif
res = __Pyx_PyObject_CallNoArg(method);
Py_DECREF(method);
#endif
if (unlikely(!res)) goto bad;
if (!PyIter_Check(res)) {
PyErr_Format(PyExc_TypeError,
"__await__() returned non-iterator of type '%.100s'",
Py_TYPE(res)->tp_name);
Py_CLEAR(res);
} else {
int is_coroutine = 0;
#ifdef __Pyx_Coroutine_USED
is_coroutine |= __Pyx_Coroutine_CheckExact(res);
#endif
#if PY_VERSION_HEX >= 0x030500B1
is_coroutine |= PyGen_CheckCoroutineExact(res);
#endif
if (unlikely(is_coroutine)) {
/* __await__ must return an *iterator*, not
a coroutine or another awaitable (see PEP 492) */
PyErr_SetString(PyExc_TypeError,
"__await__() returned a coroutine");
Py_CLEAR(res);
}
}
return res;
slot_error:
PyErr_Format(PyExc_TypeError,
"object %.100s can't be used in 'await' expression",
Py_TYPE(o)->tp_name);
bad:
return NULL;
}
//////////////////// pep479.proto ////////////////////
static void __Pyx_Generator_Replace_StopIteration(void); /*proto*/
......@@ -107,6 +226,7 @@ static int __pyx_Generator_init(void);
#include <structmember.h>
#include <frameobject.h>
static PyObject *__Pyx_Generator_Next(PyObject *self);
static PyObject *__Pyx_Coroutine_Send(PyObject *self, PyObject *value);
static PyObject *__Pyx_Coroutine_Close(PyObject *self);
static PyObject *__Pyx_Coroutine_Throw(PyObject *gen, PyObject *args);
......@@ -403,6 +523,28 @@ static int __Pyx_Coroutine_CloseIter(__pyx_CoroutineObject *gen, PyObject *yf) {
return err;
}
static PyObject *__Pyx_Generator_Next(PyObject *self) {
__pyx_CoroutineObject *gen = (__pyx_CoroutineObject*) self;
PyObject *yf = gen->yieldfrom;
if (unlikely(__Pyx_Coroutine_CheckRunning(gen)))
return NULL;
if (yf) {
PyObject *ret;
// FIXME: does this really need an INCREF() ?
//Py_INCREF(yf);
// YieldFrom code ensures that yf is an iterator
gen->is_running = 1;
ret = Py_TYPE(yf)->tp_iternext(yf);
gen->is_running = 0;
//Py_DECREF(yf);
if (likely(ret)) {
return ret;
}
return __Pyx_Coroutine_FinishDelegation(gen);
}
return __Pyx_Coroutine_SendEx(gen, Py_None);
}
static PyObject *__Pyx_Coroutine_Close(PyObject *self) {
__pyx_CoroutineObject *gen = (__pyx_CoroutineObject *) self;
PyObject *retval, *raised_exception;
......@@ -729,6 +871,14 @@ static __pyx_CoroutineObject *__Pyx__Coroutine_New(PyTypeObject* type, __pyx_cor
//@requires: CoroutineBase
//@requires: PatchGeneratorABC
#if PY_VERSION_HEX >= 0x030500B1
static PyAsyncMethods __pyx_Coroutine_as_async {
0, /*am_await*/
0, /*am_aiter*/
0, /*am_anext*/
}
#endif
static PyTypeObject __pyx_CoroutineType_type = {
PyVarObject_HEAD_INIT(0, 0)
"coroutine", /*tp_name*/
......@@ -738,10 +888,10 @@ static PyTypeObject __pyx_CoroutineType_type = {
0, /*tp_print*/
0, /*tp_getattr*/
0, /*tp_setattr*/
#if PY_MAJOR_VERSION < 3
0, /*tp_compare*/
#if PY_VERSION_HEX >= 0x030500B1
__pyx_Coroutine_as_async, /*tp_as_async*/
#else
0, /*reserved*/
0, /*tp_reserved resp. tp_compare*/
#endif
0, /*tp_repr*/
0, /*tp_as_number*/
......@@ -759,8 +909,9 @@ static PyTypeObject __pyx_CoroutineType_type = {
0, /*tp_clear*/
0, /*tp_richcompare*/
offsetof(__pyx_CoroutineObject, gi_weakreflist), /*tp_weaklistoffset*/
// no tp_iter() as iterator is only available through __await__()
0, /*tp_iter*/
0, /*tp_iternext*/
(iternextfunc) __Pyx_Generator_Next, /*tp_iternext*/
__pyx_Coroutine_methods, /*tp_methods*/
__pyx_Coroutine_memberlist, /*tp_members*/
__pyx_Coroutine_getsets, /*tp_getset*/
......@@ -790,7 +941,7 @@ static PyTypeObject __pyx_CoroutineType_type = {
#endif
};
static int __pyx_Generator_init(void) {
static int __pyx_Coroutine_init(void) {
// on Windows, C-API functions can't be used in slots statically
__pyx_CoroutineType_type.tp_getattro = PyObject_GenericGetAttr;
......@@ -801,35 +952,10 @@ static int __pyx_Generator_init(void) {
return 0;
}
//////////////////// Generator ////////////////////
//@requires: CoroutineBase
//@requires: PatchGeneratorABC
static PyObject *__Pyx_Generator_Next(PyObject *self);
static PyObject *__Pyx_Generator_Next(PyObject *self) {
__pyx_CoroutineObject *gen = (__pyx_CoroutineObject*) self;
PyObject *yf = gen->yieldfrom;
if (unlikely(__Pyx_Coroutine_CheckRunning(gen)))
return NULL;
if (yf) {
PyObject *ret;
// FIXME: does this really need an INCREF() ?
//Py_INCREF(yf);
// YieldFrom code ensures that yf is an iterator
gen->is_running = 1;
ret = Py_TYPE(yf)->tp_iternext(yf);
gen->is_running = 0;
//Py_DECREF(yf);
if (likely(ret)) {
return ret;
}
return __Pyx_Coroutine_FinishDelegation(gen);
}
return __Pyx_Coroutine_SendEx(gen, Py_None);
}
static PyTypeObject __pyx_GeneratorType_type = {
PyVarObject_HEAD_INIT(0, 0)
"generator", /*tp_name*/
......
# cython: language_level=3, binding=True
import gc
import sys
import types
import inspect
import unittest
import warnings
import contextlib
class AsyncYieldFrom:
def __init__(self, obj):
self.obj = obj
def __await__(self):
yield from self.obj
class AsyncYield:
def __init__(self, value):
self.value = value
def __await__(self):
yield self.value
def run_async(coro):
#assert coro.__class__ is types.GeneratorType
assert coro.__class__.__name__ == 'coroutine'
buffer = []
result = None
while True:
try:
buffer.append(coro.send(None))
except StopIteration as ex:
result = ex.args[0] if ex.args else None
break
return buffer, result
@contextlib.contextmanager
def silence_coro_gc():
with warnings.catch_warnings():
warnings.simplefilter("ignore")
yield
gc.collect()
class AsyncBadSyntaxTest(unittest.TestCase):
def test_badsyntax_1(self):
with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
import test.badsyntax_async1
def test_badsyntax_2(self):
with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
import test.badsyntax_async2
def test_badsyntax_3(self):
with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
import test.badsyntax_async3
def test_badsyntax_4(self):
with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
import test.badsyntax_async4
def test_badsyntax_5(self):
with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
import test.badsyntax_async5
def test_badsyntax_6(self):
with self.assertRaisesRegex(
SyntaxError, "'yield' inside async function"):
import test.badsyntax_async6
def test_badsyntax_7(self):
with self.assertRaisesRegex(
SyntaxError, "'yield from' inside async function"):
import test.badsyntax_async7
def test_badsyntax_8(self):
with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
import test.badsyntax_async8
def test_badsyntax_9(self):
with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
import test.badsyntax_async9
class TokenizerRegrTest(unittest.TestCase):
def test_oneline_defs(self):
buf = []
for i in range(500):
buf.append('def i{i}(): return {i}'.format(i=i))
buf = '\n'.join(buf)
# Test that 500 consequent, one-line defs is OK
ns = {}
exec(buf, ns, ns)
self.assertEqual(ns['i499'](), 499)
# Test that 500 consequent, one-line defs *and*
# one 'async def' following them is OK
buf += '\nasync def foo():\n return'
ns = {}
exec(buf, ns, ns)
self.assertEqual(ns['i499'](), 499)
self.assertTrue(inspect.iscoroutinefunction(ns['foo']))
class CoroutineTest(unittest.TestCase):
@contextlib.contextmanager
def assertRaisesRegex(self, exc_type, regex):
# the error messages usually don't match, so we just ignore them
try:
yield
except exc_type:
self.assertTrue(True)
else:
self.assertTrue(False)
def test_gen_1(self):
def gen(): yield
self.assertFalse(hasattr(gen, '__await__'))
def test_func_1(self):
async def foo():
return 10
f = foo()
self.assertEqual(f.__class__.__name__, 'coroutine')
#self.assertIsInstance(f, types.GeneratorType)
#self.assertTrue(bool(foo.__code__.co_flags & 0x80))
#self.assertTrue(bool(foo.__code__.co_flags & 0x20))
#self.assertTrue(bool(f.gi_code.co_flags & 0x80))
#self.assertTrue(bool(f.gi_code.co_flags & 0x20))
self.assertEqual(run_async(f), ([], 10))
def bar(): pass
self.assertFalse(bool(bar.__code__.co_flags & 0x80))
def test_func_2(self):
async def foo():
raise StopIteration
with self.assertRaisesRegex(
RuntimeError, "generator raised StopIteration"):
run_async(foo())
def test_func_3(self):
async def foo():
raise StopIteration
with silence_coro_gc():
self.assertRegex(repr(foo()), '^<coroutine object.* at 0x.*>$')
def test_func_4(self):
async def foo():
raise StopIteration
check = lambda: self.assertRaisesRegex(
TypeError, "coroutine-objects do not support iteration")
with check():
list(foo())
with check():
tuple(foo())
with check():
sum(foo())
with check():
iter(foo())
with check():
next(foo())
with silence_coro_gc(), check():
for i in foo():
pass
with silence_coro_gc(), check():
[i for i in foo()]
def test_func_5(self):
@types.coroutine
def bar():
yield 1
async def foo():
await bar()
check = lambda: self.assertRaisesRegex(
TypeError, "coroutine-objects do not support iteration")
with check():
for el in foo(): pass
# the following should pass without an error
for el in bar():
self.assertEqual(el, 1)
self.assertEqual([el for el in bar()], [1])
self.assertEqual(tuple(bar()), (1,))
self.assertEqual(next(iter(bar())), 1)
def test_func_6(self):
@types.coroutine
def bar():
yield 1
yield 2
async def foo():
await bar()
f = foo()
self.assertEqual(f.send(None), 1)
self.assertEqual(f.send(None), 2)
with self.assertRaises(StopIteration):
f.send(None)
def test_func_7(self):
async def bar():
return 10
def foo():
yield from bar()
with silence_coro_gc(), self.assertRaisesRegex(
TypeError,
"cannot 'yield from' a coroutine object from a generator"):
list(foo())
def test_func_8(self):
@types.coroutine
def bar():
return (yield from foo())
async def foo():
return 'spam'
self.assertEqual(run_async(bar()), ([], 'spam') )
def test_func_9(self):
async def foo(): pass
with self.assertWarnsRegex(
RuntimeWarning, "coroutine '.*test_func_9.*foo' was never awaited"):
foo()
gc.collect()
def test_await_1(self):
async def foo():
await 1
with self.assertRaisesRegex(TypeError, "object int can.t.*await"):
run_async(foo())
def test_await_2(self):
async def foo():
await []
with self.assertRaisesRegex(TypeError, "object list can.t.*await"):
run_async(foo())
def test_await_3(self):
async def foo():
await AsyncYieldFrom([1, 2, 3])
self.assertEqual(run_async(foo()), ([1, 2, 3], None))
def test_await_4(self):
async def bar():
return 42
async def foo():
return await bar()
self.assertEqual(run_async(foo()), ([], 42))
def test_await_5(self):
class Awaitable:
def __await__(self):
return
async def foo():
return (await Awaitable())
with self.assertRaisesRegex(
TypeError, "__await__.*returned non-iterator of type"):
run_async(foo())
def test_await_6(self):
class Awaitable:
def __await__(self):
return iter([52])
async def foo():
return (await Awaitable())
self.assertEqual(run_async(foo()), ([52], None))
def test_await_7(self):
class Awaitable:
def __await__(self):
yield 42
return 100
async def foo():
return (await Awaitable())
self.assertEqual(run_async(foo()), ([42], 100))
def test_await_8(self):
class Awaitable:
pass
async def foo():
return (await Awaitable())
with self.assertRaisesRegex(
TypeError, "object Awaitable can't be used in 'await' expression"):
run_async(foo())
def test_await_9(self):
def wrap():
return bar
async def bar():
return 42
async def foo():
b = bar()
db = {'b': lambda: wrap}
class DB:
b = staticmethod(wrap)
return (await bar() + await wrap()() + await db['b']()()() +
await bar() * 1000 + await DB.b()())
async def foo2():
return -await bar()
self.assertEqual(run_async(foo()), ([], 42168))
self.assertEqual(run_async(foo2()), ([], -42))
def test_await_10(self):
async def baz():
return 42
async def bar():
return baz()
async def foo():
return await (await bar())
self.assertEqual(run_async(foo()), ([], 42))
def test_await_11(self):
def ident(val):
return val
async def bar():
return 'spam'
async def foo():
return ident(val=await bar())
async def foo2():
return await bar(), 'ham'
self.assertEqual(run_async(foo2()), ([], ('spam', 'ham')))
def test_await_12(self):
async def coro():
return 'spam'
class Awaitable:
def __await__(self):
return coro()
async def foo():
return await Awaitable()
with self.assertRaisesRegex(
TypeError, "__await__\(\) returned a coroutine"):
run_async(foo())
def test_await_13(self):
class Awaitable:
def __await__(self):
return self
async def foo():
return await Awaitable()
with self.assertRaisesRegex(
TypeError, "__await__.*returned non-iterator of type"):
run_async(foo())
def test_with_1(self):
class Manager:
def __init__(self, name):
self.name = name
async def __aenter__(self):
await AsyncYieldFrom(['enter-1-' + self.name,
'enter-2-' + self.name])
return self
async def __aexit__(self, *args):
await AsyncYieldFrom(['exit-1-' + self.name,
'exit-2-' + self.name])
if self.name == 'B':
return True
async def foo():
async with Manager("A") as a, Manager("B") as b:
await AsyncYieldFrom([('managers', a.name, b.name)])
1/0
f = foo()
result, _ = run_async(f)
self.assertEqual(
result, ['enter-1-A', 'enter-2-A', 'enter-1-B', 'enter-2-B',
('managers', 'A', 'B'),
'exit-1-B', 'exit-2-B', 'exit-1-A', 'exit-2-A']
)
async def foo():
async with Manager("A") as a, Manager("C") as c:
await AsyncYieldFrom([('managers', a.name, c.name)])
1/0
with self.assertRaises(ZeroDivisionError):
run_async(foo())
def test_with_2(self):
class CM:
def __aenter__(self):
pass
async def foo():
async with CM():
pass
with self.assertRaisesRegex(AttributeError, '__aexit__'):
run_async(foo())
def test_with_3(self):
class CM:
def __aexit__(self):
pass
async def foo():
async with CM():
pass
with self.assertRaisesRegex(AttributeError, '__aenter__'):
run_async(foo())
def test_with_4(self):
class CM:
def __enter__(self):
pass
def __exit__(self):
pass
async def foo():
async with CM():
pass
with self.assertRaisesRegex(AttributeError, '__aexit__'):
run_async(foo())
def test_with_5(self):
# While this test doesn't make a lot of sense,
# it's a regression test for an early bug with opcodes
# generation
class CM:
async def __aenter__(self):
return self
async def __aexit__(self, *exc):
pass
async def func():
async with CM():
assert (1, ) == 1
with self.assertRaises(AssertionError):
run_async(func())
def test_with_6(self):
class CM:
def __aenter__(self):
return 123
def __aexit__(self, *e):
return 456
async def foo():
async with CM():
pass
with self.assertRaisesRegex(
TypeError, "object int can't be used in 'await' expression"):
# it's important that __aexit__ wasn't called
run_async(foo())
def test_with_7(self):
class CM:
async def __aenter__(self):
return self
def __aexit__(self, *e):
return 444
async def foo():
async with CM():
1/0
try:
run_async(foo())
except TypeError as exc:
self.assertRegex(
exc.args[0], "object int can't be used in 'await' expression")
self.assertTrue(exc.__context__ is not None)
self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
else:
self.fail('invalid asynchronous context manager did not fail')
def test_with_8(self):
CNT = 0
class CM:
async def __aenter__(self):
return self
def __aexit__(self, *e):
return 456
async def foo():
nonlocal CNT
async with CM():
CNT += 1
with self.assertRaisesRegex(
TypeError, "object int can't be used in 'await' expression"):
run_async(foo())
self.assertEqual(CNT, 1)
def test_with_9(self):
CNT = 0
class CM:
async def __aenter__(self):
return self
async def __aexit__(self, *e):
1/0
async def foo():
nonlocal CNT
async with CM():
CNT += 1
with self.assertRaises(ZeroDivisionError):
run_async(foo())
self.assertEqual(CNT, 1)
def test_with_10(self):
CNT = 0
class CM:
async def __aenter__(self):
return self
async def __aexit__(self, *e):
1/0
async def foo():
nonlocal CNT
async with CM():
async with CM():
raise RuntimeError
try:
run_async(foo())
except ZeroDivisionError as exc:
self.assertTrue(exc.__context__ is not None)
self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
self.assertTrue(isinstance(exc.__context__.__context__,
RuntimeError))
else:
self.fail('exception from __aexit__ did not propagate')
def test_with_11(self):
CNT = 0
class CM:
async def __aenter__(self):
raise NotImplementedError
async def __aexit__(self, *e):
1/0
async def foo():
nonlocal CNT
async with CM():
raise RuntimeError
try:
run_async(foo())
except NotImplementedError as exc:
self.assertTrue(exc.__context__ is None)
else:
self.fail('exception from __aenter__ did not propagate')
def test_with_12(self):
CNT = 0
class CM:
async def __aenter__(self):
return self
async def __aexit__(self, *e):
return True
async def foo():
nonlocal CNT
async with CM() as cm:
self.assertIs(cm.__class__, CM)
raise RuntimeError
run_async(foo())
def test_with_13(self):
CNT = 0
class CM:
async def __aenter__(self):
1/0
async def __aexit__(self, *e):
return True
async def foo():
nonlocal CNT
CNT += 1
async with CM():
CNT += 1000
CNT += 10000
with self.assertRaises(ZeroDivisionError):
run_async(foo())
self.assertEqual(CNT, 1)
def test_for_1(self):
aiter_calls = 0
class AsyncIter:
def __init__(self):
self.i = 0
async def __aiter__(self):
nonlocal aiter_calls
aiter_calls += 1
return self
async def __anext__(self):
self.i += 1
if not (self.i % 10):
await AsyncYield(self.i * 10)
if self.i > 100:
raise StopAsyncIteration
return self.i, self.i
buffer = []
async def test1():
async for i1, i2 in AsyncIter():
buffer.append(i1 + i2)
yielded, _ = run_async(test1())
# Make sure that __aiter__ was called only once
self.assertEqual(aiter_calls, 1)
self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
self.assertEqual(buffer, [i*2 for i in range(1, 101)])
buffer = []
async def test2():
nonlocal buffer
async for i in AsyncIter():
buffer.append(i[0])
if i[0] == 20:
break
else:
buffer.append('what?')
buffer.append('end')
yielded, _ = run_async(test2())
# Make sure that __aiter__ was called only once
self.assertEqual(aiter_calls, 2)
self.assertEqual(yielded, [100, 200])
self.assertEqual(buffer, [i for i in range(1, 21)] + ['end'])
buffer = []
async def test3():
nonlocal buffer
async for i in AsyncIter():
if i[0] > 20:
continue
buffer.append(i[0])
else:
buffer.append('what?')
buffer.append('end')
yielded, _ = run_async(test3())
# Make sure that __aiter__ was called only once
self.assertEqual(aiter_calls, 3)
self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
self.assertEqual(buffer, [i for i in range(1, 21)] +
['what?', 'end'])
def test_for_2(self):
tup = (1, 2, 3)
refs_before = sys.getrefcount(tup)
async def foo():
async for i in tup:
print('never going to happen')
with self.assertRaisesRegex(
TypeError, "async for' requires an object.*__aiter__.*tuple"):
run_async(foo())
self.assertEqual(sys.getrefcount(tup), refs_before)
def test_for_3(self):
class I:
def __aiter__(self):
return self
aiter = I()
refs_before = sys.getrefcount(aiter)
async def foo():
async for i in aiter:
print('never going to happen')
with self.assertRaisesRegex(
TypeError,
"async for' received an invalid object.*__aiter.*\: I"):
run_async(foo())
self.assertEqual(sys.getrefcount(aiter), refs_before)
def test_for_4(self):
class I:
async def __aiter__(self):
return self
def __anext__(self):
return ()
aiter = I()
refs_before = sys.getrefcount(aiter)
async def foo():
async for i in aiter:
print('never going to happen')
with self.assertRaisesRegex(
TypeError,
"async for' received an invalid object.*__anext__.*tuple"):
run_async(foo())
self.assertEqual(sys.getrefcount(aiter), refs_before)
def test_for_5(self):
class I:
async def __aiter__(self):
return self
def __anext__(self):
return 123
async def foo():
async for i in I():
print('never going to happen')
with self.assertRaisesRegex(
TypeError,
"async for' received an invalid object.*__anext.*int"):
run_async(foo())
def test_for_6(self):
I = 0
class Manager:
async def __aenter__(self):
nonlocal I
I += 10000
async def __aexit__(self, *args):
nonlocal I
I += 100000
class Iterable:
def __init__(self):
self.i = 0
async def __aiter__(self):
return self
async def __anext__(self):
if self.i > 10:
raise StopAsyncIteration
self.i += 1
return self.i
##############
manager = Manager()
iterable = Iterable()
mrefs_before = sys.getrefcount(manager)
irefs_before = sys.getrefcount(iterable)
async def main():
nonlocal I
async with manager:
async for i in iterable:
I += 1
I += 1000
run_async(main())
self.assertEqual(I, 111011)
self.assertEqual(sys.getrefcount(manager), mrefs_before)
self.assertEqual(sys.getrefcount(iterable), irefs_before)
##############
async def main():
nonlocal I
async with Manager():
async for i in Iterable():
I += 1
I += 1000
async with Manager():
async for i in Iterable():
I += 1
I += 1000
run_async(main())
self.assertEqual(I, 333033)
##############
async def main():
nonlocal I
async with Manager():
I += 100
async for i in Iterable():
I += 1
else:
I += 10000000
I += 1000
async with Manager():
I += 100
async for i in Iterable():
I += 1
else:
I += 10000000
I += 1000
run_async(main())
self.assertEqual(I, 20555255)
def test_for_7(self):
CNT = 0
class AI:
async def __aiter__(self):
1/0
async def foo():
nonlocal CNT
async for i in AI():
CNT += 1
CNT += 10
with self.assertRaises(ZeroDivisionError):
run_async(foo())
self.assertEqual(CNT, 0)
class CoroAsyncIOCompatTest(unittest.TestCase):
def test_asyncio_1(self):
import asyncio
class MyException(Exception):
pass
buffer = []
class CM:
async def __aenter__(self):
buffer.append(1)
await asyncio.sleep(0.01)
buffer.append(2)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await asyncio.sleep(0.01)
buffer.append(exc_type.__name__)
async def f():
async with CM() as c:
await asyncio.sleep(0.01)
raise MyException
buffer.append('unreachable')
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(f())
except MyException:
pass
finally:
loop.close()
asyncio.set_event_loop(None)
self.assertEqual(buffer, [1, 2, 'MyException'])
class SysSetCoroWrapperTest(unittest.TestCase):
def test_set_wrapper_1(self):
async def foo():
return 'spam'
wrapped = None
def wrap(gen):
nonlocal wrapped
wrapped = gen
return gen
self.assertIsNone(sys.get_coroutine_wrapper())
sys.set_coroutine_wrapper(wrap)
self.assertIs(sys.get_coroutine_wrapper(), wrap)
try:
f = foo()
self.assertTrue(wrapped)
self.assertEqual(run_async(f), ([], 'spam'))
finally:
sys.set_coroutine_wrapper(None)
self.assertIsNone(sys.get_coroutine_wrapper())
wrapped = None
with silence_coro_gc():
foo()
self.assertFalse(wrapped)
def test_set_wrapper_2(self):
self.assertIsNone(sys.get_coroutine_wrapper())
with self.assertRaisesRegex(TypeError, "callable expected, got int"):
sys.set_coroutine_wrapper(1)
self.assertIsNone(sys.get_coroutine_wrapper())
class CAPITest(unittest.TestCase):
def test_tp_await_1(self):
from _testcapi import awaitType as at
async def foo():
future = at(iter([1]))
return (await future)
self.assertEqual(foo().send(None), 1)
def test_tp_await_2(self):
# Test tp_await to __await__ mapping
from _testcapi import awaitType as at
future = at(iter([1]))
self.assertEqual(next(future.__await__()), 1)
def test_tp_await_3(self):
from _testcapi import awaitType as at
async def foo():
future = at(1)
return (await future)
with self.assertRaisesRegex(
TypeError, "__await__.*returned non-iterator of type 'int'"):
self.assertEqual(foo().send(None), 1)
# disable some tests that only apply to CPython
CAPITest = None # no CAPI module
if sys.version_info < (3, 5):
SysSetCoroWrapperTest = None
if __name__=="__main__":
unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment