Commit e8d25746 authored by Robert Bradshaw's avatar Robert Bradshaw

List comprehension

parent 061558c7
......@@ -916,10 +916,23 @@ class IteratorNode(ExprNode):
self.sequence = self.sequence.coerce_to_pyobject(env)
self.type = py_object_type
self.is_temp = 1
self.counter = TempNode(self.pos, PyrexTypes.c_py_ssize_t_type, env)
self.counter.allocate_temp(env)
def release_temp(self, env):
env.release_temp(self.result_code)
self.counter.release_temp(env)
def generate_result_code(self, code):
code.putln(
"%s = PyObject_GetIter(%s); if (!%s) %s" % (
"if (PyList_CheckExact(%s)) { %s = 0; %s = %s; Py_INCREF(%s); }" % (
self.sequence.py_result(),
self.counter.result_code,
self.result_code,
self.sequence.py_result(),
self.result_code))
code.putln("else { %s = PyObject_GetIter(%s); if (!%s) %s }" % (
self.result_code,
self.sequence.py_result(),
self.result_code,
......@@ -941,6 +954,16 @@ class NextNode(AtomicExprNode):
self.is_temp = 1
def generate_result_code(self, code):
code.putln(
"if (PyList_CheckExact(%s)) { if (%s >= PyList_GET_SIZE(%s)) break; %s = PyList_GET_ITEM(%s, %s++); Py_INCREF(%s); }" % (
self.iterator.py_result(),
self.iterator.counter.result_code,
self.iterator.py_result(),
self.result_code,
self.iterator.py_result(),
self.iterator.counter.result_code,
self.result_code))
code.putln("else {")
code.putln(
"%s = PyIter_Next(%s);" % (
self.result_code,
......@@ -951,10 +974,9 @@ class NextNode(AtomicExprNode):
code.putln(
"if (PyErr_Occurred()) %s" %
code.error_goto(self.pos))
code.putln(
"break;")
code.putln(
"}")
code.putln("break;")
code.putln("}")
code.putln("}")
class ExcValueNode(AtomicExprNode):
......@@ -1832,6 +1854,51 @@ class ListNode(SequenceNode):
for arg in self.args:
arg.generate_post_assignment_code(code)
class ListComprehensionNode(SequenceNode):
subexprs = []
def analyse_types(self, env):
self.type = py_object_type
self.is_temp = 1
self.append.target = self # this is a CloneNode used in the PyList_Append in the inner loop
def allocate_temps(self, env, result = None):
if debug_temp_alloc:
print self, "Allocating temps"
self.allocate_temp(env, result)
self.loop.analyse_declarations(env)
self.loop.analyse_expressions(env)
def generate_operation_code(self, code):
code.putln("%s = PyList_New(%s); if (!%s) %s" %
(self.result_code,
0,
self.result_code,
code.error_goto(self.pos)))
self.loop.generate_execution_code(code)
class ListComprehensionAppendNode(ExprNode):
subexprs = ['expr']
def analyse_types(self, env):
self.expr.analyse_types(env)
if self.expr.type != py_object_type:
self.expr = self.expr.coerce_to_pyobject(env)
self.type = PyrexTypes.c_int_type
self.is_temp = 1
def generate_result_code(self, code):
code.putln("%s = PyList_Append(%s, %s); if (%s) %s" %
(self.result_code,
self.target.result_code,
self.expr.result_code,
self.result_code,
code.error_goto(self.pos)))
class DictNode(ExprNode):
# Dictionary constructor.
......@@ -2979,6 +3046,11 @@ class CloneNode(CoercionNode):
def calculate_result_code(self):
return self.arg.result_code
def analyse_types(self, env):
self.type = self.arg.type
self.result_ctype = self.arg.result_ctype
self.is_temp = 1
#def result_as_extension_type(self):
# return self.arg.result_as_extension_type()
......
......@@ -3065,6 +3065,12 @@ class WhileStatNode(StatNode):
code.put_label(break_label)
def ForStatNode(pos, **kw):
if kw.has_key('iterator'):
return ForInStatNode(pos, **kw)
else:
return ForFromStatNode(pos, **kw)
class ForInStatNode(StatNode):
# for statement
#
......
......@@ -573,14 +573,65 @@ def unquote(s):
s = "".join(l2)
return s
# list_display ::= "[" [listmaker] "]"
# listmaker ::= expression ( list_for | ( "," expression )* [","] )
# list_iter ::= list_for | list_if
# list_for ::= "for" expression_list "in" testlist [list_iter]
# list_if ::= "if" test [list_iter]
def p_list_maker(s):
# s.sy == '['
pos = s.position()
s.next()
exprs = p_simple_expr_list(s)
s.expect(']')
return ExprNodes.ListNode(pos, args = exprs)
if s.sy == ']':
s.expect(']')
return ExprNodes.ListNode(pos, args = [])
expr = p_simple_expr(s)
if s.sy == 'for':
loop = p_list_for(s)
s.expect(']')
inner_loop = loop
while not isinstance(inner_loop.body, Nodes.PassStatNode):
inner_loop = inner_loop.body
if isinstance(inner_loop, Nodes.IfStatNode):
inner_loop = inner_loop.if_clauses[0]
append = ExprNodes.ListComprehensionAppendNode( pos, expr = expr )
inner_loop.body = Nodes.ExprStatNode(pos, expr = append)
return ExprNodes.ListComprehensionNode(pos, loop = loop, append = append)
else:
exprs = [expr]
if s.sy == ',':
s.next()
exprs += p_simple_expr_list(s)
s.expect(']')
return ExprNodes.ListNode(pos, args = exprs)
def p_list_iter(s):
if s.sy == 'for':
return p_list_for(s)
elif s.sy == 'if':
return p_list_if(s)
else:
return Nodes.PassStatNode(s.position())
def p_list_for(s):
# s.sy == 'for'
pos = s.position()
s.next()
kw = p_for_bounds(s)
kw['else_clause'] = None
kw['body'] = p_list_iter(s)
return Nodes.ForStatNode(pos, **kw)
def p_list_if(s):
# s.sy == 'if'
pos = s.position()
s.next()
test = p_simple_expr(s)
return Nodes.IfStatNode(pos,
if_clauses = [Nodes.IfClauseNode(pos, condition = test, body = p_list_iter(s))],
else_clause = None )
#dictmaker: test ':' test (',' test ':' test)* [',']
def p_dict_maker(s):
......@@ -931,17 +982,17 @@ def p_for_statement(s):
# s.sy == 'for'
pos = s.position()
s.next()
kw = p_for_bounds(s)
kw['body'] = p_suite(s)
kw['else_clause'] = p_else_clause(s)
return Nodes.ForStatNode(pos, **kw)
def p_for_bounds(s):
target = p_for_target(s)
if s.sy == 'in':
s.next()
iterator = p_for_iterator(s)
body = p_suite(s)
else_clause = p_else_clause(s)
return Nodes.ForInStatNode(pos,
target = target,
iterator = iterator,
body = body,
else_clause = else_clause)
return { 'target': target, 'iterator': iterator }
elif s.sy == 'from':
s.next()
bound1 = p_bit_expr(s)
......@@ -960,16 +1011,11 @@ def p_for_statement(s):
if rel1[0] <> rel2[0]:
error(rel2_pos,
"Relation directions in for-from do not match")
body = p_suite(s)
else_clause = p_else_clause(s)
return Nodes.ForFromStatNode(pos,
target = target,
bound1 = bound1,
relation1 = rel1,
relation2 = rel2,
bound2 = bound2,
body = body,
else_clause = else_clause)
return {'target': target,
'bound1': bound1,
'relation1': rel1,
'relation2': rel2,
'bound2': bound2 }
def p_for_from_relation(s):
if s.sy in inequality_relations:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment