Commit 7445f6fc authored by Stefan Behnel's avatar Stefan Behnel

support for inlining the __enter__() method call in with statements

parent 263b0091
...@@ -5144,21 +5144,26 @@ class WithStatNode(StatNode): ...@@ -5144,21 +5144,26 @@ class WithStatNode(StatNode):
# manager The with statement manager object # manager The with statement manager object
# target ExprNode the target lhs of the __enter__() call # target ExprNode the target lhs of the __enter__() call
# body StatNode # body StatNode
# enter_call ExprNode the call to the __enter__() method
child_attrs = ["manager", "target", "body"] child_attrs = ["manager", "target", "body", "enter_call"]
enter_call = None
has_target = False has_target = False
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.manager.analyse_declarations(env) self.manager.analyse_declarations(env)
self.enter_call.analyse_declarations(env)
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.manager.analyse_types(env) self.manager.analyse_types(env)
self.enter_call.analyse_types(env)
self.body.analyse_expressions(env) self.body.analyse_expressions(env)
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
self.manager.generate_function_definitions(env, code) self.manager.generate_function_definitions(env, code)
self.enter_call.generate_function_definitions(env, code)
self.body.generate_function_definitions(env, code) self.body.generate_function_definitions(env, code)
def generate_execution_code(self, code): def generate_execution_code(self, code):
...@@ -5177,34 +5182,17 @@ class WithStatNode(StatNode): ...@@ -5177,34 +5182,17 @@ class WithStatNode(StatNode):
old_error_label = code.new_error_label() old_error_label = code.new_error_label()
intermediate_error_label = code.error_label intermediate_error_label = code.error_label
enter_func = code.funcstate.allocate_temp(py_object_type, manage_ref=True) self.enter_call.generate_evaluation_code(code)
code.putln("%s = PyObject_GetAttr(%s, %s); %s" % ( if not self.target:
enter_func, self.enter_call.generate_disposal_code(code)
self.manager.py_result(), self.enter_call.free_temps(code)
code.get_py_string_const(EncodedString('__enter__'), identifier=True),
code.error_goto_if_null(enter_func, self.pos),
))
code.put_gotref(enter_func)
self.manager.generate_disposal_code(code) self.manager.generate_disposal_code(code)
self.manager.free_temps(code) self.manager.free_temps(code)
self.target_temp.allocate(code)
code.putln('%s = PyObject_Call(%s, ((PyObject *)%s), NULL); %s' % (
self.target_temp.result(),
enter_func,
Naming.empty_tuple,
code.error_goto_if_null(self.target_temp.result(), self.pos),
))
code.put_gotref(self.target_temp.result())
code.put_decref_clear(enter_func, py_object_type)
code.funcstate.release_temp(enter_func)
if not self.has_target:
code.put_decref_clear(self.target_temp.result(), type=py_object_type)
self.target_temp.release(code)
# otherwise, WithTargetAssignmentStatNode will do it for us
code.error_label = old_error_label code.error_label = old_error_label
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
if code.label_used(intermediate_error_label):
step_over_label = code.new_label() step_over_label = code.new_label()
code.put_goto(step_over_label) code.put_goto(step_over_label)
code.put_label(intermediate_error_label) code.put_label(intermediate_error_label)
...@@ -5223,7 +5211,8 @@ class WithTargetAssignmentStatNode(AssignmentNode): ...@@ -5223,7 +5211,8 @@ class WithTargetAssignmentStatNode(AssignmentNode):
# and frees its temp. # and frees its temp.
# #
# lhs ExprNode the assignment target # lhs ExprNode the assignment target
# rhs TempNode the return value of the __enter__() call # orig_rhs ExprNode the return value of the __enter__() call (not owned by this node!)
# rhs ResultRefNode a ResultRefNode for the orig_rhs (owned by this node)
child_attrs = ["lhs", "rhs"] child_attrs = ["lhs", "rhs"]
...@@ -5234,16 +5223,13 @@ class WithTargetAssignmentStatNode(AssignmentNode): ...@@ -5234,16 +5223,13 @@ class WithTargetAssignmentStatNode(AssignmentNode):
self.rhs.analyse_types(env) self.rhs.analyse_types(env)
self.lhs.analyse_target_types(env) self.lhs.analyse_target_types(env)
self.lhs.gil_assignment_check(env) self.lhs.gil_assignment_check(env)
self.orig_rhs = self.rhs
self.rhs = self.rhs.coerce_to(self.lhs.type, env) self.rhs = self.rhs.coerce_to(self.lhs.type, env)
def generate_execution_code(self, code): def generate_execution_code(self, code):
self.rhs.generate_evaluation_code(code) self.rhs.generate_evaluation_code(code)
self.lhs.generate_assignment_code(self.rhs, code) self.lhs.generate_assignment_code(self.rhs, code)
self.orig_rhs.release(code) self.orig_rhs.generate_disposal_code(code)
self.orig_rhs.free_temps(code)
def generate_function_definitions(self, env, code):
self.rhs.generate_function_definitions(env, code)
def annotate(self, code): def annotate(self, code):
self.lhs.annotate(code) self.lhs.annotate(code)
......
...@@ -1164,16 +1164,20 @@ class WithTransform(CythonTransform, SkipDeclarations): ...@@ -1164,16 +1164,20 @@ class WithTransform(CythonTransform, SkipDeclarations):
self.visitchildren(node, 'body') self.visitchildren(node, 'body')
pos = node.pos pos = node.pos
body, target, manager = node.body, node.target, node.manager body, target, manager = node.body, node.target, node.manager
node.target_temp = ExprNodes.TempNode(pos, type=PyrexTypes.py_object_type) node.enter_call = ExprNodes.SimpleCallNode(
pos, function = ExprNodes.AttributeNode(
pos, obj = ResultRefNode(manager),
attribute = EncodedString('__enter__')),
args = [],
is_temp = True)
if target is not None: if target is not None:
node.has_target = True
body = Nodes.StatListNode( body = Nodes.StatListNode(
pos, stats = [ pos, stats = [
Nodes.WithTargetAssignmentStatNode( Nodes.WithTargetAssignmentStatNode(
pos, lhs = target, rhs = node.target_temp), pos, lhs = target,
body rhs = ResultRefNode(node.enter_call),
]) orig_rhs = node.enter_call),
node.target = None body])
excinfo_target = ResultRefNode( excinfo_target = ResultRefNode(
pos=pos, type=Builtin.tuple_type, may_hold_none=False) pos=pos, type=Builtin.tuple_type, may_hold_none=False)
......
...@@ -539,12 +539,47 @@ def with_statement(): ...@@ -539,12 +539,47 @@ def with_statement():
""" """
>>> with_statement() >>> with_statement()
Python object Python object
'Python object' Python object
""" """
x = 1.0 x = 1.0
with EmptyContextManager() as x: with EmptyContextManager() as x:
print(typeof(x)) print(typeof(x))
return typeof(x) print(typeof(x))
return x
@cython.final
cdef class TypedContextManager(object):
cpdef double __enter__(self):
return 2.0
def __exit__(self, *args):
return 0
def with_statement_typed():
"""
>>> with_statement_typed()
double
double
2.0
"""
x = 1.0
with TypedContextManager() as x:
print(typeof(x))
print(typeof(x))
return x
def with_statement_untyped():
"""
>>> with_statement_untyped()
Python object
Python object
2.0
"""
x = 1.0
cdef object t = TypedContextManager()
with t as x:
print(typeof(x))
print(typeof(x))
return x
# Regression test for trac #638. # Regression test for trac #638.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment