Commit 6e3c7cee authored by Stefan Behnel's avatar Stefan Behnel

in generators/coroutines, save away the current exception in the 'return' case...

in generators/coroutines, save away the current exception in the 'return' case of finally clauses as 'return' actually raises an (Async)StopIteration exception
parent c6236d8e
...@@ -6994,6 +6994,7 @@ class TryFinallyStatNode(StatNode): ...@@ -6994,6 +6994,7 @@ class TryFinallyStatNode(StatNode):
# body StatNode # body StatNode
# finally_clause StatNode # finally_clause StatNode
# finally_except_clause deep-copy of finally_clause for exception case # finally_except_clause deep-copy of finally_clause for exception case
# in_generator inside of generator => must store away current exception also in return case
# #
# Each of the continue, break, return and error gotos runs # Each of the continue, break, return and error gotos runs
# into its own deep-copy of the finally block code. # into its own deep-copy of the finally block code.
...@@ -7011,6 +7012,7 @@ class TryFinallyStatNode(StatNode): ...@@ -7011,6 +7012,7 @@ class TryFinallyStatNode(StatNode):
finally_except_clause = None finally_except_clause = None
is_try_finally_in_nogil = False is_try_finally_in_nogil = False
in_generator = False
@staticmethod @staticmethod
def create_analysed(pos, env, body, finally_clause): def create_analysed(pos, env, body, finally_clause):
...@@ -7129,16 +7131,25 @@ class TryFinallyStatNode(StatNode): ...@@ -7129,16 +7131,25 @@ class TryFinallyStatNode(StatNode):
code.set_all_labels(old_labels) code.set_all_labels(old_labels)
return_label = code.return_label return_label = code.return_label
exc_vars = ()
for i, (new_label, old_label) in enumerate(zip(new_labels, old_labels)): for i, (new_label, old_label) in enumerate(zip(new_labels, old_labels)):
if not code.label_used(new_label): if not code.label_used(new_label):
continue continue
if new_label == new_error_label and preserve_error: if new_label == new_error_label and preserve_error:
continue # handled above continue # handled above
code.put('%s: ' % new_label) code.putln('%s: {' % new_label)
code.putln('{')
ret_temp = None ret_temp = None
if old_label == return_label and not self.finally_clause.is_terminator: if old_label == return_label:
# return actually raises an (uncatchable) exception in generators that we must preserve
if self.in_generator:
code.putln("__Pyx_PyThreadState_declare")
exc_vars = tuple([
code.funcstate.allocate_temp(py_object_type, manage_ref=False)
for _ in range(6)])
self.put_error_catcher(code, [], exc_vars)
if not self.finally_clause.is_terminator:
# store away return value for later reuse # store away return value for later reuse
if (self.func_return_type and if (self.func_return_type and
not self.is_try_finally_in_nogil and not self.is_try_finally_in_nogil and
...@@ -7148,13 +7159,19 @@ class TryFinallyStatNode(StatNode): ...@@ -7148,13 +7159,19 @@ class TryFinallyStatNode(StatNode):
code.putln("%s = %s;" % (ret_temp, Naming.retval_cname)) code.putln("%s = %s;" % (ret_temp, Naming.retval_cname))
if self.func_return_type.is_pyobject: if self.func_return_type.is_pyobject:
code.putln("%s = 0;" % Naming.retval_cname) code.putln("%s = 0;" % Naming.retval_cname)
fresh_finally_clause().generate_execution_code(code) fresh_finally_clause().generate_execution_code(code)
if old_label == return_label:
if ret_temp: if ret_temp:
code.putln("%s = %s;" % (Naming.retval_cname, ret_temp)) code.putln("%s = %s;" % (Naming.retval_cname, ret_temp))
if self.func_return_type.is_pyobject: if self.func_return_type.is_pyobject:
code.putln("%s = 0;" % ret_temp) code.putln("%s = 0;" % ret_temp)
code.funcstate.release_temp(ret_temp) code.funcstate.release_temp(ret_temp)
ret_temp = None ret_temp = None
if self.in_generator:
self.put_error_uncatcher(code, exc_vars)
if not self.finally_clause.is_terminator: if not self.finally_clause.is_terminator:
code.put_goto(old_label) code.put_goto(old_label)
code.putln('}') code.putln('}')
...@@ -7169,7 +7186,7 @@ class TryFinallyStatNode(StatNode): ...@@ -7169,7 +7186,7 @@ class TryFinallyStatNode(StatNode):
self.finally_clause.generate_function_definitions(env, code) self.finally_clause.generate_function_definitions(env, code)
def put_error_catcher(self, code, temps_to_clean_up, exc_vars, def put_error_catcher(self, code, temps_to_clean_up, exc_vars,
exc_lineno_cnames, exc_filename_cname): exc_lineno_cnames=None, exc_filename_cname=None):
code.globalstate.use_utility_code(restore_exception_utility_code) code.globalstate.use_utility_code(restore_exception_utility_code)
code.globalstate.use_utility_code(get_exception_utility_code) code.globalstate.use_utility_code(get_exception_utility_code)
code.globalstate.use_utility_code(swap_exception_utility_code) code.globalstate.use_utility_code(swap_exception_utility_code)
...@@ -7202,12 +7219,13 @@ class TryFinallyStatNode(StatNode): ...@@ -7202,12 +7219,13 @@ class TryFinallyStatNode(StatNode):
if self.is_try_finally_in_nogil: if self.is_try_finally_in_nogil:
code.put_release_ensured_gil() code.put_release_ensured_gil()
def put_error_uncatcher(self, code, exc_vars, exc_lineno_cnames, exc_filename_cname): def put_error_uncatcher(self, code, exc_vars, exc_lineno_cnames=None, exc_filename_cname=None):
code.globalstate.use_utility_code(restore_exception_utility_code) code.globalstate.use_utility_code(restore_exception_utility_code)
code.globalstate.use_utility_code(reset_exception_utility_code) code.globalstate.use_utility_code(reset_exception_utility_code)
if self.is_try_finally_in_nogil: if self.is_try_finally_in_nogil:
code.put_ensure_gil(declare_gilstate=False) code.put_ensure_gil(declare_gilstate=False)
if self.in_generator:
code.putln("__Pyx_PyThreadState_assign") # re-assign in case a generator yielded code.putln("__Pyx_PyThreadState_assign") # re-assign in case a generator yielded
# not using preprocessor here to avoid warnings about # not using preprocessor here to avoid warnings about
...@@ -7235,6 +7253,7 @@ class TryFinallyStatNode(StatNode): ...@@ -7235,6 +7253,7 @@ class TryFinallyStatNode(StatNode):
code.globalstate.use_utility_code(reset_exception_utility_code) code.globalstate.use_utility_code(reset_exception_utility_code)
if self.is_try_finally_in_nogil: if self.is_try_finally_in_nogil:
code.put_ensure_gil(declare_gilstate=False) code.put_ensure_gil(declare_gilstate=False)
if self.in_generator:
code.putln("__Pyx_PyThreadState_assign") # re-assign in case a generator yielded code.putln("__Pyx_PyThreadState_assign") # re-assign in case a generator yielded
# not using preprocessor here to avoid warnings about # not using preprocessor here to avoid warnings about
......
...@@ -51,6 +51,7 @@ cdef class AlignFunctionDefinitions(CythonTransform): ...@@ -51,6 +51,7 @@ cdef class AlignFunctionDefinitions(CythonTransform):
cdef class YieldNodeCollector(TreeVisitor): cdef class YieldNodeCollector(TreeVisitor):
cdef public list yields cdef public list yields
cdef public list returns cdef public list returns
cdef public list finallys
cdef public bint has_return_value cdef public bint has_return_value
cdef public bint has_yield cdef public bint has_yield
cdef public bint has_await cdef public bint has_await
......
...@@ -2458,6 +2458,7 @@ class YieldNodeCollector(TreeVisitor): ...@@ -2458,6 +2458,7 @@ class YieldNodeCollector(TreeVisitor):
super(YieldNodeCollector, self).__init__() super(YieldNodeCollector, self).__init__()
self.yields = [] self.yields = []
self.returns = [] self.returns = []
self.finallys = []
self.has_return_value = False self.has_return_value = False
self.has_yield = False self.has_yield = False
self.has_await = False self.has_await = False
...@@ -2481,6 +2482,10 @@ class YieldNodeCollector(TreeVisitor): ...@@ -2481,6 +2482,10 @@ class YieldNodeCollector(TreeVisitor):
self.has_return_value = True self.has_return_value = True
self.returns.append(node) self.returns.append(node)
def visit_TryFinallyStatNode(self, node):
self.visitchildren(node)
self.finallys.append(node)
def visit_ClassDefNode(self, node): def visit_ClassDefNode(self, node):
pass pass
...@@ -2531,7 +2536,7 @@ class MarkClosureVisitor(CythonTransform): ...@@ -2531,7 +2536,7 @@ class MarkClosureVisitor(CythonTransform):
for i, yield_expr in enumerate(collector.yields, 1): for i, yield_expr in enumerate(collector.yields, 1):
yield_expr.label_num = i yield_expr.label_num = i
for retnode in collector.returns: for retnode in collector.returns + collector.finallys:
retnode.in_generator = True retnode.in_generator = True
gbody = Nodes.GeneratorBodyDefNode( gbody = Nodes.GeneratorBodyDefNode(
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment