Commit b2cb180f authored by Stefan Behnel's avatar Stefan Behnel

refactor comprehensions by removing separate target node (to simplify a future...

refactor comprehensions by removing separate target node (to simplify a future length-hint optimisation)

--HG--
extra : rebase_source : 476b22eeaeaea1ff69ee8069328fb47ffe18ea20
parent d2fb1655
...@@ -6103,11 +6103,14 @@ class ScopedExprNode(ExprNode): ...@@ -6103,11 +6103,14 @@ class ScopedExprNode(ExprNode):
class ComprehensionNode(ScopedExprNode): class ComprehensionNode(ScopedExprNode):
subexprs = ["target"] # A list/set/dict comprehension
child_attrs = ["loop"] child_attrs = ["loop"]
is_temp = True
def infer_type(self, env): def infer_type(self, env):
return self.target.infer_type(env) return self.type
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.append.target = self # this is used in the PyList_Append of the inner loop self.append.target = self # this is used in the PyList_Append of the inner loop
...@@ -6117,8 +6120,6 @@ class ComprehensionNode(ScopedExprNode): ...@@ -6117,8 +6120,6 @@ class ComprehensionNode(ScopedExprNode):
self.loop.analyse_declarations(env) self.loop.analyse_declarations(env)
def analyse_types(self, env): def analyse_types(self, env):
self.target = self.target.analyse_expressions(env)
self.type = self.target.type
if not self.has_local_scope: if not self.has_local_scope:
self.loop = self.loop.analyse_expressions(env) self.loop = self.loop.analyse_expressions(env)
return self return self
...@@ -6131,13 +6132,23 @@ class ComprehensionNode(ScopedExprNode): ...@@ -6131,13 +6132,23 @@ class ComprehensionNode(ScopedExprNode):
def may_be_none(self): def may_be_none(self):
return False return False
def calculate_result_code(self):
return self.target.result()
def generate_result_code(self, code): def generate_result_code(self, code):
self.generate_operation_code(code) self.generate_operation_code(code)
def generate_operation_code(self, code): def generate_operation_code(self, code):
if self.type is Builtin.list_type:
create_code = 'PyList_New(0)'
elif self.type is Builtin.set_type:
create_code = 'PySet_New(NULL)'
elif self.type is Builtin.dict_type:
create_code = 'PyDict_New()'
else:
raise InternalError("illegal type for comprehension: %s" % self.type)
code.putln('%s = %s; %s' % (
self.result(), create_code,
code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.result())
self.loop.generate_execution_code(code) self.loop.generate_execution_code(code)
def annotate(self, code): def annotate(self, code):
...@@ -6149,6 +6160,7 @@ class ComprehensionAppendNode(Node): ...@@ -6149,6 +6160,7 @@ class ComprehensionAppendNode(Node):
# target must not be in child_attrs/subexprs # target must not be in child_attrs/subexprs
child_attrs = ['expr'] child_attrs = ['expr']
target = None
type = PyrexTypes.c_int_type type = PyrexTypes.c_int_type
......
...@@ -1246,7 +1246,6 @@ class ControlFlowAnalysis(CythonTransform): ...@@ -1246,7 +1246,6 @@ class ControlFlowAnalysis(CythonTransform):
self.env_stack.append(self.env) self.env_stack.append(self.env)
self.env = node.expr_scope self.env = node.expr_scope
# Skip append node here # Skip append node here
self._visit(node.target)
self._visit(node.loop) self._visit(node.loop)
if node.expr_scope: if node.expr_scope:
self.env = self.env_stack.pop() self.env = self.env_stack.pop()
......
...@@ -1404,7 +1404,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1404,7 +1404,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
if len(pos_args) != 1: if len(pos_args) != 1:
return node return node
if isinstance(pos_args[0], ExprNodes.ComprehensionNode) \ if isinstance(pos_args[0], ExprNodes.ComprehensionNode) \
and pos_args[0].target.type is Builtin.list_type: and pos_args[0].type is Builtin.list_type:
listcomp_node = pos_args[0] listcomp_node = pos_args[0]
loop_node = listcomp_node.loop loop_node = listcomp_node.loop
elif isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): elif isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
...@@ -1414,18 +1414,17 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1414,18 +1414,17 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
if yield_expression is None: if yield_expression is None:
return node return node
target = ExprNodes.ListNode(node.pos, args = [])
append_node = ExprNodes.ComprehensionAppendNode( append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos, expr = yield_expression, yield_expression.pos, expr = yield_expression)
target = ExprNodes.CloneNode(target))
Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node) Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
listcomp_node = ExprNodes.ComprehensionNode( listcomp_node = ExprNodes.ComprehensionNode(
gen_expr_node.pos, loop = loop_node, target = target, gen_expr_node.pos, loop = loop_node,
append = append_node, type = Builtin.list_type, append = append_node, type = Builtin.list_type,
expr_scope = gen_expr_node.expr_scope, expr_scope = gen_expr_node.expr_scope,
has_local_scope = True) has_local_scope = True)
append_node.target = listcomp_node
else: else:
return node return node
...@@ -1550,7 +1549,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1550,7 +1549,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
# the items into a list and then copy them into a tuple of the # the items into a list and then copy them into a tuple of the
# final size. This takes up to twice as much memory, but will # final size. This takes up to twice as much memory, but will
# have to do until we have real support for genexps. # have to do until we have real support for genexps.
result = self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode) result = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
if result is not node: if result is not node:
return ExprNodes.AsTupleNode(node.pos, arg=result) return ExprNodes.AsTupleNode(node.pos, arg=result)
return node return node
...@@ -1558,14 +1557,14 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1558,14 +1557,14 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
def _handle_simple_function_list(self, node, pos_args): def _handle_simple_function_list(self, node, pos_args):
if not pos_args: if not pos_args:
return ExprNodes.ListNode(node.pos, args=[], constant_result=[]) return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode) return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
def _handle_simple_function_set(self, node, pos_args): def _handle_simple_function_set(self, node, pos_args):
if not pos_args: if not pos_args:
return ExprNodes.SetNode(node.pos, args=[], constant_result=set()) return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode) return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type)
def _transform_list_set_genexpr(self, node, pos_args, container_node_class): def _transform_list_set_genexpr(self, node, pos_args, target_type):
"""Replace set(genexpr) and list(genexpr) by a literal comprehension. """Replace set(genexpr) and list(genexpr) by a literal comprehension.
""" """
if len(pos_args) > 1: if len(pos_args) > 1:
...@@ -1579,23 +1578,21 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1579,23 +1578,21 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
if yield_expression is None: if yield_expression is None:
return node return node
target_node = container_node_class(node.pos, args=[])
append_node = ExprNodes.ComprehensionAppendNode( append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos, yield_expression.pos,
expr = yield_expression, expr = yield_expression)
target = ExprNodes.CloneNode(target_node))
Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node) Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
setcomp = ExprNodes.ComprehensionNode( comp = ExprNodes.ComprehensionNode(
node.pos, node.pos,
has_local_scope = True, has_local_scope = True,
expr_scope = gen_expr_node.expr_scope, expr_scope = gen_expr_node.expr_scope,
loop = loop_node, loop = loop_node,
append = append_node, append = append_node,
target = target_node) type = target_type)
append_node.target = setcomp append_node.target = comp
return setcomp return comp
def _handle_simple_function_dict(self, node, pos_args): def _handle_simple_function_dict(self, node, pos_args):
"""Replace dict( (a,b) for ... ) by a literal { a:b for ... }. """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
...@@ -1618,12 +1615,10 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1618,12 +1615,10 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
if len(yield_expression.args) != 2: if len(yield_expression.args) != 2:
return node return node
target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
append_node = ExprNodes.DictComprehensionAppendNode( append_node = ExprNodes.DictComprehensionAppendNode(
yield_expression.pos, yield_expression.pos,
key_expr = yield_expression.args[0], key_expr = yield_expression.args[0],
value_expr = yield_expression.args[1], value_expr = yield_expression.args[1])
target = ExprNodes.CloneNode(target_node))
Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node) Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
...@@ -1633,7 +1628,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1633,7 +1628,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
expr_scope = gen_expr_node.expr_scope, expr_scope = gen_expr_node.expr_scope,
loop = loop_node, loop = loop_node,
append = append_node, append = append_node,
target = target_node) type = Builtin.dict_type)
append_node.target = dictcomp append_node.target = dictcomp
return dictcomp return dictcomp
...@@ -3245,7 +3240,12 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -3245,7 +3240,12 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
self.visitchildren(node) self.visitchildren(node)
if isinstance(node.loop, Nodes.StatListNode) and not node.loop.stats: if isinstance(node.loop, Nodes.StatListNode) and not node.loop.stats:
# loop was pruned already => transform into literal # loop was pruned already => transform into literal
return node.target if node.type is Builtin.list_type:
return ExprNodes.ListNode(node.pos, args=[])
elif node.type is Builtin.set_type:
return ExprNodes.SetNode(node.pos, args=[])
elif node.type is Builtin.dict_type:
return ExprNodes.DictNode(node.pos, key_value_pairs=[])
return node return node
def visit_ForInStatNode(self, node): def visit_ForInStatNode(self, node):
......
...@@ -7,7 +7,8 @@ ...@@ -7,7 +7,8 @@
import cython import cython
cython.declare(Nodes=object, ExprNodes=object, EncodedString=object, cython.declare(Nodes=object, ExprNodes=object, EncodedString=object,
StringEncoding=object, lookup_unicodechar=object, re=object, StringEncoding=object, lookup_unicodechar=object, re=object,
Future=object, Options=object, error=object, warning=object) Future=object, Options=object, error=object, warning=object,
Builtin=object)
import re import re
from unicodedata import lookup as lookup_unicodechar from unicodedata import lookup as lookup_unicodechar
...@@ -15,6 +16,7 @@ from unicodedata import lookup as lookup_unicodechar ...@@ -15,6 +16,7 @@ from unicodedata import lookup as lookup_unicodechar
from Cython.Compiler.Scanning import PyrexScanner, FileSourceDescriptor from Cython.Compiler.Scanning import PyrexScanner, FileSourceDescriptor
import Nodes import Nodes
import ExprNodes import ExprNodes
import Builtin
import StringEncoding import StringEncoding
from StringEncoding import EncodedString, BytesLiteral, _unicode, _bytes from StringEncoding import EncodedString, BytesLiteral, _unicode, _bytes
from ModuleNode import ModuleNode from ModuleNode import ModuleNode
...@@ -897,13 +899,11 @@ def p_list_maker(s): ...@@ -897,13 +899,11 @@ def p_list_maker(s):
return ExprNodes.ListNode(pos, args = []) return ExprNodes.ListNode(pos, args = [])
expr = p_test(s) expr = p_test(s)
if s.sy == 'for': if s.sy == 'for':
target = ExprNodes.ListNode(pos, args = []) append = ExprNodes.ComprehensionAppendNode(pos, expr=expr)
append = ExprNodes.ComprehensionAppendNode(
pos, expr=expr, target=ExprNodes.CloneNode(target))
loop = p_comp_for(s, append) loop = p_comp_for(s, append)
s.expect(']') s.expect(']')
return ExprNodes.ComprehensionNode( return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target, pos, loop=loop, append=append, type = Builtin.list_type,
# list comprehensions leak their loop variable in Py2 # list comprehensions leak their loop variable in Py2
has_local_scope = s.context.language_level >= 3) has_local_scope = s.context.language_level >= 3)
else: else:
...@@ -964,13 +964,12 @@ def p_dict_or_set_maker(s): ...@@ -964,13 +964,12 @@ def p_dict_or_set_maker(s):
return ExprNodes.SetNode(pos, args=values) return ExprNodes.SetNode(pos, args=values)
elif s.sy == 'for': elif s.sy == 'for':
# set comprehension # set comprehension
target = ExprNodes.SetNode(pos, args=[])
append = ExprNodes.ComprehensionAppendNode( append = ExprNodes.ComprehensionAppendNode(
item.pos, expr=item, target=ExprNodes.CloneNode(target)) item.pos, expr=item)
loop = p_comp_for(s, append) loop = p_comp_for(s, append)
s.expect('}') s.expect('}')
return ExprNodes.ComprehensionNode( return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target) pos, loop=loop, append=append, type=Builtin.set_type)
elif s.sy == ':': elif s.sy == ':':
# dict literal or comprehension # dict literal or comprehension
key = item key = item
...@@ -978,14 +977,12 @@ def p_dict_or_set_maker(s): ...@@ -978,14 +977,12 @@ def p_dict_or_set_maker(s):
value = p_test(s) value = p_test(s)
if s.sy == 'for': if s.sy == 'for':
# dict comprehension # dict comprehension
target = ExprNodes.DictNode(pos, key_value_pairs = [])
append = ExprNodes.DictComprehensionAppendNode( append = ExprNodes.DictComprehensionAppendNode(
item.pos, key_expr=key, value_expr=value, item.pos, key_expr=key, value_expr=value)
target=ExprNodes.CloneNode(target))
loop = p_comp_for(s, append) loop = p_comp_for(s, append)
s.expect('}') s.expect('}')
return ExprNodes.ComprehensionNode( return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target) pos, loop=loop, append=append, type=Builtin.dict_type)
else: else:
# dict literal # dict literal
items = [ExprNodes.DictItemNode(key.pos, key=key, value=value)] items = [ExprNodes.DictItemNode(key.pos, key=key, value=value)]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment