Commit 245434dc authored by Stefan Behnel's avatar Stefan Behnel

drop sum(genexpr) into plain C code when the result is C typed

parent 514f2542
......@@ -4037,10 +4037,13 @@ class GeneratorExpressionNode(ScopedExprNode):
class InlinedGeneratorExpressionNode(GeneratorExpressionNode):
# An inlined generator expression for which the result is
# calculated inside of the loop.
# calculated inside of the loop. This will only be created by
# transforms when replacing builtin calls on generator
# expressions.
#
# loop ForStatNode the for-loop, not containing any YieldExprNodes
# result_node ResultRefNode the reference to the result value temp
# orig_func String the name of the builtin function this node replaces
child_attrs = ["loop"]
......@@ -4048,6 +4051,13 @@ class InlinedGeneratorExpressionNode(GeneratorExpressionNode):
self.type = self.result_node.type
self.is_temp = True
def coerce_to(self, dst_type, env):
if self.orig_func == 'sum' and dst_type.is_numeric:
# we can optimise by dropping the aggregation variable into C
self.result_node.type = self.type = dst_type
return self
return GeneratorExpressionNode.coerce_to(self, dst_type, env)
def generate_result_code(self, code):
self.result_node.result_code = self.result()
self.loop.generate_execution_code(code)
......
......@@ -1186,7 +1186,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
return ExprNodes.InlinedGeneratorExpressionNode(
gen_expr_node.pos, loop = loop_node, result_node = result_ref,
expr_scope = gen_expr_node.expr_scope)
expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
def _handle_simple_function_sum(self, node, pos_args):
"""Transform sum(genexpr) into an equivalent inlined aggregation loop.
......@@ -1230,7 +1230,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
return ExprNodes.InlinedGeneratorExpressionNode(
gen_expr_node.pos, loop = exec_code, result_node = result_ref,
expr_scope = gen_expr_node.expr_scope)
expr_scope = gen_expr_node.expr_scope, orig_func = 'sum')
# specific handlers for general call nodes
......
......@@ -15,6 +15,20 @@ def range_sum(int N):
result = sum(i for i in range(N))
return result
@cython.test_assert_path_exists('//ForFromStatNode',
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists('//SimpleCallNode',
'//ForInStatNode')
def range_sum_typed(int N):
"""
>>> sum(range(10))
45
>>> range_sum_typed(10)
45
"""
cdef int result = sum(i for i in range(N))
return result
@cython.test_assert_path_exists('//ForFromStatNode',
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists('//SimpleCallNode',
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment