Commit c1d71d70 authored by Stefan Behnel

use inlined generator expression for unicode.join(genexpr)

parent 5bcbcf84
...@@ -41,10 +41,10 @@ Features added ...@@ -41,10 +41,10 @@ Features added
* Binary and/or/xor/rshift operations with small constant Python integers * Binary and/or/xor/rshift operations with small constant Python integers
are faster. are faster.
* When called on generator expressions, the builtin functions ``all()``, * When called on generator expressions, the builtins ``all()``, ``any()``,
``any()``, ``dict()``, ``list()``, ``set()`` and ``sorted()`` are ``dict()``, ``list()``, ``set()``, ``sorted()`` and ``unicode.join()``
(partially) inlined into the for-loops to avoid the generator iteration avoid the generator iteration overhead by inlining a part of their
overhead. functionality into the for-loop.
* Keyword argument dicts are no longer copied on function entry when they * Keyword argument dicts are no longer copied on function entry when they
are not being used or only passed through to other function calls (e.g. are not being used or only passed through to other function calls (e.g.
......
...@@ -33,19 +33,23 @@ try: ...@@ -33,19 +33,23 @@ try:
except ImportError: except ImportError:
basestring = str # Python 3 basestring = str # Python 3
def load_c_utility(name): def load_c_utility(name):
return UtilityCode.load_cached(name, "Optimize.c") return UtilityCode.load_cached(name, "Optimize.c")
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)): def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
if isinstance(node, coercion_nodes): if isinstance(node, coercion_nodes):
return node.arg return node.arg
return node return node
def unwrap_node(node): def unwrap_node(node):
while isinstance(node, UtilNodes.ResultRefNode): while isinstance(node, UtilNodes.ResultRefNode):
node = node.expression node = node.expression
return node return node
def is_common_value(a, b): def is_common_value(a, b):
a = unwrap_node(a) a = unwrap_node(a)
b = unwrap_node(b) b = unwrap_node(b)
...@@ -55,11 +59,54 @@ def is_common_value(a, b): ...@@ -55,11 +59,54 @@ def is_common_value(a, b):
return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
return False return False
def filter_none_node(node): def filter_none_node(node):
if node is not None and node.constant_result is None: if node is not None and node.constant_result is None:
return None return None
return node return node
class _YieldNodeCollector(Visitor.TreeVisitor):
    """
    Tree visitor that gathers the YieldExprNodes of a generator expression.

    After a traversal, 'yield_nodes' lists every YieldExprNode that was
    visited, and 'yield_stat_nodes' maps each of those that is the direct
    expression of an ExprStatNode to that statement node.  Nested generator
    expressions and lambdas are not descended into.
    """
    def __init__(self):
        Visitor.TreeVisitor.__init__(self)
        self.yield_nodes = []
        self.yield_stat_nodes = {}

    # generic handler: simply keep walking down the tree
    visit_Node = Visitor.TreeVisitor.visitchildren

    def visit_YieldExprNode(self, node):
        self.yield_nodes.append(node)
        self.visitchildren(node)

    def visit_ExprStatNode(self, node):
        # visit the children first so that a yield directly below this
        # statement has already been recorded when we look it up
        self.visitchildren(node)
        stat_expr = node.expr
        if stat_expr in self.yield_nodes:
            self.yield_stat_nodes[stat_expr] = node

    # everything below these nodes is out of scope:

    def visit_GeneratorExpressionNode(self, node):
        return None

    def visit_LambdaNode(self, node):
        return None
def _find_single_yield_expression(node):
    """
    Look for exactly one yield expression below 'node'.

    Returns a tuple (yielded expression, statement node holding the yield).
    Returns (None, None) if the number of yields found is not exactly one,
    or if the single yield was not recorded as the expression of a
    statement node.
    """
    yield_finder = _YieldNodeCollector()
    yield_finder.visitchildren(node)
    found = yield_finder.yield_nodes
    if len(found) != 1:
        return None, None

    single_yield = found[0]
    yielded_value = single_yield.arg
    try:
        stat_node = yield_finder.yield_stat_nodes[single_yield]
    except KeyError:
        # the yield is not used as a plain expression statement
        return None, None
    return yielded_value, stat_node
class IterationTransform(Visitor.EnvTransform): class IterationTransform(Visitor.EnvTransform):
"""Transform some common for-in loop patterns into efficient C loops: """Transform some common for-in loop patterns into efficient C loops:
...@@ -1447,40 +1494,6 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1447,40 +1494,6 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
# sequence processing # sequence processing
class YieldNodeCollector(Visitor.TreeVisitor):
    """
    Tree visitor that records the YieldExprNodes it encounters.

    'yield_nodes' lists the visited YieldExprNodes, and 'yield_stat_nodes'
    maps those that are the expression of an ExprStatNode to that statement
    node.  Nested generator expressions are not descended into.
    """
    def __init__(self):
        Visitor.TreeVisitor.__init__(self)
        self.yield_nodes = []
        self.yield_stat_nodes = {}

    visit_Node = Visitor.TreeVisitor.visitchildren

    # NOTE(review): an earlier comment here read "disable inlining while
    # it's not back supported", yet the handler below is active - confirm
    # whether that note is stale.
    def visit_YieldExprNode(self, node):
        self.yield_nodes.append(node)
        self.visitchildren(node)

    def visit_ExprStatNode(self, node):
        # visit the children first so that a yield directly below this
        # statement has already been recorded when we look it up
        self.visitchildren(node)
        stat_expr = node.expr
        if stat_expr in self.yield_nodes:
            self.yield_stat_nodes[stat_expr] = node

    def visit_GeneratorExpressionNode(self, node):
        # everything below this node is out of scope
        #
        # NOTE(review): previously annotated "enable when we support generic
        # generator expressions" although the method is enabled - confirm.
        return None
def _find_single_yield_expression(self, node):
    """
    Look for exactly one yield expression below 'node'.

    Returns a tuple (yielded expression, statement node holding the yield).
    Returns (None, None) if the number of yields found is not exactly one,
    or if the single yield was not recorded as the expression of a
    statement node.
    """
    yield_finder = self.YieldNodeCollector()
    yield_finder.visitchildren(node)
    found = yield_finder.yield_nodes
    if len(found) != 1:
        return None, None

    single_yield = found[0]
    yielded_value = single_yield.arg
    try:
        stat_node = yield_finder.yield_stat_nodes[single_yield]
    except KeyError:
        # the yield is not used as a plain expression statement
        return None, None
    return (yielded_value, stat_node)
def _handle_simple_function_all(self, node, pos_args): def _handle_simple_function_all(self, node, pos_args):
"""Transform """Transform
...@@ -1529,7 +1542,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1529,7 +1542,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node = pos_args[0] gen_expr_node = pos_args[0]
generator_body = gen_expr_node.def_node.gbody generator_body = gen_expr_node.def_node.gbody
loop_node = generator_body.body loop_node = generator_body.body
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
if yield_expression is None: if yield_expression is None:
return node return node
...@@ -1588,7 +1601,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1588,7 +1601,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
elif isinstance(arg, ExprNodes.GeneratorExpressionNode): elif isinstance(arg, ExprNodes.GeneratorExpressionNode):
gen_expr_node = arg gen_expr_node = arg
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
if yield_expression is None: if yield_expression is None:
return node return node
...@@ -1646,7 +1659,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1646,7 +1659,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode): if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
# FIXME: currently nonfunctional # FIXME: currently nonfunctional
yield_expression = None yield_expression = None
if yield_expression is None: if yield_expression is None:
...@@ -1779,7 +1792,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1779,7 +1792,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node = pos_args[0] gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
if yield_expression is None: if yield_expression is None:
return node return node
...@@ -1808,7 +1821,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1808,7 +1821,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node = pos_args[0] gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
if yield_expression is None: if yield_expression is None:
return node return node
...@@ -3066,6 +3079,42 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ...@@ -3066,6 +3079,42 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
"PyUnicode_Split", self.PyUnicode_Split_func_type, "PyUnicode_Split", self.PyUnicode_Split_func_type,
'split', is_unbound_method, args) 'split', is_unbound_method, args)
# C function type used when substituting unicode.join() method calls with
# "PyUnicode_Join": the "str" argument is typed as a unicode object and the
# "seq" argument as a generic Python object.  The leading unicode type is
# presumably the return type of the C function - confirm against CFuncType.
PyUnicode_Join_func_type = PyrexTypes.CFuncType(
Builtin.unicode_type, [
PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
PyrexTypes.CFuncTypeArg("seq", PyrexTypes.py_object_type, None),
])
def _handle_simple_method_unicode_join(self, node, function, args, is_unbound_method):
    """
    Replace a unicode.join() method call by a "PyUnicode_Join" call.

    unicode.join() builds a list first, so when the sequence argument is a
    generator expression with a single yield, it is rewritten into an
    inlined generator expression that fills a list directly.

    Reports a wrong argument count and returns 'node' unchanged unless
    exactly two arguments (separator and sequence) were passed.
    """
    if len(args) != 2:
        self._error_wrong_arg_count('unicode.join', node, args, "2")
        return node

    seq_arg = args[1]
    if isinstance(seq_arg, ExprNodes.GeneratorExpressionNode):
        genexpr_loop = seq_arg.loop
        yield_expression, yield_stat_node = _find_single_yield_expression(genexpr_loop)
        if yield_expression is not None:
            list_builder = ExprNodes.InlinedGeneratorExpressionNode(
                node.pos, seq_arg, orig_func='list',
                comprehension_type=Builtin.list_type)
            # let the loop append to the result where it used to yield
            Visitor.recursively_replace_node(
                genexpr_loop, yield_stat_node,
                ExprNodes.ComprehensionAppendNode(
                    yield_expression.pos,
                    expr=yield_expression,
                    target=list_builder.target))
            args[1] = list_builder

    return self._substitute_method_call(
        node, function,
        "PyUnicode_Join", self.PyUnicode_Join_func_type,
        'join', is_unbound_method, args)
PyString_Tailmatch_func_type = PyrexTypes.CFuncType( PyString_Tailmatch_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_bint_type, [ PyrexTypes.c_bint_type, [
PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None), # bytes/str/unicode PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None), # bytes/str/unicode
......
...@@ -594,9 +594,15 @@ class MethodDispatcherTransform(EnvTransform): ...@@ -594,9 +594,15 @@ class MethodDispatcherTransform(EnvTransform):
return function_handler(node, function, arg_list, kwargs) return function_handler(node, function, arg_list, kwargs)
else: else:
return function_handler(node, function, arg_list) return function_handler(node, function, arg_list)
elif function.is_attribute and function.type.is_pyobject: elif function.is_attribute:
attr_name = function.attribute attr_name = function.attribute
self_arg = function.obj if function.type.is_pyobject:
self_arg = function.obj
elif node.self:
self_arg = node.self
arg_list = arg_list[1:] # drop CloneNode of self argument
else:
return node
obj_type = self_arg.type obj_type = self_arg.type
is_unbound_method = False is_unbound_method = False
if obj_type.is_builtin_type: if obj_type.is_builtin_type:
...@@ -634,11 +640,12 @@ class MethodDispatcherTransform(EnvTransform): ...@@ -634,11 +640,12 @@ class MethodDispatcherTransform(EnvTransform):
if self_arg is not None: if self_arg is not None:
arg_list = [self_arg] + list(arg_list) arg_list = [self_arg] + list(arg_list)
if kwargs: if kwargs:
return method_handler( result = method_handler(
node, function, arg_list, is_unbound_method, kwargs) node, function, arg_list, is_unbound_method, kwargs)
else: else:
return method_handler( result = method_handler(
node, function, arg_list, is_unbound_method) node, function, arg_list, is_unbound_method)
return result
def _handle_function(self, node, function_name, function, arg_list, kwargs): def _handle_function(self, node, function_name, function, arg_list, kwargs):
"""Fallback handler""" """Fallback handler"""
......
...@@ -22,8 +22,8 @@ pipe_sep = u'|' ...@@ -22,8 +22,8 @@ pipe_sep = u'|'
@cython.test_assert_path_exists( @cython.test_assert_path_exists(
"//SimpleCallNode", "//PythonCapiCallNode",
"//SimpleCallNode//NameNode") )
def test_unicode_join_bound(unicode sep, l): def test_unicode_join_bound(unicode sep, l):
""" """
>>> l = text.split() >>> l = text.split()
......
...@@ -186,9 +186,8 @@ pipe_sep = u'|' ...@@ -186,9 +186,8 @@ pipe_sep = u'|'
"//CastNode", "//TypecastNode", "//CastNode", "//TypecastNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = true]") "//SimpleCallNode//AttributeNode[@is_py_attr = true]")
@cython.test_assert_path_exists( @cython.test_assert_path_exists(
"//SimpleCallNode", "//PythonCapiCallNode",
"//SimpleCallNode//NoneCheckNode", )
"//SimpleCallNode//AttributeNode[@is_py_attr = false]")
def join(unicode sep, l): def join(unicode sep, l):
""" """
>>> l = text.split() >>> l = text.split()
...@@ -201,13 +200,14 @@ def join(unicode sep, l): ...@@ -201,13 +200,14 @@ def join(unicode sep, l):
""" """
return sep.join(l) return sep.join(l)
@cython.test_fail_if_path_exists( @cython.test_fail_if_path_exists(
"//CoerceToPyTypeNode", "//CoerceFromPyTypeNode", "//CoerceToPyTypeNode", "//CoerceFromPyTypeNode",
"//CastNode", "//TypecastNode", "//NoneCheckNode", "//CastNode", "//TypecastNode", "//NoneCheckNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = true]") "//SimpleCallNode//AttributeNode[@is_py_attr = true]")
@cython.test_assert_path_exists( @cython.test_assert_path_exists(
"//SimpleCallNode", "//PythonCapiCallNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = false]") )
def join_sep(l): def join_sep(l):
""" """
>>> l = text.split() >>> l = text.split()
...@@ -222,9 +222,34 @@ def join_sep(l): ...@@ -222,9 +222,34 @@ def join_sep(l):
assert cython.typeof(result) == 'unicode object', cython.typeof(result) assert cython.typeof(result) == 'unicode object', cython.typeof(result)
return result return result
# Joining a generator expression with a literal unicode separator: the tree
# assertions below require a PythonCapiCallNode and an
# InlinedGeneratorExpressionNode to be present, and forbid coercion, cast and
# None-check nodes as well as a Python attribute lookup inside a call.
@cython.test_fail_if_path_exists(
    "//CoerceToPyTypeNode", "//CoerceFromPyTypeNode",
    "//CastNode", "//TypecastNode", "//NoneCheckNode",
    "//SimpleCallNode//AttributeNode[@is_py_attr = true]"
)
@cython.test_assert_path_exists(
    "//PythonCapiCallNode",
    "//InlinedGeneratorExpressionNode"
)
def join_sep_genexpr(l):
    """
    >>> l = text.split()
    >>> len(l)
    8
    >>> print( '<<%s>>' % '|'.join(s + ' ' for s in l) )
    <<ab |jd |sdflk |as |sa |sadas |asdas |fsdf >>
    >>> print( '<<%s>>' % join_sep_genexpr(l) )
    <<ab |jd |sdflk |as |sa |sadas |asdas |fsdf >>
    """
    # keep the generator expression directly inside join(): that is the
    # pattern whose inlining the decorators above check for
    result = u'|'.join(s + u' ' for s in l)
    # the result of the call must be typed as a unicode object
    assert cython.typeof(result) == 'unicode object', cython.typeof(result)
    return result
@cython.test_assert_path_exists( @cython.test_assert_path_exists(
"//SimpleCallNode", "//PythonCapiCallNode",
"//SimpleCallNode//NameNode") )
def join_unbound(unicode sep, l): def join_unbound(unicode sep, l):
""" """
>>> l = text.split() >>> l = text.split()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment