Commit e8d25746 authored by Robert Bradshaw's avatar Robert Bradshaw

List comprehension

parent 061558c7
...@@ -917,9 +917,22 @@ class IteratorNode(ExprNode): ...@@ -917,9 +917,22 @@ class IteratorNode(ExprNode):
self.type = py_object_type self.type = py_object_type
self.is_temp = 1 self.is_temp = 1
self.counter = TempNode(self.pos, PyrexTypes.c_py_ssize_t_type, env)
self.counter.allocate_temp(env)
def release_temp(self, env):
env.release_temp(self.result_code)
self.counter.release_temp(env)
def generate_result_code(self, code): def generate_result_code(self, code):
code.putln( code.putln(
"%s = PyObject_GetIter(%s); if (!%s) %s" % ( "if (PyList_CheckExact(%s)) { %s = 0; %s = %s; Py_INCREF(%s); }" % (
self.sequence.py_result(),
self.counter.result_code,
self.result_code,
self.sequence.py_result(),
self.result_code))
code.putln("else { %s = PyObject_GetIter(%s); if (!%s) %s }" % (
self.result_code, self.result_code,
self.sequence.py_result(), self.sequence.py_result(),
self.result_code, self.result_code,
...@@ -941,6 +954,16 @@ class NextNode(AtomicExprNode): ...@@ -941,6 +954,16 @@ class NextNode(AtomicExprNode):
self.is_temp = 1 self.is_temp = 1
def generate_result_code(self, code): def generate_result_code(self, code):
code.putln(
"if (PyList_CheckExact(%s)) { if (%s >= PyList_GET_SIZE(%s)) break; %s = PyList_GET_ITEM(%s, %s++); Py_INCREF(%s); }" % (
self.iterator.py_result(),
self.iterator.counter.result_code,
self.iterator.py_result(),
self.result_code,
self.iterator.py_result(),
self.iterator.counter.result_code,
self.result_code))
code.putln("else {")
code.putln( code.putln(
"%s = PyIter_Next(%s);" % ( "%s = PyIter_Next(%s);" % (
self.result_code, self.result_code,
...@@ -951,10 +974,9 @@ class NextNode(AtomicExprNode): ...@@ -951,10 +974,9 @@ class NextNode(AtomicExprNode):
code.putln( code.putln(
"if (PyErr_Occurred()) %s" % "if (PyErr_Occurred()) %s" %
code.error_goto(self.pos)) code.error_goto(self.pos))
code.putln( code.putln("break;")
"break;") code.putln("}")
code.putln( code.putln("}")
"}")
class ExcValueNode(AtomicExprNode): class ExcValueNode(AtomicExprNode):
...@@ -1833,6 +1855,51 @@ class ListNode(SequenceNode): ...@@ -1833,6 +1855,51 @@ class ListNode(SequenceNode):
arg.generate_post_assignment_code(code) arg.generate_post_assignment_code(code)
class ListComprehensionNode(SequenceNode):
subexprs = []
def analyse_types(self, env):
self.type = py_object_type
self.is_temp = 1
self.append.target = self # this is a CloneNode used in the PyList_Append in the inner loop
def allocate_temps(self, env, result = None):
if debug_temp_alloc:
print self, "Allocating temps"
self.allocate_temp(env, result)
self.loop.analyse_declarations(env)
self.loop.analyse_expressions(env)
def generate_operation_code(self, code):
code.putln("%s = PyList_New(%s); if (!%s) %s" %
(self.result_code,
0,
self.result_code,
code.error_goto(self.pos)))
self.loop.generate_execution_code(code)
class ListComprehensionAppendNode(ExprNode):
subexprs = ['expr']
def analyse_types(self, env):
self.expr.analyse_types(env)
if self.expr.type != py_object_type:
self.expr = self.expr.coerce_to_pyobject(env)
self.type = PyrexTypes.c_int_type
self.is_temp = 1
def generate_result_code(self, code):
code.putln("%s = PyList_Append(%s, %s); if (%s) %s" %
(self.result_code,
self.target.result_code,
self.expr.result_code,
self.result_code,
code.error_goto(self.pos)))
class DictNode(ExprNode): class DictNode(ExprNode):
# Dictionary constructor. # Dictionary constructor.
# #
...@@ -2980,6 +3047,11 @@ class CloneNode(CoercionNode): ...@@ -2980,6 +3047,11 @@ class CloneNode(CoercionNode):
def calculate_result_code(self): def calculate_result_code(self):
return self.arg.result_code return self.arg.result_code
def analyse_types(self, env):
self.type = self.arg.type
self.result_ctype = self.arg.result_ctype
self.is_temp = 1
#def result_as_extension_type(self): #def result_as_extension_type(self):
# return self.arg.result_as_extension_type() # return self.arg.result_as_extension_type()
......
...@@ -3065,6 +3065,12 @@ class WhileStatNode(StatNode): ...@@ -3065,6 +3065,12 @@ class WhileStatNode(StatNode):
code.put_label(break_label) code.put_label(break_label)
def ForStatNode(pos, **kw):
if kw.has_key('iterator'):
return ForInStatNode(pos, **kw)
else:
return ForFromStatNode(pos, **kw)
class ForInStatNode(StatNode): class ForInStatNode(StatNode):
# for statement # for statement
# #
......
...@@ -573,14 +573,65 @@ def unquote(s): ...@@ -573,14 +573,65 @@ def unquote(s):
s = "".join(l2) s = "".join(l2)
return s return s
# list_display ::= "[" [listmaker] "]"
# listmaker ::= expression ( list_for | ( "," expression )* [","] )
# list_iter ::= list_for | list_if
# list_for ::= "for" expression_list "in" testlist [list_iter]
# list_if ::= "if" test [list_iter]
def p_list_maker(s): def p_list_maker(s):
# s.sy == '[' # s.sy == '['
pos = s.position() pos = s.position()
s.next() s.next()
exprs = p_simple_expr_list(s) if s.sy == ']':
s.expect(']')
return ExprNodes.ListNode(pos, args = [])
expr = p_simple_expr(s)
if s.sy == 'for':
loop = p_list_for(s)
s.expect(']')
inner_loop = loop
while not isinstance(inner_loop.body, Nodes.PassStatNode):
inner_loop = inner_loop.body
if isinstance(inner_loop, Nodes.IfStatNode):
inner_loop = inner_loop.if_clauses[0]
append = ExprNodes.ListComprehensionAppendNode( pos, expr = expr )
inner_loop.body = Nodes.ExprStatNode(pos, expr = append)
return ExprNodes.ListComprehensionNode(pos, loop = loop, append = append)
else:
exprs = [expr]
if s.sy == ',':
s.next()
exprs += p_simple_expr_list(s)
s.expect(']') s.expect(']')
return ExprNodes.ListNode(pos, args = exprs) return ExprNodes.ListNode(pos, args = exprs)
def p_list_iter(s):
if s.sy == 'for':
return p_list_for(s)
elif s.sy == 'if':
return p_list_if(s)
else:
return Nodes.PassStatNode(s.position())
def p_list_for(s):
# s.sy == 'for'
pos = s.position()
s.next()
kw = p_for_bounds(s)
kw['else_clause'] = None
kw['body'] = p_list_iter(s)
return Nodes.ForStatNode(pos, **kw)
def p_list_if(s):
# s.sy == 'if'
pos = s.position()
s.next()
test = p_simple_expr(s)
return Nodes.IfStatNode(pos,
if_clauses = [Nodes.IfClauseNode(pos, condition = test, body = p_list_iter(s))],
else_clause = None )
#dictmaker: test ':' test (',' test ':' test)* [','] #dictmaker: test ':' test (',' test ':' test)* [',']
def p_dict_maker(s): def p_dict_maker(s):
...@@ -931,17 +982,17 @@ def p_for_statement(s): ...@@ -931,17 +982,17 @@ def p_for_statement(s):
# s.sy == 'for' # s.sy == 'for'
pos = s.position() pos = s.position()
s.next() s.next()
kw = p_for_bounds(s)
kw['body'] = p_suite(s)
kw['else_clause'] = p_else_clause(s)
return Nodes.ForStatNode(pos, **kw)
def p_for_bounds(s):
target = p_for_target(s) target = p_for_target(s)
if s.sy == 'in': if s.sy == 'in':
s.next() s.next()
iterator = p_for_iterator(s) iterator = p_for_iterator(s)
body = p_suite(s) return { 'target': target, 'iterator': iterator }
else_clause = p_else_clause(s)
return Nodes.ForInStatNode(pos,
target = target,
iterator = iterator,
body = body,
else_clause = else_clause)
elif s.sy == 'from': elif s.sy == 'from':
s.next() s.next()
bound1 = p_bit_expr(s) bound1 = p_bit_expr(s)
...@@ -960,16 +1011,11 @@ def p_for_statement(s): ...@@ -960,16 +1011,11 @@ def p_for_statement(s):
if rel1[0] <> rel2[0]: if rel1[0] <> rel2[0]:
error(rel2_pos, error(rel2_pos,
"Relation directions in for-from do not match") "Relation directions in for-from do not match")
body = p_suite(s) return {'target': target,
else_clause = p_else_clause(s) 'bound1': bound1,
return Nodes.ForFromStatNode(pos, 'relation1': rel1,
target = target, 'relation2': rel2,
bound1 = bound1, 'bound2': bound2 }
relation1 = rel1,
relation2 = rel2,
bound2 = bound2,
body = body,
else_clause = else_clause)
def p_for_from_relation(s): def p_for_from_relation(s):
if s.sy in inequality_relations: if s.sy in inequality_relations:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment