Commit eee86b80 authored by Stefan Behnel's avatar Stefan Behnel

major cleanup for comprehension code to remove redundant classes

parent 21ffedf5
...@@ -3180,63 +3180,33 @@ class ListNode(SequenceNode): ...@@ -3180,63 +3180,33 @@ class ListNode(SequenceNode):
# generate_evaluation_code which will do that. # generate_evaluation_code which will do that.
class ComprehensionNode(SequenceNode): class ComprehensionNode(NewTempExprNode):
subexprs = [] subexprs = ["target"]
is_sequence_constructor = 0 # not unpackable
comp_result_type = py_object_type
child_attrs = ["loop", "append"] child_attrs = ["loop", "append"]
def analyse_types(self, env): def analyse_types(self, env):
self.type = self.comp_result_type self.target.analyse_expressions(env)
self.is_temp = 1 self.type = self.target.type
self.append.target = self # this is a CloneNode used in the PyList_Append in the inner loop self.append.target = self # this is a CloneNode used in the PyList_Append in the inner loop
self.loop.analyse_declarations(env)
self.loop.analyse_expressions(env)
def allocate_temps(self, env, result = None): def allocate_temps(self, env, result = None):
if debug_temp_alloc: if debug_temp_alloc:
print("%s Allocating temps" % self) print("%s Allocating temps" % self)
self.allocate_temp(env, result) self.allocate_temp(env, result)
self.loop.analyse_declarations(env)
self.loop.analyse_expressions(env)
def generate_operation_code(self, code):
code.putln("%s = PyList_New(%s); %s" %
(self.result(),
0,
code.error_goto_if_null(self.result(), self.pos)))
self.loop.generate_execution_code(code)
def annotate(self, code):
self.loop.annotate(code)
class ListComprehensionNode(ComprehensionNode): def calculate_result_code(self):
comp_result_type = list_type return self.target.result()
def generate_operation_code(self, code):
code.putln("%s = PyList_New(%s); %s" %
(self.result(),
0,
code.error_goto_if_null(self.result(), self.pos)))
self.loop.generate_execution_code(code)
class SetComprehensionNode(ComprehensionNode): def generate_result_code(self, code):
comp_result_type = set_type self.generate_operation_code(code)
def generate_operation_code(self, code): def generate_operation_code(self, code):
code.putln("%s = PySet_New(0); %s" % # arg == iterable, not size!
(self.result(),
code.error_goto_if_null(self.result(), self.pos)))
self.loop.generate_execution_code(code) self.loop.generate_execution_code(code)
class DictComprehensionNode(ComprehensionNode): def annotate(self, code):
comp_result_type = dict_type self.loop.annotate(code)
def generate_operation_code(self, code):
code.putln("%s = PyDict_New(); %s" %
(self.result(),
code.error_goto_if_null(self.result(), self.pos)))
self.loop.generate_execution_code(code)
class ComprehensionAppendNode(NewTempExprNode): class ComprehensionAppendNode(NewTempExprNode):
...@@ -3251,18 +3221,18 @@ class ComprehensionAppendNode(NewTempExprNode): ...@@ -3251,18 +3221,18 @@ class ComprehensionAppendNode(NewTempExprNode):
self.type = PyrexTypes.c_int_type self.type = PyrexTypes.c_int_type
self.is_temp = 1 self.is_temp = 1
class ListComprehensionAppendNode(ComprehensionAppendNode):
def generate_result_code(self, code): def generate_result_code(self, code):
code.putln("%s = PyList_Append(%s, (PyObject*)%s); %s" % if self.target.type is list_type:
(self.result(), function = "PyList_Append"
self.target.result(), elif self.target.type is set_type:
self.expr.result(), function = "PySet_Add"
code.error_goto_if(self.result(), self.pos))) else:
raise InternalError(
"Invalid type for comprehension node: %s" % self.target.type)
class SetComprehensionAppendNode(ComprehensionAppendNode): code.putln("%s = %s(%s, (PyObject*)%s); %s" %
def generate_result_code(self, code):
code.putln("%s = PySet_Add(%s, (PyObject*)%s); %s" %
(self.result(), (self.result(),
function,
self.target.result(), self.target.result(),
self.expr.result(), self.expr.result(),
code.error_goto_if(self.result(), self.pos))) code.error_goto_if(self.result(), self.pos)))
......
...@@ -488,9 +488,11 @@ class FlattenBuiltinTypeCreation(Visitor.VisitorTransform): ...@@ -488,9 +488,11 @@ class FlattenBuiltinTypeCreation(Visitor.VisitorTransform):
if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)): if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
return ExprNodes.SetNode(node.pos, args=iterable.args, return ExprNodes.SetNode(node.pos, args=iterable.args,
type=Builtin.set_type, is_temp=1) type=Builtin.set_type, is_temp=1)
elif isinstance(iterable, ExprNodes.ListComprehensionNode): elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
iterable.__class__ = ExprNodes.SetComprehensionNode iterable.type is Builtin.list_type:
iterable.append.__class__ = ExprNodes.SetComprehensionAppendNode iterable.target = ExprNodes.SetNode(
node.pos, args=[], type=Builtin.set_type, is_temp=1)
iterable.type = Builtin.set_type
iterable.pos = node.pos iterable.pos = node.pos
return iterable return iterable
else: else:
......
...@@ -699,11 +699,13 @@ def p_list_maker(s): ...@@ -699,11 +699,13 @@ def p_list_maker(s):
return ExprNodes.ListNode(pos, args = []) return ExprNodes.ListNode(pos, args = [])
expr = p_simple_expr(s) expr = p_simple_expr(s)
if s.sy == 'for': if s.sy == 'for':
loop = p_list_for(s) target = ExprNodes.ListNode(pos, args = [])
append = ExprNodes.ComprehensionAppendNode(
pos, expr=expr, target=ExprNodes.CloneNode(target))
loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append))
s.expect(']') s.expect(']')
append = ExprNodes.ListComprehensionAppendNode( pos, expr = expr ) return ExprNodes.ComprehensionNode(
set_inner_comp_append(loop, append) pos, loop=loop, append=append, target=target)
return ExprNodes.ListComprehensionNode(pos, loop = loop, append = append)
else: else:
exprs = [expr] exprs = [expr]
if s.sy == ',': if s.sy == ',':
...@@ -712,40 +714,34 @@ def p_list_maker(s): ...@@ -712,40 +714,34 @@ def p_list_maker(s):
s.expect(']') s.expect(']')
return ExprNodes.ListNode(pos, args = exprs) return ExprNodes.ListNode(pos, args = exprs)
def p_list_iter(s): def p_list_iter(s, body):
if s.sy == 'for': if s.sy == 'for':
return p_list_for(s) return p_list_for(s, body)
elif s.sy == 'if': elif s.sy == 'if':
return p_list_if(s) return p_list_if(s, body)
else: else:
return Nodes.PassStatNode(s.position()) # insert the 'append' operation into the loop
return body
def p_list_for(s): def p_list_for(s, body):
# s.sy == 'for' # s.sy == 'for'
pos = s.position() pos = s.position()
s.next() s.next()
kw = p_for_bounds(s) kw = p_for_bounds(s)
kw['else_clause'] = None kw['else_clause'] = None
kw['body'] = p_list_iter(s) kw['body'] = p_list_iter(s, body)
return Nodes.ForStatNode(pos, **kw) return Nodes.ForStatNode(pos, **kw)
def p_list_if(s): def p_list_if(s, body):
# s.sy == 'if' # s.sy == 'if'
pos = s.position() pos = s.position()
s.next() s.next()
test = p_test(s) test = p_test(s)
return Nodes.IfStatNode(pos, return Nodes.IfStatNode(pos,
if_clauses = [Nodes.IfClauseNode(pos, condition = test, body = p_list_iter(s))], if_clauses = [Nodes.IfClauseNode(pos, condition = test,
body = p_list_iter(s, body))],
else_clause = None ) else_clause = None )
def set_inner_comp_append(loop, append):
inner_loop = loop
while not isinstance(inner_loop.body, Nodes.PassStatNode):
inner_loop = inner_loop.body
if isinstance(inner_loop, Nodes.IfStatNode):
inner_loop = inner_loop.if_clauses[0]
inner_loop.body = Nodes.ExprStatNode(append.pos, expr = append)
#dictmaker: test ':' test (',' test ':' test)* [','] #dictmaker: test ':' test (',' test ':' test)* [',']
def p_dict_or_set_maker(s): def p_dict_or_set_maker(s):
...@@ -768,11 +764,13 @@ def p_dict_or_set_maker(s): ...@@ -768,11 +764,13 @@ def p_dict_or_set_maker(s):
return ExprNodes.SetNode(pos, args=values) return ExprNodes.SetNode(pos, args=values)
elif s.sy == 'for': elif s.sy == 'for':
# set comprehension # set comprehension
loop = p_list_for(s) target = ExprNodes.SetNode(pos, args=[])
append = ExprNodes.ComprehensionAppendNode(
item.pos, expr=item, target=ExprNodes.CloneNode(target))
loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append))
s.expect('}') s.expect('}')
append = ExprNodes.SetComprehensionAppendNode(item.pos, expr=item) return ExprNodes.ComprehensionNode(
set_inner_comp_append(loop, append) pos, loop=loop, append=append, target=target)
return ExprNodes.SetComprehensionNode(pos, loop=loop, append=append)
elif s.sy == ':': elif s.sy == ':':
# dict literal or comprehension # dict literal or comprehension
key = item key = item
...@@ -780,12 +778,14 @@ def p_dict_or_set_maker(s): ...@@ -780,12 +778,14 @@ def p_dict_or_set_maker(s):
value = p_simple_expr(s) value = p_simple_expr(s)
if s.sy == 'for': if s.sy == 'for':
# dict comprehension # dict comprehension
loop = p_list_for(s) target = ExprNodes.DictNode(pos, key_value_pairs = [])
s.expect('}')
append = ExprNodes.DictComprehensionAppendNode( append = ExprNodes.DictComprehensionAppendNode(
item.pos, key_expr = key, value_expr = value) item.pos, key_expr=key, value_expr=value,
set_inner_comp_append(loop, append) target=ExprNodes.CloneNode(target))
return ExprNodes.DictComprehensionNode(pos, loop=loop, append=append) loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append))
s.expect('}')
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target)
else: else:
# dict literal # dict literal
items = [ExprNodes.DictItemNode(key.pos, key=key, value=value)] items = [ExprNodes.DictItemNode(key.pos, key=key, value=value)]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment