Commit c1d71d70 authored by Stefan Behnel's avatar Stefan Behnel

use inlined generator expression for unicode.join(genexpr)

parent 5bcbcf84
......@@ -41,10 +41,10 @@ Features added
* Binary and/or/xor/rshift operations with small constant Python integers
are faster.
* When called on generator expressions, the builtin functions ``all()``,
``any()``, ``dict()``, ``list()``, ``set()`` and ``sorted()`` are
(partially) inlined into the for-loops to avoid the generator iteration
overhead.
* When called on generator expressions, the builtins ``all()``, ``any()``,
``dict()``, ``list()``, ``set()``, ``sorted()`` and ``unicode.join()``
avoid the generator iteration overhead by inlining a part of their
functionality into the for-loop.
* Keyword argument dicts are no longer copied on function entry when they
are not being used or only passed through to other function calls (e.g.
......
......@@ -33,19 +33,23 @@ try:
except ImportError:
basestring = str # Python 3
def load_c_utility(name):
return UtilityCode.load_cached(name, "Optimize.c")
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
if isinstance(node, coercion_nodes):
return node.arg
return node
def unwrap_node(node):
while isinstance(node, UtilNodes.ResultRefNode):
node = node.expression
return node
def is_common_value(a, b):
a = unwrap_node(a)
b = unwrap_node(b)
......@@ -55,11 +59,54 @@ def is_common_value(a, b):
return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
return False
def filter_none_node(node):
if node is not None and node.constant_result is None:
return None
return node
class _YieldNodeCollector(Visitor.TreeVisitor):
"""
YieldExprNode finder for generator expressions.
"""
def __init__(self):
Visitor.TreeVisitor.__init__(self)
self.yield_stat_nodes = {}
self.yield_nodes = []
visit_Node = Visitor.TreeVisitor.visitchildren
def visit_YieldExprNode(self, node):
self.yield_nodes.append(node)
self.visitchildren(node)
def visit_ExprStatNode(self, node):
self.visitchildren(node)
if node.expr in self.yield_nodes:
self.yield_stat_nodes[node.expr] = node
# everything below these nodes is out of scope:
def visit_GeneratorExpressionNode(self, node):
pass
def visit_LambdaNode(self, node):
pass
def _find_single_yield_expression(node):
collector = _YieldNodeCollector()
collector.visitchildren(node)
if len(collector.yield_nodes) != 1:
return None, None
yield_node = collector.yield_nodes[0]
try:
return yield_node.arg, collector.yield_stat_nodes[yield_node]
except KeyError:
return None, None
class IterationTransform(Visitor.EnvTransform):
"""Transform some common for-in loop patterns into efficient C loops:
......@@ -1447,40 +1494,6 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
# sequence processing
class YieldNodeCollector(Visitor.TreeVisitor):
def __init__(self):
Visitor.TreeVisitor.__init__(self)
self.yield_stat_nodes = {}
self.yield_nodes = []
visit_Node = Visitor.TreeVisitor.visitchildren
# XXX: disable inlining while it's not back supported
def visit_YieldExprNode(self, node):
self.yield_nodes.append(node)
self.visitchildren(node)
def visit_ExprStatNode(self, node):
self.visitchildren(node)
if node.expr in self.yield_nodes:
self.yield_stat_nodes[node.expr] = node
def visit_GeneratorExpressionNode(self, node):
# enable when we support generic generator expressions
#
# everything below this node is out of scope
pass
def _find_single_yield_expression(self, node):
collector = self.YieldNodeCollector()
collector.visitchildren(node)
if len(collector.yield_nodes) != 1:
return None, None
yield_node = collector.yield_nodes[0]
try:
return (yield_node.arg, collector.yield_stat_nodes[yield_node])
except KeyError:
return None, None
def _handle_simple_function_all(self, node, pos_args):
"""Transform
......@@ -1529,7 +1542,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node = pos_args[0]
generator_body = gen_expr_node.def_node.gbody
loop_node = generator_body.body
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
if yield_expression is None:
return node
......@@ -1588,7 +1601,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
elif isinstance(arg, ExprNodes.GeneratorExpressionNode):
gen_expr_node = arg
loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
if yield_expression is None:
return node
......@@ -1646,7 +1659,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
loop_node = gen_expr_node.loop
if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
# FIXME: currently nonfunctional
yield_expression = None
if yield_expression is None:
......@@ -1779,7 +1792,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
if yield_expression is None:
return node
......@@ -1808,7 +1821,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
if yield_expression is None:
return node
......@@ -3066,6 +3079,42 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
"PyUnicode_Split", self.PyUnicode_Split_func_type,
'split', is_unbound_method, args)
PyUnicode_Join_func_type = PyrexTypes.CFuncType(
Builtin.unicode_type, [
PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
PyrexTypes.CFuncTypeArg("seq", PyrexTypes.py_object_type, None),
])
def _handle_simple_method_unicode_join(self, node, function, args, is_unbound_method):
"""
unicode.join() builds a list first => see if we can do this more efficiently
"""
if len(args) != 2:
self._error_wrong_arg_count('unicode.join', node, args, "2")
return node
if isinstance(args[1], ExprNodes.GeneratorExpressionNode):
gen_expr_node = args[1]
loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
if yield_expression is not None:
inlined_genexpr = ExprNodes.InlinedGeneratorExpressionNode(
node.pos, gen_expr_node, orig_func='list',
comprehension_type=Builtin.list_type)
append_node = ExprNodes.ComprehensionAppendNode(
yield_expression.pos,
expr=yield_expression,
target=inlined_genexpr.target)
Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
args[1] = inlined_genexpr
return self._substitute_method_call(
node, function,
"PyUnicode_Join", self.PyUnicode_Join_func_type,
'join', is_unbound_method, args)
PyString_Tailmatch_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_bint_type, [
PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None), # bytes/str/unicode
......
......@@ -594,9 +594,15 @@ class MethodDispatcherTransform(EnvTransform):
return function_handler(node, function, arg_list, kwargs)
else:
return function_handler(node, function, arg_list)
elif function.is_attribute and function.type.is_pyobject:
elif function.is_attribute:
attr_name = function.attribute
self_arg = function.obj
if function.type.is_pyobject:
self_arg = function.obj
elif node.self:
self_arg = node.self
arg_list = arg_list[1:] # drop CloneNode of self argument
else:
return node
obj_type = self_arg.type
is_unbound_method = False
if obj_type.is_builtin_type:
......@@ -634,11 +640,12 @@ class MethodDispatcherTransform(EnvTransform):
if self_arg is not None:
arg_list = [self_arg] + list(arg_list)
if kwargs:
return method_handler(
result = method_handler(
node, function, arg_list, is_unbound_method, kwargs)
else:
return method_handler(
result = method_handler(
node, function, arg_list, is_unbound_method)
return result
def _handle_function(self, node, function_name, function, arg_list, kwargs):
"""Fallback handler"""
......
......@@ -22,8 +22,8 @@ pipe_sep = u'|'
@cython.test_assert_path_exists(
"//SimpleCallNode",
"//SimpleCallNode//NameNode")
"//PythonCapiCallNode",
)
def test_unicode_join_bound(unicode sep, l):
"""
>>> l = text.split()
......
......@@ -186,9 +186,8 @@ pipe_sep = u'|'
"//CastNode", "//TypecastNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = true]")
@cython.test_assert_path_exists(
"//SimpleCallNode",
"//SimpleCallNode//NoneCheckNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = false]")
"//PythonCapiCallNode",
)
def join(unicode sep, l):
"""
>>> l = text.split()
......@@ -201,13 +200,14 @@ def join(unicode sep, l):
"""
return sep.join(l)
@cython.test_fail_if_path_exists(
"//CoerceToPyTypeNode", "//CoerceFromPyTypeNode",
"//CastNode", "//TypecastNode", "//NoneCheckNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = true]")
@cython.test_assert_path_exists(
"//SimpleCallNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = false]")
"//PythonCapiCallNode",
)
def join_sep(l):
"""
>>> l = text.split()
......@@ -222,9 +222,34 @@ def join_sep(l):
assert cython.typeof(result) == 'unicode object', cython.typeof(result)
return result
@cython.test_fail_if_path_exists(
"//CoerceToPyTypeNode", "//CoerceFromPyTypeNode",
"//CastNode", "//TypecastNode", "//NoneCheckNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = true]"
)
@cython.test_assert_path_exists(
"//PythonCapiCallNode",
"//InlinedGeneratorExpressionNode"
)
def join_sep_genexpr(l):
"""
>>> l = text.split()
>>> len(l)
8
>>> print( '<<%s>>' % '|'.join(s + ' ' for s in l) )
<<ab |jd |sdflk |as |sa |sadas |asdas |fsdf >>
>>> print( '<<%s>>' % join_sep_genexpr(l) )
<<ab |jd |sdflk |as |sa |sadas |asdas |fsdf >>
"""
result = u'|'.join(s + u' ' for s in l)
assert cython.typeof(result) == 'unicode object', cython.typeof(result)
return result
@cython.test_assert_path_exists(
"//SimpleCallNode",
"//SimpleCallNode//NameNode")
"//PythonCapiCallNode",
)
def join_unbound(unicode sep, l):
"""
>>> l = text.split()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment