Commit ee2d4178 authored by Stefan Behnel

clean up comprehensions to bring them closer to generator expressions, make...

clean up comprehensions to bring them closer to generator expressions, make their scoping behaviour configurable
remove optimisations for set([...]) and dict([...]) as they do not take side effects into account: unhashable items lead to a premature exit from the loop
instead, transform set(genexp), list(genexp) and dict(genexp) into inlined comprehensions that do not leak loop variables
parent ac7b0df6
......@@ -3919,27 +3919,45 @@ class ScopedExprNode(ExprNode):
# this is called with the expr_scope as env
pass
def init_scope(self, outer_scope, expr_scope=None):
self.expr_scope = expr_scope
class ComprehensionNode(ExprNode): # (ScopedExprNode)
class ComprehensionNode(ScopedExprNode):
subexprs = ["target"]
child_attrs = ["loop", "append"]
# different behaviour in Py2 and Py3: leak loop variables or not?
has_local_scope = False # Py2 behaviour as default
def infer_type(self, env):
return self.target.infer_type(env)
def analyse_declarations(self, env):
self.append.target = self # this is used in the PyList_Append of the inner loop
self.init_scope(env)
if self.expr_scope is not None:
self.loop.analyse_declarations(self.expr_scope)
else:
self.loop.analyse_declarations(env)
# self.expr_scope = Symtab.GeneratorExpressionScope(env)
# self.loop.analyse_declarations(self.expr_scope)
def init_scope(self, outer_scope, expr_scope=None):
    # An explicitly supplied scope always wins.  Without one, a private
    # scope is created only when has_local_scope is set; otherwise
    # expr_scope stays None.
    if expr_scope is None and self.has_local_scope:
        expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
    self.expr_scope = expr_scope
def analyse_types(self, env):
self.target.analyse_expressions(env)
self.type = self.target.type
if not self.has_local_scope:
self.loop.analyse_expressions(env)
# def analyse_scoped_expressions(self, env):
# self.loop.analyse_expressions(env)
def analyse_scoped_expressions(self, env):
if self.has_local_scope:
self.loop.analyse_expressions(env)
def may_be_none(self):
return False
......@@ -3957,20 +3975,20 @@ class ComprehensionNode(ExprNode): # (ScopedExprNode)
self.loop.annotate(code)
class ComprehensionAppendNode(ExprNode):
class ComprehensionAppendNode(Node):
# Need to be careful to avoid infinite recursion:
# target must not be in child_attrs/subexprs
subexprs = ['expr']
child_attrs = ['expr']
type = PyrexTypes.c_int_type
def analyse_types(self, env):
self.expr.analyse_types(env)
def analyse_expressions(self, env):
self.expr.analyse_expressions(env)
if not self.expr.type.is_pyobject:
self.expr = self.expr.coerce_to_pyobject(env)
self.is_temp = 1
def generate_result_code(self, code):
def generate_execution_code(self, code):
if self.target.type is list_type:
function = "PyList_Append"
elif self.target.type is set_type:
......@@ -3979,32 +3997,52 @@ class ComprehensionAppendNode(ExprNode):
raise InternalError(
"Invalid type for comprehension node: %s" % self.target.type)
code.putln("%s = %s(%s, (PyObject*)%s); %s" %
(self.result(),
self.expr.generate_evaluation_code(code)
code.putln(code.error_goto_if("%s(%s, (PyObject*)%s)" % (
function,
self.target.result(),
self.expr.result(),
code.error_goto_if(self.result(), self.pos)))
self.expr.result()
), self.pos))
self.expr.generate_disposal_code(code)
self.expr.free_temps(code)
def generate_function_definitions(self, env, code):
self.expr.generate_function_definitions(env, code)
def annotate(self, code):
self.expr.annotate(code)
class DictComprehensionAppendNode(ComprehensionAppendNode):
subexprs = ['key_expr', 'value_expr']
child_attrs = ['key_expr', 'value_expr']
def analyse_types(self, env):
self.key_expr.analyse_types(env)
def analyse_expressions(self, env):
self.key_expr.analyse_expressions(env)
if not self.key_expr.type.is_pyobject:
self.key_expr = self.key_expr.coerce_to_pyobject(env)
self.value_expr.analyse_types(env)
self.value_expr.analyse_expressions(env)
if not self.value_expr.type.is_pyobject:
self.value_expr = self.value_expr.coerce_to_pyobject(env)
self.is_temp = 1
def generate_result_code(self, code):
code.putln("%s = PyDict_SetItem(%s, (PyObject*)%s, (PyObject*)%s); %s" %
(self.result(),
def generate_execution_code(self, code):
self.key_expr.generate_evaluation_code(code)
self.value_expr.generate_evaluation_code(code)
code.putln(code.error_goto_if("PyDict_SetItem(%s, (PyObject*)%s, (PyObject*)%s)" % (
self.target.result(),
self.key_expr.result(),
self.value_expr.result(),
code.error_goto_if(self.result(), self.pos)))
self.value_expr.result()
), self.pos))
self.key_expr.generate_disposal_code(code)
self.key_expr.free_temps(code)
self.value_expr.generate_disposal_code(code)
self.value_expr.free_temps(code)
def generate_function_definitions(self, env, code):
self.key_expr.generate_function_definitions(env, code)
self.value_expr.generate_function_definitions(env, code)
def annotate(self, code):
self.key_expr.annotate(code)
self.value_expr.annotate(code)
class GeneratorExpressionNode(ScopedExprNode):
......@@ -4019,9 +4057,15 @@ class GeneratorExpressionNode(ScopedExprNode):
type = py_object_type
def analyse_declarations(self, env):
self.expr_scope = Symtab.GeneratorExpressionScope(env)
self.init_scope(env)
self.loop.analyse_declarations(self.expr_scope)
def init_scope(self, outer_scope, expr_scope=None):
    # Use the supplied scope if there is one; otherwise create a fresh
    # GeneratorExpressionScope nested in outer_scope.
    if expr_scope is None:
        expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
    self.expr_scope = expr_scope
def analyse_types(self, env):
self.is_temp = True
......
......@@ -1022,52 +1022,6 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
# specific handlers for simple call nodes
def _handle_simple_function_set(self, node, pos_args):
"""Replace set([a,b,...]) by a literal set {a,b,...} and
set([ x for ... ]) by a literal { x for ... }.
"""
arg_count = len(pos_args)
if arg_count == 0:
return ExprNodes.SetNode(node.pos, args=[],
type=Builtin.set_type)
if arg_count > 1:
return node
iterable = pos_args[0]
if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
return ExprNodes.SetNode(node.pos, args=iterable.args)
elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
isinstance(iterable.target, (ExprNodes.ListNode,
ExprNodes.SetNode)):
iterable.target = ExprNodes.SetNode(node.pos, args=[])
iterable.pos = node.pos
return iterable
else:
return node
def _handle_simple_function_dict(self, node, pos_args):
"""Replace dict([ (a,b) for ... ]) by a literal { a:b for ... }.
"""
if len(pos_args) != 1:
return node
arg = pos_args[0]
if isinstance(arg, ExprNodes.ComprehensionNode) and \
isinstance(arg.target, (ExprNodes.ListNode,
ExprNodes.SetNode)):
append_node = arg.append
if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
len(append_node.expr.args) == 2:
key_node, value_node = append_node.expr.args
target_node = ExprNodes.DictNode(
pos=arg.target.pos, key_value_pairs=[])
new_append_node = ExprNodes.DictComprehensionAppendNode(
append_node.pos, target=target_node,
key_expr=key_node, value_expr=value_node)
arg.target = target_node
arg.type = target_node.type
replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
return replace_in(arg)
return node
def _handle_simple_function_float(self, node, pos_args):
if len(pos_args) == 0:
return ExprNodes.FloatNode(node.pos, value='0.0')
......@@ -1182,7 +1136,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
rhs = ExprNodes.BoolNode(yield_node.pos, value = not is_any,
constant_result = not is_any))
Visitor.RecursiveNodeReplacer(yield_node, test_node).visitchildren(loop_node)
Visitor.recursively_replace_node(loop_node, yield_node, test_node)
return ExprNodes.InlinedGeneratorExpressionNode(
gen_expr_node.pos, loop = loop_node, result_node = result_ref,
......@@ -1215,7 +1169,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
)
Visitor.RecursiveNodeReplacer(yield_node, add_node).visitchildren(loop_node)
Visitor.recursively_replace_node(loop_node, yield_node, add_node)
exec_code = Nodes.StatListNode(
node.pos,
......@@ -1232,6 +1186,113 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node.pos, loop = exec_code, result_node = result_ref,
expr_scope = gen_expr_node.expr_scope, orig_func = 'sum')
def _handle_simple_function_list(self, node, pos_args):
    """Replace list() by an empty list literal and hand list(<args>)
    on to _transform_list_set_genexpr().
    """
    if pos_args:
        return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
    return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
def _handle_simple_function_set(self, node, pos_args):
    """Replace set() by an empty set literal and hand set(<args>)
    on to _transform_list_set_genexpr().
    """
    if pos_args:
        return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode)
    return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
    """Turn set(genexpr) / list(genexpr) into an inlined comprehension.

    The call ``node`` is returned unchanged when it has more than one
    positional argument, when the argument is not a generator expression,
    or when no single yield node is found in the generator's loop.
    Otherwise the yield is replaced by an append to a fresh container of
    type ``container_node_class`` and a ComprehensionNode is returned
    that reuses the generator expression's scope.
    """
    if len(pos_args) > 1 or not isinstance(
            pos_args[0], ExprNodes.GeneratorExpressionNode):
        return node

    genexpr = pos_args[0]
    loop = genexpr.loop
    yield_stat = self._find_single_yield_node(loop)
    if yield_stat is None:
        return node
    yielded_value = yield_stat.arg

    container = container_node_class(node.pos, args=[])
    appender = ExprNodes.ComprehensionAppendNode(
        yield_stat.pos,
        expr=yielded_value,
        target=ExprNodes.CloneNode(container),
        is_temp=1)  # FIXME: why is this an ExprNode?

    # swap the yield inside the loop body for the append operation
    Visitor.recursively_replace_node(loop, yield_stat, appender)

    comprehension = ExprNodes.ComprehensionNode(
        node.pos,
        has_local_scope=True,
        expr_scope=genexpr.expr_scope,
        loop=loop,
        append=appender,
        target=container)
    # the append node writes into the comprehension's result
    appender.target = comprehension
    return comprehension
def _handle_simple_function_dict(self, node, pos_args):
    """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.

    A call without arguments becomes an empty dict literal.  A call with
    a single generator expression whose only yield produces a 2-tuple is
    rewritten into a ComprehensionNode that stores each key/value pair
    directly and reuses the generator expression's scope.  In every
    other case ``node`` is returned unchanged.
    """
    if len(pos_args) == 0:
        return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
    if len(pos_args) > 1:
        return node
    if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
        return node
    gen_expr_node = pos_args[0]
    loop_node = gen_expr_node.loop

    yield_node = self._find_single_yield_node(loop_node)
    if yield_node is None:
        return node
    yield_expression = yield_node.arg

    # only a literal (key, value) tuple can be split into key and value here
    if not isinstance(yield_expression, ExprNodes.TupleNode):
        return node
    if len(yield_expression.args) != 2:
        return node

    target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
    append_node = ExprNodes.DictComprehensionAppendNode(
        yield_node.pos,
        key_expr = yield_expression.args[0],
        value_expr = yield_expression.args[1],
        target = ExprNodes.CloneNode(target_node),
        is_temp = 1) # FIXME: why is this an ExprNode?

    # swap the yield inside the loop body for the dict item assignment
    Visitor.recursively_replace_node(loop_node, yield_node, append_node)

    dictcomp = ExprNodes.ComprehensionNode(
        node.pos,
        has_local_scope = True,
        expr_scope = gen_expr_node.expr_scope,
        loop = loop_node,
        append = append_node,
        target = target_node)
    # the append node writes into the comprehension's result
    append_node.target = dictcomp
    return dictcomp
# specific handlers for general call nodes
def _handle_general_function_dict(self, node, pos_args, kwargs):
......
......@@ -1142,6 +1142,7 @@ class AnalyseExpressionsTransform(CythonTransform):
return node
def visit_ScopedExprNode(self, node):
if node.expr_scope is not None:
node.expr_scope.infer_types()
node.analyse_scoped_expressions(node.expr_scope)
self.visitchildren(node)
......
......@@ -777,7 +777,7 @@ def p_list_maker(s):
target = ExprNodes.ListNode(pos, args = [])
append = ExprNodes.ComprehensionAppendNode(
pos, expr=expr, target=ExprNodes.CloneNode(target))
loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append))
loop = p_comp_for(s, append)
s.expect(']')
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target)
......@@ -843,7 +843,7 @@ def p_dict_or_set_maker(s):
target = ExprNodes.SetNode(pos, args=[])
append = ExprNodes.ComprehensionAppendNode(
item.pos, expr=item, target=ExprNodes.CloneNode(target))
loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append))
loop = p_comp_for(s, append)
s.expect('}')
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target)
......@@ -858,7 +858,7 @@ def p_dict_or_set_maker(s):
append = ExprNodes.DictComprehensionAppendNode(
item.pos, key_expr=key, value_expr=value,
target=ExprNodes.CloneNode(target))
loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append))
loop = p_comp_for(s, append)
s.expect('}')
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target)
......
......@@ -352,7 +352,9 @@ class RecursiveNodeReplacer(VisitorTransform):
else:
return node
def recursively_replace_node(tree, old_node, new_node):
    """Apply a RecursiveNodeReplacer for old_node -> new_node to ``tree``.

    The result of the replacer call is discarded; nothing is returned.
    """
    RecursiveNodeReplacer(old_node, new_node)(tree)
# Utils
......
__doc__ = u"""
>>> type(smoketest_dict()) is dict
True
>>> type(smoketest_list()) is dict
True
>>> sorted(smoketest_dict().items())
[(2, 0), (4, 4), (6, 8)]
>>> sorted(smoketest_list().items())
[(2, 0), (4, 4), (6, 8)]
>>> list(typed().items())
[(A, 1), (A, 1), (A, 1)]
>>> sorted(iterdict().items())
[(1, 'a'), (2, 'b'), (3, 'c')]
"""
cimport cython
def smoketest_dict():
return { x+2:x*2
def dictcomp():
"""
>>> sorted(dictcomp().items())
[(2, 0), (4, 4), (6, 8)]
>>> sorted(dictcomp().items())
[(2, 0), (4, 4), (6, 8)]
"""
x = 'abc'
result = { x+2:x*2
for x in range(5)
if x % 2 == 0 }
assert x != 'abc'
return result
@cython.test_fail_if_path_exists(
"//ComprehensionNode//ComprehensionAppendNode",
"//SimpleCallNode//ComprehensionNode")
"//GeneratorExpressionNode",
"//SimpleCallNode")
@cython.test_assert_path_exists(
"//ComprehensionNode",
"//ComprehensionNode//DictComprehensionAppendNode")
def smoketest_list():
return dict([ (x+2,x*2)
def genexpr():
"""
>>> type(genexpr()) is dict
True
>>> type(genexpr()) is dict
True
"""
x = 'abc'
result = dict( (x+2,x*2)
for x in range(5)
if x % 2 == 0 ])
if x % 2 == 0 )
assert x == 'abc'
return result
cdef class A:
def __repr__(self): return u"A"
def __richcmp__(one, other, op): return one is other
def __hash__(self): return id(self) % 65536
def typed():
def typed_dictcomp():
    """
    >>> list(typed_dictcomp().items())
    [(A, 1), (A, 1), (A, 1)]
    """
    # The comprehension's loop variable is declared with extension type A.
    cdef A obj
    # Dict comprehension mapping each of three A instances to 1.
    return {obj:1 for obj in [A(), A(), A()]}
def iterdict():
def iterdict_dictcomp():
    """
    >>> sorted(iterdict_dictcomp().items())
    [(1, 'a'), (2, 'b'), (3, 'c')]
    """
    # d is declared as a typed dict so the comprehension iterates a known dict.
    cdef dict d = dict(a=1,b=2,c=3)
    # Builds the inverse mapping: value -> key.
    return {d[key]:key for key in d}
......
......@@ -3,7 +3,20 @@ def smoketest():
>>> smoketest()
[0, 4, 8]
"""
print [x*2 for x in range(5) if x % 2 == 0]
x = 'abc'
result = [x*2 for x in range(5) if x % 2 == 0]
assert x != 'abc'
return result
def list_genexp():
    """
    >>> list_genexp()
    [0, 4, 8]
    """
    # Sentinel value: the generator expression below reuses the name x.
    x = 'abc'
    result = list(x*2 for x in range(5) if x % 2 == 0)
    # The outer x must still hold the sentinel after list(genexp) has run.
    assert x == 'abc'
    return result
def int_runvar():
"""
......
__doc__ = u"""
>>> type(smoketest_set()) is not list
True
>>> type(smoketest_set()) is _set
True
>>> type(smoketest_list()) is _set
True
>>> sorted(smoketest_set())
[0, 4, 8]
>>> sorted(smoketest_list())
[0, 4, 8]
>>> list(typed())
[A, A, A]
>>> sorted(iterdict())
[1, 2, 3]
"""
cimport cython
# Py2.3 doesn't have the set type, but Cython does :)
_set = set
def smoketest_set():
def setcomp():
    """
    >>> type(setcomp()) is not list
    True
    >>> type(setcomp()) is _set
    True
    >>> sorted(setcomp())
    [0, 4, 8]
    """
    # Set comprehension literal: doubles of the even numbers in range(5).
    return { x*2
             for x in range(5)
             if x % 2 == 0 }
@cython.test_fail_if_path_exists("//SimpleCallNode//ComprehensionNode")
@cython.test_assert_path_exists("//ComprehensionNode",
@cython.test_fail_if_path_exists(
"//GeneratorExpressionNode",
"//SimpleCallNode")
@cython.test_assert_path_exists(
"//ComprehensionNode",
"//ComprehensionNode//ComprehensionAppendNode")
def smoketest_list():
return set([ x*2
def genexp_set():
"""
>>> type(genexp_set()) is _set
True
>>> sorted(genexp_set())
[0, 4, 8]
"""
return set( x*2
for x in range(5)
if x % 2 == 0 ])
if x % 2 == 0 )
cdef class A:
def __repr__(self): return u"A"
......@@ -41,10 +40,18 @@ cdef class A:
def __hash__(self): return id(self) % 65536
def typed():
    """
    >>> list(typed())
    [A, A, A]
    """
    # The comprehension's loop variable is declared with extension type A.
    cdef A obj
    # Set comprehension iterating over a set literal of three A instances.
    return {obj for obj in {A(), A(), A()}}
def iterdict():
    """
    >>> sorted(iterdict())
    [1, 2, 3]
    """
    # d is declared as a typed dict so the comprehension iterates a known dict.
    cdef dict d = dict(a=1,b=2,c=3)
    # Collects the dict's values into a set by looking up each key.
    return {d[key] for key in d}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment