Commit 749af0fa authored by Stefan Behnel's avatar Stefan Behnel

simplify WithTargetAssignmentStatNode and make it more robust against...

simplify WithTargetAssignmentStatNode and make it more robust against replacements of the context manager node; undo node.result() checking as it broke TempNode's disposal code
parent e3036318
......@@ -344,10 +344,10 @@ class ExprNode(Node):
def result(self):
if self.is_temp:
if not self.temp_code:
pos = (os.path.basename(self.pos[0].get_description()),) + self.pos[1:] if self.pos else '(?)'
raise RuntimeError("temp result name not set in %s at %r" % (
self.__class__.__name__, pos))
#if not self.temp_code:
# pos = (os.path.basename(self.pos[0].get_description()),) + self.pos[1:] if self.pos else '(?)'
# raise RuntimeError("temp result name not set in %s at %r" % (
# self.__class__.__name__, pos))
return self.temp_code
else:
return self.calculate_result_code()
......@@ -620,7 +620,7 @@ class ExprNode(Node):
# postponed from self.generate_evaluation_code()
self.generate_subexpr_disposal_code(code)
self.free_subexpr_temps(code)
if self.temp_code:
if self.result():
if self.type.is_pyobject:
code.put_decref_clear(self.result(), self.ctype())
elif self.type.is_memoryviewslice:
......@@ -4759,7 +4759,7 @@ class SimpleCallNode(CallNode):
exc_checks.append("PyErr_Occurred()")
if self.is_temp or exc_checks:
rhs = self.c_call_code()
if self.temp_code:
if self.result():
lhs = "%s = " % self.result()
if self.is_temp and self.type.is_pyobject:
#return_type = self.type # func_type.return_type
......
......@@ -1104,7 +1104,7 @@ class ControlFlowAnalysis(CythonTransform):
raise InternalError("Generic loops are not supported")
def visit_WithTargetAssignmentStatNode(self, node):
self.mark_assignment(node.lhs, node.rhs)
self.mark_assignment(node.lhs, node.with_node.enter_call)
return node
def visit_WithStatNode(self, node):
......
......@@ -6124,6 +6124,7 @@ class WithStatNode(StatNode):
child_attrs = ["manager", "enter_call", "target", "body"]
enter_call = None
target_temp = None
def analyse_declarations(self, env):
self.manager.analyse_declarations(env)
......@@ -6133,6 +6134,10 @@ class WithStatNode(StatNode):
def analyse_expressions(self, env):
self.manager = self.manager.analyse_types(env)
self.enter_call = self.enter_call.analyse_types(env)
if self.target:
# set up target_temp before descending into body (which uses it)
from .ExprNodes import TempNode
self.target_temp = TempNode(self.enter_call.pos, self.enter_call.type)
self.body = self.body.analyse_expressions(env)
return self
......@@ -6160,14 +6165,17 @@ class WithStatNode(StatNode):
intermediate_error_label = code.error_label
self.enter_call.generate_evaluation_code(code)
if not self.target:
if self.target:
# The temp result will be cleaned up by the WithTargetAssignmentStatNode
# after assigning its result to the target of the 'with' statement.
self.target_temp.allocate(code)
self.enter_call.make_owned_reference(code)
code.putln("%s = %s;" % (self.target_temp.result(), self.enter_call.result()))
self.enter_call.generate_post_assignment_code(code)
else:
self.enter_call.generate_disposal_code(code)
self.enter_call.free_temps(code)
else:
# Otherwise, the node will be cleaned up by the
# WithTargetAssignmentStatNode after assigning its result
# to the target of the 'with' statement.
pass
self.manager.generate_disposal_code(code)
self.manager.free_temps(code)
......@@ -6185,52 +6193,34 @@ class WithStatNode(StatNode):
code.funcstate.release_temp(self.exit_var)
code.putln('}')
class WithTargetAssignmentStatNode(AssignmentNode):
# The target assignment of the 'with' statement value (return
# value of the __enter__() call).
#
# This is a special cased assignment that steals the RHS reference
# and frees its temp.
# This is a special cased assignment that properly cleans up the RHS.
#
# lhs ExprNode the assignment target
# rhs CloneNode a (coerced) CloneNode for the orig_rhs (not owned by this node)
# orig_rhs ExprNode the original ExprNode of the rhs. this node will clean up the
# temps of the orig_rhs. basically, it takes ownership of the node
# when the WithStatNode is done with it.
# rhs ExprNode a (coerced) TempNode for the rhs (from WithStatNode)
# with_node WithStatNode the surrounding with-statement
child_attrs = ["lhs"]
child_attrs = ["rhs", "lhs"]
with_node = None
rhs = None
def analyse_declarations(self, env):
self.lhs.analyse_target_declaration(env)
def analyse_expressions(self, env):
self.rhs = self.rhs.analyse_types(env)
self.lhs = self.lhs.analyse_target_types(env)
self.lhs.gil_assignment_check(env)
self.rhs = self.rhs.coerce_to(self.lhs.type, env)
self.rhs = self.with_node.target_temp.coerce_to(self.lhs.type, env)
return self
def generate_execution_code(self, code):
if self.orig_rhs.type.is_pyobject:
# make sure rhs gets freed on errors, see below
old_error_label = code.new_error_label()
intermediate_error_label = code.error_label
self.rhs.generate_evaluation_code(code)
self.lhs.generate_assignment_code(self.rhs, code)
if self.orig_rhs.type.is_pyobject:
self.orig_rhs.generate_disposal_code(code)
code.error_label = old_error_label
if code.label_used(intermediate_error_label):
step_over_label = code.new_label()
code.put_goto(step_over_label)
code.put_label(intermediate_error_label)
self.orig_rhs.generate_disposal_code(code)
code.put_goto(old_error_label)
code.put_label(step_over_label)
self.orig_rhs.free_temps(code)
self.with_node.target_temp.release(code)
def annotate(self, code):
self.lhs.annotate(code)
......
......@@ -1220,20 +1220,19 @@ class WithTransform(CythonTransform, SkipDeclarations):
self.visitchildren(node, 'body')
pos = node.pos
body, target, manager = node.body, node.target, node.manager
node.enter_call = ExprNodes.ProxyNode(ExprNodes.SimpleCallNode(
node.enter_call = ExprNodes.SimpleCallNode(
pos, function=ExprNodes.AttributeNode(
pos, obj=ExprNodes.CloneNode(manager),
attribute=EncodedString('__enter__'),
is_special_lookup=True),
args=[],
is_temp=True))
is_temp=True)
if target is not None:
body = Nodes.StatListNode(
pos, stats=[
Nodes.WithTargetAssignmentStatNode(
pos, lhs=target,
rhs=ResultRefNode(node.enter_call),
orig_rhs=node.enter_call),
pos, lhs=target, with_node=node),
body])
excinfo_target = ExprNodes.TupleNode(pos, slow=True, args=[
......
......@@ -68,7 +68,7 @@ class MarkParallelAssignments(EnvTransform):
pass
def visit_WithTargetAssignmentStatNode(self, node):
self.mark_assignment(node.lhs, node.rhs)
self.mark_assignment(node.lhs, node.with_node.enter_call)
self.visitchildren(node)
return node
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment