Commit eee86b80 authored by Stefan Behnel's avatar Stefan Behnel

major cleanup for comprehension code to remove redundant classes

parent 21ffedf5
......@@ -3180,63 +3180,33 @@ class ListNode(SequenceNode):
# generate_evaluation_code which will do that.
class ComprehensionNode(SequenceNode):
subexprs = []
is_sequence_constructor = 0 # not unpackable
comp_result_type = py_object_type
class ComprehensionNode(NewTempExprNode):
subexprs = ["target"]
child_attrs = ["loop", "append"]
def analyse_types(self, env):
self.type = self.comp_result_type
self.is_temp = 1
self.target.analyse_expressions(env)
self.type = self.target.type
self.append.target = self # this is a CloneNode used in the PyList_Append in the inner loop
self.loop.analyse_declarations(env)
self.loop.analyse_expressions(env)
def allocate_temps(self, env, result = None):
if debug_temp_alloc:
print("%s Allocating temps" % self)
self.allocate_temp(env, result)
self.loop.analyse_declarations(env)
self.loop.analyse_expressions(env)
def generate_operation_code(self, code):
code.putln("%s = PyList_New(%s); %s" %
(self.result(),
0,
code.error_goto_if_null(self.result(), self.pos)))
self.loop.generate_execution_code(code)
def annotate(self, code):
self.loop.annotate(code)
class ListComprehensionNode(ComprehensionNode):
comp_result_type = list_type
def generate_operation_code(self, code):
code.putln("%s = PyList_New(%s); %s" %
(self.result(),
0,
code.error_goto_if_null(self.result(), self.pos)))
self.loop.generate_execution_code(code)
def calculate_result_code(self):
return self.target.result()
class SetComprehensionNode(ComprehensionNode):
comp_result_type = set_type
def generate_result_code(self, code):
self.generate_operation_code(code)
def generate_operation_code(self, code):
code.putln("%s = PySet_New(0); %s" % # arg == iterable, not size!
(self.result(),
code.error_goto_if_null(self.result(), self.pos)))
self.loop.generate_execution_code(code)
class DictComprehensionNode(ComprehensionNode):
comp_result_type = dict_type
def generate_operation_code(self, code):
code.putln("%s = PyDict_New(); %s" %
(self.result(),
code.error_goto_if_null(self.result(), self.pos)))
self.loop.generate_execution_code(code)
def annotate(self, code):
self.loop.annotate(code)
class ComprehensionAppendNode(NewTempExprNode):
......@@ -3251,18 +3221,18 @@ class ComprehensionAppendNode(NewTempExprNode):
self.type = PyrexTypes.c_int_type
self.is_temp = 1
class ListComprehensionAppendNode(ComprehensionAppendNode):
def generate_result_code(self, code):
code.putln("%s = PyList_Append(%s, (PyObject*)%s); %s" %
(self.result(),
self.target.result(),
self.expr.result(),
code.error_goto_if(self.result(), self.pos)))
if self.target.type is list_type:
function = "PyList_Append"
elif self.target.type is set_type:
function = "PySet_Add"
else:
raise InternalError(
"Invalid type for comprehension node: %s" % self.target.type)
class SetComprehensionAppendNode(ComprehensionAppendNode):
def generate_result_code(self, code):
code.putln("%s = PySet_Add(%s, (PyObject*)%s); %s" %
code.putln("%s = %s(%s, (PyObject*)%s); %s" %
(self.result(),
function,
self.target.result(),
self.expr.result(),
code.error_goto_if(self.result(), self.pos)))
......
......@@ -488,9 +488,11 @@ class FlattenBuiltinTypeCreation(Visitor.VisitorTransform):
if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
return ExprNodes.SetNode(node.pos, args=iterable.args,
type=Builtin.set_type, is_temp=1)
elif isinstance(iterable, ExprNodes.ListComprehensionNode):
iterable.__class__ = ExprNodes.SetComprehensionNode
iterable.append.__class__ = ExprNodes.SetComprehensionAppendNode
elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
iterable.type is Builtin.list_type:
iterable.target = ExprNodes.SetNode(
node.pos, args=[], type=Builtin.set_type, is_temp=1)
iterable.type = Builtin.set_type
iterable.pos = node.pos
return iterable
else:
......
......@@ -699,11 +699,13 @@ def p_list_maker(s):
return ExprNodes.ListNode(pos, args = [])
expr = p_simple_expr(s)
if s.sy == 'for':
loop = p_list_for(s)
target = ExprNodes.ListNode(pos, args = [])
append = ExprNodes.ComprehensionAppendNode(
pos, expr=expr, target=ExprNodes.CloneNode(target))
loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append))
s.expect(']')
append = ExprNodes.ListComprehensionAppendNode( pos, expr = expr )
set_inner_comp_append(loop, append)
return ExprNodes.ListComprehensionNode(pos, loop = loop, append = append)
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target)
else:
exprs = [expr]
if s.sy == ',':
......@@ -712,40 +714,34 @@ def p_list_maker(s):
s.expect(']')
return ExprNodes.ListNode(pos, args = exprs)
def p_list_iter(s):
def p_list_iter(s, body):
if s.sy == 'for':
return p_list_for(s)
return p_list_for(s, body)
elif s.sy == 'if':
return p_list_if(s)
return p_list_if(s, body)
else:
return Nodes.PassStatNode(s.position())
# insert the 'append' operation into the loop
return body
def p_list_for(s):
def p_list_for(s, body):
# s.sy == 'for'
pos = s.position()
s.next()
kw = p_for_bounds(s)
kw['else_clause'] = None
kw['body'] = p_list_iter(s)
kw['body'] = p_list_iter(s, body)
return Nodes.ForStatNode(pos, **kw)
def p_list_if(s):
def p_list_if(s, body):
# s.sy == 'if'
pos = s.position()
s.next()
test = p_test(s)
return Nodes.IfStatNode(pos,
if_clauses = [Nodes.IfClauseNode(pos, condition = test, body = p_list_iter(s))],
if_clauses = [Nodes.IfClauseNode(pos, condition = test,
body = p_list_iter(s, body))],
else_clause = None )
def set_inner_comp_append(loop, append):
inner_loop = loop
while not isinstance(inner_loop.body, Nodes.PassStatNode):
inner_loop = inner_loop.body
if isinstance(inner_loop, Nodes.IfStatNode):
inner_loop = inner_loop.if_clauses[0]
inner_loop.body = Nodes.ExprStatNode(append.pos, expr = append)
#dictmaker: test ':' test (',' test ':' test)* [',']
def p_dict_or_set_maker(s):
......@@ -768,11 +764,13 @@ def p_dict_or_set_maker(s):
return ExprNodes.SetNode(pos, args=values)
elif s.sy == 'for':
# set comprehension
loop = p_list_for(s)
target = ExprNodes.SetNode(pos, args=[])
append = ExprNodes.ComprehensionAppendNode(
item.pos, expr=item, target=ExprNodes.CloneNode(target))
loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append))
s.expect('}')
append = ExprNodes.SetComprehensionAppendNode(item.pos, expr=item)
set_inner_comp_append(loop, append)
return ExprNodes.SetComprehensionNode(pos, loop=loop, append=append)
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target)
elif s.sy == ':':
# dict literal or comprehension
key = item
......@@ -780,12 +778,14 @@ def p_dict_or_set_maker(s):
value = p_simple_expr(s)
if s.sy == 'for':
# dict comprehension
loop = p_list_for(s)
s.expect('}')
target = ExprNodes.DictNode(pos, key_value_pairs = [])
append = ExprNodes.DictComprehensionAppendNode(
item.pos, key_expr = key, value_expr = value)
set_inner_comp_append(loop, append)
return ExprNodes.DictComprehensionNode(pos, loop=loop, append=append)
item.pos, key_expr=key, value_expr=value,
target=ExprNodes.CloneNode(target))
loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append))
s.expect('}')
return ExprNodes.ComprehensionNode(
pos, loop=loop, append=append, target=target)
else:
# dict literal
items = [ExprNodes.DictItemNode(key.pos, key=key, value=value)]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment