Commit eac2e5ff authored by Stefan Behnel's avatar Stefan Behnel

implement set/dict comprehensions and set literals

parent 4a8f753f
......@@ -12,7 +12,7 @@ import Naming
from Nodes import Node
import PyrexTypes
from PyrexTypes import py_object_type, c_long_type, typecast, error_type
from Builtin import list_type, tuple_type, dict_type, unicode_type
from Builtin import list_type, tuple_type, set_type, dict_type, unicode_type
import Symtab
import Options
from Annotate import AnnotationItem
......@@ -3007,7 +3007,7 @@ class ListNode(SequenceNode):
gil_message = "Constructing Python list"
def analyse_expressions(self, env):
ExprNode.analyse_expressions(self, env)
SequenceNode.analyse_expressions(self, env)
self.coerce_to_pyobject(env)
def analyse_types(self, env):
......@@ -3102,15 +3102,15 @@ class ListNode(SequenceNode):
# generate_evaluation_code which will do that.
class ListComprehensionNode(SequenceNode):
class ComprehensionNode(SequenceNode):
subexprs = []
is_sequence_constructor = 0 # not unpackable
comp_result_type = py_object_type
child_attrs = ["loop", "append"]
def analyse_types(self, env):
self.type = list_type
self.type = self.comp_result_type
self.is_temp = 1
self.append.target = self # this is a CloneNode used in the PyList_Append in the inner loop
......@@ -3132,19 +3132,48 @@ class ListComprehensionNode(SequenceNode):
self.loop.annotate(code)
class ListComprehensionAppendNode(ExprNode):
class ListComprehensionNode(ComprehensionNode):
comp_result_type = list_type
def generate_operation_code(self, code):
code.putln("%s = PyList_New(%s); %s" %
(self.result(),
0,
code.error_goto_if_null(self.result(), self.pos)))
self.loop.generate_execution_code(code)
class SetComprehensionNode(ComprehensionNode):
comp_result_type = set_type
def generate_operation_code(self, code):
code.putln("%s = PySet_New(0); %s" % # arg == iterable, not size!
(self.result(),
code.error_goto_if_null(self.result(), self.pos)))
self.loop.generate_execution_code(code)
class DictComprehensionNode(ComprehensionNode):
comp_result_type = dict_type
def generate_operation_code(self, code):
code.putln("%s = PyDict_New(); %s" %
(self.result(),
code.error_goto_if_null(self.result(), self.pos)))
self.loop.generate_execution_code(code)
class ComprehensionAppendNode(NewTempExprNode):
# Need to be careful to avoid infinite recursion:
# target must not be in child_attrs/subexprs
subexprs = ['expr']
def analyse_types(self, env):
self.expr.analyse_types(env)
if self.expr.type != py_object_type:
if not self.expr.type.is_pyobject:
self.expr = self.expr.coerce_to_pyobject(env)
self.type = PyrexTypes.c_int_type
self.is_temp = 1
class ListComprehensionAppendNode(ComprehensionAppendNode):
def generate_result_code(self, code):
code.putln("%s = PyList_Append(%s, (PyObject*)%s); %s" %
(self.result(),
......@@ -3152,6 +3181,78 @@ class ListComprehensionAppendNode(ExprNode):
self.expr.result(),
code.error_goto_if(self.result(), self.pos)))
class SetComprehensionAppendNode(ComprehensionAppendNode):
def generate_result_code(self, code):
code.putln("%s = PySet_Add(%s, (PyObject*)%s); %s" %
(self.result(),
self.target.result(),
self.expr.result(),
code.error_goto_if(self.result(), self.pos)))
class DictComprehensionAppendNode(ComprehensionAppendNode):
subexprs = ['key_expr', 'value_expr']
def analyse_types(self, env):
self.key_expr.analyse_types(env)
if not self.key_expr.type.is_pyobject:
self.key_expr = self.key_expr.coerce_to_pyobject(env)
self.value_expr.analyse_types(env)
if not self.value_expr.type.is_pyobject:
self.value_expr = self.value_expr.coerce_to_pyobject(env)
self.type = PyrexTypes.c_int_type
self.is_temp = 1
def generate_result_code(self, code):
code.putln("%s = PyDict_SetItem(%s, (PyObject*)%s, (PyObject*)%s); %s" %
(self.result(),
self.target.result(),
self.key_expr.result(),
self.value_expr.result(),
code.error_goto_if(self.result(), self.pos)))
class SetNode(NewTempExprNode):
# Set constructor.
subexprs = ['args']
gil_message = "Constructing Python set"
def analyse_types(self, env):
for i in range(len(self.args)):
arg = self.args[i]
arg.analyse_types(env)
self.args[i] = arg.coerce_to_pyobject(env)
self.type = set_type
self.gil_check(env)
self.is_temp = 1
def compile_time_value(self, denv):
values = [arg.compile_time_value(denv) for arg in self.args]
try:
set
except NameError:
from sets import Set as set
try:
return set(values)
except Exception, e:
self.compile_time_value_error(e)
def generate_evaluation_code(self, code):
self.allocate_temp_result(code)
code.putln(
"%s = PySet_New(0); %s" % (
self.result(),
code.error_goto_if_null(self.result(), self.pos)))
for arg in self.args:
arg.generate_evaluation_code(code)
code.putln(
code.error_goto_if_neg(
"PySet_Add(%s, %s)" % (self.result(), arg.py_result()),
self.pos))
arg.generate_disposal_code(code)
arg.free_temps(code)
class DictNode(ExprNode):
# Dictionary constructor.
......
......@@ -473,7 +473,7 @@ def make_slice_node(pos, start, stop = None, step = None):
return ExprNodes.SliceNode(pos,
start = start, stop = stop, step = step)
#atom: '(' [testlist] ')' | '[' [listmaker] ']' | '{' [dictmaker] '}' | '`' testlist '`' | NAME | NUMBER | STRING+
#atom: '(' [testlist] ')' | '[' [listmaker] ']' | '{' [dict_or_set_maker] '}' | '`' testlist '`' | NAME | NUMBER | STRING+
def p_atom(s):
pos = s.position()
......@@ -491,7 +491,7 @@ def p_atom(s):
elif sy == '[':
return p_list_maker(s)
elif sy == '{':
return p_dict_maker(s)
return p_dict_or_set_maker(s)
elif sy == '`':
return p_backquote_expr(s)
elif sy == 'INT':
......@@ -701,13 +701,8 @@ def p_list_maker(s):
if s.sy == 'for':
loop = p_list_for(s)
s.expect(']')
inner_loop = loop
while not isinstance(inner_loop.body, Nodes.PassStatNode):
inner_loop = inner_loop.body
if isinstance(inner_loop, Nodes.IfStatNode):
inner_loop = inner_loop.if_clauses[0]
append = ExprNodes.ListComprehensionAppendNode( pos, expr = expr )
inner_loop.body = Nodes.ExprStatNode(pos, expr = append)
set_inner_comp_append(loop, append)
return ExprNodes.ListComprehensionNode(pos, loop = loop, append = append)
else:
exprs = [expr]
......@@ -743,26 +738,68 @@ def p_list_if(s):
if_clauses = [Nodes.IfClauseNode(pos, condition = test, body = p_list_iter(s))],
else_clause = None )
def set_inner_comp_append(loop, append):
inner_loop = loop
while not isinstance(inner_loop.body, Nodes.PassStatNode):
inner_loop = inner_loop.body
if isinstance(inner_loop, Nodes.IfStatNode):
inner_loop = inner_loop.if_clauses[0]
inner_loop.body = Nodes.ExprStatNode(append.pos, expr = append)
#dictmaker: test ':' test (',' test ':' test)* [',']
def p_dict_maker(s):
def p_dict_or_set_maker(s):
# s.sy == '{'
pos = s.position()
s.next()
items = []
while s.sy != '}':
items.append(p_dict_item(s))
if s.sy != ',':
break
if s.sy == '}':
s.next()
return ExprNodes.DictNode(pos, key_value_pairs = [])
item = p_simple_expr(s)
if s.sy == ',' or s.sy == '}':
# set literal
values = [item]
while s.sy == ',':
s.next()
values.append( p_simple_expr(s) )
s.expect('}')
return ExprNodes.DictNode(pos, key_value_pairs = items)
def p_dict_item(s):
return ExprNodes.SetNode(pos, args=values)
elif s.sy == 'for':
# set comprehension
loop = p_list_for(s)
s.expect('}')
append = ExprNodes.SetComprehensionAppendNode(item.pos, expr=item)
set_inner_comp_append(loop, append)
return ExprNodes.SetComprehensionNode(pos, loop=loop, append=append)
elif s.sy == ':':
# dict literal or comprehension
key = item
s.next()
value = p_simple_expr(s)
if s.sy == 'for':
# dict comprehension
loop = p_list_for(s)
s.expect('}')
append = ExprNodes.DictComprehensionAppendNode(
item.pos, key_expr = key, value_expr = value)
set_inner_comp_append(loop, append)
return ExprNodes.DictComprehensionNode(pos, loop=loop, append=append)
else:
# dict literal
items = [ExprNodes.DictItemNode(key.pos, key=key, value=value)]
while s.sy == ',':
s.next()
key = p_simple_expr(s)
s.expect(':')
value = p_simple_expr(s)
return ExprNodes.DictItemNode(key.pos, key=key, value=value)
items.append(
ExprNodes.DictItemNode(key.pos, key=key, value=value))
s.expect('}')
return ExprNodes.DictNode(pos, key_value_pairs=items)
else:
# raise an error
s.expect('}')
return ExprNodes.DictNode(pos, key_value_pairs = [])
def p_backquote_expr(s):
# s.sy == '`'
......
u"""
>>> type(smoketest()) is dict
True
>>> sorted(smoketest().items())
[(2, 0), (4, 4), (6, 8)]
>>> list(typed().items())
[(A, 1), (A, 1), (A, 1)]
>>> sorted(iterdict().items())
[(1, 'a'), (2, 'b'), (3, 'c')]
"""
def smoketest():
return {x+2:x*2 for x in range(5) if x % 2 == 0}
cdef class A:
def __repr__(self): return u"A"
def __richcmp__(one, other, op): return one is other
def __hash__(self): return id(self) % 65536
def typed():
cdef A obj
return {obj:1 for obj in [A(), A(), A()]}
def iterdict():
cdef dict d = dict(a=1,b=2,c=3)
return {d[key]:key for key in d}
def sorted(it):
l = list(it)
l.sort()
return l
__doc__ = u"""
>>> test_set_add()
set(['a', 1])
>>> test_set_clear()
set([])
>>> test_set_pop()
set([])
>>> test_set_discard()
set([233, '12'])
u"""
>>> type(test_set_literal()) is _set
True
>>> sorted(test_set_literal())
['a', 'b', 1]
>>> type(test_set_add()) is _set
True
>>> sorted(test_set_add())
['a', 1]
>>> type(test_set_add()) is _set
True
>>> list(test_set_clear())
[]
>>> type(test_set_pop()) is _set
True
>>> list(test_set_pop())
[]
>>> type(test_set_discard()) is _set
True
>>> sorted(test_set_discard())
['12', 233]
"""
# Py2.3 doesn't have the 'set' builtin type, but Cython does :)
_set = set
def test_set_literal():
cdef set s1 = {1,'a',1,'b','a'}
return s1
def test_set_add():
cdef set s1
s1 = set([1])
......@@ -40,3 +63,15 @@ def test_set_discard():
s1.discard(3)
return s1
def sorted(it):
# Py3 can't compare strings to ints
chars = []
nums = []
for item in it:
if type(item) is int:
nums.append(item)
else:
chars.append(item)
nums.sort()
chars.sort()
return chars+nums
u"""
>>> type(smoketest()) is not list
True
>>> type(smoketest()) is _set
True
>>> sorted(smoketest())
[0, 4, 8]
>>> list(typed())
[A, A, A]
>>> sorted(iterdict())
[1, 2, 3]
"""
# Py2.3 doesn't have the set type, but Cython does :)
_set = set
def smoketest():
return {x*2 for x in range(5) if x % 2 == 0}
cdef class A:
def __repr__(self): return u"A"
def __richcmp__(one, other, op): return one is other
def __hash__(self): return id(self) % 65536
def typed():
cdef A obj
return {obj for obj in {A(), A(), A()}}
def iterdict():
cdef dict d = dict(a=1,b=2,c=3)
return {d[key] for key in d}
def sorted(it):
l = list(it)
l.sort()
return l
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment