Commit 25987676 authored by Stefan Behnel

implement any(genexpr) and all(genexpr) as special cased optimisations without requiring generators

parent bbf293be
...@@ -981,8 +981,9 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -981,8 +981,9 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
if not function.is_name: if not function.is_name:
return False return False
entry = self.current_env().lookup(function.name) entry = self.current_env().lookup(function.name)
if not entry or getattr(entry, 'scope', None) is not Builtin.builtin_scope: if entry and getattr(entry, 'scope', None) is not Builtin.builtin_scope:
return False return False
# if entry is None, it's at least an undeclared name, so likely builtin
return True return True
def _dispatch_to_handler(self, node, function, args, kwargs=None): def _dispatch_to_handler(self, node, function, args, kwargs=None):
...@@ -1074,6 +1075,121 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1074,6 +1075,121 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
self._error_wrong_arg_count('float', node, pos_args, 1) self._error_wrong_arg_count('float', node, pos_args, 1)
return node return node
class YieldNodeCollector(Visitor.TreeVisitor):
    """Tree visitor that collects the YieldExprNode instances it visits.

    The nodes are appended to ``self.yield_nodes`` in visiting order.
    """
    def __init__(self):
        Visitor.TreeVisitor.__init__(self)
        # all YieldExprNode instances seen so far
        self.yield_nodes = []

    # Reuse the plain child traversal for visit_Node
    # (presumably the visitor's fallback handler for all other node
    # types - confirm against Visitor.TreeVisitor).
    visit_Node = Visitor.TreeVisitor.visitchildren

    def visit_YieldExprNode(self, node):
        # Record the yield node first, then continue the search below it.
        self.yield_nodes.append(node)
        self.visitchildren(node)
def _handle_simple_function_all(self, node, pos_args):
"""Transform
_result = all(x for L in LL for x in L)
into
for L in LL:
for x in L:
if not x:
_result = False
break
else:
continue
break
else:
_result = True
"""
return self._transform_any_all(node, pos_args, False)
def _handle_simple_function_any(self, node, pos_args):
"""Transform
_result = any(x for L in LL for x in L)
into
for L in LL:
for x in L:
if x:
_result = True
break
else:
continue
break
else:
_result = False
"""
return self._transform_any_all(node, pos_args, True)
def _transform_any_all(self, node, pos_args, is_any):
if len(pos_args) != 1:
return node
if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
return node
loop_node = pos_args[0].loop
collector = self.YieldNodeCollector()
collector.visitchildren(loop_node)
if len(collector.yield_nodes) != 1:
return node
yield_node = collector.yield_nodes[0]
yield_expression = yield_node.arg
del collector
result_ref = UtilNodes.ResultRefNode(pos=node.pos)
result_ref.type = PyrexTypes.c_bint_type
if is_any:
condition = yield_expression
else:
condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)
# Transform generator expression into plain for-loop, replace
# yield node in body by assignment of True to the node result,
# set the 'else' branch to a False assignment. Propagate the
# break after the inner assignment by injecting breaks after
# the inner loops, and putting a default 'continue' into their
# 'else' clauses.
test_node = Nodes.IfStatNode(
yield_node.pos,
else_clause = None,
if_clauses = [ Nodes.IfClauseNode(
yield_node.pos,
condition = condition,
body = Nodes.StatListNode(
node.pos,
stats = [
Nodes.SingleAssignmentNode(
node.pos,
lhs = result_ref,
rhs = ExprNodes.BoolNode(yield_node.pos, value = is_any,
constant_result = is_any)),
Nodes.BreakStatNode(node.pos)
])) ]
)
loop = loop_node
while isinstance(loop.body, Nodes.LoopNode):
next_loop = loop.body
loop.body = Nodes.StatListNode(loop.body.pos, stats = [
loop.body,
Nodes.BreakStatNode(yield_node.pos)
])
next_loop.else_clause = Nodes.ContinueStatNode(yield_node.pos)
loop = next_loop
loop_node.else_clause = Nodes.SingleAssignmentNode(
node.pos,
lhs = result_ref,
rhs = ExprNodes.BoolNode(yield_node.pos, value = not is_any,
constant_result = not is_any))
Visitor.RecursiveNodeReplacer(yield_node, test_node).visitchildren(loop_node)
return UtilNodes.TempResultFromStatNode(result_ref, loop_node)
# specific handlers for general call nodes # specific handlers for general call nodes
def _handle_general_function_dict(self, node, pos_args, kwargs): def _handle_general_function_dict(self, node, pos_args, kwargs):
......
# Test helper: wraps a sequence and prints every index requested through
# __getitem__, so the doctests in this file can show how many items an
# iteration actually fetched.
cdef class VerboseGetItem(object):
    # the wrapped sequence
    cdef object sequence
    def __init__(self, seq):
        self.sequence = seq
    def __getitem__(self, i):
        # trace the access before delegating to the wrapped sequence
        print i
        return self.sequence[i] # may raise IndexError
cimport cython
# all() on a plain argument (no generator expression).  Going by the
# decorator names, the compiled tree must still contain a call node and
# must not contain a for-in loop.  The doctests compare the item accesses
# of builtin all() and all_item() on a VerboseGetItem wrapper.
@cython.test_assert_path_exists("//SimpleCallNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
def all_item(x):
    """
    >>> all_item([1,1,1,1,1])
    True
    >>> all_item([1,1,1,1,0])
    False
    >>> all_item([0,1,1,1,0])
    False
    >>> all(VerboseGetItem([1,1,1,0,0]))
    0
    1
    2
    3
    False
    >>> all_item(VerboseGetItem([1,1,1,0,0]))
    0
    1
    2
    3
    False
    >>> all(VerboseGetItem([1,1,1,1,1]))
    0
    1
    2
    3
    4
    5
    True
    >>> all_item(VerboseGetItem([1,1,1,1,1]))
    0
    1
    2
    3
    4
    5
    True
    """
    return all(x)
# all() over a single-loop generator expression.  Going by the decorator
# names, the compiled tree must contain a for-in loop and no call, yield
# or generator expression node.  The VerboseGetItem doctests show that
# iteration stops at the first false item.
@cython.test_assert_path_exists("//ForInStatNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
                                 "//YieldExprNode",
                                 "//GeneratorExpressionNode")
def all_in_simple_gen(seq):
    """
    >>> all_in_simple_gen([1,1,1])
    True
    >>> all_in_simple_gen([1,1,0])
    False
    >>> all_in_simple_gen([1,0,1])
    False
    >>> all_in_simple_gen(VerboseGetItem([1,1,1,1,1]))
    0
    1
    2
    3
    4
    5
    True
    >>> all_in_simple_gen(VerboseGetItem([1,1,0,1,1]))
    0
    1
    2
    False
    """
    return all(x for x in seq)
# Same as the simple generator expression case, but the loop variable is
# declared as a C int in the enclosing function (see the FIXME below on
# why this currently works).
@cython.test_assert_path_exists("//ForInStatNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
                                 "//YieldExprNode",
                                 "//GeneratorExpressionNode")
def all_in_typed_gen(seq):
    """
    >>> all_in_typed_gen([1,1,1])
    True
    >>> all_in_typed_gen([1,0,0])
    False
    >>> all_in_typed_gen(VerboseGetItem([1,1,1,1,1]))
    0
    1
    2
    3
    4
    5
    True
    >>> all_in_typed_gen(VerboseGetItem([1,1,1,1,0]))
    0
    1
    2
    3
    4
    False
    """
    # FIXME: this isn't really supposed to work, but it currently does
    # due to incorrect scoping - this should be fixed!!
    cdef int x
    return all(x for x in seq)
# all() over a generator expression with two nested loops.  The last two
# doctests show the item accesses per inner sequence: an exhausted inner
# sequence moves on to the next one, and the first false item ends the
# whole iteration.
@cython.test_assert_path_exists("//ForInStatNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
                                 "//YieldExprNode",
                                 "//GeneratorExpressionNode")
def all_in_nested_gen(seq):
    """
    >>> all(x for L in [[1,1,1],[1,1,1],[1,1,1]] for x in L)
    True
    >>> all_in_nested_gen([[1,1,1],[1,1,1],[1,1,1]])
    True
    >>> all(x for L in [[1,1,1],[1,1,1],[1,1,0]] for x in L)
    False
    >>> all_in_nested_gen([[1,1,1],[1,1,1],[1,1,0]])
    False
    >>> all(x for L in [[1,1,1],[0,1,1],[1,1,1]] for x in L)
    False
    >>> all_in_nested_gen([[1,1,1],[0,1,1],[1,1,1]])
    False
    >>> all_in_nested_gen([VerboseGetItem([1,1,1]), VerboseGetItem([1,1,1,1,1])])
    0
    1
    2
    3
    0
    1
    2
    3
    4
    5
    True
    >>> all_in_nested_gen([VerboseGetItem([1,1,1]),VerboseGetItem([1,1]),VerboseGetItem([1,1,0])])
    0
    1
    2
    3
    0
    1
    2
    0
    1
    2
    False
    """
    # FIXME: this isn't really supposed to work, but it currently does
    # due to incorrect scoping - this should be fixed!!
    cdef int x
    return all(x for L in seq for x in L)
# Test helper: wraps a sequence and prints every index requested through
# __getitem__, so the doctests in this file can show how many items an
# iteration actually fetched.
cdef class VerboseGetItem(object):
    # the wrapped sequence
    cdef object sequence
    def __init__(self, seq):
        self.sequence = seq
    def __getitem__(self, i):
        # trace the access before delegating to the wrapped sequence
        print i
        return self.sequence[i] # may raise IndexError
cimport cython
# any() on a plain argument (no generator expression).  Going by the
# decorator names, the compiled tree must still contain a call node and
# must not contain a for-in loop.  The doctests compare the item accesses
# of builtin any() and any_item() on a VerboseGetItem wrapper.
@cython.test_assert_path_exists("//SimpleCallNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
def any_item(x):
    """
    >>> any_item([0,0,1,0,0])
    True
    >>> any_item([0,0,0,0,1])
    True
    >>> any_item([0,0,0,0,0])
    False
    >>> any(VerboseGetItem([0,0,1,0,0]))
    0
    1
    2
    True
    >>> any_item(VerboseGetItem([0,0,1,0,0]))
    0
    1
    2
    True
    >>> any(VerboseGetItem([0,0,0,0,0]))
    0
    1
    2
    3
    4
    5
    False
    >>> any_item(VerboseGetItem([0,0,0,0,0]))
    0
    1
    2
    3
    4
    5
    False
    """
    return any(x)
# any() over a single-loop generator expression.  Going by the decorator
# names, the compiled tree must contain a for-in loop and no call, yield
# or generator expression node.  The VerboseGetItem doctests show that
# iteration stops at the first true item.
@cython.test_assert_path_exists("//ForInStatNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
                                 "//YieldExprNode",
                                 "//GeneratorExpressionNode")
def any_in_simple_gen(seq):
    """
    >>> any_in_simple_gen([0,1,0])
    True
    >>> any_in_simple_gen([0,0,0])
    False
    >>> any_in_simple_gen(VerboseGetItem([0,0,1,0,0]))
    0
    1
    2
    True
    >>> any_in_simple_gen(VerboseGetItem([0,0,0,0,0]))
    0
    1
    2
    3
    4
    5
    False
    """
    return any(x for x in seq)
# Same as the simple generator expression case, but the loop variable is
# declared as a C int in the enclosing function (see the FIXME below on
# why this currently works).
@cython.test_assert_path_exists("//ForInStatNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
                                 "//YieldExprNode",
                                 "//GeneratorExpressionNode")
def any_in_typed_gen(seq):
    """
    >>> any_in_typed_gen([0,1,0])
    True
    >>> any_in_typed_gen([0,0,0])
    False
    >>> any_in_typed_gen(VerboseGetItem([0,0,1,0,0]))
    0
    1
    2
    True
    >>> any_in_typed_gen(VerboseGetItem([0,0,0,0,0]))
    0
    1
    2
    3
    4
    5
    False
    """
    # FIXME: this isn't really supposed to work, but it currently does
    # due to incorrect scoping - this should be fixed!!
    cdef int x
    return any(x for x in seq)
# any() over a generator expression with two nested loops.  The last two
# doctests show the item accesses per inner sequence: an exhausted inner
# sequence moves on to the next one, and the first true item ends the
# whole iteration.
@cython.test_assert_path_exists("//ForInStatNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
                                 "//YieldExprNode",
                                 "//GeneratorExpressionNode")
def any_in_nested_gen(seq):
    """
    >>> any(x for L in [[0,0,0],[0,0,1],[0,0,0]] for x in L)
    True
    >>> any_in_nested_gen([[0,0,0],[0,0,1],[0,0,0]])
    True
    >>> any(x for L in [[0,0,0],[0,0,0],[0,0,0]] for x in L)
    False
    >>> any_in_nested_gen([[0,0,0],[0,0,0],[0,0,0]])
    False
    >>> any_in_nested_gen([VerboseGetItem([0,0,0]), VerboseGetItem([0,0,1,0,0])])
    0
    1
    2
    3
    0
    1
    2
    True
    >>> any_in_nested_gen([VerboseGetItem([0,0,0]),VerboseGetItem([0,0]),VerboseGetItem([0,0,0])])
    0
    1
    2
    3
    0
    1
    2
    0
    1
    2
    3
    False
    """
    # FIXME: this isn't really supposed to work, but it currently does
    # due to incorrect scoping - this should be fixed!!
    cdef int x
    return any(x for L in seq for x in L)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment