Commit 13ace3a9 authored by Stefan Behnel's avatar Stefan Behnel

re-establish a simple form of generator inlining for any() and all() that does...

re-establish a simple form of generator inlining for any() and all() that does not remove the generator but inlines the evaluation into the inner loop
parent 0710b467
......@@ -7315,7 +7315,39 @@ class DictComprehensionAppendNode(ComprehensionAppendNode):
self.value_expr.annotate(code)
class InlinedGeneratorExpressionNode(ScopedExprNode):
class InlinedGeneratorExpressionNode(ExprNode):
# An inlined generator expression for which the result is
# calculated inside of the loop. This will only be created by
# transforms when replacing builtin calls on generator
# expressions.
#
# gen GeneratorExpressionNode the generator, not containing any YieldExprNodes
# orig_func String the name of the builtin function this node replaces
subexprs = ["gen"]
orig_func = None
type = py_object_type
def may_be_none(self):
return self.orig_func not in ('any', 'all')
def infer_type(self, env):
return py_object_type
def analyse_types(self, env):
self.gen = self.gen.analyse_expressions(env)
self.is_temp = True
return self
def generate_result_code(self, code):
code.globalstate.use_utility_code(UtilityCode.load_cached("GetGenexpResult", "Coroutine.c"))
code.putln("%s = __Pyx_Generator_GetGenexpResult(%s); %s" % (
self.result(), self.gen.result(),
code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.result())
class __InlinedGeneratorExpressionNode(ScopedExprNode):
# An inlined generator expression for which the result is
# calculated inside of the loop. This will only be created by
# transforms when replacing builtin calls on generator
......
......@@ -1455,16 +1455,16 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
visit_Node = Visitor.TreeVisitor.visitchildren
# XXX: disable inlining while it's not back supported
def __visit_YieldExprNode(self, node):
def visit_YieldExprNode(self, node):
self.yield_nodes.append(node)
self.visitchildren(node)
def __visit_ExprStatNode(self, node):
def visit_ExprStatNode(self, node):
self.visitchildren(node)
if node.expr in self.yield_nodes:
self.yield_stat_nodes[node.expr] = node
def __visit_GeneratorExpressionNode(self, node):
def visit_GeneratorExpressionNode(self, node):
# enable when we support generic generator expressions
#
# everything below this node is out of scope
......@@ -1527,7 +1527,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
return node
gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop
generator_body = gen_expr_node.def_node.gbody
loop_node = generator_body.body
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
if yield_expression is None:
return node
......@@ -1535,46 +1536,37 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
if is_any:
condition = yield_expression
else:
condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)
condition = ExprNodes.NotNode(yield_expression.pos, operand=yield_expression)
result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
test_node = Nodes.IfStatNode(
yield_expression.pos, else_clause=None, if_clauses=[
Nodes.IfClauseNode(
yield_expression.pos,
else_clause = None,
if_clauses = [ Nodes.IfClauseNode(
yield_expression.pos,
condition = condition,
body = Nodes.StatListNode(
condition=condition,
body=Nodes.ReturnStatNode(
node.pos,
stats = [
Nodes.SingleAssignmentNode(
node.pos,
lhs = result_ref,
rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
constant_result = is_any)),
Nodes.BreakStatNode(node.pos)
])) ]
value=ExprNodes.BoolNode(yield_expression.pos, value=is_any, constant_result=is_any),
in_generator=True)
)]
)
loop = loop_node
while isinstance(loop.body, Nodes.LoopNode):
next_loop = loop.body
loop.body = Nodes.StatListNode(loop.body.pos, stats = [
loop.body = Nodes.StatListNode(loop.body.pos, stats=[
loop.body,
Nodes.BreakStatNode(yield_expression.pos)
])
next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
loop = next_loop
loop_node.else_clause = Nodes.SingleAssignmentNode(
loop_node.else_clause = Nodes.ReturnStatNode(
node.pos,
lhs = result_ref,
rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any,
constant_result = not is_any))
value=ExprNodes.BoolNode(yield_expression.pos, value=not is_any, constant_result=not is_any),
in_generator=True)
Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)
return ExprNodes.InlinedGeneratorExpressionNode(
gen_expr_node.pos, loop = loop_node, result_node = result_ref,
expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
gen_expr_node.pos, gen=gen_expr_node, orig_func='any' if is_any else 'all')
PySequence_List_func_type = PyrexTypes.CFuncType(
Builtin.list_type,
......@@ -1597,6 +1589,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
# FIXME: currently nonfunctional
yield_expression = None
if yield_expression is None:
return node
......@@ -1642,7 +1636,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
result_node,
Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ]))
def _handle_simple_function_sum(self, node, pos_args):
def __handle_simple_function_sum(self, node, pos_args):
"""Transform sum(genexpr) into an equivalent inlined aggregation loop.
"""
if len(pos_args) not in (1,2):
......@@ -1655,6 +1649,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
# FIXME: currently nonfunctional
yield_expression = None
if yield_expression is None:
return node
else: # ComprehensionNode
......@@ -1786,6 +1782,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
# FIXME: currently nonfunctional
yield_expression = None
if yield_expression is None:
return node
......@@ -1818,6 +1816,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
# FIXME: currently nonfunctional
yield_expression = None
if yield_expression is None:
return node
......
......@@ -246,6 +246,28 @@ static void __Pyx_Generator_Replace_StopIteration(void) {
}
//////////////////// GetGenexpResult.proto ////////////////////
static CYTHON_INLINE PyObject* __Pyx_Generator_GetGenexpResult(PyObject* gen); /*proto*/
//////////////////// GetGenexpResult ////////////////////
//@requires: Generator
static CYTHON_INLINE PyObject* __Pyx_Generator_GetGenexpResult(PyObject* gen) {
PyObject *result;
result = __Pyx_Generator_Next(gen);
if (unlikely(result)) {
PyErr_Format(PyExc_RuntimeError, "Generator expression returned with non-StopIteration result '%.100s'",
result ? Py_TYPE(result)->tp_name : "NULL");
Py_XDECREF(result);
return NULL;
}
if (unlikely(__Pyx_PyGen_FetchStopIterationValue(&result) < 0))
return NULL;
return result;
}
//////////////////// CoroutineBase.proto ////////////////////
typedef PyObject *(*__pyx_coroutine_body_t)(PyObject *, PyObject *);
......
......@@ -52,6 +52,4 @@ pyregr.test_urllib2net
pyregr.test_urllibnet
# Inlined generators
all
any
inlined_generator_expressions
......@@ -53,10 +53,14 @@ def all_item(x):
"""
return all(x)
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode")
@cython.test_assert_path_exists(
"//ForInStatNode",
"//InlinedGeneratorExpressionNode"
)
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode"
)
def all_in_simple_gen(seq):
"""
>>> all_in_simple_gen([1,1,1])
......@@ -82,10 +86,14 @@ def all_in_simple_gen(seq):
"""
return all(x for x in seq)
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode")
@cython.test_assert_path_exists(
"//ForInStatNode",
"//InlinedGeneratorExpressionNode"
)
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode"
)
def all_in_simple_gen_scope(seq):
"""
>>> all_in_simple_gen_scope([1,1,1])
......@@ -114,10 +122,14 @@ def all_in_simple_gen_scope(seq):
assert x == 'abc'
return result
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode")
@cython.test_assert_path_exists(
"//ForInStatNode",
"//InlinedGeneratorExpressionNode"
)
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode"
)
def all_in_conditional_gen(seq):
"""
>>> all_in_conditional_gen([3,6,9])
......@@ -150,10 +162,14 @@ mixed_ustring = u'AbcDefGhIjKlmnoP'
lower_ustring = mixed_ustring.lower()
upper_ustring = mixed_ustring.upper()
@cython.test_assert_path_exists('//PythonCapiCallNode',
'//ForFromStatNode')
@cython.test_fail_if_path_exists('//SimpleCallNode',
'//ForInStatNode')
@cython.test_assert_path_exists(
'//PythonCapiCallNode',
'//ForFromStatNode'
)
@cython.test_fail_if_path_exists(
'//SimpleCallNode',
'//ForInStatNode'
)
def all_lower_case_characters(unicode ustring):
"""
>>> all_lower_case_characters(mixed_ustring)
......@@ -165,12 +181,16 @@ def all_lower_case_characters(unicode ustring):
"""
return all(uchar.islower() for uchar in ustring)
@cython.test_assert_path_exists("//ForInStatNode",
@cython.test_assert_path_exists(
"//ForInStatNode",
"//InlinedGeneratorExpressionNode",
"//InlinedGeneratorExpressionNode//IfStatNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//InlinedGeneratorExpressionNode//IfStatNode"
)
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode",
"//IfStatNode//CoerceToBooleanNode")
# "//IfStatNode//CoerceToBooleanNode"
)
def all_in_typed_gen(seq):
"""
>>> all_in_typed_gen([1,1,1])
......@@ -197,12 +217,16 @@ def all_in_typed_gen(seq):
cdef int x
return all(x for x in seq)
@cython.test_assert_path_exists("//ForInStatNode",
@cython.test_assert_path_exists(
"//ForInStatNode",
"//InlinedGeneratorExpressionNode",
"//InlinedGeneratorExpressionNode//IfStatNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//InlinedGeneratorExpressionNode//IfStatNode"
)
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode",
"//IfStatNode//CoerceToBooleanNode")
# "//IfStatNode//CoerceToBooleanNode"
)
def all_in_double_gen(seq):
"""
>>> all(x for L in [[1,1,1],[1,1,1],[1,1,1]] for x in L)
......
......@@ -52,10 +52,14 @@ def any_item(x):
return any(x)
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode")
@cython.test_assert_path_exists(
"//ForInStatNode",
"//InlinedGeneratorExpressionNode"
)
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode"
)
def any_in_simple_gen(seq):
"""
>>> any_in_simple_gen([0,1,0])
......@@ -80,10 +84,14 @@ def any_in_simple_gen(seq):
return any(x for x in seq)
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode")
@cython.test_assert_path_exists(
"//ForInStatNode",
"//InlinedGeneratorExpressionNode"
)
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode"
)
def any_in_simple_gen_scope(seq):
"""
>>> any_in_simple_gen_scope([0,1,0])
......@@ -111,10 +119,14 @@ def any_in_simple_gen_scope(seq):
return result
@cython.test_assert_path_exists("//ForInStatNode",
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode")
@cython.test_assert_path_exists(
"//ForInStatNode",
"//InlinedGeneratorExpressionNode"
)
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode"
)
def any_in_conditional_gen(seq):
"""
>>> any_in_conditional_gen([3,6,9])
......@@ -146,11 +158,15 @@ lower_ustring = mixed_ustring.lower()
upper_ustring = mixed_ustring.upper()
@cython.test_assert_path_exists('//PythonCapiCallNode',
@cython.test_assert_path_exists(
'//PythonCapiCallNode',
'//ForFromStatNode',
"//InlinedGeneratorExpressionNode")
@cython.test_fail_if_path_exists('//SimpleCallNode',
'//ForInStatNode')
"//InlinedGeneratorExpressionNode"
)
@cython.test_fail_if_path_exists(
'//SimpleCallNode',
'//ForInStatNode'
)
def any_lower_case_characters(unicode ustring):
"""
>>> any_lower_case_characters(upper_ustring)
......@@ -163,12 +179,16 @@ def any_lower_case_characters(unicode ustring):
return any(uchar.islower() for uchar in ustring)
@cython.test_assert_path_exists("//ForInStatNode",
@cython.test_assert_path_exists(
"//ForInStatNode",
"//InlinedGeneratorExpressionNode",
"//InlinedGeneratorExpressionNode//IfStatNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//InlinedGeneratorExpressionNode//IfStatNode"
)
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode",
"//IfStatNode//CoerceToBooleanNode")
# "//IfStatNode//CoerceToBooleanNode"
)
def any_in_typed_gen(seq):
"""
>>> any_in_typed_gen([0,1,0])
......@@ -194,11 +214,15 @@ def any_in_typed_gen(seq):
return any(x for x in seq)
@cython.test_assert_path_exists("//ForInStatNode",
@cython.test_assert_path_exists(
"//ForInStatNode",
"//InlinedGeneratorExpressionNode",
"//InlinedGeneratorExpressionNode//IfStatNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//YieldExprNode")
"//InlinedGeneratorExpressionNode//IfStatNode"
)
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode"
)
def any_in_gen_builtin_name(seq):
"""
>>> any_in_gen_builtin_name([0,1,0])
......@@ -223,12 +247,16 @@ def any_in_gen_builtin_name(seq):
return any(type for type in seq)
@cython.test_assert_path_exists("//ForInStatNode",
@cython.test_assert_path_exists(
"//ForInStatNode",
"//InlinedGeneratorExpressionNode",
"//InlinedGeneratorExpressionNode//IfStatNode")
@cython.test_fail_if_path_exists("//SimpleCallNode",
"//InlinedGeneratorExpressionNode//IfStatNode"
)
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode",
"//IfStatNode//CoerceToBooleanNode")
# "//IfStatNode//CoerceToBooleanNode"
)
def any_in_double_gen(seq):
"""
>>> any(x for L in [[0,0,0],[0,0,1],[0,0,0]] for x in L)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment