Commit 30b77186 authored by Vitja Makarov's avatar Vitja Makarov

Experimental support for generators

parent 951baff4
......@@ -4938,8 +4938,9 @@ class LambdaNode(InnerFunctionNode):
self.pymethdef_cname = self.def_node.entry.pymethdef_cname
env.add_lambda_def(self.def_node)
class YieldExprNode(ExprNode):
# Yield expression node
class OldYieldExprNode(ExprNode):
# XXX: remove me someday
#
# arg ExprNode the value to return from the generator
# label_name string name of the C label used for this yield
......@@ -4964,6 +4965,72 @@ class YieldExprNode(ExprNode):
code.putln("/* FIXME: restore temporary variables and */")
code.putln("/* FIXME: extract sent value from closure */")
class YieldExprNode(ExprNode):
# Yield expression node
#
# arg ExprNode the value to return from the generator
# label_name string name of the C label used for this yield
subexprs = ['arg']
type = py_object_type
def analyse_types(self, env):
self.is_temp = 1
if self.arg is not None:
self.arg.analyse_types(env)
if not self.arg.type.is_pyobject:
self.arg = self.arg.coerce_to_pyobject(env)
env.use_utility_code(generator_utility_code)
def generate_evaluation_code(self, code):
saved = []
self.temp_allocator.reset()
code.putln('/* Save temporary variables */')
for cname, type, manage_ref in code.funcstate.temps_in_use():
save_cname = self.temp_allocator.allocate_temp(type)
saved.append((cname, save_cname, type))
code.putln('%s->%s = %s;' % (Naming.cur_scope_cname, save_cname, cname))
if type.is_pyobject:
code.put_giveref(cname)
self.label_name = code.new_label('resume_from_yield')
code.use_label(self.label_name)
self.allocate_temp_result(code)
if self.arg:
self.arg.generate_evaluation_code(code)
self.arg.make_owned_reference(code)
code.putln(
"%s = %s;" % (
Naming.retval_cname,
self.arg.result_as(py_object_type)))
self.arg.generate_post_assignment_code(code)
#self.arg.generate_disposal_code(code)
self.arg.free_temps(code)
else:
code.put_init_to_py_none(Naming.retval_cname, py_object_type)
code.put_finish_refcount_context()
code.putln("/* return from function, yielding value */")
code.putln("%s->%s.resume_label = %d;" % (Naming.cur_scope_cname, Naming.obj_base_cname, self.label_num))
code.putln("return %s;" % Naming.retval_cname);
code.put_label(self.label_name)
code.putln('/* Restore temporary variables */')
for cname, save_cname, type in saved:
code.putln('%s = %s->%s;' % (cname, Naming.cur_scope_cname, save_cname))
if type.is_pyobject:
code.putln('%s->%s = 0;' % (Naming.cur_scope_cname, save_cname))
code.put_gotref(cname)
code.putln('%s = __pyx_send_value;' % self.result())
code.put_incref(self.result(), py_object_type)
class StopIterationNode(YieldExprNode):
subexprs = []
def generate_evaluation_code(self, code):
self.allocate_temp_result(code)
self.label_name = code.new_label('resume_from_yield')
code.use_label(self.label_name)
code.put_label(self.label_name)
code.putln('PyErr_SetNone(PyExc_StopIteration); %s' % code.error_goto(self.pos))
#-------------------------------------------------------------------
#
......@@ -8230,3 +8297,53 @@ int %(binding_cfunc)s_init(void) {
}
""" % Naming.__dict__)
generator_utility_code = UtilityCode(
proto="""
static PyObject *__CyGenerator_Next(PyObject *self);
static PyObject *__CyGenerator_Send(PyObject *self, PyObject *value);
typedef PyObject *(*__cygenerator_body_t)(PyObject *, PyObject *, int);
""",
impl="""
static CYTHON_INLINE PyObject *__CyGenerator_SendEx(struct __CyGenerator *self, PyObject *value, int is_exc)
{
PyObject *retval;
if (self->is_running) {
PyErr_SetString(PyExc_ValueError,
"generator already executing");
return NULL;
}
if (self->resume_label == 0) {
if (value && value != Py_None) {
PyErr_SetString(PyExc_TypeError,
"can't send non-None value to a "
"just-started generator");
return NULL;
}
}
self->is_running = 1;
retval = self->body((PyObject *) self, value, is_exc);
self->is_running = 0;
return retval;
}
static PyObject *__CyGenerator_Next(PyObject *self)
{
struct __CyGenerator *generator = (struct __CyGenerator *) self;
PyObject *retval;
Py_INCREF(Py_None);
retval = __CyGenerator_SendEx(generator, Py_None, 0);
Py_DECREF(Py_None);
return retval;
}
static PyObject *__CyGenerator_Send(PyObject *self, PyObject *value)
{
return __CyGenerator_SendEx((struct __CyGenerator *) self, value, 0);
}
""", proto_block='utility_code_proto_before_types')
......@@ -97,6 +97,7 @@ class Context(object):
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import MarkGeneratorVisitor
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from ParseTreeTransforms import ExpandInplaceOperators
from TypeInference import MarkAssignments, MarkOverflowingArithmetic
......@@ -129,6 +130,7 @@ class Context(object):
InterpretCompilerDirectives(self, self.compiler_directives),
_align_function_definitions,
MarkClosureVisitor(self),
MarkGeneratorVisitor(self),
ConstantFolding(),
FlattenInListTransform(),
WithTransform(self),
......
......@@ -19,6 +19,7 @@ funcdoc_prefix = pyrex_prefix + "doc_"
enum_prefix = pyrex_prefix + "e_"
func_prefix = pyrex_prefix + "f_"
pyfunc_prefix = pyrex_prefix + "pf_"
genbody_prefix = pyrex_prefix + "gb_"
gstab_prefix = pyrex_prefix + "getsets_"
prop_get_prefix = pyrex_prefix + "getprop_"
const_prefix = pyrex_prefix + "k_"
......
......@@ -1166,6 +1166,7 @@ class FuncDefNode(StatNode, BlockNode):
assmt = None
needs_closure = False
needs_outer_scope = False
is_generator = False
modifiers = []
def analyse_default_values(self, env):
......@@ -1251,7 +1252,7 @@ class FuncDefNode(StatNode, BlockNode):
# Generate C code for header and body of function
code.enter_cfunc_scope()
code.return_from_error_cleanup_label = code.new_label()
# ----- Top-level constants used by this function
code.mark_pos(self.pos)
self.generate_cached_builtins_decls(lenv, code)
......@@ -1295,7 +1296,8 @@ class FuncDefNode(StatNode, BlockNode):
(self.return_type.declaration_code(Naming.retval_cname),
init))
tempvardecl_code = code.insertion_point()
self.generate_keyword_list(code)
if not self.is_generator:
self.generate_keyword_list(code)
if profile:
code.put_trace_declarations()
# ----- Extern library function declarations
......@@ -1314,7 +1316,12 @@ class FuncDefNode(StatNode, BlockNode):
if is_getbuffer_slot:
self.getbuffer_init(code)
# ----- Create closure scope object
if self.needs_closure:
if self.is_generator:
code.putln("%s = (%s) %s;" % (
Naming.cur_scope_cname,
lenv.scope_class.type.declaration_code(''),
Naming.self_cname))
elif self.needs_closure:
code.putln("%s = (%s)%s->tp_new(%s, %s, NULL);" % (
Naming.cur_scope_cname,
lenv.scope_class.type.declaration_code(''),
......@@ -1331,7 +1338,7 @@ class FuncDefNode(StatNode, BlockNode):
code.putln("}")
code.put_gotref(Naming.cur_scope_cname)
# Note that it is unsafe to decref the scope at this point.
if self.needs_outer_scope:
if self.needs_outer_scope and not self.is_generator:
code.putln("%s = (%s)%s;" % (
outer_scope_cname,
cenv.scope_class.type.declaration_code(''),
......@@ -1348,7 +1355,13 @@ class FuncDefNode(StatNode, BlockNode):
# fatal error before hand, it's not really worth tracing
code.put_trace_call(self.entry.name, self.pos)
# ----- Fetch arguments
self.generate_argument_parsing_code(env, code)
if self.is_generator:
resume_code = code.insertion_point()
first_run_label = code.new_label('first_run')
code.use_label(first_run_label)
code.put_label(first_run_label)
if not self.is_generator:
self.generate_argument_parsing_code(env, code)
# If an argument is assigned to in the body, we must
# incref it to properly keep track of refcounts.
for entry in lenv.arg_entries:
......@@ -1465,7 +1478,7 @@ class FuncDefNode(StatNode, BlockNode):
code.put_var_giveref(entry)
elif entry.assignments:
code.put_var_decref(entry)
if self.needs_closure:
if self.needs_closure and not self.is_generator:
code.put_decref(Naming.cur_scope_cname, lenv.scope_class.type)
# ----- Return
......@@ -1504,15 +1517,26 @@ class FuncDefNode(StatNode, BlockNode):
if preprocessor_guard:
code.putln("#endif /*!(%s)*/" % preprocessor_guard)
# ----- Go back and insert temp variable declarations
tempvardecl_code.put_temp_declarations(code.funcstate)
# ----- Generator resume code
if self.is_generator:
resume_code.putln("switch (%s->%s.resume_label) {" % (Naming.cur_scope_cname, Naming.obj_base_cname));
resume_code.putln("case 0: goto %s;" % first_run_label)
for yield_expr in self.yields:
resume_code.putln("case %d: goto %s;" % (yield_expr.label_num, yield_expr.label_name));
resume_code.putln("default: /* raise error here */");
resume_code.putln("return NULL;");
resume_code.putln("}");
# ----- Python version
code.exit_cfunc_scope()
if self.py_func:
self.py_func.generate_function_definitions(env, code)
self.generate_wrapper_functions(code)
if self.is_generator:
self.generator.generate_function_body(self.local_scope, code)
def declare_argument(self, env, arg):
if arg.type.is_void:
error(arg.pos, "Invalid use of 'void'")
......@@ -1863,6 +1887,57 @@ class DecoratorNode(Node):
child_attrs = ['decorator']
class GeneratorWrapperNode(object):
# Wrapper
def __init__(self, def_node, func_cname=None, body_cname=None, header=None):
self.def_node = def_node
self.func_cname = func_cname
self.body_cname = body_cname
self.header = header
def generate_function_body(self, env, code):
cenv = env.outer_scope # XXX: correct?
while cenv.is_py_class_scope or cenv.is_c_class_scope:
cenv = cenv.outer_scope
lenv = self.def_node.local_scope
code.enter_cfunc_scope()
code.putln()
code.putln('%s {' % self.header)
self.def_node.generate_keyword_list(code)
code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
code.putln(";")
code.put_setup_refcount_context(self.def_node.entry.name)
code.putln("%s = (%s)%s->tp_new(%s, %s, NULL);" % (
Naming.cur_scope_cname,
lenv.scope_class.type.declaration_code(''),
lenv.scope_class.type.typeptr_cname,
lenv.scope_class.type.typeptr_cname,
Naming.empty_tuple))
code.putln("if (unlikely(!%s)) {" % Naming.cur_scope_cname)
code.put_finish_refcount_context()
code.putln("return NULL;");
code.putln("}");
code.put_gotref(Naming.cur_scope_cname)
if self.def_node.needs_outer_scope:
code.putln("%s->%s = (%s)%s;" % (
Naming.cur_scope_cname,
Naming.outer_scope_cname,
cenv.scope_class.type.declaration_code(''),
Naming.self_cname))
self.def_node.generate_argument_parsing_code(env, code)
generator_cname = '%s->%s' % (Naming.cur_scope_cname, Naming.obj_base_cname)
code.putln('%s.resume_label = 0;' % generator_cname)
code.putln('%s.body = (void *) %s;' % (generator_cname, self.body_cname))
code.put_giveref(Naming.cur_scope_cname)
code.put_finish_refcount_context()
code.putln("return (PyObject *) %s;" % Naming.cur_scope_cname);
code.putln('}\n')
code.exit_cfunc_scope()
class DefNode(FuncDefNode):
# A Python function definition.
#
......@@ -2156,6 +2231,10 @@ class DefNode(FuncDefNode):
Naming.pyfunc_prefix + prefix + name
entry.pymethdef_cname = \
Naming.pymethdef_prefix + prefix + name
if self.is_generator:
self.generator_body_cname = Naming.genbody_prefix + env.next_id(env.scope_prefix) + name
if Options.docstrings:
entry.doc = embed_position(self.pos, self.doc)
entry.doc_cname = \
......@@ -2303,7 +2382,15 @@ class DefNode(FuncDefNode):
"static PyMethodDef %s = " %
self.entry.pymethdef_cname)
code.put_pymethoddef(self.entry, ";", allow_skip=False)
code.putln("%s {" % header)
if self.is_generator:
code.putln("static PyObject *%s(PyObject *%s, PyObject *__pyx_send_value, int __pyx_is_exc) /* generator body */\n{" %
(self.generator_body_cname, Naming.self_cname))
self.generator = GeneratorWrapperNode(self,
func_cname=self.entry.func_cname,
body_cname=self.generator_body_cname,
header=header)
else:
code.putln("%s {" % header)
def generate_argument_declarations(self, env, code):
for arg in self.args:
......
......@@ -1166,7 +1166,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
self.yield_nodes = []
visit_Node = Visitor.TreeVisitor.visitchildren
def visit_YieldExprNode(self, node):
def visit_OldYieldExprNode(self, node):
self.yield_nodes.append(node)
self.visitchildren(node)
......
......@@ -1316,7 +1316,7 @@ class MarkClosureVisitor(CythonTransform):
node.needs_closure = self.needs_closure
self.needs_closure = True
return node
def visit_CFuncDefNode(self, node):
self.visit_FuncDefNode(node)
if node.needs_closure:
......@@ -1335,6 +1335,89 @@ class MarkClosureVisitor(CythonTransform):
self.needs_closure = True
return node
class ClosureTempAllocator(object):
def __init__(self, klass=None):
self.klass = klass
self.temps_allocated = {}
self.temps_free = {}
self.temps_count = 0
def reset(self):
for type, cnames in self.temps_allocated:
self.temps_free[type] = list(cnames)
def allocate_temp(self, type):
if not type in self.temps_allocated:
self.temps_allocated[type] = []
self.temps_free[type] = []
if self.temps_free[type]:
return self.temps_free[type].pop(0)
cname = '%s_%d' % (Naming.codewriter_temp_prefix, self.temps_count)
self.klass.declare_var(pos=None, name=cname, cname=cname, type=type, is_cdef=True)
self.temps_allocated[type].append(cname)
self.temps_count += 1
return cname
class YieldCollector(object):
def __init__(self, node):
self.node = node
self.yields = []
self.returns = []
class MarkGeneratorVisitor(CythonTransform):
"""XXX: merge me with MarkClosureVisitor"""
def __init__(self, context):
super(MarkGeneratorVisitor, self).__init__(context)
self.allow_yield = False
self.path = []
def visit_ModuleNode(self, node):
self.visitchildren(node)
return node
def visit_ClassDefNode(self, node):
saved = self.allow_yield
self.allow_yield = False
self.visitchildren(node)
self.allow_yield = saved
return node
def visit_FuncDefNode(self, node):
saved = self.allow_yield
self.allow_yield = True
self.path.append(YieldCollector(node))
self.visitchildren(node)
self.allow_yield = saved
collector = self.path.pop()
if collector.yields and collector.returns:
error(collector.returns[0].pos, "'return' with argument inside generator")
elif collector.yields:
allocator = ClosureTempAllocator()
stop_node = ExprNodes.StopIterationNode(node.pos, arg=None)
collector.yields.append(stop_node)
for y in collector.yields: # XXX: find a better way
y.temp_allocator = allocator
node.temp_allocator = allocator
stop_node.label_num = len(collector.yields)
node.body.stats.append(Nodes.ExprStatNode(node.pos, expr=stop_node))
node.is_generator = True
node.needs_closure = True
node.yields = collector.yields
return node
def visit_YieldExprNode(self, node):
if not self.allow_yield:
error(node.pos, "'yield' outside function")
return node
collector = self.path[-1]
collector.yields.append(node)
node.label_num = len(collector.yields)
return node
def visit_ReturnStatNode(self, node):
if self.path:
self.path[-1].returns.append(node)
return node
class CreateClosureClasses(CythonTransform):
# Output closure classes in module scope for all functions
......@@ -1344,12 +1427,57 @@ class CreateClosureClasses(CythonTransform):
super(CreateClosureClasses, self).__init__(context)
self.path = []
self.in_lambda = False
self.generator_class = None
def visit_ModuleNode(self, node):
self.module_scope = node.scope
self.visitchildren(node)
return node
def create_abstract_generator(self, target_module_scope, pos):
if self.generator_class:
return self.generator_class
# XXX: make generator class creation cleaner
entry = target_module_scope.declare_c_class(name='__CyGenerator',
objstruct_cname='__CyGenerator',
typeobj_cname='__CyGeneratorType',
pos=pos, defining=True, implementing=True)
entry.cname = 'CyGenerator'
klass = entry.type.scope
klass.is_internal = True
klass.directives = {'final': True}
body_type = PyrexTypes.create_typedef_type('generator_body',
PyrexTypes.c_void_ptr_type,
'__cygenerator_body_t')
klass.declare_var(pos=pos, name='body', cname='body',
type=body_type, is_cdef=True)
klass.declare_var(pos=pos, name='is_running', cname='is_running', type=PyrexTypes.c_int_type,
is_cdef=True)
klass.declare_var(pos=pos, name='resume_label', cname='resume_label', type=PyrexTypes.c_int_type,
is_cdef=True)
import TypeSlots
e = klass.declare_pyfunction('send', pos)
e.func_cname = '__CyGenerator_Send'
e.signature = TypeSlots.binaryfunc
#e = klass.declare_pyfunction('close', pos)
#e.func_cname = '__CyGenerator_Close'
#e.signature = TypeSlots.unaryfunc
#e = klass.declare_pyfunction('throw', pos)
#e.func_cname = '__CyGenerator_Throw'
e = klass.declare_var('__iter__', PyrexTypes.py_object_type, pos, visibility='public')
e.func_cname = 'PyObject_SelfIter'
e = klass.declare_var('__next__', PyrexTypes.py_object_type, pos, visibility='public')
e.func_cname = '__CyGenerator_Next'
self.generator_class = entry.type
return self.generator_class
def get_scope_use(self, node):
from_closure = []
in_closure = []
......@@ -1361,6 +1489,12 @@ class CreateClosureClasses(CythonTransform):
return from_closure, in_closure
def create_class_from_scope(self, node, target_module_scope, inner_node=None):
# move local variables into closure
if node.is_generator:
for entry in node.local_scope.entries.values():
if not entry.from_closure:
entry.in_closure = True
from_closure, in_closure = self.get_scope_use(node)
in_closure.sort()
......@@ -1380,8 +1514,10 @@ class CreateClosureClasses(CythonTransform):
inner_node = node.assmt.rhs
inner_node.needs_self_code = False
node.needs_outer_scope = False
# Simple cases
if not in_closure and not from_closure:
if node.is_generator:
generator_class = self.create_abstract_generator(target_module_scope, node.pos)
elif not in_closure and not from_closure:
return
elif not in_closure:
func_scope.is_passthrough = True
......@@ -1391,13 +1527,19 @@ class CreateClosureClasses(CythonTransform):
as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)
entry = target_module_scope.declare_c_class(name = as_name,
pos = node.pos, defining = True, implementing = True)
if node.is_generator:
entry = target_module_scope.declare_c_class(name = as_name,
pos = node.pos, defining = True, implementing = True, base_type=generator_class)
else:
entry = target_module_scope.declare_c_class(name = as_name,
pos = node.pos, defining = True, implementing = True)
func_scope.scope_class = entry
class_scope = entry.type.scope
class_scope.is_internal = True
class_scope.directives = {'final': True}
if node.is_generator:
node.temp_allocator.klass = class_scope
if from_closure:
assert cscope.is_closure_scope
class_scope.declare_var(pos=node.pos,
......
......@@ -1029,7 +1029,7 @@ def p_testlist_comp(s):
def p_genexp(s, expr):
# s.sy == 'for'
loop = p_comp_for(s, Nodes.ExprStatNode(
expr.pos, expr = ExprNodes.YieldExprNode(expr.pos, arg=expr)))
expr.pos, expr = ExprNodes.OldYieldExprNode(expr.pos, arg=expr)))
return ExprNodes.GeneratorExpressionNode(expr.pos, loop=loop)
expr_terminators = (')', ']', '}', ':', '=', 'NEWLINE')
......
def simple():
"""
>>> x = simple()
>>> list(x)
[1, 2, 3]
"""
yield 1
yield 2
yield 3
def simple_seq(seq):
"""
>>> x = simple_seq("abc")
>>> list(x)
['a', 'b', 'c']
"""
for i in seq:
yield i
def simple_send():
"""
>>> x = simple_send()
>>> next(x)
>>> x.send(1)
1
>>> x.send(2)
2
>>> x.send(3)
3
"""
i = None
while True:
i = yield i
def with_outer(*args):
"""
>>> x = with_outer(1, 2, 3)
>>> list(x())
[1, 2, 3]
"""
def generator():
for i in args:
yield i
return generator
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment