Commit b2cb180f authored by Stefan Behnel's avatar Stefan Behnel

refactor comprehensions by removing separate target node (to simplify a future...

refactor comprehensions by removing separate target node (to simplify a future length-hint optimisation)

--HG--
extra : rebase_source : 476b22eeaeaea1ff69ee8069328fb47ffe18ea20
parent d2fb1655
......@@ -6103,11 +6103,14 @@ class ScopedExprNode(ExprNode):
class ComprehensionNode(ScopedExprNode):
subexprs = ["target"]
# A list/set/dict comprehension
child_attrs = ["loop"]
is_temp = True
def infer_type(self, env):
return self.target.infer_type(env)
return self.type
def analyse_declarations(self, env):
self.append.target = self # this is used in the PyList_Append of the inner loop
......@@ -6117,8 +6120,6 @@ class ComprehensionNode(ScopedExprNode):
self.loop.analyse_declarations(env)
def analyse_types(self, env):
self.target = self.target.analyse_expressions(env)
self.type = self.target.type
if not self.has_local_scope:
self.loop = self.loop.analyse_expressions(env)
return self
......@@ -6131,13 +6132,23 @@ class ComprehensionNode(ScopedExprNode):
def may_be_none(self):
return False
def calculate_result_code(self):
return self.target.result()
def generate_result_code(self, code):
self.generate_operation_code(code)
def generate_operation_code(self, code):
if self.type is Builtin.list_type:
create_code = 'PyList_New(0)'
elif self.type is Builtin.set_type:
create_code = 'PySet_New(NULL)'
elif self.type is Builtin.dict_type:
create_code = 'PyDict_New()'
else:
raise InternalError("illegal type for comprehension: %s" % self.type)
code.putln('%s = %s; %s' % (
self.result(), create_code,
code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.result())
self.loop.generate_execution_code(code)
def annotate(self, code):
......@@ -6149,6 +6160,7 @@ class ComprehensionAppendNode(Node):
# target must not be in child_attrs/subexprs
child_attrs = ['expr']
target = None
type = PyrexTypes.c_int_type
......
......@@ -1246,7 +1246,6 @@ class ControlFlowAnalysis(CythonTransform):
self.env_stack.append(self.env)
self.env = node.expr_scope
# Skip append node here
self._visit(node.target)
self._visit(node.loop)
if node.expr_scope:
self.env = self.env_stack.pop()
......
......@@ -1404,7 +1404,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
if len(pos_args) != 1:
return node
if isinstance(pos_args[0], ExprNodes.ComprehensionNode) \
and pos_args[0].target.type is Builtin.list_type:
and pos_args[0].type is Builtin.list_type:
listcomp_node = pos_args[0]
loop_node = listcomp_node.loop
elif isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
......@@ -1414,18 +1414,17 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
if yield_expression is None:
return node
target = ExprNodes.ListNode(node.pos, args = [])
append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos, expr = yield_expression,
target = ExprNodes.CloneNode(target))
yield_expression.pos, expr = yield_expression)
Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
listcomp_node = ExprNodes.ComprehensionNode(
gen_expr_node.pos, loop = loop_node, target = target,
gen_expr_node.pos, loop = loop_node,
append = append_node, type = Builtin.list_type,
expr_scope = gen_expr_node.expr_scope,
has_local_scope = True)
append_node.target = listcomp_node
else:
return node
......@@ -1550,7 +1549,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
# the items into a list and then copy them into a tuple of the
# final size. This takes up to twice as much memory, but will
# have to do until we have real support for genexps.
result = self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
result = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
if result is not node:
return ExprNodes.AsTupleNode(node.pos, arg=result)
return node
......@@ -1558,14 +1557,14 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
def _handle_simple_function_list(self, node, pos_args):
if not pos_args:
return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
def _handle_simple_function_set(self, node, pos_args):
if not pos_args:
return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode)
return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type)
def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
def _transform_list_set_genexpr(self, node, pos_args, target_type):
"""Replace set(genexpr) and list(genexpr) by a literal comprehension.
"""
if len(pos_args) > 1:
......@@ -1579,23 +1578,21 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
if yield_expression is None:
return node
target_node = container_node_class(node.pos, args=[])
append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos,
expr = yield_expression,
target = ExprNodes.CloneNode(target_node))
expr = yield_expression)
Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
setcomp = ExprNodes.ComprehensionNode(
comp = ExprNodes.ComprehensionNode(
node.pos,
has_local_scope = True,
expr_scope = gen_expr_node.expr_scope,
loop = loop_node,
append = append_node,
target = target_node)
append_node.target = setcomp
return setcomp
type = target_type)
append_node.target = comp
return comp
def _handle_simple_function_dict(self, node, pos_args):
"""Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
......@@ -1618,12 +1615,10 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
if len(yield_expression.args) != 2:
return node
target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
append_node = ExprNodes.DictComprehensionAppendNode(
yield_expression.pos,
key_expr = yield_expression.args[0],
value_expr = yield_expression.args[1],
target = ExprNodes.CloneNode(target_node))
value_expr = yield_expression.args[1])
Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
......@@ -1633,7 +1628,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
expr_scope = gen_expr_node.expr_scope,
loop = loop_node,
append = append_node,
target = target_node)
type = Builtin.dict_type)
append_node.target = dictcomp
return dictcomp
......@@ -3245,7 +3240,12 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
self.visitchildren(node)
if isinstance(node.loop, Nodes.StatListNode) and not node.loop.stats:
# loop was pruned already => transform into literal
return node.target
if node.type is Builtin.list_type:
return ExprNodes.ListNode(node.pos, args=[])
elif node.type is Builtin.set_type:
return ExprNodes.SetNode(node.pos, args=[])
elif node.type is Builtin.dict_type:
return ExprNodes.DictNode(node.pos, key_value_pairs=[])
return node
def visit_ForInStatNode(self, node):
......
......@@ -7,7 +7,8 @@
import cython
cython.declare(Nodes=object, ExprNodes=object, EncodedString=object,
StringEncoding=object, lookup_unicodechar=object, re=object,
Future=object, Options=object, error=object, warning=object)
Future=object, Options=object, error=object, warning=object,
Builtin=object)
import re
from unicodedata import lookup as lookup_unicodechar
......@@ -15,6 +16,7 @@ from unicodedata import lookup as lookup_unicodechar
from Cython.Compiler.Scanning import PyrexScanner, FileSourceDescriptor
import Nodes
import ExprNodes
import Builtin
import StringEncoding
from StringEncoding import EncodedString, BytesLiteral, _unicode, _bytes
from ModuleNode import ModuleNode
......@@ -897,13 +899,11 @@ def p_list_maker(s):
return ExprNodes.ListNode(pos, args = [])
expr = p_test(s)
if s.sy == 'for':
target = ExprNodes.ListNode(pos, args = [])
append = ExprNodes.ComprehensionAppendNode(
pos, expr=expr, target=ExprNodes.CloneNode(target))
append = ExprNodes.ComprehensionAppendNode(pos, expr=expr)
loop = p_comp_for(s, append)
s.expect(']')
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target,
pos, loop=loop, append=append, type = Builtin.list_type,
# list comprehensions leak their loop variable in Py2
has_local_scope = s.context.language_level >= 3)
else:
......@@ -964,13 +964,12 @@ def p_dict_or_set_maker(s):
return ExprNodes.SetNode(pos, args=values)
elif s.sy == 'for':
# set comprehension
target = ExprNodes.SetNode(pos, args=[])
append = ExprNodes.ComprehensionAppendNode(
item.pos, expr=item, target=ExprNodes.CloneNode(target))
item.pos, expr=item)
loop = p_comp_for(s, append)
s.expect('}')
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target)
pos, loop=loop, append=append, type=Builtin.set_type)
elif s.sy == ':':
# dict literal or comprehension
key = item
......@@ -978,14 +977,12 @@ def p_dict_or_set_maker(s):
value = p_test(s)
if s.sy == 'for':
# dict comprehension
target = ExprNodes.DictNode(pos, key_value_pairs = [])
append = ExprNodes.DictComprehensionAppendNode(
item.pos, key_expr=key, value_expr=value,
target=ExprNodes.CloneNode(target))
item.pos, key_expr=key, value_expr=value)
loop = p_comp_for(s, append)
s.expect('}')
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target)
pos, loop=loop, append=append, type=Builtin.dict_type)
else:
# dict literal
items = [ExprNodes.DictItemNode(key.pos, key=key, value=value)]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment