Commit 0ca78ef8 authored by Stefan Behnel's avatar Stefan Behnel

implement sum(genexp) as inlined genexp loop

parent 84ee8e9c
...@@ -1085,6 +1085,13 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1085,6 +1085,13 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
self.yield_nodes.append(node) self.yield_nodes.append(node)
self.visitchildren(node) self.visitchildren(node)
def _find_single_yield_node(self, node):
collector = self.YieldNodeCollector()
collector.visitchildren(node)
if len(collector.yield_nodes) != 1:
return None
return collector.yield_nodes[0]
def _handle_simple_function_all(self, node, pos_args): def _handle_simple_function_all(self, node, pos_args):
"""Transform """Transform
...@@ -1132,14 +1139,10 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1132,14 +1139,10 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
return node return node
gen_expr_node = pos_args[0] gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_node = self._find_single_yield_node(loop_node)
collector = self.YieldNodeCollector() if yield_node is None:
collector.visitchildren(loop_node)
if len(collector.yield_nodes) != 1:
return node return node
yield_node = collector.yield_nodes[0]
yield_expression = yield_node.arg yield_expression = yield_node.arg
del collector
if is_any: if is_any:
condition = yield_expression condition = yield_expression
...@@ -1185,6 +1188,48 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1185,6 +1188,48 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node.pos, loop = loop_node, result_node = result_ref, gen_expr_node.pos, loop = loop_node, result_node = result_ref,
expr_scope = gen_expr_node.expr_scope) expr_scope = gen_expr_node.expr_scope)
def _handle_simple_function_sum(self, node, pos_args):
if len(pos_args) not in (1,2):
return node
if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
return node
gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop
yield_node = self._find_single_yield_node(loop_node)
if yield_node is None:
return node
yield_expression = yield_node.arg
if len(pos_args) == 1:
start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
else:
start = pos_args[1]
result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
add_node = Nodes.SingleAssignmentNode(
yield_node.pos,
lhs = result_ref,
rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
)
Visitor.RecursiveNodeReplacer(yield_node, add_node).visitchildren(loop_node)
exec_code = Nodes.StatListNode(
node.pos,
stats = [
Nodes.SingleAssignmentNode(
start.pos,
lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
rhs = start,
first = True),
loop_node
])
return ExprNodes.InlinedGeneratorExpressionNode(
gen_expr_node.pos, loop = exec_code, result_node = result_ref,
expr_scope = gen_expr_node.expr_scope)
# specific handlers for general call nodes # specific handlers for general call nodes
def _handle_general_function_dict(self, node, pos_args, kwargs): def _handle_general_function_dict(self, node, pos_args, kwargs):
......
...@@ -144,6 +144,11 @@ class ResultRefNode(AtomicExprNode): ...@@ -144,6 +144,11 @@ class ResultRefNode(AtomicExprNode):
return True return True
def result(self): def result(self):
try:
return self.result_code
except AttributeError:
if self.expression is not None:
self.result_code = self.expression.result()
return self.result_code return self.result_code
def generate_evaluation_code(self, code): def generate_evaluation_code(self, code):
......
def range_sum(int N):
"""
>>> sum(range(10))
45
>>> range_sum(10)
45
"""
result = sum(i for i in range(N))
return result
def return_range_sum(int N):
"""
>>> sum(range(10))
45
>>> return_range_sum(10)
45
"""
return sum(i for i in range(N))
def return_range_sum_squares(int N):
"""
>>> sum([i*i for i in range(10)])
285
>>> return_range_sum_squares(10)
285
>>> sum([i*i for i in range(10000)])
333283335000
>>> return_range_sum_squares(10000)
333283335000
"""
return sum(i*i for i in range(N))
def return_sum_squares(seq):
"""
>>> sum([i*i for i in range(10)])
285
>>> return_sum_squares(range(10))
285
>>> sum([i*i for i in range(10000)])
333283335000
>>> return_sum_squares(range(10000))
333283335000
"""
return sum(i*i for i in seq)
def return_sum_squares_start(seq, int start):
"""
>>> sum([i*i for i in range(10)], -1)
284
>>> return_sum_squares_start(range(10), -1)
284
>>> sum([i*i for i in range(10000)], 9)
333283335009
>>> return_sum_squares_start(range(10000), 9)
333283335009
"""
return sum((i*i for i in seq), start)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment