Commit 30b77186 authored by Vitja Makarov's avatar Vitja Makarov

Experimental support for generators

parent 951baff4
...@@ -4938,8 +4938,9 @@ class LambdaNode(InnerFunctionNode): ...@@ -4938,8 +4938,9 @@ class LambdaNode(InnerFunctionNode):
self.pymethdef_cname = self.def_node.entry.pymethdef_cname self.pymethdef_cname = self.def_node.entry.pymethdef_cname
env.add_lambda_def(self.def_node) env.add_lambda_def(self.def_node)
class YieldExprNode(ExprNode):
# Yield expression node class OldYieldExprNode(ExprNode):
# XXX: remove me someday
# #
# arg ExprNode the value to return from the generator # arg ExprNode the value to return from the generator
# label_name string name of the C label used for this yield # label_name string name of the C label used for this yield
...@@ -4964,6 +4965,72 @@ class YieldExprNode(ExprNode): ...@@ -4964,6 +4965,72 @@ class YieldExprNode(ExprNode):
code.putln("/* FIXME: restore temporary variables and */") code.putln("/* FIXME: restore temporary variables and */")
code.putln("/* FIXME: extract sent value from closure */") code.putln("/* FIXME: extract sent value from closure */")
class YieldExprNode(ExprNode):
# Yield expression node
#
# arg ExprNode the value to return from the generator
# label_name string name of the C label used for this yield
subexprs = ['arg']
type = py_object_type
def analyse_types(self, env):
self.is_temp = 1
if self.arg is not None:
self.arg.analyse_types(env)
if not self.arg.type.is_pyobject:
self.arg = self.arg.coerce_to_pyobject(env)
env.use_utility_code(generator_utility_code)
def generate_evaluation_code(self, code):
saved = []
self.temp_allocator.reset()
code.putln('/* Save temporary variables */')
for cname, type, manage_ref in code.funcstate.temps_in_use():
save_cname = self.temp_allocator.allocate_temp(type)
saved.append((cname, save_cname, type))
code.putln('%s->%s = %s;' % (Naming.cur_scope_cname, save_cname, cname))
if type.is_pyobject:
code.put_giveref(cname)
self.label_name = code.new_label('resume_from_yield')
code.use_label(self.label_name)
self.allocate_temp_result(code)
if self.arg:
self.arg.generate_evaluation_code(code)
self.arg.make_owned_reference(code)
code.putln(
"%s = %s;" % (
Naming.retval_cname,
self.arg.result_as(py_object_type)))
self.arg.generate_post_assignment_code(code)
#self.arg.generate_disposal_code(code)
self.arg.free_temps(code)
else:
code.put_init_to_py_none(Naming.retval_cname, py_object_type)
code.put_finish_refcount_context()
code.putln("/* return from function, yielding value */")
code.putln("%s->%s.resume_label = %d;" % (Naming.cur_scope_cname, Naming.obj_base_cname, self.label_num))
code.putln("return %s;" % Naming.retval_cname);
code.put_label(self.label_name)
code.putln('/* Restore temporary variables */')
for cname, save_cname, type in saved:
code.putln('%s = %s->%s;' % (cname, Naming.cur_scope_cname, save_cname))
if type.is_pyobject:
code.putln('%s->%s = 0;' % (Naming.cur_scope_cname, save_cname))
code.put_gotref(cname)
code.putln('%s = __pyx_send_value;' % self.result())
code.put_incref(self.result(), py_object_type)
class StopIterationNode(YieldExprNode):
subexprs = []
def generate_evaluation_code(self, code):
self.allocate_temp_result(code)
self.label_name = code.new_label('resume_from_yield')
code.use_label(self.label_name)
code.put_label(self.label_name)
code.putln('PyErr_SetNone(PyExc_StopIteration); %s' % code.error_goto(self.pos))
#------------------------------------------------------------------- #-------------------------------------------------------------------
# #
...@@ -8230,3 +8297,53 @@ int %(binding_cfunc)s_init(void) { ...@@ -8230,3 +8297,53 @@ int %(binding_cfunc)s_init(void) {
} }
""" % Naming.__dict__) """ % Naming.__dict__)
generator_utility_code = UtilityCode(
proto="""
static PyObject *__CyGenerator_Next(PyObject *self);
static PyObject *__CyGenerator_Send(PyObject *self, PyObject *value);
typedef PyObject *(*__cygenerator_body_t)(PyObject *, PyObject *, int);
""",
impl="""
static CYTHON_INLINE PyObject *__CyGenerator_SendEx(struct __CyGenerator *self, PyObject *value, int is_exc)
{
PyObject *retval;
if (self->is_running) {
PyErr_SetString(PyExc_ValueError,
"generator already executing");
return NULL;
}
if (self->resume_label == 0) {
if (value && value != Py_None) {
PyErr_SetString(PyExc_TypeError,
"can't send non-None value to a "
"just-started generator");
return NULL;
}
}
self->is_running = 1;
retval = self->body((PyObject *) self, value, is_exc);
self->is_running = 0;
return retval;
}
static PyObject *__CyGenerator_Next(PyObject *self)
{
struct __CyGenerator *generator = (struct __CyGenerator *) self;
PyObject *retval;
Py_INCREF(Py_None);
retval = __CyGenerator_SendEx(generator, Py_None, 0);
Py_DECREF(Py_None);
return retval;
}
static PyObject *__CyGenerator_Send(PyObject *self, PyObject *value)
{
return __CyGenerator_SendEx((struct __CyGenerator *) self, value, 0);
}
""", proto_block='utility_code_proto_before_types')
...@@ -97,6 +97,7 @@ class Context(object): ...@@ -97,6 +97,7 @@ class Context(object):
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import MarkGeneratorVisitor
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from ParseTreeTransforms import ExpandInplaceOperators from ParseTreeTransforms import ExpandInplaceOperators
from TypeInference import MarkAssignments, MarkOverflowingArithmetic from TypeInference import MarkAssignments, MarkOverflowingArithmetic
...@@ -129,6 +130,7 @@ class Context(object): ...@@ -129,6 +130,7 @@ class Context(object):
InterpretCompilerDirectives(self, self.compiler_directives), InterpretCompilerDirectives(self, self.compiler_directives),
_align_function_definitions, _align_function_definitions,
MarkClosureVisitor(self), MarkClosureVisitor(self),
MarkGeneratorVisitor(self),
ConstantFolding(), ConstantFolding(),
FlattenInListTransform(), FlattenInListTransform(),
WithTransform(self), WithTransform(self),
......
...@@ -19,6 +19,7 @@ funcdoc_prefix = pyrex_prefix + "doc_" ...@@ -19,6 +19,7 @@ funcdoc_prefix = pyrex_prefix + "doc_"
enum_prefix = pyrex_prefix + "e_" enum_prefix = pyrex_prefix + "e_"
func_prefix = pyrex_prefix + "f_" func_prefix = pyrex_prefix + "f_"
pyfunc_prefix = pyrex_prefix + "pf_" pyfunc_prefix = pyrex_prefix + "pf_"
genbody_prefix = pyrex_prefix + "gb_"
gstab_prefix = pyrex_prefix + "getsets_" gstab_prefix = pyrex_prefix + "getsets_"
prop_get_prefix = pyrex_prefix + "getprop_" prop_get_prefix = pyrex_prefix + "getprop_"
const_prefix = pyrex_prefix + "k_" const_prefix = pyrex_prefix + "k_"
......
...@@ -1166,6 +1166,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1166,6 +1166,7 @@ class FuncDefNode(StatNode, BlockNode):
assmt = None assmt = None
needs_closure = False needs_closure = False
needs_outer_scope = False needs_outer_scope = False
is_generator = False
modifiers = [] modifiers = []
def analyse_default_values(self, env): def analyse_default_values(self, env):
...@@ -1251,7 +1252,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1251,7 +1252,7 @@ class FuncDefNode(StatNode, BlockNode):
# Generate C code for header and body of function # Generate C code for header and body of function
code.enter_cfunc_scope() code.enter_cfunc_scope()
code.return_from_error_cleanup_label = code.new_label() code.return_from_error_cleanup_label = code.new_label()
# ----- Top-level constants used by this function # ----- Top-level constants used by this function
code.mark_pos(self.pos) code.mark_pos(self.pos)
self.generate_cached_builtins_decls(lenv, code) self.generate_cached_builtins_decls(lenv, code)
...@@ -1295,7 +1296,8 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1295,7 +1296,8 @@ class FuncDefNode(StatNode, BlockNode):
(self.return_type.declaration_code(Naming.retval_cname), (self.return_type.declaration_code(Naming.retval_cname),
init)) init))
tempvardecl_code = code.insertion_point() tempvardecl_code = code.insertion_point()
self.generate_keyword_list(code) if not self.is_generator:
self.generate_keyword_list(code)
if profile: if profile:
code.put_trace_declarations() code.put_trace_declarations()
# ----- Extern library function declarations # ----- Extern library function declarations
...@@ -1314,7 +1316,12 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1314,7 +1316,12 @@ class FuncDefNode(StatNode, BlockNode):
if is_getbuffer_slot: if is_getbuffer_slot:
self.getbuffer_init(code) self.getbuffer_init(code)
# ----- Create closure scope object # ----- Create closure scope object
if self.needs_closure: if self.is_generator:
code.putln("%s = (%s) %s;" % (
Naming.cur_scope_cname,
lenv.scope_class.type.declaration_code(''),
Naming.self_cname))
elif self.needs_closure:
code.putln("%s = (%s)%s->tp_new(%s, %s, NULL);" % ( code.putln("%s = (%s)%s->tp_new(%s, %s, NULL);" % (
Naming.cur_scope_cname, Naming.cur_scope_cname,
lenv.scope_class.type.declaration_code(''), lenv.scope_class.type.declaration_code(''),
...@@ -1331,7 +1338,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1331,7 +1338,7 @@ class FuncDefNode(StatNode, BlockNode):
code.putln("}") code.putln("}")
code.put_gotref(Naming.cur_scope_cname) code.put_gotref(Naming.cur_scope_cname)
# Note that it is unsafe to decref the scope at this point. # Note that it is unsafe to decref the scope at this point.
if self.needs_outer_scope: if self.needs_outer_scope and not self.is_generator:
code.putln("%s = (%s)%s;" % ( code.putln("%s = (%s)%s;" % (
outer_scope_cname, outer_scope_cname,
cenv.scope_class.type.declaration_code(''), cenv.scope_class.type.declaration_code(''),
...@@ -1348,7 +1355,13 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1348,7 +1355,13 @@ class FuncDefNode(StatNode, BlockNode):
# fatal error before hand, it's not really worth tracing # fatal error before hand, it's not really worth tracing
code.put_trace_call(self.entry.name, self.pos) code.put_trace_call(self.entry.name, self.pos)
# ----- Fetch arguments # ----- Fetch arguments
self.generate_argument_parsing_code(env, code) if self.is_generator:
resume_code = code.insertion_point()
first_run_label = code.new_label('first_run')
code.use_label(first_run_label)
code.put_label(first_run_label)
if not self.is_generator:
self.generate_argument_parsing_code(env, code)
# If an argument is assigned to in the body, we must # If an argument is assigned to in the body, we must
# incref it to properly keep track of refcounts. # incref it to properly keep track of refcounts.
for entry in lenv.arg_entries: for entry in lenv.arg_entries:
...@@ -1465,7 +1478,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1465,7 +1478,7 @@ class FuncDefNode(StatNode, BlockNode):
code.put_var_giveref(entry) code.put_var_giveref(entry)
elif entry.assignments: elif entry.assignments:
code.put_var_decref(entry) code.put_var_decref(entry)
if self.needs_closure: if self.needs_closure and not self.is_generator:
code.put_decref(Naming.cur_scope_cname, lenv.scope_class.type) code.put_decref(Naming.cur_scope_cname, lenv.scope_class.type)
# ----- Return # ----- Return
...@@ -1504,15 +1517,26 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1504,15 +1517,26 @@ class FuncDefNode(StatNode, BlockNode):
if preprocessor_guard: if preprocessor_guard:
code.putln("#endif /*!(%s)*/" % preprocessor_guard) code.putln("#endif /*!(%s)*/" % preprocessor_guard)
# ----- Go back and insert temp variable declarations # ----- Go back and insert temp variable declarations
tempvardecl_code.put_temp_declarations(code.funcstate) tempvardecl_code.put_temp_declarations(code.funcstate)
# ----- Generator resume code
if self.is_generator:
resume_code.putln("switch (%s->%s.resume_label) {" % (Naming.cur_scope_cname, Naming.obj_base_cname));
resume_code.putln("case 0: goto %s;" % first_run_label)
for yield_expr in self.yields:
resume_code.putln("case %d: goto %s;" % (yield_expr.label_num, yield_expr.label_name));
resume_code.putln("default: /* raise error here */");
resume_code.putln("return NULL;");
resume_code.putln("}");
# ----- Python version # ----- Python version
code.exit_cfunc_scope() code.exit_cfunc_scope()
if self.py_func: if self.py_func:
self.py_func.generate_function_definitions(env, code) self.py_func.generate_function_definitions(env, code)
self.generate_wrapper_functions(code) self.generate_wrapper_functions(code)
if self.is_generator:
self.generator.generate_function_body(self.local_scope, code)
def declare_argument(self, env, arg): def declare_argument(self, env, arg):
if arg.type.is_void: if arg.type.is_void:
error(arg.pos, "Invalid use of 'void'") error(arg.pos, "Invalid use of 'void'")
...@@ -1863,6 +1887,57 @@ class DecoratorNode(Node): ...@@ -1863,6 +1887,57 @@ class DecoratorNode(Node):
child_attrs = ['decorator'] child_attrs = ['decorator']
class GeneratorWrapperNode(object):
# Wrapper
def __init__(self, def_node, func_cname=None, body_cname=None, header=None):
self.def_node = def_node
self.func_cname = func_cname
self.body_cname = body_cname
self.header = header
def generate_function_body(self, env, code):
cenv = env.outer_scope # XXX: correct?
while cenv.is_py_class_scope or cenv.is_c_class_scope:
cenv = cenv.outer_scope
lenv = self.def_node.local_scope
code.enter_cfunc_scope()
code.putln()
code.putln('%s {' % self.header)
self.def_node.generate_keyword_list(code)
code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
code.putln(";")
code.put_setup_refcount_context(self.def_node.entry.name)
code.putln("%s = (%s)%s->tp_new(%s, %s, NULL);" % (
Naming.cur_scope_cname,
lenv.scope_class.type.declaration_code(''),
lenv.scope_class.type.typeptr_cname,
lenv.scope_class.type.typeptr_cname,
Naming.empty_tuple))
code.putln("if (unlikely(!%s)) {" % Naming.cur_scope_cname)
code.put_finish_refcount_context()
code.putln("return NULL;");
code.putln("}");
code.put_gotref(Naming.cur_scope_cname)
if self.def_node.needs_outer_scope:
code.putln("%s->%s = (%s)%s;" % (
Naming.cur_scope_cname,
Naming.outer_scope_cname,
cenv.scope_class.type.declaration_code(''),
Naming.self_cname))
self.def_node.generate_argument_parsing_code(env, code)
generator_cname = '%s->%s' % (Naming.cur_scope_cname, Naming.obj_base_cname)
code.putln('%s.resume_label = 0;' % generator_cname)
code.putln('%s.body = (void *) %s;' % (generator_cname, self.body_cname))
code.put_giveref(Naming.cur_scope_cname)
code.put_finish_refcount_context()
code.putln("return (PyObject *) %s;" % Naming.cur_scope_cname);
code.putln('}\n')
code.exit_cfunc_scope()
class DefNode(FuncDefNode): class DefNode(FuncDefNode):
# A Python function definition. # A Python function definition.
# #
...@@ -2156,6 +2231,10 @@ class DefNode(FuncDefNode): ...@@ -2156,6 +2231,10 @@ class DefNode(FuncDefNode):
Naming.pyfunc_prefix + prefix + name Naming.pyfunc_prefix + prefix + name
entry.pymethdef_cname = \ entry.pymethdef_cname = \
Naming.pymethdef_prefix + prefix + name Naming.pymethdef_prefix + prefix + name
if self.is_generator:
self.generator_body_cname = Naming.genbody_prefix + env.next_id(env.scope_prefix) + name
if Options.docstrings: if Options.docstrings:
entry.doc = embed_position(self.pos, self.doc) entry.doc = embed_position(self.pos, self.doc)
entry.doc_cname = \ entry.doc_cname = \
...@@ -2303,7 +2382,15 @@ class DefNode(FuncDefNode): ...@@ -2303,7 +2382,15 @@ class DefNode(FuncDefNode):
"static PyMethodDef %s = " % "static PyMethodDef %s = " %
self.entry.pymethdef_cname) self.entry.pymethdef_cname)
code.put_pymethoddef(self.entry, ";", allow_skip=False) code.put_pymethoddef(self.entry, ";", allow_skip=False)
code.putln("%s {" % header) if self.is_generator:
code.putln("static PyObject *%s(PyObject *%s, PyObject *__pyx_send_value, int __pyx_is_exc) /* generator body */\n{" %
(self.generator_body_cname, Naming.self_cname))
self.generator = GeneratorWrapperNode(self,
func_cname=self.entry.func_cname,
body_cname=self.generator_body_cname,
header=header)
else:
code.putln("%s {" % header)
def generate_argument_declarations(self, env, code): def generate_argument_declarations(self, env, code):
for arg in self.args: for arg in self.args:
......
...@@ -1166,7 +1166,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1166,7 +1166,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
self.yield_nodes = [] self.yield_nodes = []
visit_Node = Visitor.TreeVisitor.visitchildren visit_Node = Visitor.TreeVisitor.visitchildren
def visit_YieldExprNode(self, node): def visit_OldYieldExprNode(self, node):
self.yield_nodes.append(node) self.yield_nodes.append(node)
self.visitchildren(node) self.visitchildren(node)
......
...@@ -1316,7 +1316,7 @@ class MarkClosureVisitor(CythonTransform): ...@@ -1316,7 +1316,7 @@ class MarkClosureVisitor(CythonTransform):
node.needs_closure = self.needs_closure node.needs_closure = self.needs_closure
self.needs_closure = True self.needs_closure = True
return node return node
def visit_CFuncDefNode(self, node): def visit_CFuncDefNode(self, node):
self.visit_FuncDefNode(node) self.visit_FuncDefNode(node)
if node.needs_closure: if node.needs_closure:
...@@ -1335,6 +1335,89 @@ class MarkClosureVisitor(CythonTransform): ...@@ -1335,6 +1335,89 @@ class MarkClosureVisitor(CythonTransform):
self.needs_closure = True self.needs_closure = True
return node return node
class ClosureTempAllocator(object):
def __init__(self, klass=None):
self.klass = klass
self.temps_allocated = {}
self.temps_free = {}
self.temps_count = 0
def reset(self):
for type, cnames in self.temps_allocated:
self.temps_free[type] = list(cnames)
def allocate_temp(self, type):
if not type in self.temps_allocated:
self.temps_allocated[type] = []
self.temps_free[type] = []
if self.temps_free[type]:
return self.temps_free[type].pop(0)
cname = '%s_%d' % (Naming.codewriter_temp_prefix, self.temps_count)
self.klass.declare_var(pos=None, name=cname, cname=cname, type=type, is_cdef=True)
self.temps_allocated[type].append(cname)
self.temps_count += 1
return cname
class YieldCollector(object):
def __init__(self, node):
self.node = node
self.yields = []
self.returns = []
class MarkGeneratorVisitor(CythonTransform):
"""XXX: merge me with MarkClosureVisitor"""
def __init__(self, context):
super(MarkGeneratorVisitor, self).__init__(context)
self.allow_yield = False
self.path = []
def visit_ModuleNode(self, node):
self.visitchildren(node)
return node
def visit_ClassDefNode(self, node):
saved = self.allow_yield
self.allow_yield = False
self.visitchildren(node)
self.allow_yield = saved
return node
def visit_FuncDefNode(self, node):
saved = self.allow_yield
self.allow_yield = True
self.path.append(YieldCollector(node))
self.visitchildren(node)
self.allow_yield = saved
collector = self.path.pop()
if collector.yields and collector.returns:
error(collector.returns[0].pos, "'return' with argument inside generator")
elif collector.yields:
allocator = ClosureTempAllocator()
stop_node = ExprNodes.StopIterationNode(node.pos, arg=None)
collector.yields.append(stop_node)
for y in collector.yields: # XXX: find a better way
y.temp_allocator = allocator
node.temp_allocator = allocator
stop_node.label_num = len(collector.yields)
node.body.stats.append(Nodes.ExprStatNode(node.pos, expr=stop_node))
node.is_generator = True
node.needs_closure = True
node.yields = collector.yields
return node
def visit_YieldExprNode(self, node):
if not self.allow_yield:
error(node.pos, "'yield' outside function")
return node
collector = self.path[-1]
collector.yields.append(node)
node.label_num = len(collector.yields)
return node
def visit_ReturnStatNode(self, node):
if self.path:
self.path[-1].returns.append(node)
return node
class CreateClosureClasses(CythonTransform): class CreateClosureClasses(CythonTransform):
# Output closure classes in module scope for all functions # Output closure classes in module scope for all functions
...@@ -1344,12 +1427,57 @@ class CreateClosureClasses(CythonTransform): ...@@ -1344,12 +1427,57 @@ class CreateClosureClasses(CythonTransform):
super(CreateClosureClasses, self).__init__(context) super(CreateClosureClasses, self).__init__(context)
self.path = [] self.path = []
self.in_lambda = False self.in_lambda = False
self.generator_class = None
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.module_scope = node.scope self.module_scope = node.scope
self.visitchildren(node) self.visitchildren(node)
return node return node
def create_abstract_generator(self, target_module_scope, pos):
if self.generator_class:
return self.generator_class
# XXX: make generator class creation cleaner
entry = target_module_scope.declare_c_class(name='__CyGenerator',
objstruct_cname='__CyGenerator',
typeobj_cname='__CyGeneratorType',
pos=pos, defining=True, implementing=True)
entry.cname = 'CyGenerator'
klass = entry.type.scope
klass.is_internal = True
klass.directives = {'final': True}
body_type = PyrexTypes.create_typedef_type('generator_body',
PyrexTypes.c_void_ptr_type,
'__cygenerator_body_t')
klass.declare_var(pos=pos, name='body', cname='body',
type=body_type, is_cdef=True)
klass.declare_var(pos=pos, name='is_running', cname='is_running', type=PyrexTypes.c_int_type,
is_cdef=True)
klass.declare_var(pos=pos, name='resume_label', cname='resume_label', type=PyrexTypes.c_int_type,
is_cdef=True)
import TypeSlots
e = klass.declare_pyfunction('send', pos)
e.func_cname = '__CyGenerator_Send'
e.signature = TypeSlots.binaryfunc
#e = klass.declare_pyfunction('close', pos)
#e.func_cname = '__CyGenerator_Close'
#e.signature = TypeSlots.unaryfunc
#e = klass.declare_pyfunction('throw', pos)
#e.func_cname = '__CyGenerator_Throw'
e = klass.declare_var('__iter__', PyrexTypes.py_object_type, pos, visibility='public')
e.func_cname = 'PyObject_SelfIter'
e = klass.declare_var('__next__', PyrexTypes.py_object_type, pos, visibility='public')
e.func_cname = '__CyGenerator_Next'
self.generator_class = entry.type
return self.generator_class
def get_scope_use(self, node): def get_scope_use(self, node):
from_closure = [] from_closure = []
in_closure = [] in_closure = []
...@@ -1361,6 +1489,12 @@ class CreateClosureClasses(CythonTransform): ...@@ -1361,6 +1489,12 @@ class CreateClosureClasses(CythonTransform):
return from_closure, in_closure return from_closure, in_closure
def create_class_from_scope(self, node, target_module_scope, inner_node=None): def create_class_from_scope(self, node, target_module_scope, inner_node=None):
# move local variables into closure
if node.is_generator:
for entry in node.local_scope.entries.values():
if not entry.from_closure:
entry.in_closure = True
from_closure, in_closure = self.get_scope_use(node) from_closure, in_closure = self.get_scope_use(node)
in_closure.sort() in_closure.sort()
...@@ -1380,8 +1514,10 @@ class CreateClosureClasses(CythonTransform): ...@@ -1380,8 +1514,10 @@ class CreateClosureClasses(CythonTransform):
inner_node = node.assmt.rhs inner_node = node.assmt.rhs
inner_node.needs_self_code = False inner_node.needs_self_code = False
node.needs_outer_scope = False node.needs_outer_scope = False
# Simple cases
if not in_closure and not from_closure: if node.is_generator:
generator_class = self.create_abstract_generator(target_module_scope, node.pos)
elif not in_closure and not from_closure:
return return
elif not in_closure: elif not in_closure:
func_scope.is_passthrough = True func_scope.is_passthrough = True
...@@ -1391,13 +1527,19 @@ class CreateClosureClasses(CythonTransform): ...@@ -1391,13 +1527,19 @@ class CreateClosureClasses(CythonTransform):
as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname) as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)
entry = target_module_scope.declare_c_class(name = as_name, if node.is_generator:
pos = node.pos, defining = True, implementing = True) entry = target_module_scope.declare_c_class(name = as_name,
pos = node.pos, defining = True, implementing = True, base_type=generator_class)
else:
entry = target_module_scope.declare_c_class(name = as_name,
pos = node.pos, defining = True, implementing = True)
func_scope.scope_class = entry func_scope.scope_class = entry
class_scope = entry.type.scope class_scope = entry.type.scope
class_scope.is_internal = True class_scope.is_internal = True
class_scope.directives = {'final': True} class_scope.directives = {'final': True}
if node.is_generator:
node.temp_allocator.klass = class_scope
if from_closure: if from_closure:
assert cscope.is_closure_scope assert cscope.is_closure_scope
class_scope.declare_var(pos=node.pos, class_scope.declare_var(pos=node.pos,
......
...@@ -1029,7 +1029,7 @@ def p_testlist_comp(s): ...@@ -1029,7 +1029,7 @@ def p_testlist_comp(s):
def p_genexp(s, expr): def p_genexp(s, expr):
# s.sy == 'for' # s.sy == 'for'
loop = p_comp_for(s, Nodes.ExprStatNode( loop = p_comp_for(s, Nodes.ExprStatNode(
expr.pos, expr = ExprNodes.YieldExprNode(expr.pos, arg=expr))) expr.pos, expr = ExprNodes.OldYieldExprNode(expr.pos, arg=expr)))
return ExprNodes.GeneratorExpressionNode(expr.pos, loop=loop) return ExprNodes.GeneratorExpressionNode(expr.pos, loop=loop)
expr_terminators = (')', ']', '}', ':', '=', 'NEWLINE') expr_terminators = (')', ']', '}', ':', '=', 'NEWLINE')
......
def simple():
"""
>>> x = simple()
>>> list(x)
[1, 2, 3]
"""
yield 1
yield 2
yield 3
def simple_seq(seq):
"""
>>> x = simple_seq("abc")
>>> list(x)
['a', 'b', 'c']
"""
for i in seq:
yield i
def simple_send():
"""
>>> x = simple_send()
>>> next(x)
>>> x.send(1)
1
>>> x.send(2)
2
>>> x.send(3)
3
"""
i = None
while True:
i = yield i
def with_outer(*args):
"""
>>> x = with_outer(1, 2, 3)
>>> list(x())
[1, 2, 3]
"""
def generator():
for i in args:
yield i
return generator
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment