Commit 6430e5b4 authored by Stefan Behnel's avatar Stefan Behnel

optimise dict([ (x,y) for x,y in ... ]) into dict comprehension

parent 6cbb99b3
......@@ -3271,6 +3271,8 @@ class ListNode(SequenceNode):
# obj_conversion_errors [PyrexError] used internally
# orignial_args [ExprNode] used internally
obj_conversion_errors = []
gil_message = "Constructing Python list"
def analyse_expressions(self, env):
......@@ -3404,11 +3406,12 @@ class ComprehensionAppendNode(ExprNode):
# target must not be in child_attrs/subexprs
subexprs = ['expr']
type = PyrexTypes.c_int_type
def analyse_types(self, env):
self.expr.analyse_types(env)
if not self.expr.type.is_pyobject:
self.expr = self.expr.coerce_to_pyobject(env)
self.type = PyrexTypes.c_int_type
self.is_temp = 1
def generate_result_code(self, code):
......@@ -3437,7 +3440,6 @@ class DictComprehensionAppendNode(ComprehensionAppendNode):
self.value_expr.analyse_types(env)
if not self.value_expr.type.is_pyobject:
self.value_expr = self.value_expr.coerce_to_pyobject(env)
self.type = PyrexTypes.c_int_type
self.is_temp = 1
def generate_result_code(self, code):
......@@ -3502,6 +3504,9 @@ class DictNode(ExprNode):
subexprs = ['key_value_pairs']
type = dict_type
obj_conversion_errors = []
def calculate_constant_result(self):
self.constant_result = dict([
item.constant_result for item in self.key_value_pairs])
......
......@@ -34,7 +34,6 @@ def is_common_value(a, b):
return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
return False
class IterationTransform(Visitor.VisitorTransform):
"""Transform some common for-in loop patterns into efficient C loops:
......@@ -613,24 +612,41 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
])
def _handle_simple_function_dict(self, node, pos_args):
"""Replace dict(some_dict) by PyDict_Copy(some_dict).
"""Replace dict(some_dict) by PyDict_Copy(some_dict) and
dict([ (a,b) for ... ]) by a literal { a:b for ... }.
"""
if len(pos_args.args) != 1:
return node
dict_arg = pos_args.args[0]
if dict_arg.type is not Builtin.dict_type:
return node
dict_arg = ExprNodes.NoneCheckNode(
dict_arg, "PyExc_TypeError", "'NoneType' is not iterable")
arg = pos_args.args[0]
if arg.type is Builtin.dict_type:
arg = ExprNodes.NoneCheckNode(
arg, "PyExc_TypeError", "'NoneType' is not iterable")
return ExprNodes.PythonCapiCallNode(
node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
args = [dict_arg],
is_temp = node.is_temp
)
elif isinstance(arg, ExprNodes.ComprehensionNode) and \
arg.type is Builtin.list_type:
append_node = arg.append
if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
len(append_node.expr.args) == 2:
key_node, value_node = append_node.expr.args
target_node = ExprNodes.DictNode(
pos=arg.target.pos, key_value_pairs=[], is_temp=1)
new_append_node = ExprNodes.DictComprehensionAppendNode(
append_node.pos, target=target_node,
key_expr=key_node, value_expr=value_node,
is_temp=1)
arg.target = target_node
arg.type = target_node.type
replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
return replace_in(arg)
return node
def _handle_simple_function_set(self, node, pos_args):
"""Replace set([a,b,...]) by a literal set {a,b,...}.
"""Replace set([a,b,...]) by a literal set {a,b,...} and
set([ x for ... ]) by a literal { x for ... }.
"""
arg_count = len(pos_args.args)
if arg_count == 0:
......
__doc__ = u"""
>>> type(smoketest()) is dict
>>> type(smoketest_dict()) is dict
True
>>> type(smoketest_list()) is dict
True
>>> sorted(smoketest().items())
>>> sorted(smoketest_dict().items())
[(2, 0), (4, 4), (6, 8)]
>>> sorted(smoketest_list().items())
[(2, 0), (4, 4), (6, 8)]
>>> list(typed().items())
[(A, 1), (A, 1), (A, 1)]
>>> sorted(iterdict().items())
[(1, 'a'), (2, 'b'), (3, 'c')]
"""
def smoketest():
return {x+2:x*2 for x in range(5) if x % 2 == 0}
def smoketest_dict():
return { x+2:x*2
for x in range(5)
if x % 2 == 0 }
def smoketest_list():
return dict([ (x+2,x*2)
for x in range(5)
if x % 2 == 0 ])
cdef class A:
def __repr__(self): return u"A"
......
__doc__ = u"""
>>> type(smoketest()) is not list
>>> type(smoketest_set()) is not list
True
>>> type(smoketest()) is _set
>>> type(smoketest_set()) is _set
True
>>> type(smoketest_list()) is _set
True
>>> sorted(smoketest())
>>> sorted(smoketest_set())
[0, 4, 8]
>>> sorted(smoketest_list())
[0, 4, 8]
>>> list(typed())
[A, A, A]
>>> sorted(iterdict())
......@@ -15,8 +20,15 @@ True
# Py2.3 doesn't have the set type, but Cython does :)
_set = set
def smoketest():
return {x*2 for x in range(5) if x % 2 == 0}
def smoketest_set():
return { x*2
for x in range(5)
if x % 2 == 0 }
def smoketest_list():
return set([ x*2
for x in range(5)
if x % 2 == 0 ])
cdef class A:
def __repr__(self): return u"A"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment