Commit 6430e5b4 authored by Stefan Behnel's avatar Stefan Behnel

optimise dict([ (x,y) for x,y in ... ]) into dict comprehension

parent 6cbb99b3
...@@ -3271,6 +3271,8 @@ class ListNode(SequenceNode): ...@@ -3271,6 +3271,8 @@ class ListNode(SequenceNode):
# obj_conversion_errors [PyrexError] used internally # obj_conversion_errors [PyrexError] used internally
# orignial_args [ExprNode] used internally # orignial_args [ExprNode] used internally
obj_conversion_errors = []
gil_message = "Constructing Python list" gil_message = "Constructing Python list"
def analyse_expressions(self, env): def analyse_expressions(self, env):
...@@ -3403,12 +3405,13 @@ class ComprehensionAppendNode(ExprNode): ...@@ -3403,12 +3405,13 @@ class ComprehensionAppendNode(ExprNode):
# Need to be careful to avoid infinite recursion: # Need to be careful to avoid infinite recursion:
# target must not be in child_attrs/subexprs # target must not be in child_attrs/subexprs
subexprs = ['expr'] subexprs = ['expr']
type = PyrexTypes.c_int_type
def analyse_types(self, env): def analyse_types(self, env):
self.expr.analyse_types(env) self.expr.analyse_types(env)
if not self.expr.type.is_pyobject: if not self.expr.type.is_pyobject:
self.expr = self.expr.coerce_to_pyobject(env) self.expr = self.expr.coerce_to_pyobject(env)
self.type = PyrexTypes.c_int_type
self.is_temp = 1 self.is_temp = 1
def generate_result_code(self, code): def generate_result_code(self, code):
...@@ -3429,7 +3432,7 @@ class ComprehensionAppendNode(ExprNode): ...@@ -3429,7 +3432,7 @@ class ComprehensionAppendNode(ExprNode):
class DictComprehensionAppendNode(ComprehensionAppendNode): class DictComprehensionAppendNode(ComprehensionAppendNode):
subexprs = ['key_expr', 'value_expr'] subexprs = ['key_expr', 'value_expr']
def analyse_types(self, env): def analyse_types(self, env):
self.key_expr.analyse_types(env) self.key_expr.analyse_types(env)
if not self.key_expr.type.is_pyobject: if not self.key_expr.type.is_pyobject:
...@@ -3437,7 +3440,6 @@ class DictComprehensionAppendNode(ComprehensionAppendNode): ...@@ -3437,7 +3440,6 @@ class DictComprehensionAppendNode(ComprehensionAppendNode):
self.value_expr.analyse_types(env) self.value_expr.analyse_types(env)
if not self.value_expr.type.is_pyobject: if not self.value_expr.type.is_pyobject:
self.value_expr = self.value_expr.coerce_to_pyobject(env) self.value_expr = self.value_expr.coerce_to_pyobject(env)
self.type = PyrexTypes.c_int_type
self.is_temp = 1 self.is_temp = 1
def generate_result_code(self, code): def generate_result_code(self, code):
...@@ -3502,6 +3504,9 @@ class DictNode(ExprNode): ...@@ -3502,6 +3504,9 @@ class DictNode(ExprNode):
subexprs = ['key_value_pairs'] subexprs = ['key_value_pairs']
type = dict_type
obj_conversion_errors = []
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = dict([ self.constant_result = dict([
item.constant_result for item in self.key_value_pairs]) item.constant_result for item in self.key_value_pairs])
......
...@@ -34,7 +34,6 @@ def is_common_value(a, b): ...@@ -34,7 +34,6 @@ def is_common_value(a, b):
return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
return False return False
class IterationTransform(Visitor.VisitorTransform): class IterationTransform(Visitor.VisitorTransform):
"""Transform some common for-in loop patterns into efficient C loops: """Transform some common for-in loop patterns into efficient C loops:
...@@ -613,24 +612,41 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): ...@@ -613,24 +612,41 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
]) ])
def _handle_simple_function_dict(self, node, pos_args): def _handle_simple_function_dict(self, node, pos_args):
"""Replace dict(some_dict) by PyDict_Copy(some_dict). """Replace dict(some_dict) by PyDict_Copy(some_dict) and
dict([ (a,b) for ... ]) by a literal { a:b for ... }.
""" """
if len(pos_args.args) != 1: if len(pos_args.args) != 1:
return node return node
dict_arg = pos_args.args[0] arg = pos_args.args[0]
if dict_arg.type is not Builtin.dict_type: if arg.type is Builtin.dict_type:
return node arg = ExprNodes.NoneCheckNode(
arg, "PyExc_TypeError", "'NoneType' is not iterable")
dict_arg = ExprNodes.NoneCheckNode( return ExprNodes.PythonCapiCallNode(
dict_arg, "PyExc_TypeError", "'NoneType' is not iterable") node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
return ExprNodes.PythonCapiCallNode( args = [dict_arg],
node.pos, "PyDict_Copy", self.PyDict_Copy_func_type, is_temp = node.is_temp
args = [dict_arg], )
is_temp = node.is_temp elif isinstance(arg, ExprNodes.ComprehensionNode) and \
) arg.type is Builtin.list_type:
append_node = arg.append
if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
len(append_node.expr.args) == 2:
key_node, value_node = append_node.expr.args
target_node = ExprNodes.DictNode(
pos=arg.target.pos, key_value_pairs=[], is_temp=1)
new_append_node = ExprNodes.DictComprehensionAppendNode(
append_node.pos, target=target_node,
key_expr=key_node, value_expr=value_node,
is_temp=1)
arg.target = target_node
arg.type = target_node.type
replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
return replace_in(arg)
return node
def _handle_simple_function_set(self, node, pos_args): def _handle_simple_function_set(self, node, pos_args):
"""Replace set([a,b,...]) by a literal set {a,b,...}. """Replace set([a,b,...]) by a literal set {a,b,...} and
set([ x for ... ]) by a literal { x for ... }.
""" """
arg_count = len(pos_args.args) arg_count = len(pos_args.args)
if arg_count == 0: if arg_count == 0:
......
__doc__ = u""" __doc__ = u"""
>>> type(smoketest()) is dict >>> type(smoketest_dict()) is dict
True
>>> type(smoketest_list()) is dict
True True
>>> sorted(smoketest().items()) >>> sorted(smoketest_dict().items())
[(2, 0), (4, 4), (6, 8)]
>>> sorted(smoketest_list().items())
[(2, 0), (4, 4), (6, 8)] [(2, 0), (4, 4), (6, 8)]
>>> list(typed().items()) >>> list(typed().items())
[(A, 1), (A, 1), (A, 1)] [(A, 1), (A, 1), (A, 1)]
>>> sorted(iterdict().items()) >>> sorted(iterdict().items())
[(1, 'a'), (2, 'b'), (3, 'c')] [(1, 'a'), (2, 'b'), (3, 'c')]
""" """
def smoketest(): def smoketest_dict():
return {x+2:x*2 for x in range(5) if x % 2 == 0} return { x+2:x*2
for x in range(5)
if x % 2 == 0 }
def smoketest_list():
return dict([ (x+2,x*2)
for x in range(5)
if x % 2 == 0 ])
cdef class A: cdef class A:
def __repr__(self): return u"A" def __repr__(self): return u"A"
......
__doc__ = u""" __doc__ = u"""
>>> type(smoketest()) is not list >>> type(smoketest_set()) is not list
True True
>>> type(smoketest()) is _set >>> type(smoketest_set()) is _set
True
>>> type(smoketest_list()) is _set
True True
>>> sorted(smoketest()) >>> sorted(smoketest_set())
[0, 4, 8]
>>> sorted(smoketest_list())
[0, 4, 8] [0, 4, 8]
>>> list(typed()) >>> list(typed())
[A, A, A] [A, A, A]
>>> sorted(iterdict()) >>> sorted(iterdict())
...@@ -15,8 +20,15 @@ True ...@@ -15,8 +20,15 @@ True
# Py2.3 doesn't have the set type, but Cython does :) # Py2.3 doesn't have the set type, but Cython does :)
_set = set _set = set
def smoketest(): def smoketest_set():
return {x*2 for x in range(5) if x % 2 == 0} return { x*2
for x in range(5)
if x % 2 == 0 }
def smoketest_list():
return set([ x*2
for x in range(5)
if x % 2 == 0 ])
cdef class A: cdef class A:
def __repr__(self): return u"A" def __repr__(self): return u"A"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment