Commit 72dfd2e8 authored by Stefan Behnel's avatar Stefan Behnel

replace set([...]) by a literal set {...}

parent 50b97529
...@@ -123,8 +123,8 @@ class Context: ...@@ -123,8 +123,8 @@ class Context:
IntroduceBufferAuxiliaryVars(self), IntroduceBufferAuxiliaryVars(self),
_check_c_declarations, _check_c_declarations,
AnalyseExpressionsTransform(self), AnalyseExpressionsTransform(self),
ConstantFolding(),
FlattenBuiltinTypeCreation(), FlattenBuiltinTypeCreation(),
ConstantFolding(),
DictIterTransform(), DictIterTransform(),
SwitchTransform(), SwitchTransform(),
FinalOptimizePhase(self), FinalOptimizePhase(self),
......
...@@ -354,24 +354,64 @@ class FlattenBuiltinTypeCreation(Visitor.VisitorTransform): ...@@ -354,24 +354,64 @@ class FlattenBuiltinTypeCreation(Visitor.VisitorTransform):
"""Optimise some common instantiation patterns for builtin types. """Optimise some common instantiation patterns for builtin types.
""" """
def visit_GeneralCallNode(self, node): def visit_GeneralCallNode(self, node):
self.visitchildren(node)
handler = self._find_handler('general', node.function)
if handler is not None:
node = handler(node, node.positional_args, node.keyword_args)
return node
def visit_SimpleCallNode(self, node):
self.visitchildren(node)
handler = self._find_handler('simple', node.function)
if handler is not None:
node = handler(node, node.arg_tuple, None)
return node
def _find_handler(self, call_type, function):
if not function.type.is_builtin_type:
return None
handler = getattr(self, '_handle_%s_%s' % (call_type, function.name), None)
if handler is None:
handler = getattr(self, '_handle_any_%s' % function.name, None)
return handler
def _handle_general_dict(self, node, pos_args, kwargs):
"""Replace dict(a=b,c=d,...) by the underlying keyword dict """Replace dict(a=b,c=d,...) by the underlying keyword dict
construction which is done anyway. construction which is done anyway.
""" """
self.visitchildren(node) if not isinstance(pos_args, ExprNodes.TupleNode):
if not node.function.type.is_builtin_type:
return node return node
if node.function.name != 'dict': if len(pos_args.args) > 0:
return node return node
if not isinstance(node.positional_args, ExprNodes.TupleNode): if not isinstance(kwargs, ExprNodes.DictNode):
return node
if len(node.positional_args.args) > 0:
return node
if not isinstance(node.keyword_args, ExprNodes.DictNode):
return node return node
if node.starstar_arg: if node.starstar_arg:
# we could optimise this by updating the kw dict instead # we could optimise this by updating the kw dict instead
return node return node
return node.keyword_args return kwargs
def _handle_simple_set(self, node, pos_args, kwargs):
"""Replace set([a,b,...]) by a literal set {a,b,...}.
"""
if not isinstance(pos_args, ExprNodes.TupleNode):
return node
arg_count = len(pos_args.args)
if arg_count == 0:
return ExprNodes.SetNode(node.pos, args=[],
type=Builtin.set_type, is_temp=1)
if arg_count > 1:
return node
iterable = pos_args.args[0]
if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
return ExprNodes.SetNode(node.pos, args=iterable.args,
type=Builtin.set_type, is_temp=1)
elif isinstance(iterable, ExprNodes.ListComprehensionNode):
iterable.__class__ = ExprNodes.SetComprehensionNode
iterable.append.__class__ = ExprNodes.SetComprehensionAppendNode
iterable.pos = node.pos
return iterable
else:
return node
def visit_PyTypeTestNode(self, node): def visit_PyTypeTestNode(self, node):
"""Flatten redundant type checks after tree changes. """Flatten redundant type checks after tree changes.
......
...@@ -9,7 +9,12 @@ True ...@@ -9,7 +9,12 @@ True
>>> sorted(test_set_add()) >>> sorted(test_set_add())
['a', 1] ['a', 1]
>>> type(test_set_add()) is _set >>> type(test_set_list_comp()) is _set
True
>>> sorted(test_set_list_comp())
[0, 1, 2]
>>> type(test_set_clear()) is _set
True True
>>> list(test_set_clear()) >>> list(test_set_clear())
[] []
...@@ -46,6 +51,11 @@ def test_set_clear(): ...@@ -46,6 +51,11 @@ def test_set_clear():
s1.clear() s1.clear()
return s1 return s1
def test_set_list_comp():
cdef set s1
s1 = set([i%3 for i in range(5)])
return s1
def test_set_pop(): def test_set_pop():
cdef set s1 cdef set s1
s1 = set() s1 = set()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment