Commit 7ef3e52c authored by Stefan Behnel's avatar Stefan Behnel

merged in Vitek's generators branch

parents b2b337b2 fb3ac076
...@@ -34,6 +34,8 @@ cdef class FunctionState: ...@@ -34,6 +34,8 @@ cdef class FunctionState:
cdef public dict temps_used_type cdef public dict temps_used_type
cdef public size_t temp_counter cdef public size_t temp_counter
cdef public object closure_temps
@cython.locals(n=size_t) @cython.locals(n=size_t)
cpdef new_label(self, name=*) cpdef new_label(self, name=*)
cpdef tuple get_loop_labels(self) cpdef tuple get_loop_labels(self)
......
...@@ -117,6 +117,7 @@ class FunctionState(object): ...@@ -117,6 +117,7 @@ class FunctionState(object):
self.temps_free = {} # (type, manage_ref) -> list of free vars with same type/managed status self.temps_free = {} # (type, manage_ref) -> list of free vars with same type/managed status
self.temps_used_type = {} # name -> (type, manage_ref) self.temps_used_type = {} # name -> (type, manage_ref)
self.temp_counter = 0 self.temp_counter = 0
self.closure_temps = None
# labels # labels
...@@ -270,6 +271,9 @@ class FunctionState(object): ...@@ -270,6 +271,9 @@ class FunctionState(object):
if manage_ref if manage_ref
for cname in freelist] for cname in freelist]
def init_closure_temps(self, scope):
self.closure_temps = ClosureTempAllocator(scope)
class IntConst(object): class IntConst(object):
"""Global info about a Python integer constant held by GlobalState. """Global info about a Python integer constant held by GlobalState.
...@@ -475,6 +479,7 @@ class GlobalState(object): ...@@ -475,6 +479,7 @@ class GlobalState(object):
w.enter_cfunc_scope() w.enter_cfunc_scope()
w.putln("") w.putln("")
w.putln("static int __Pyx_InitCachedConstants(void) {") w.putln("static int __Pyx_InitCachedConstants(void) {")
w.put_declare_refcount_context()
w.put_setup_refcount_context("__Pyx_InitCachedConstants") w.put_setup_refcount_context("__Pyx_InitCachedConstants")
w = self.parts['init_globals'] w = self.parts['init_globals']
...@@ -1297,6 +1302,8 @@ class CCodeWriter(object): ...@@ -1297,6 +1302,8 @@ class CCodeWriter(object):
#if entry.type.is_extension_type: #if entry.type.is_extension_type:
# code = "((PyObject*)%s)" % code # code = "((PyObject*)%s)" % code
self.put_init_to_py_none(code, entry.type, nanny) self.put_init_to_py_none(code, entry.type, nanny)
if entry.in_closure:
self.put_giveref('Py_None')
def put_pymethoddef(self, entry, term, allow_skip=True): def put_pymethoddef(self, entry, term, allow_skip=True):
if entry.is_special or entry.name == '__getattribute__': if entry.is_special or entry.name == '__getattribute__':
...@@ -1366,6 +1373,9 @@ class CCodeWriter(object): ...@@ -1366,6 +1373,9 @@ class CCodeWriter(object):
def lookup_filename(self, filename): def lookup_filename(self, filename):
return self.globalstate.lookup_filename(filename) return self.globalstate.lookup_filename(filename)
def put_declare_refcount_context(self):
self.putln('__Pyx_RefNannyDeclareContext;')
def put_setup_refcount_context(self, name): def put_setup_refcount_context(self, name):
self.putln('__Pyx_RefNannySetupContext("%s");' % name) self.putln('__Pyx_RefNannySetupContext("%s");' % name)
...@@ -1402,3 +1412,26 @@ class PyrexCodeWriter(object): ...@@ -1402,3 +1412,26 @@ class PyrexCodeWriter(object):
def dedent(self): def dedent(self):
self.level -= 1 self.level -= 1
class ClosureTempAllocator(object):
def __init__(self, klass):
self.klass = klass
self.temps_allocated = {}
self.temps_free = {}
self.temps_count = 0
def reset(self):
for type, cnames in self.temps_allocated.items():
self.temps_free[type] = list(cnames)
def allocate_temp(self, type):
if not type in self.temps_allocated:
self.temps_allocated[type] = []
self.temps_free[type] = []
elif self.temps_free[type]:
return self.temps_free[type].pop(0)
cname = '%s%d' % (Naming.codewriter_temp_prefix, self.temps_count)
self.klass.declare_var(pos=None, name=cname, cname=cname, type=type, is_cdef=True)
self.temps_allocated[type].append(cname)
self.temps_count += 1
return cname
This diff is collapsed.
...@@ -964,9 +964,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -964,9 +964,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
type.vtabstruct_cname, type.vtabstruct_cname,
type.vtabslot_cname)) type.vtabslot_cname))
for attr in type.scope.var_entries: for attr in type.scope.var_entries:
if attr.is_declared_generic:
attr_type = py_object_type
else:
attr_type = attr.type
code.putln( code.putln(
"%s;" % "%s;" %
attr.type.declaration_code(attr.cname)) attr_type.declaration_code(attr.cname))
code.putln(footer) code.putln(footer)
if type.objtypedef_cname is not None: if type.objtypedef_cname is not None:
# Only for exposing public typedef name. # Only for exposing public typedef name.
...@@ -1265,6 +1269,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1265,6 +1269,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
for entry in py_attrs: for entry in py_attrs:
name = "p->%s" % entry.cname name = "p->%s" % entry.cname
code.putln("tmp = ((PyObject*)%s);" % name) code.putln("tmp = ((PyObject*)%s);" % name)
if entry.is_declared_generic:
code.put_init_to_py_none(name, py_object_type, nanny=False)
else:
code.put_init_to_py_none(name, entry.type, nanny=False) code.put_init_to_py_none(name, entry.type, nanny=False)
code.putln("Py_XDECREF(tmp);") code.putln("Py_XDECREF(tmp);")
code.putln( code.putln(
...@@ -2762,8 +2769,9 @@ refnanny_utility_code = UtilityCode(proto=""" ...@@ -2762,8 +2769,9 @@ refnanny_utility_code = UtilityCode(proto="""
Py_XDECREF(m); Py_XDECREF(m);
return (__Pyx_RefNannyAPIStruct *)r; return (__Pyx_RefNannyAPIStruct *)r;
} }
#define __Pyx_RefNannyDeclareContext void *__pyx_refnanny;
#define __Pyx_RefNannySetupContext(name) \ #define __Pyx_RefNannySetupContext(name) \
void *__pyx_refnanny = __Pyx_RefNanny->SetupContext((name), __LINE__, __FILE__) __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), __LINE__, __FILE__)
#define __Pyx_RefNannyFinishContext() \ #define __Pyx_RefNannyFinishContext() \
__Pyx_RefNanny->FinishContext(&__pyx_refnanny) __Pyx_RefNanny->FinishContext(&__pyx_refnanny)
#define __Pyx_INCREF(r) __Pyx_RefNanny->INCREF(__pyx_refnanny, (PyObject *)(r), __LINE__) #define __Pyx_INCREF(r) __Pyx_RefNanny->INCREF(__pyx_refnanny, (PyObject *)(r), __LINE__)
...@@ -2772,6 +2780,7 @@ refnanny_utility_code = UtilityCode(proto=""" ...@@ -2772,6 +2780,7 @@ refnanny_utility_code = UtilityCode(proto="""
#define __Pyx_GIVEREF(r) __Pyx_RefNanny->GIVEREF(__pyx_refnanny, (PyObject *)(r), __LINE__) #define __Pyx_GIVEREF(r) __Pyx_RefNanny->GIVEREF(__pyx_refnanny, (PyObject *)(r), __LINE__)
#define __Pyx_XDECREF(r) do { if((r) != NULL) {__Pyx_DECREF(r);} } while(0) #define __Pyx_XDECREF(r) do { if((r) != NULL) {__Pyx_DECREF(r);} } while(0)
#else #else
#define __Pyx_RefNannyDeclareContext
#define __Pyx_RefNannySetupContext(name) #define __Pyx_RefNannySetupContext(name)
#define __Pyx_RefNannyFinishContext() #define __Pyx_RefNannyFinishContext()
#define __Pyx_INCREF(r) Py_INCREF(r) #define __Pyx_INCREF(r) Py_INCREF(r)
......
...@@ -19,6 +19,7 @@ funcdoc_prefix = pyrex_prefix + "doc_" ...@@ -19,6 +19,7 @@ funcdoc_prefix = pyrex_prefix + "doc_"
enum_prefix = pyrex_prefix + "e_" enum_prefix = pyrex_prefix + "e_"
func_prefix = pyrex_prefix + "f_" func_prefix = pyrex_prefix + "f_"
pyfunc_prefix = pyrex_prefix + "pf_" pyfunc_prefix = pyrex_prefix + "pf_"
genbody_prefix = pyrex_prefix + "gb_"
gstab_prefix = pyrex_prefix + "getsets_" gstab_prefix = pyrex_prefix + "getsets_"
prop_get_prefix = pyrex_prefix + "getprop_" prop_get_prefix = pyrex_prefix + "getprop_"
const_prefix = pyrex_prefix + "k_" const_prefix = pyrex_prefix + "k_"
...@@ -51,6 +52,7 @@ lambda_func_prefix = pyrex_prefix + "lambda_" ...@@ -51,6 +52,7 @@ lambda_func_prefix = pyrex_prefix + "lambda_"
module_is_main = pyrex_prefix + "module_is_main_" module_is_main = pyrex_prefix + "module_is_main_"
args_cname = pyrex_prefix + "args" args_cname = pyrex_prefix + "args"
sent_value_cname = pyrex_prefix + "sent_value"
pykwdlist_cname = pyrex_prefix + "pyargnames" pykwdlist_cname = pyrex_prefix + "pyargnames"
obj_base_cname = pyrex_prefix + "base" obj_base_cname = pyrex_prefix + "base"
builtins_cname = pyrex_prefix + "b" builtins_cname = pyrex_prefix + "b"
...@@ -107,10 +109,6 @@ exc_lineno_name = pyrex_prefix + "exc_lineno" ...@@ -107,10 +109,6 @@ exc_lineno_name = pyrex_prefix + "exc_lineno"
exc_vars = (exc_type_name, exc_value_name, exc_tb_name) exc_vars = (exc_type_name, exc_value_name, exc_tb_name)
exc_save_vars = (pyrex_prefix + 'save_exc_type',
pyrex_prefix + 'save_exc_value',
pyrex_prefix + 'save_exc_tb')
api_name = pyrex_prefix + "capi__" api_name = pyrex_prefix + "capi__"
h_guard_prefix = "__PYX_HAVE__" h_guard_prefix = "__PYX_HAVE__"
......
This diff is collapsed.
...@@ -1170,11 +1170,12 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1170,11 +1170,12 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
self.yield_nodes = [] self.yield_nodes = []
visit_Node = Visitor.TreeVisitor.visitchildren visit_Node = Visitor.TreeVisitor.visitchildren
def visit_YieldExprNode(self, node): # XXX: disable inlining while it's not back supported
def __visit_YieldExprNode(self, node):
self.yield_nodes.append(node) self.yield_nodes.append(node)
self.visitchildren(node) self.visitchildren(node)
def visit_ExprStatNode(self, node): def __visit_ExprStatNode(self, node):
self.visitchildren(node) self.visitchildren(node)
if node.expr in self.yield_nodes: if node.expr in self.yield_nodes:
self.yield_stat_nodes[node.expr] = node self.yield_stat_nodes[node.expr] = node
......
...@@ -19,6 +19,7 @@ cdef class NormalizeTree(CythonTransform): ...@@ -19,6 +19,7 @@ cdef class NormalizeTree(CythonTransform):
cdef class PostParse(ScopeTrackingTransform): cdef class PostParse(ScopeTrackingTransform):
cdef dict specialattribute_handlers cdef dict specialattribute_handlers
cdef size_t lambda_counter cdef size_t lambda_counter
cdef size_t genexpr_counter
cdef _visit_assignment_node(self, node, list expr_list) cdef _visit_assignment_node(self, node, list expr_list)
...@@ -45,6 +46,11 @@ cdef class AlignFunctionDefinitions(CythonTransform): ...@@ -45,6 +46,11 @@ cdef class AlignFunctionDefinitions(CythonTransform):
cdef dict directives cdef dict directives
cdef scope cdef scope
cdef class YieldNodeCollector(TreeVisitor):
cdef public list yields
cdef public list returns
cdef public bint has_return_value
cdef class MarkClosureVisitor(CythonTransform): cdef class MarkClosureVisitor(CythonTransform):
cdef bint needs_closure cdef bint needs_closure
...@@ -52,6 +58,7 @@ cdef class CreateClosureClasses(CythonTransform): ...@@ -52,6 +58,7 @@ cdef class CreateClosureClasses(CythonTransform):
cdef list path cdef list path
cdef bint in_lambda cdef bint in_lambda
cdef module_scope cdef module_scope
cdef generator_class
cdef class GilCheck(VisitorTransform): cdef class GilCheck(VisitorTransform):
cdef list env_stack cdef list env_stack
......
...@@ -182,6 +182,7 @@ class PostParse(ScopeTrackingTransform): ...@@ -182,6 +182,7 @@ class PostParse(ScopeTrackingTransform):
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.lambda_counter = 1 self.lambda_counter = 1
self.genexpr_counter = 1
return super(PostParse, self).visit_ModuleNode(node) return super(PostParse, self).visit_ModuleNode(node)
def visit_LambdaNode(self, node): def visit_LambdaNode(self, node):
...@@ -189,14 +190,34 @@ class PostParse(ScopeTrackingTransform): ...@@ -189,14 +190,34 @@ class PostParse(ScopeTrackingTransform):
lambda_id = self.lambda_counter lambda_id = self.lambda_counter
self.lambda_counter += 1 self.lambda_counter += 1
node.lambda_name = EncodedString(u'lambda%d' % lambda_id) node.lambda_name = EncodedString(u'lambda%d' % lambda_id)
collector = YieldNodeCollector()
collector.visitchildren(node.result_expr)
if collector.yields or isinstance(node.result_expr, ExprNodes.YieldExprNode):
body = ExprNodes.YieldExprNode(
node.result_expr.pos, arg=node.result_expr)
body = Nodes.ExprStatNode(node.result_expr.pos, expr=body)
else:
body = Nodes.ReturnStatNode( body = Nodes.ReturnStatNode(
node.result_expr.pos, value = node.result_expr) node.result_expr.pos, value=node.result_expr)
node.def_node = Nodes.DefNode( node.def_node = Nodes.DefNode(
node.pos, name=node.name, lambda_name=node.lambda_name, node.pos, name=node.name, lambda_name=node.lambda_name,
args=node.args, star_arg=node.star_arg, args=node.args, star_arg=node.star_arg,
starstar_arg=node.starstar_arg, starstar_arg=node.starstar_arg,
body=body) body=body, doc=None)
self.visitchildren(node)
return node
def visit_GeneratorExpressionNode(self, node):
# unpack a generator expression into the corresponding DefNode
genexpr_id = self.genexpr_counter
self.genexpr_counter += 1
node.genexpr_name = EncodedString(u'genexpr%d' % genexpr_id)
node.def_node = Nodes.DefNode(node.pos, name=node.name,
doc=None,
args=[], star_arg=None,
starstar_arg=None,
body=node.loop)
self.visitchildren(node) self.visitchildren(node)
return node return node
...@@ -1408,6 +1429,42 @@ class AlignFunctionDefinitions(CythonTransform): ...@@ -1408,6 +1429,42 @@ class AlignFunctionDefinitions(CythonTransform):
return node return node
class YieldNodeCollector(TreeVisitor):
def __init__(self):
super(YieldNodeCollector, self).__init__()
self.yields = []
self.returns = []
self.has_return_value = False
def visit_Node(self, node):
return self.visitchildren(node)
def visit_YieldExprNode(self, node):
if self.has_return_value:
error(node.pos, "'yield' outside function")
self.yields.append(node)
self.visitchildren(node)
def visit_ReturnStatNode(self, node):
if node.value:
self.has_return_value = True
if self.yields:
error(node.pos, "'return' with argument inside generator")
self.returns.append(node)
def visit_ClassDefNode(self, node):
pass
def visit_DefNode(self, node):
pass
def visit_LambdaNode(self, node):
pass
def visit_GeneratorExpressionNode(self, node):
pass
class MarkClosureVisitor(CythonTransform): class MarkClosureVisitor(CythonTransform):
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
...@@ -1420,6 +1477,27 @@ class MarkClosureVisitor(CythonTransform): ...@@ -1420,6 +1477,27 @@ class MarkClosureVisitor(CythonTransform):
self.visitchildren(node) self.visitchildren(node)
node.needs_closure = self.needs_closure node.needs_closure = self.needs_closure
self.needs_closure = True self.needs_closure = True
collector = YieldNodeCollector()
collector.visitchildren(node)
if collector.yields:
for i, yield_expr in enumerate(collector.yields):
yield_expr.label_num = i + 1
gbody = Nodes.GeneratorBodyDefNode(pos=node.pos,
name=node.name,
body=node.body)
generator = Nodes.GeneratorDefNode(pos=node.pos,
name=node.name,
args=node.args,
star_arg=node.star_arg,
starstar_arg=node.starstar_arg,
doc=node.doc,
decorators=node.decorators,
gbody=gbody,
lambda_name=node.lambda_name)
return generator
return node return node
def visit_CFuncDefNode(self, node): def visit_CFuncDefNode(self, node):
...@@ -1440,7 +1518,6 @@ class MarkClosureVisitor(CythonTransform): ...@@ -1440,7 +1518,6 @@ class MarkClosureVisitor(CythonTransform):
self.needs_closure = True self.needs_closure = True
return node return node
class CreateClosureClasses(CythonTransform): class CreateClosureClasses(CythonTransform):
# Output closure classes in module scope for all functions # Output closure classes in module scope for all functions
# that really need it. # that really need it.
...@@ -1449,24 +1526,78 @@ class CreateClosureClasses(CythonTransform): ...@@ -1449,24 +1526,78 @@ class CreateClosureClasses(CythonTransform):
super(CreateClosureClasses, self).__init__(context) super(CreateClosureClasses, self).__init__(context)
self.path = [] self.path = []
self.in_lambda = False self.in_lambda = False
self.generator_class = None
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.module_scope = node.scope self.module_scope = node.scope
self.visitchildren(node) self.visitchildren(node)
return node return node
def get_scope_use(self, node): def create_generator_class(self, target_module_scope, pos):
if self.generator_class:
return self.generator_class
# XXX: make generator class creation cleaner
entry = target_module_scope.declare_c_class(name='__pyx_Generator',
objstruct_cname='__pyx_Generator_object',
typeobj_cname='__pyx_Generator_type',
pos=pos, defining=True, implementing=True)
klass = entry.type.scope
klass.is_internal = True
klass.directives = {'final': True}
body_type = PyrexTypes.create_typedef_type('generator_body',
PyrexTypes.c_void_ptr_type,
'__pyx_generator_body_t')
klass.declare_var(pos=pos, name='body', cname='body',
type=body_type, is_cdef=True)
klass.declare_var(pos=pos, name='is_running', cname='is_running', type=PyrexTypes.c_int_type,
is_cdef=True)
klass.declare_var(pos=pos, name='resume_label', cname='resume_label', type=PyrexTypes.c_int_type,
is_cdef=True)
import TypeSlots
e = klass.declare_pyfunction('send', pos)
e.func_cname = '__Pyx_Generator_Send'
e.signature = TypeSlots.binaryfunc
e = klass.declare_pyfunction('close', pos)
e.func_cname = '__Pyx_Generator_Close'
e.signature = TypeSlots.unaryfunc
e = klass.declare_pyfunction('throw', pos)
e.func_cname = '__Pyx_Generator_Throw'
e.signature = TypeSlots.pyfunction_signature
e = klass.declare_var('__iter__', PyrexTypes.py_object_type, pos, visibility='public')
e.func_cname = 'PyObject_SelfIter'
e = klass.declare_var('__next__', PyrexTypes.py_object_type, pos, visibility='public')
e.func_cname = '__Pyx_Generator_Next'
self.generator_class = entry.type
return self.generator_class
def find_entries_used_in_closures(self, node):
from_closure = [] from_closure = []
in_closure = [] in_closure = []
for name, entry in node.local_scope.entries.items(): for name, entry in node.local_scope.entries.items():
if entry.from_closure: if entry.from_closure:
from_closure.append((name, entry)) from_closure.append((name, entry))
elif entry.in_closure and not entry.from_closure: elif entry.in_closure:
in_closure.append((name, entry)) in_closure.append((name, entry))
return from_closure, in_closure return from_closure, in_closure
def create_class_from_scope(self, node, target_module_scope, inner_node=None): def create_class_from_scope(self, node, target_module_scope, inner_node=None):
from_closure, in_closure = self.get_scope_use(node) # skip generator body
if node.is_generator_body:
return
# move local variables into closure
if node.is_generator:
for entry in node.local_scope.entries.values():
if not entry.from_closure:
entry.in_closure = True
from_closure, in_closure = self.find_entries_used_in_closures(node)
in_closure.sort() in_closure.sort()
# Now from the begining # Now from the begining
...@@ -1485,8 +1616,11 @@ class CreateClosureClasses(CythonTransform): ...@@ -1485,8 +1616,11 @@ class CreateClosureClasses(CythonTransform):
inner_node = node.assmt.rhs inner_node = node.assmt.rhs
inner_node.needs_self_code = False inner_node.needs_self_code = False
node.needs_outer_scope = False node.needs_outer_scope = False
# Simple cases
if not in_closure and not from_closure: base_type = None
if node.is_generator:
base_type = self.create_generator_class(target_module_scope, node.pos)
elif not in_closure and not from_closure:
return return
elif not in_closure: elif not in_closure:
func_scope.is_passthrough = True func_scope.is_passthrough = True
...@@ -1496,8 +1630,10 @@ class CreateClosureClasses(CythonTransform): ...@@ -1496,8 +1630,10 @@ class CreateClosureClasses(CythonTransform):
as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname) as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)
entry = target_module_scope.declare_c_class(name = as_name, entry = target_module_scope.declare_c_class(
pos = node.pos, defining = True, implementing = True) name=as_name, pos=node.pos, defining=True,
implementing=True, base_type=base_type)
func_scope.scope_class = entry func_scope.scope_class = entry
class_scope = entry.type.scope class_scope = entry.type.scope
class_scope.is_internal = True class_scope.is_internal = True
...@@ -1512,11 +1648,13 @@ class CreateClosureClasses(CythonTransform): ...@@ -1512,11 +1648,13 @@ class CreateClosureClasses(CythonTransform):
is_cdef=True) is_cdef=True)
node.needs_outer_scope = True node.needs_outer_scope = True
for name, entry in in_closure: for name, entry in in_closure:
class_scope.declare_var(pos=entry.pos, closure_entry = class_scope.declare_var(pos=entry.pos,
name=entry.name, name=entry.name,
cname=entry.cname, cname=entry.cname,
type=entry.type, type=entry.type,
is_cdef=True) is_cdef=True)
if entry.is_declared_generic:
closure_entry.is_declared_generic = 1
node.needs_closure = True node.needs_closure = True
# Do it here because other classes are already checked # Do it here because other classes are already checked
target_module_scope.check_c_class(func_scope.scope_class) target_module_scope.check_c_class(func_scope.scope_class)
......
...@@ -84,6 +84,7 @@ cdef p_genexp(PyrexScanner s, expr) ...@@ -84,6 +84,7 @@ cdef p_genexp(PyrexScanner s, expr)
#------------------------------------------------------- #-------------------------------------------------------
cdef p_global_statement(PyrexScanner s) cdef p_global_statement(PyrexScanner s)
cdef p_nonlocal_statement(PyrexScanner s)
cdef p_expression_or_assignment(PyrexScanner s) cdef p_expression_or_assignment(PyrexScanner s)
cdef p_print_statement(PyrexScanner s) cdef p_print_statement(PyrexScanner s)
cdef p_exec_statement(PyrexScanner s) cdef p_exec_statement(PyrexScanner s)
......
...@@ -1045,6 +1045,12 @@ def p_global_statement(s): ...@@ -1045,6 +1045,12 @@ def p_global_statement(s):
names = p_ident_list(s) names = p_ident_list(s)
return Nodes.GlobalNode(pos, names = names) return Nodes.GlobalNode(pos, names = names)
def p_nonlocal_statement(s):
pos = s.position()
s.next()
names = p_ident_list(s)
return Nodes.NonlocalNode(pos, names = names)
def p_expression_or_assignment(s): def p_expression_or_assignment(s):
expr_list = [p_testlist_star_expr(s)] expr_list = [p_testlist_star_expr(s)]
while s.sy == '=': while s.sy == '=':
...@@ -1598,6 +1604,8 @@ def p_simple_statement(s, first_statement = 0): ...@@ -1598,6 +1604,8 @@ def p_simple_statement(s, first_statement = 0):
#print "p_simple_statement:", s.sy, s.systring ### #print "p_simple_statement:", s.sy, s.systring ###
if s.sy == 'global': if s.sy == 'global':
node = p_global_statement(s) node = p_global_statement(s)
elif s.sy == 'nonlocal':
node = p_nonlocal_statement(s)
elif s.sy == 'print': elif s.sy == 'print':
node = p_print_statement(s) node = p_print_statement(s)
elif s.sy == 'exec': elif s.sy == 'exec':
......
...@@ -36,7 +36,7 @@ def get_lexicon(): ...@@ -36,7 +36,7 @@ def get_lexicon():
#------------------------------------------------------------------ #------------------------------------------------------------------
py_reserved_words = [ py_reserved_words = [
"global", "def", "class", "print", "del", "pass", "break", "global", "nonlocal", "def", "class", "print", "del", "pass", "break",
"continue", "return", "raise", "import", "exec", "try", "continue", "return", "raise", "import", "exec", "try",
"except", "finally", "while", "if", "elif", "else", "for", "except", "finally", "while", "if", "elif", "else", "for",
"in", "assert", "and", "or", "not", "is", "in", "lambda", "in", "assert", "and", "or", "not", "is", "in", "lambda",
......
...@@ -1309,6 +1309,16 @@ class LocalScope(Scope): ...@@ -1309,6 +1309,16 @@ class LocalScope(Scope):
entry = self.global_scope().lookup_target(name) entry = self.global_scope().lookup_target(name)
self.entries[name] = entry self.entries[name] = entry
def declare_nonlocal(self, name, pos):
# Pull entry from outer scope into local scope
orig_entry = self.lookup_here(name)
if orig_entry and orig_entry.scope is self and not orig_entry.from_closure:
error(pos, "'%s' redeclared as nonlocal" % name)
else:
entry = self.lookup(name)
if entry is None or not entry.from_closure:
error(pos, "no binding for nonlocal '%s' found" % name)
def lookup(self, name): def lookup(self, name):
# Look up name in this scope or an enclosing one. # Look up name in this scope or an enclosing one.
# Return None if not found. # Return None if not found.
...@@ -1326,6 +1336,7 @@ class LocalScope(Scope): ...@@ -1326,6 +1336,7 @@ class LocalScope(Scope):
inner_entry.is_variable = True inner_entry.is_variable = True
inner_entry.outer_entry = entry inner_entry.outer_entry = entry
inner_entry.from_closure = True inner_entry.from_closure = True
inner_entry.is_declared_generic = entry.is_declared_generic
self.entries[name] = inner_entry self.entries[name] = inner_entry
return inner_entry return inner_entry
return entry return entry
...@@ -1479,6 +1490,20 @@ class PyClassScope(ClassScope): ...@@ -1479,6 +1490,20 @@ class PyClassScope(ClassScope):
entry.is_pyclass_attr = 1 entry.is_pyclass_attr = 1
return entry return entry
def declare_nonlocal(self, name, pos):
# Pull entry from outer scope into local scope
orig_entry = self.lookup_here(name)
if orig_entry and orig_entry.scope is self and not orig_entry.from_closure:
error(pos, "'%s' redeclared as nonlocal" % name)
else:
entry = self.lookup(name)
if entry is None:
error(pos, "no binding for nonlocal '%s' found" % name)
else:
# FIXME: this works, but it's unclear if it's the
# right thing to do
self.entries[name] = entry
def add_default_value(self, type): def add_default_value(self, type):
return self.outer_scope.add_default_value(type) return self.outer_scope.add_default_value(type)
......
...@@ -221,7 +221,8 @@ class SimpleAssignmentTypeInferer(object): ...@@ -221,7 +221,8 @@ class SimpleAssignmentTypeInferer(object):
# TODO: Implement a real type inference algorithm. # TODO: Implement a real type inference algorithm.
# (Something more powerful than just extending this one...) # (Something more powerful than just extending this one...)
def infer_types(self, scope): def infer_types(self, scope):
enabled = not scope.is_closure_scope and scope.directives['infer_types'] closure_or_inner = scope.is_closure_scope or (scope.outer_scope and scope.outer_scope.is_closure_scope)
enabled = not closure_or_inner and scope.directives['infer_types']
verbose = scope.directives['infer_types.verbose'] verbose = scope.directives['infer_types.verbose']
if enabled == True: if enabled == True:
spanning_type = aggressive_spanning_type spanning_type = aggressive_spanning_type
......
...@@ -10,7 +10,6 @@ cfunc_call_tuple_args_T408 ...@@ -10,7 +10,6 @@ cfunc_call_tuple_args_T408
compile.cpp_operators compile.cpp_operators
cpp_templated_ctypedef cpp_templated_ctypedef
cpp_structs cpp_structs
genexpr_T491
with_statement_module_level_T536 with_statement_module_level_T536
function_as_method_T494 function_as_method_T494
closure_inside_cdef_T554 closure_inside_cdef_T554
...@@ -19,6 +18,7 @@ genexpr_iterable_lookup_T600 ...@@ -19,6 +18,7 @@ genexpr_iterable_lookup_T600
for_from_pyvar_loop_T601 for_from_pyvar_loop_T601
decorators_T593 decorators_T593
temp_sideeffects_T654 temp_sideeffects_T654
generator_type_inference
# CPython regression tests that don't current work: # CPython regression tests that don't current work:
pyregr.test_threadsignals pyregr.test_threadsignals
......
def foo():
yield
return 0
def bar(a):
return 0
yield
yield
class Foo:
yield
_ERRORS = u"""
3:4: 'return' with argument inside generator
7:4: 'yield' outside function
9:0: 'yield' not supported here
12:4: 'yield' not supported here
"""
def test_non_existant():
nonlocal no_such_name
no_such_name = 1
def redef():
x = 1
def f():
x = 2
nonlocal x
global_name = 5
def ref_to_global():
nonlocal global_name
global_name = 6
def global_in_class_scope():
class Test():
nonlocal global_name
global_name = 6
def redef_in_class_scope():
x = 1
class Test():
x = 2
nonlocal x
_ERRORS = u"""
3:4: no binding for nonlocal 'no_such_name' found
10:8: 'x' redeclared as nonlocal
15:4: no binding for nonlocal 'global_name' found
27:8: 'x' redeclared as nonlocal
"""
cdef class Test:
cdef int x
cdef class SelfInClosure(object):
cdef Test _t
cdef int x
def plain(self):
"""
>>> o = SelfInClosure()
>>> o.plain()
1
"""
self.x = 1
return self.x
def closure_method(self):
"""
>>> o = SelfInClosure()
>>> o.closure_method()() == o
True
"""
def nested():
return self
return nested
def closure_method_cdef_attr(self, Test t):
"""
>>> o = SelfInClosure()
>>> o.closure_method_cdef_attr(Test())()
(1, 2)
"""
t.x = 2
self._t = t
self.x = 1
def nested():
return self.x, t.x
return nested
# mode: run
# tag: typeinference, generators
cimport cython
def test_type_inference():
"""
>>> [ item for item in test_type_inference() ]
[(2.0, 'double'), (2.0, 'double'), (2.0, 'double')]
"""
x = 1.0
for i in range(3):
yield x * 2.0, cython.typeof(x)
try:
from builtins import next # Py3k
except ImportError:
def next(it):
return it.next()
if hasattr(__builtins__, 'GeneratorExit'):
GeneratorExit = __builtins__.GeneratorExit
else: # < 2.5
GeneratorExit = StopIteration
def very_simple():
"""
>>> x = very_simple()
>>> next(x)
1
>>> next(x)
Traceback (most recent call last):
StopIteration
>>> next(x)
Traceback (most recent call last):
StopIteration
>>> x = very_simple()
>>> x.send(1)
Traceback (most recent call last):
TypeError: can't send non-None value to a just-started generator
"""
yield 1
def simple():
"""
>>> x = simple()
>>> list(x)
[1, 2, 3]
"""
yield 1
yield 2
yield 3
def simple_seq(seq):
"""
>>> x = simple_seq("abc")
>>> list(x)
['a', 'b', 'c']
"""
for i in seq:
yield i
def simple_send():
"""
>>> x = simple_send()
>>> next(x)
>>> x.send(1)
1
>>> x.send(2)
2
>>> x.send(3)
3
"""
i = None
while True:
i = yield i
def raising():
"""
>>> x = raising()
>>> next(x)
Traceback (most recent call last):
KeyError: 'foo'
>>> next(x)
Traceback (most recent call last):
StopIteration
"""
yield {}['foo']
def with_outer(*args):
"""
>>> x = with_outer(1, 2, 3)
>>> list(x())
[1, 2, 3]
"""
def generator():
for i in args:
yield i
return generator
def with_outer_raising(*args):
"""
>>> x = with_outer_raising(1, 2, 3)
>>> list(x())
[1, 2, 3]
"""
def generator():
for i in args:
yield i
raise StopIteration
return generator
def test_close():
"""
>>> x = test_close()
>>> x.close()
>>> x = test_close()
>>> next(x)
>>> x.close()
>>> next(x)
Traceback (most recent call last):
StopIteration
"""
while True:
yield
def test_ignore_close():
"""
>>> x = test_ignore_close()
>>> x.close()
>>> x = test_ignore_close()
>>> next(x)
>>> x.close()
Traceback (most recent call last):
RuntimeError: generator ignored GeneratorExit
"""
try:
yield
except GeneratorExit:
yield
def check_throw():
"""
>>> x = check_throw()
>>> x.throw(ValueError)
Traceback (most recent call last):
ValueError
>>> next(x)
Traceback (most recent call last):
StopIteration
>>> x = check_throw()
>>> next(x)
>>> x.throw(ValueError)
>>> next(x)
>>> x.throw(IndexError, "oops")
Traceback (most recent call last):
IndexError: oops
>>> next(x)
Traceback (most recent call last):
StopIteration
"""
while True:
try:
yield
except ValueError:
pass
def test_first_assignment():
"""
>>> gen = test_first_assignment()
>>> next(gen)
5
>>> next(gen)
10
>>> next(gen)
(5, 10)
"""
cdef x = 5 # first
yield x
cdef y = 10 # first
yield y
yield (x,y)
def test_swap_assignment():
"""
>>> gen = test_swap_assignment()
>>> next(gen)
(5, 10)
>>> next(gen)
(10, 5)
"""
x,y = 5,10
yield (x,y)
x,y = y,x # no ref-counting here
yield (x,y)
class Foo(object):
"""
>>> obj = Foo()
>>> list(obj.simple(1, 2, 3))
[1, 2, 3]
"""
def simple(self, *args):
for i in args:
yield i
def generator_nonlocal():
"""
>>> g = generator_nonlocal()
>>> list(g(5))
[2, 3, 4, 5, 6]
"""
def f(x):
def g(y):
nonlocal x
for i in range(y):
x += 1
yield x
return g
return f(1)
def test_nested(a, b, c):
"""
>>> obj = test_nested(1, 2, 3)
>>> [i() for i in obj]
[1, 2, 3, 4]
"""
def one():
return a
def two():
return b
def three():
return c
def new_closure(a, b):
def sum():
return a + b
return sum
yield one
yield two
yield three
yield new_closure(a, c)
def tolist(func):
def wrapper(*args, **kwargs):
return list(func(*args, **kwargs))
return wrapper
@tolist
def test_decorated(*args):
"""
>>> test_decorated(1, 2, 3)
[1, 2, 3]
"""
for i in args:
yield i
def test_return(a):
"""
>>> d = dict()
>>> obj = test_return(d)
>>> next(obj)
1
>>> next(obj)
Traceback (most recent call last):
StopIteration
>>> d['i_was_here']
True
"""
yield 1
a['i_was_here'] = True
return
def test_copied_yield(foo):
"""
>>> class Manager(object):
... def __enter__(self):
... return self
... def __exit__(self, type, value, tb):
... pass
>>> list(test_copied_yield(Manager()))
[1]
"""
with foo:
yield 1
def test_nested_yield():
"""
>>> obj = test_nested_yield()
>>> next(obj)
1
>>> obj.send(2)
2
>>> obj.send(3)
3
>>> obj.send(4)
Traceback (most recent call last):
StopIteration
"""
yield (yield (yield 1))
def test_inside_lambda():
"""
>>> obj = test_inside_lambda()()
>>> next(obj)
1
>>> obj.send('a')
2
>>> obj.send('b')
('a', 'b')
"""
return lambda:((yield 1), (yield 2))
def test_nested_gen(int n):
"""
>>> [list(a) for a in test_nested_gen(5)]
[[], [0], [0, 1], [0, 1, 2], [0, 1, 2, 3]]
"""
for a in range(n):
yield (b for b in range(a))
def test_lambda(n):
"""
>>> [i() for i in test_lambda(3)]
[0, 1, 2]
"""
for i in range(n):
yield lambda : i
def simple():
"""
>>> simple()
1
2
"""
x = 1
y = 2
def f():
nonlocal x
nonlocal x, y
print(x)
print(y)
f()
def assign():
"""
>>> assign()
1
"""
xx = 0
def ff():
nonlocal xx
xx += 1
print(xx)
ff()
def nested():
"""
>>> nested()
1
"""
x = 0
def fx():
def gx():
nonlocal x
x=1
print(x)
return gx
fx()()
def arg(x):
"""
>>> arg('x')
xyy
"""
def appendy():
nonlocal x
x += 'y'
x+='y'
appendy()
print x
return
def argtype(int n):
"""
>>> argtype(0)
1
"""
def inc():
nonlocal n
n += 1
inc()
print n
return
def ping_pong():
"""
>>> f = ping_pong()
>>> inc, dec = f(0)
>>> inc()
1
>>> inc()
2
>>> dec()
1
>>> inc()
2
>>> dec()
1
>>> dec()
0
"""
def f(x):
def inc():
nonlocal x
x += 1
return x
def dec():
nonlocal x
x -= 1
return x
return inc, dec
return f
def methods():
"""
>>> f = methods()
>>> c = f(0)
>>> c.inc()
1
>>> c.inc()
2
>>> c.dec()
1
>>> c.dec()
0
"""
def f(x):
class c:
def inc(self):
nonlocal x
x += 1
return x
def dec(self):
nonlocal x
x -= 1
return x
return c()
return f
def class_body(int x, y):
"""
>>> c = class_body(2,99)
>>> c.z
(3, 2)
>>> c.x #doctest: +ELLIPSIS
Traceback (most recent call last):
AttributeError: ...
>>> c.y #doctest: +ELLIPSIS
Traceback (most recent call last):
AttributeError: ...
"""
class c(object):
nonlocal x
nonlocal y
y = 2
x += 1
z = x,y
return c()
def nested_nonlocals(x):
"""
>>> g = nested_nonlocals(1)
>>> h = g()
>>> h()
3
"""
def g():
nonlocal x
x -= 2
def h():
nonlocal x
x += 4
return x
return h
return g
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment