Commit 7eca82ec authored by Craig Citro's avatar Craig Citro

Add changes to allow def statements anywhere they're legal.

parent ad9a2205
......@@ -516,6 +516,9 @@ class ExprNode(Node):
for sub in self.subexpr_nodes():
sub.free_temps(code)
def generate_function_definitions(self, env, code):
pass
# ---------------- Annotation ---------------------
def annotate(self, code):
......
......@@ -1606,7 +1606,9 @@ class PyArgDeclNode(Node):
# name string
# entry Symtab.Entry
child_attrs = []
def generate_function_definitions(self, env, code):
self.entry.generate_function_definitions(env, code)
class DecoratorNode(Node):
# A decorator
......@@ -2918,6 +2920,9 @@ class ExprStatNode(StatNode):
self.expr.generate_disposal_code(code)
self.expr.free_temps(code)
def generate_function_definitions(self, env, code):
self.expr.generate_function_definitions(env, code)
def annotate(self, code):
self.expr.annotate(code)
......@@ -3036,6 +3041,9 @@ class SingleAssignmentNode(AssignmentNode):
def generate_assignment_code(self, code):
self.lhs.generate_assignment_code(self.rhs, code)
def generate_function_definitions(self, env, code):
self.rhs.generate_function_definitions(env, code)
def annotate(self, code):
self.lhs.annotate(code)
self.rhs.annotate(code)
......@@ -3088,6 +3096,9 @@ class CascadedAssignmentNode(AssignmentNode):
self.rhs.generate_disposal_code(code)
self.rhs.free_temps(code)
def generate_function_definitions(self, env, code):
self.rhs.generate_function_definitions(env, code)
def annotate(self, code):
for i in range(len(self.lhs_list)):
lhs = self.lhs_list[i].annotate(code)
......@@ -3131,13 +3142,17 @@ class ParallelAssignmentNode(AssignmentNode):
for stat in self.stats:
stat.generate_assignment_code(code)
def generate_function_definitions(self, env, code):
for stat in self.stats:
stat.generate_function_definitions(env, code)
def annotate(self, code):
for stat in self.stats:
stat.annotate(code)
class InPlaceAssignmentNode(AssignmentNode):
# An in place arithmatic operand:
# An in place arithmetic operand:
#
# a += b
# a -= b
......@@ -3327,6 +3342,10 @@ class PrintStatNode(StatNode):
self.arg_tuple.generate_disposal_code(code)
self.arg_tuple.free_temps(code)
def generate_function_definitions(self, env, code):
for item in self.arg_tuple:
item.generate_function_definitions(env, code)
def annotate(self, code):
self.arg_tuple.annotate(code)
......@@ -3510,6 +3529,10 @@ class ReturnStatNode(StatNode):
for cname, type in code.funcstate.temps_holding_reference():
code.put_decref_clear(cname, type)
code.put_goto(code.return_label)
def generate_function_definitions(self, env, code):
if self.value is not None:
self.value.generate_function_definitions(env, code)
def annotate(self, code):
if self.value:
......@@ -3568,6 +3591,14 @@ class RaiseStatNode(StatNode):
code.putln(
code.error_goto(self.pos))
def generate_function_definitions(self, env, code):
if self.exc_type is not None:
self.exc_type.generate_function_definitions(env, code)
if self.exc_value is not None:
self.exc_value.generate_function_definitions(env, code)
if self.exc_tb is not None:
self.exc_tb.generate_function_definitions(env, code)
def annotate(self, code):
if self.exc_type:
self.exc_type.annotate(code)
......@@ -3642,6 +3673,11 @@ class AssertStatNode(StatNode):
self.cond.free_temps(code)
code.putln("#endif")
def generate_function_definitions(self, env, code):
self.cond.generate_function_definitions(env, code)
if self.value is not None:
self.value.generate_function_definitions(env, code)
def annotate(self, code):
self.cond.annotate(code)
if self.value:
......@@ -3687,6 +3723,12 @@ class IfStatNode(StatNode):
self.else_clause.generate_execution_code(code)
code.putln("}")
code.put_label(end_label)
def generate_function_definitions(self, env, code):
for clause in self.if_clauses:
clause.generate_function_definitions(env, code)
if self.else_clause is not None:
self.else_clause.generate_function_definitions(env, code)
def annotate(self, code):
for if_clause in self.if_clauses:
......@@ -3729,6 +3771,10 @@ class IfClauseNode(Node):
code.put_goto(end_label)
code.putln("}")
def generate_function_definitions(self, env, code):
self.condition.generate_function_definitions(env, code)
self.body.generate_function_definitions(env, code)
def annotate(self, code):
self.condition.annotate(code)
self.body.annotate(code)
......@@ -3749,6 +3795,11 @@ class SwitchCaseNode(StatNode):
code.putln("case %s:" % cond.result())
self.body.generate_execution_code(code)
code.putln("break;")
def generate_function_definitions(self, env, code):
for cond in self.conditions:
cond.generate_function_definitions(env, code)
self.body.generate_function_definitions(env, code)
def annotate(self, code):
for cond in self.conditions:
......@@ -3774,6 +3825,13 @@ class SwitchStatNode(StatNode):
code.putln("break;")
code.putln("}")
def generate_function_definitions(self, env, code):
self.test.generate_function_definitions(env, code)
for case in self.cases:
case.generate_function_definitions(env, code)
if self.else_clause is not None:
self.else_clause.generate_function_definitions(env, code)
def annotate(self, code):
self.test.annotate(code)
for case in self.cases:
......@@ -3834,6 +3892,12 @@ class WhileStatNode(LoopNode, StatNode):
code.putln("}")
code.put_label(break_label)
def generate_function_definitions(self, env, code):
self.condition.generate_function_definitions(env, code)
self.body.generate_function_definitions(env, code)
if self.else_clause is not None:
self.else_clause.generate_function_definitions(env, code)
def annotate(self, code):
self.condition.annotate(code)
self.body.annotate(code)
......@@ -3898,6 +3962,13 @@ class ForInStatNode(LoopNode, StatNode):
self.iterator.generate_disposal_code(code)
self.iterator.free_temps(code)
def generate_function_definitions(self, env, code):
self.target.generate_function_definitions(env, code)
self.iterator.generate_function_definitions(env, code)
self.body.generate_function_definitions(env, code)
if self.else_clause is not None:
self.else_clause.generate_function_definitions(env, code)
def annotate(self, code):
self.target.annotate(code)
self.iterator.annotate(code)
......@@ -4087,13 +4158,23 @@ class ForFromStatNode(LoopNode, StatNode):
'>=': ("", "--"),
'>' : ("-1", "--")
}
def generate_function_definitions(self, env, code):
self.target.generate_function_definitions(env, code)
self.bound1.generate_function_definitions(env, code)
self.bound2.generate_function_definitions(env, code)
if self.step is not None:
self.step.generate_function_definitions(env, code)
self.body.generate_function_definitions(env, code)
if self.else_clause is not None:
self.else_clause.generate_function_definitions(env, code)
def annotate(self, code):
self.target.annotate(code)
self.bound1.annotate(code)
self.bound2.annotate(code)
if self.step:
self.bound2.annotate(code)
self.step.annotate(code)
self.body.annotate(code)
if self.else_clause:
self.else_clause.annotate(code)
......@@ -4248,6 +4329,13 @@ class TryExceptStatNode(StatNode):
code.continue_label = old_continue_label
code.error_label = old_error_label
def generate_function_definitions(self, env, code):
self.body.generate_function_definitions(env, code)
for except_clause in self.except_clauses:
except_clause.generate_function_definitions(env, code)
if self.else_clause is not None:
self.else_clause.generate_function_definitions(env, code)
def annotate(self, code):
self.body.annotate(code)
for except_node in self.except_clauses:
......@@ -4386,6 +4474,11 @@ class ExceptClauseNode(Node):
code.putln(
"}")
def generate_function_definitions(self, env, code):
if self.target is not None:
self.target.generate_function_definitions(env, code)
self.body.generate_function_definitions(env, code)
def annotate(self, code):
if self.pattern:
self.pattern.annotate(code)
......@@ -4533,6 +4626,10 @@ class TryFinallyStatNode(StatNode):
code.putln(
"}")
def generate_function_definitions(self, env, code):
self.body.generate_function_definitions(env, code)
self.finally_clause.generate_function_definitions(env, code)
def put_error_catcher(self, code, error_label, i, catch_label, temps_to_clean_up):
code.globalstate.use_utility_code(restore_exception_utility_code)
code.putln(
......
......@@ -915,7 +915,6 @@ class CreateClosureClasses(CythonTransform):
return node
def create_class_from_scope(self, node, target_module_scope):
as_name = "%s%s" % (Naming.closure_class_prefix, node.entry.cname)
func_scope = node.local_scope
......@@ -931,7 +930,8 @@ class CreateClosureClasses(CythonTransform):
type=node.entry.scope.scope_class.type,
is_cdef=True)
for entry in func_scope.entries.values():
# This is wasteful--we should do this later when we know which vars are actually being used inside...
# This is wasteful--we should do this later when we know
# which vars are actually being used inside...
cname = entry.cname
class_scope.declare_var(pos=entry.pos,
name=entry.name,
......
......@@ -1658,8 +1658,8 @@ def p_statement(s, ctx, first_statement = 0):
if ctx.api:
error(s.pos, "'api' not allowed with this statement")
elif s.sy == 'def':
if ctx.level not in ('module', 'class', 'c_class', 'c_class_pxd', 'property', 'function'):
s.error('def statement not allowed here')
#if ctx.level not in ('module', 'class', 'c_class', 'c_class_pxd', 'property', 'function'):
# s.error('def statement not allowed here')
s.level = ctx.level
return p_def_statement(s, decorators)
elif s.sy == 'class':
......
......@@ -249,7 +249,7 @@ class VisitorTransform(TreeVisitor):
class CythonTransform(VisitorTransform):
"""
Certain common conventions and utilitues for Cython transforms.
Certain common conventions and utilities for Cython transforms.
- Sets up the context of the pipeline in self.context
- Tracks directives in effect in self.current_directives
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment