Commit 749af0fa authored by Stefan Behnel's avatar Stefan Behnel

simplify WithTargetAssignmentStatNode and make it more robust against...

simplify WithTargetAssignmentStatNode and make it more robust against replacements of the context manager node; undo node.result() checking as it broke TempNode's disposal code
parent e3036318
...@@ -344,10 +344,10 @@ class ExprNode(Node): ...@@ -344,10 +344,10 @@ class ExprNode(Node):
def result(self): def result(self):
if self.is_temp: if self.is_temp:
if not self.temp_code: #if not self.temp_code:
pos = (os.path.basename(self.pos[0].get_description()),) + self.pos[1:] if self.pos else '(?)' # pos = (os.path.basename(self.pos[0].get_description()),) + self.pos[1:] if self.pos else '(?)'
raise RuntimeError("temp result name not set in %s at %r" % ( # raise RuntimeError("temp result name not set in %s at %r" % (
self.__class__.__name__, pos)) # self.__class__.__name__, pos))
return self.temp_code return self.temp_code
else: else:
return self.calculate_result_code() return self.calculate_result_code()
...@@ -620,7 +620,7 @@ class ExprNode(Node): ...@@ -620,7 +620,7 @@ class ExprNode(Node):
# postponed from self.generate_evaluation_code() # postponed from self.generate_evaluation_code()
self.generate_subexpr_disposal_code(code) self.generate_subexpr_disposal_code(code)
self.free_subexpr_temps(code) self.free_subexpr_temps(code)
if self.temp_code: if self.result():
if self.type.is_pyobject: if self.type.is_pyobject:
code.put_decref_clear(self.result(), self.ctype()) code.put_decref_clear(self.result(), self.ctype())
elif self.type.is_memoryviewslice: elif self.type.is_memoryviewslice:
...@@ -4759,7 +4759,7 @@ class SimpleCallNode(CallNode): ...@@ -4759,7 +4759,7 @@ class SimpleCallNode(CallNode):
exc_checks.append("PyErr_Occurred()") exc_checks.append("PyErr_Occurred()")
if self.is_temp or exc_checks: if self.is_temp or exc_checks:
rhs = self.c_call_code() rhs = self.c_call_code()
if self.temp_code: if self.result():
lhs = "%s = " % self.result() lhs = "%s = " % self.result()
if self.is_temp and self.type.is_pyobject: if self.is_temp and self.type.is_pyobject:
#return_type = self.type # func_type.return_type #return_type = self.type # func_type.return_type
......
...@@ -1104,7 +1104,7 @@ class ControlFlowAnalysis(CythonTransform): ...@@ -1104,7 +1104,7 @@ class ControlFlowAnalysis(CythonTransform):
raise InternalError("Generic loops are not supported") raise InternalError("Generic loops are not supported")
def visit_WithTargetAssignmentStatNode(self, node): def visit_WithTargetAssignmentStatNode(self, node):
self.mark_assignment(node.lhs, node.rhs) self.mark_assignment(node.lhs, node.with_node.enter_call)
return node return node
def visit_WithStatNode(self, node): def visit_WithStatNode(self, node):
......
...@@ -6124,6 +6124,7 @@ class WithStatNode(StatNode): ...@@ -6124,6 +6124,7 @@ class WithStatNode(StatNode):
child_attrs = ["manager", "enter_call", "target", "body"] child_attrs = ["manager", "enter_call", "target", "body"]
enter_call = None enter_call = None
target_temp = None
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.manager.analyse_declarations(env) self.manager.analyse_declarations(env)
...@@ -6133,6 +6134,10 @@ class WithStatNode(StatNode): ...@@ -6133,6 +6134,10 @@ class WithStatNode(StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.manager = self.manager.analyse_types(env) self.manager = self.manager.analyse_types(env)
self.enter_call = self.enter_call.analyse_types(env) self.enter_call = self.enter_call.analyse_types(env)
if self.target:
# set up target_temp before descending into body (which uses it)
from .ExprNodes import TempNode
self.target_temp = TempNode(self.enter_call.pos, self.enter_call.type)
self.body = self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
return self return self
...@@ -6160,14 +6165,17 @@ class WithStatNode(StatNode): ...@@ -6160,14 +6165,17 @@ class WithStatNode(StatNode):
intermediate_error_label = code.error_label intermediate_error_label = code.error_label
self.enter_call.generate_evaluation_code(code) self.enter_call.generate_evaluation_code(code)
if not self.target: if self.target:
self.enter_call.generate_disposal_code(code) # The temp result will be cleaned up by the WithTargetAssignmentStatNode
self.enter_call.free_temps(code) # after assigning its result to the target of the 'with' statement.
self.target_temp.allocate(code)
self.enter_call.make_owned_reference(code)
code.putln("%s = %s;" % (self.target_temp.result(), self.enter_call.result()))
self.enter_call.generate_post_assignment_code(code)
else: else:
# Otherwise, the node will be cleaned up by the self.enter_call.generate_disposal_code(code)
# WithTargetAssignmentStatNode after assigning its result self.enter_call.free_temps(code)
# to the target of the 'with' statement.
pass
self.manager.generate_disposal_code(code) self.manager.generate_disposal_code(code)
self.manager.free_temps(code) self.manager.free_temps(code)
...@@ -6185,52 +6193,34 @@ class WithStatNode(StatNode): ...@@ -6185,52 +6193,34 @@ class WithStatNode(StatNode):
code.funcstate.release_temp(self.exit_var) code.funcstate.release_temp(self.exit_var)
code.putln('}') code.putln('}')
class WithTargetAssignmentStatNode(AssignmentNode): class WithTargetAssignmentStatNode(AssignmentNode):
# The target assignment of the 'with' statement value (return # The target assignment of the 'with' statement value (return
# value of the __enter__() call). # value of the __enter__() call).
# #
# This is a special cased assignment that steals the RHS reference # This is a special cased assignment that properly cleans up the RHS.
# and frees its temp.
# #
# lhs ExprNode the assignment target # lhs ExprNode the assignment target
# rhs CloneNode a (coerced) CloneNode for the orig_rhs (not owned by this node) # rhs ExprNode a (coerced) TempNode for the rhs (from WithStatNode)
# orig_rhs ExprNode the original ExprNode of the rhs. this node will clean up the # with_node WithStatNode the surrounding with-statement
# temps of the orig_rhs. basically, it takes ownership of the node
# when the WithStatNode is done with it.
child_attrs = ["lhs"] child_attrs = ["rhs", "lhs"]
with_node = None
rhs = None
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.lhs.analyse_target_declaration(env) self.lhs.analyse_target_declaration(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.rhs = self.rhs.analyse_types(env)
self.lhs = self.lhs.analyse_target_types(env) self.lhs = self.lhs.analyse_target_types(env)
self.lhs.gil_assignment_check(env) self.lhs.gil_assignment_check(env)
self.rhs = self.rhs.coerce_to(self.lhs.type, env) self.rhs = self.with_node.target_temp.coerce_to(self.lhs.type, env)
return self return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
if self.orig_rhs.type.is_pyobject:
# make sure rhs gets freed on errors, see below
old_error_label = code.new_error_label()
intermediate_error_label = code.error_label
self.rhs.generate_evaluation_code(code) self.rhs.generate_evaluation_code(code)
self.lhs.generate_assignment_code(self.rhs, code) self.lhs.generate_assignment_code(self.rhs, code)
self.with_node.target_temp.release(code)
if self.orig_rhs.type.is_pyobject:
self.orig_rhs.generate_disposal_code(code)
code.error_label = old_error_label
if code.label_used(intermediate_error_label):
step_over_label = code.new_label()
code.put_goto(step_over_label)
code.put_label(intermediate_error_label)
self.orig_rhs.generate_disposal_code(code)
code.put_goto(old_error_label)
code.put_label(step_over_label)
self.orig_rhs.free_temps(code)
def annotate(self, code): def annotate(self, code):
self.lhs.annotate(code) self.lhs.annotate(code)
......
...@@ -1220,20 +1220,19 @@ class WithTransform(CythonTransform, SkipDeclarations): ...@@ -1220,20 +1220,19 @@ class WithTransform(CythonTransform, SkipDeclarations):
self.visitchildren(node, 'body') self.visitchildren(node, 'body')
pos = node.pos pos = node.pos
body, target, manager = node.body, node.target, node.manager body, target, manager = node.body, node.target, node.manager
node.enter_call = ExprNodes.ProxyNode(ExprNodes.SimpleCallNode( node.enter_call = ExprNodes.SimpleCallNode(
pos, function=ExprNodes.AttributeNode( pos, function=ExprNodes.AttributeNode(
pos, obj=ExprNodes.CloneNode(manager), pos, obj=ExprNodes.CloneNode(manager),
attribute=EncodedString('__enter__'), attribute=EncodedString('__enter__'),
is_special_lookup=True), is_special_lookup=True),
args=[], args=[],
is_temp=True)) is_temp=True)
if target is not None: if target is not None:
body = Nodes.StatListNode( body = Nodes.StatListNode(
pos, stats=[ pos, stats=[
Nodes.WithTargetAssignmentStatNode( Nodes.WithTargetAssignmentStatNode(
pos, lhs=target, pos, lhs=target, with_node=node),
rhs=ResultRefNode(node.enter_call),
orig_rhs=node.enter_call),
body]) body])
excinfo_target = ExprNodes.TupleNode(pos, slow=True, args=[ excinfo_target = ExprNodes.TupleNode(pos, slow=True, args=[
......
...@@ -68,7 +68,7 @@ class MarkParallelAssignments(EnvTransform): ...@@ -68,7 +68,7 @@ class MarkParallelAssignments(EnvTransform):
pass pass
def visit_WithTargetAssignmentStatNode(self, node): def visit_WithTargetAssignmentStatNode(self, node):
self.mark_assignment(node.lhs, node.rhs) self.mark_assignment(node.lhs, node.with_node.enter_call)
self.visitchildren(node) self.visitchildren(node)
return node return node
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment