Commit b20ed8dd authored by scoder, committed by GitHub

Merge pull request #1669 from shalabhc/generator-profiling-fix

Fix profiling for generators and generator expressions
parents 3377d3e3 75d36215
...@@ -9287,7 +9287,13 @@ class YieldExprNode(ExprNode): ...@@ -9287,7 +9287,13 @@ class YieldExprNode(ExprNode):
code.putln('%s->%s = %s;' % (Naming.cur_scope_cname, save_cname, cname)) code.putln('%s->%s = %s;' % (Naming.cur_scope_cname, save_cname, cname))
code.put_xgiveref(Naming.retval_cname) code.put_xgiveref(Naming.retval_cname)
profile = code.globalstate.directives['profile']
linetrace = code.globalstate.directives['linetrace']
if profile or linetrace:
code.put_trace_return(Naming.retval_cname,
nogil=not code.funcstate.gil_owned)
code.put_finish_refcount_context() code.put_finish_refcount_context()
code.putln("/* return from generator, yielding value */") code.putln("/* return from generator, yielding value */")
code.putln("%s->resume_label = %d;" % ( code.putln("%s->resume_label = %d;" % (
Naming.generator_cname, label_num)) Naming.generator_cname, label_num))
......
...@@ -1805,9 +1805,11 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1805,9 +1805,11 @@ class FuncDefNode(StatNode, BlockNode):
code.declare_gilstate() code.declare_gilstate()
if profile or linetrace: if profile or linetrace:
tempvardecl_code.put_trace_declarations() if not self.is_generator:
code_object = self.code_object.calculate_result_code(code) if self.code_object else None # generators are traced when iterated, not at creation
code.put_trace_frame_init(code_object) tempvardecl_code.put_trace_declarations()
code_object = self.code_object.calculate_result_code(code) if self.code_object else None
code.put_trace_frame_init(code_object)
# ----- set up refnanny # ----- set up refnanny
if use_refnanny: if use_refnanny:
...@@ -1862,12 +1864,14 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1862,12 +1864,14 @@ class FuncDefNode(StatNode, BlockNode):
if profile or linetrace: if profile or linetrace:
# this looks a bit late, but if we don't get here due to a # this looks a bit late, but if we don't get here due to a
# fatal error before hand, it's not really worth tracing # fatal error before hand, it's not really worth tracing
if self.is_wrapper: if not self.is_generator:
trace_name = self.entry.name + " (wrapper)" # generators are traced when iterated, not at creation
else: if self.is_wrapper:
trace_name = self.entry.name trace_name = self.entry.name + " (wrapper)"
code.put_trace_call( else:
trace_name, self.pos, nogil=not code.funcstate.gil_owned) trace_name = self.entry.name
code.put_trace_call(
trace_name, self.pos, nogil=not code.funcstate.gil_owned)
code.funcstate.can_trace = True code.funcstate.can_trace = True
# ----- Fetch arguments # ----- Fetch arguments
self.generate_argument_parsing_code(env, code) self.generate_argument_parsing_code(env, code)
...@@ -2064,12 +2068,14 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -2064,12 +2068,14 @@ class FuncDefNode(StatNode, BlockNode):
if profile or linetrace: if profile or linetrace:
code.funcstate.can_trace = False code.funcstate.can_trace = False
if self.return_type.is_pyobject: if not self.is_generator:
code.put_trace_return( # generators are traced when iterated, not at creation
Naming.retval_cname, nogil=not code.funcstate.gil_owned) if self.return_type.is_pyobject:
else: code.put_trace_return(
code.put_trace_return( Naming.retval_cname, nogil=not code.funcstate.gil_owned)
"Py_None", nogil=not code.funcstate.gil_owned) else:
code.put_trace_return(
"Py_None", nogil=not code.funcstate.gil_owned)
if not lenv.nogil: if not lenv.nogil:
# GIL holding function # GIL holding function
...@@ -4041,6 +4047,10 @@ class GeneratorBodyDefNode(DefNode): ...@@ -4041,6 +4047,10 @@ class GeneratorBodyDefNode(DefNode):
tempvardecl_code = code.insertion_point() tempvardecl_code = code.insertion_point()
code.put_declare_refcount_context() code.put_declare_refcount_context()
code.put_setup_refcount_context(self.entry.name) code.put_setup_refcount_context(self.entry.name)
profile = code.globalstate.directives['profile']
linetrace = code.globalstate.directives['linetrace']
if profile or linetrace:
code.put_trace_declarations()
# ----- Resume switch point. # ----- Resume switch point.
code.funcstate.init_closure_temps(lenv.scope_class.type.scope) code.funcstate.init_closure_temps(lenv.scope_class.type.scope)
...@@ -4112,6 +4122,9 @@ class GeneratorBodyDefNode(DefNode): ...@@ -4112,6 +4122,9 @@ class GeneratorBodyDefNode(DefNode):
code.putln('%s->resume_label = -1;' % Naming.generator_cname) code.putln('%s->resume_label = -1;' % Naming.generator_cname)
# clean up as early as possible to help breaking any reference cycles # clean up as early as possible to help breaking any reference cycles
code.putln('__Pyx_Coroutine_clear((PyObject*)%s);' % Naming.generator_cname) code.putln('__Pyx_Coroutine_clear((PyObject*)%s);' % Naming.generator_cname)
if profile or linetrace:
code.put_trace_return(Naming.retval_cname,
nogil=not code.funcstate.gil_owned)
code.put_finish_refcount_context() code.put_finish_refcount_context()
code.putln("return %s;" % Naming.retval_cname) code.putln("return %s;" % Naming.retval_cname)
code.putln("}") code.putln("}")
...@@ -4119,13 +4132,20 @@ class GeneratorBodyDefNode(DefNode): ...@@ -4119,13 +4132,20 @@ class GeneratorBodyDefNode(DefNode):
# ----- Go back and insert temp variable declarations # ----- Go back and insert temp variable declarations
tempvardecl_code.put_temp_declarations(code.funcstate) tempvardecl_code.put_temp_declarations(code.funcstate)
# ----- Generator resume code # ----- Generator resume code
if profile or linetrace:
resume_code.put_trace_call(self.entry.qualified_name, self.pos,
nogil=not code.funcstate.gil_owned)
resume_code.putln("switch (%s->resume_label) {" % ( resume_code.putln("switch (%s->resume_label) {" % (
Naming.generator_cname)) Naming.generator_cname))
resume_code.putln("case 0: goto %s;" % first_run_label) resume_code.putln("case 0: goto %s;" % first_run_label)
for i, label in code.yield_labels: for i, label in code.yield_labels:
resume_code.putln("case %d: goto %s;" % (i, label)) resume_code.putln("case %d: goto %s;" % (i, label))
resume_code.putln("default: /* CPython raises the right error here */") resume_code.putln("default: /* CPython raises the right error here */")
if profile or linetrace:
resume_code.put_trace_return("Py_None",
nogil=not code.funcstate.gil_owned)
resume_code.put_finish_refcount_context() resume_code.put_finish_refcount_context()
resume_code.putln("return NULL;") resume_code.putln("return NULL;")
resume_code.putln("}") resume_code.putln("}")
......
...@@ -65,6 +65,58 @@ __doc__ = u""" ...@@ -65,6 +65,58 @@ __doc__ = u"""
'f_raise', 'f_raise',
'm_cdef', 'm_cpdef', 'm_cpdef (wrapper)', 'm_def', 'm_cdef', 'm_cpdef', 'm_cpdef (wrapper)', 'm_def',
'withgil_prof'] 'withgil_prof']
>>> profile.runctx("test_generators()", locals(), globals(), statsfile)
>>> s = pstats.Stats(statsfile)
>>> short_stats = dict([(k[2], v[1]) for k,v in s.stats.items()])
>>> short_stats['generator']
3
>>> short_stats['generator_exception']
2
>>> short_stats['genexpr']
11
>>> sorted(callees(s, 'test_generators'))
['call_generator', 'call_generator_exception', 'generator_expr']
>>> list(callees(s, 'call_generator'))
['generator']
>>> list(callees(s, 'generator'))
[]
>>> list(callees(s, 'generator_exception'))
[]
>>> list(callees(s, 'generator_expr'))
['genexpr']
>>> list(callees(s, 'genexpr'))
[]
>>> def python_generator():
... yield 1
... yield 2
>>> def call_python_generator():
... list(python_generator())
>>> profile.runctx("call_python_generator()", locals(), globals(), statsfile)
>>> python_stats = pstats.Stats(statsfile)
>>> python_stats_dict = dict([(k[2], v[1]) for k,v in python_stats.stats.items()])
>>> profile.runctx("call_generator()", locals(), globals(), statsfile)
>>> cython_stats = pstats.Stats(statsfile)
>>> cython_stats_dict = dict([(k[2], v[1]) for k,v in cython_stats.stats.items()])
>>> python_stats_dict['python_generator'] == cython_stats_dict['generator']
True
>>> try:
... os.unlink(statsfile)
... except:
... pass
""" """
cimport cython cimport cython
...@@ -147,3 +199,29 @@ cdef class A(object): ...@@ -147,3 +199,29 @@ cdef class A(object):
return a return a
cdef m_cdef(self, long a): cdef m_cdef(self, long a):
return a return a
def test_generators():
    """Run the three generator scenarios profiled by the module doctest.

    Calls, in order, the plain-generator case, the generator-that-raises
    case and the generator-expression case.  The module doctest profiles
    this function and checks that exactly these three helpers show up as
    its callees.
    """
    call_generator()
    call_generator_exception()
    generator_expr()
def call_generator():
    """Exhaust ``generator()`` by collecting its values into a list.

    The resulting list is discarded; the call exists only so that the
    profiler records ``generator`` as a callee of this function.
    """
    list(generator())
def generator():
    """Yield 1 and then 2.

    The module doctest expects this generator to be profiled as 3 calls
    when it is exhausted -- presumably one per resume (two yields plus the
    final resume that ends the iteration); confirm against the tracing code.
    """
    yield 1
    yield 2
def call_generator_exception():
    """Consume ``generator_exception()`` and swallow its ``ValueError``.

    Only ``ValueError`` is caught; any other exception propagates.
    """
    try:
        list(generator_exception())
    except ValueError:
        # Expected: the generator raises after its first yield.
        pass
def generator_exception():
    """Yield 1, then raise ``ValueError(2)`` on the next resume.

    The module doctest expects this generator to be profiled as 2 calls.
    """
    yield 1
    raise ValueError(2)
def generator_expr():
    """Return the sum of a generator expression over ``range(10)``.

    The module doctest expects the expression to appear in the profile as
    ``genexpr`` with 11 calls, and as the only callee of this function.
    """
    # NOTE(review): the generator expression is bound to a name before being
    # passed to sum(); presumably this keeps it a real generator object so it
    # is profiled as 'genexpr' -- confirm before inlining it into the call.
    e = (x for x in range(10))
    return sum(e)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment