Commit ae1bd00b authored by Stefan Behnel's avatar Stefan Behnel

fixed multiplied tuple optimisation by unoptimising the 'unknown factor' case,...

fixed multiplied tuple optimisation by unoptimising the 'unknown factor' case, but still allow the underlying tuple to be constant
parent 772ce2d6
...@@ -4369,12 +4369,7 @@ class SequenceNode(ExprNode): ...@@ -4369,12 +4369,7 @@ class SequenceNode(ExprNode):
if self.mult_factor: if self.mult_factor:
self.mult_factor.analyse_types(env) self.mult_factor.analyse_types(env)
if not self.mult_factor.type.is_int: if not self.mult_factor.type.is_int:
if self.mult_factor.type.is_pyobject: self.mult_factor = self.mult_factor.coerce_to_pyobject(env)
self.mult_factor = self.mult_factor.coerce_to(
PyrexTypes.c_py_ssize_t_type, env)
else:
error(self.pos, "can't multiply sequence by non-int of type '%s'" %
self.mult_factor.type)
self.is_temp = 1 self.is_temp = 1
# not setting self.type here, subtypes do this # not setting self.type here, subtypes do this
...@@ -4406,38 +4401,44 @@ class SequenceNode(ExprNode): ...@@ -4406,38 +4401,44 @@ class SequenceNode(ExprNode):
def generate_result_code(self, code): def generate_result_code(self, code):
self.generate_operation_code(code) self.generate_operation_code(code)
def generate_sequence_packing_code(self, code): def generate_sequence_packing_code(self, code, target=None, plain=False):
if target is None:
target = self.result()
py_multiply = self.mult_factor and not self.mult_factor.type.is_int
if plain or py_multiply:
mult_factor = None
else:
mult_factor = self.mult_factor
if mult_factor:
mult = mult_factor.result()
if isinstance(mult_factor.constant_result, (int,long)) \
and mult_factor.constant_result > 0:
size_factor = ' * %s' % mult_factor.constant_result
else:
size_factor = ' * ((%s<0) ? 0:%s)' % (mult, mult)
else:
size_factor = ''
mult = ''
if self.type is Builtin.list_type: if self.type is Builtin.list_type:
create_func, set_item_func = 'PyList_New', 'PyList_SET_ITEM' create_func, set_item_func = 'PyList_New', 'PyList_SET_ITEM'
elif self.type is Builtin.tuple_type: elif self.type is Builtin.tuple_type:
create_func, set_item_func = 'PyTuple_New', 'PyTuple_SET_ITEM' create_func, set_item_func = 'PyTuple_New', 'PyTuple_SET_ITEM'
else: else:
raise InternalError("sequence unpacking for unexpected type %s" % self.type) raise InternalError("sequence unpacking for unexpected type %s" % self.type)
if self.mult_factor:
mult = self.mult_factor.result()
if isinstance(self.mult_factor.constant_result, (int,long)) \
and self.mult_factor.constant_result > 0:
size_factor = ' * %s' % self.mult_factor.constant_result
else:
size_factor = ' * ((%s<0) ? 0:%s)' % (mult, mult)
else:
size_factor = ''
mult = ''
arg_count = len(self.args) arg_count = len(self.args)
code.putln("%s = %s(%s%s); %s" % ( code.putln("%s = %s(%s%s); %s" % (
self.result(), target, create_func, arg_count, size_factor,
create_func, code.error_goto_if_null(target, self.pos)))
arg_count, code.put_gotref(target)
size_factor,
code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.py_result())
if mult: if mult:
# FIXME: can't use a temp variable here as the code may # FIXME: can't use a temp variable here as the code may
# end up in the constant building function. Temps # end up in the constant building function. Temps
# currently don't work there. # currently don't work there.
#counter = code.funcstate.allocate_temp(self.mult_factor.type, manage_ref=False) #counter = code.funcstate.allocate_temp(mult_factor.type, manage_ref=False)
counter = '__pyx_n' counter = Naming.quick_temp_cname
code.putln('{ Py_ssize_t %s;' % counter) code.putln('{ Py_ssize_t %s;' % counter)
if arg_count == 1: if arg_count == 1:
offset = counter offset = counter
...@@ -4454,7 +4455,7 @@ class SequenceNode(ExprNode): ...@@ -4454,7 +4455,7 @@ class SequenceNode(ExprNode):
code.put_incref(arg.result(), arg.ctype()) code.put_incref(arg.result(), arg.ctype())
code.putln("%s(%s, %s, %s);" % ( code.putln("%s(%s, %s, %s);" % (
set_item_func, set_item_func,
self.result(), target,
(offset and i) and ('%s + %s' % (offset, i)) or (offset or i), (offset and i) and ('%s + %s' % (offset, i)) or (offset or i),
arg.py_result())) arg.py_result()))
code.put_giveref(arg.py_result()) code.put_giveref(arg.py_result())
...@@ -4462,9 +4463,18 @@ class SequenceNode(ExprNode): ...@@ -4462,9 +4463,18 @@ class SequenceNode(ExprNode):
code.putln('}') code.putln('}')
#code.funcstate.release_temp(counter) #code.funcstate.release_temp(counter)
code.putln('}') code.putln('}')
elif py_multiply and not plain:
code.putln('{ PyObject* %s = PyNumber_Multiply(%s, %s); %s' % (
Naming.quick_temp_cname, target, self.mult_factor.py_result(),
code.error_goto_if_null(Naming.quick_temp_cname, self.pos)
))
code.put_gotref(Naming.quick_temp_cname)
code.put_decref(target, py_object_type)
code.putln('%s = %s;' % (target, Naming.quick_temp_cname))
code.putln('}')
def generate_subexpr_disposal_code(self, code): def generate_subexpr_disposal_code(self, code):
if self.mult_factor: if self.mult_factor and self.mult_factor.type.is_int:
super(SequenceNode, self).generate_subexpr_disposal_code(code) super(SequenceNode, self).generate_subexpr_disposal_code(code)
else: else:
# We call generate_post_assignment_code here instead # We call generate_post_assignment_code here instead
...@@ -4474,6 +4484,8 @@ class SequenceNode(ExprNode): ...@@ -4474,6 +4484,8 @@ class SequenceNode(ExprNode):
arg.generate_post_assignment_code(code) arg.generate_post_assignment_code(code)
# Should NOT call free_temps -- this is invoked by the default # Should NOT call free_temps -- this is invoked by the default
# generate_evaluation_code which will do that. # generate_evaluation_code which will do that.
if self.mult_factor:
self.mult_factor.generate_disposal_code(code)
def generate_assignment_code(self, rhs, code): def generate_assignment_code(self, rhs, code):
if self.starred_assignment: if self.starred_assignment:
...@@ -4682,21 +4694,27 @@ class TupleNode(SequenceNode): ...@@ -4682,21 +4694,27 @@ class TupleNode(SequenceNode):
# Tuple constructor. # Tuple constructor.
type = tuple_type type = tuple_type
is_partly_literal = False
gil_message = "Constructing Python tuple" gil_message = "Constructing Python tuple"
def analyse_types(self, env, skip_children=False): def analyse_types(self, env, skip_children=False):
if len(self.args) == 0: if len(self.args) == 0:
self.is_temp = 0 self.is_temp = False
self.is_literal = 1 self.is_literal = True
else: else:
SequenceNode.analyse_types(self, env, skip_children) SequenceNode.analyse_types(self, env, skip_children)
for child in self.args: for child in self.args:
if not child.is_literal: if not child.is_literal:
break break
else: else:
self.is_temp = 0 if not self.mult_factor or self.mult_factor.is_literal and \
self.is_literal = 1 isinstance(self.mult_factor.constant_result, (int, long)):
self.is_temp = False
self.is_literal = True
else:
self.is_temp = True
self.is_partly_literal = True
def is_simple(self): def is_simple(self):
# either temp or constant => always simple # either temp or constant => always simple
...@@ -4727,15 +4745,28 @@ class TupleNode(SequenceNode): ...@@ -4727,15 +4745,28 @@ class TupleNode(SequenceNode):
if len(self.args) == 0: if len(self.args) == 0:
# result_code is Naming.empty_tuple # result_code is Naming.empty_tuple
return return
if self.is_literal: if self.is_partly_literal:
# underlying tuple is const, but factor is not
tuple_target = code.get_py_const(py_object_type, 'tuple_', cleanup_level=2)
const_code = code.get_cached_constants_writer()
const_code.mark_pos(self.pos)
self.generate_sequence_packing_code(const_code, tuple_target, plain=True)
const_code.put_giveref(tuple_target)
code.putln('%s = PyNumber_Multiply(%s, %s); %s' % (
self.result(), tuple_target, self.mult_factor.py_result(),
code.error_goto_if_null(self.result(), self.pos)
))
code.put_gotref(self.py_result())
elif self.is_literal:
# non-empty cached tuple => result is global constant, # non-empty cached tuple => result is global constant,
# creation code goes into separate code writer # creation code goes into separate code writer
self.result_code = code.get_py_const(py_object_type, 'tuple_', cleanup_level=2) self.result_code = code.get_py_const(py_object_type, 'tuple_', cleanup_level=2)
code = code.get_cached_constants_writer() code = code.get_cached_constants_writer()
code.mark_pos(self.pos) code.mark_pos(self.pos)
self.generate_sequence_packing_code(code) self.generate_sequence_packing_code(code)
if self.is_literal:
code.put_giveref(self.py_result()) code.put_giveref(self.py_result())
else:
self.generate_sequence_packing_code(code)
class ListNode(SequenceNode): class ListNode(SequenceNode):
......
...@@ -92,6 +92,7 @@ enc_scope_cname = pyrex_prefix + "enc_scope" ...@@ -92,6 +92,7 @@ enc_scope_cname = pyrex_prefix + "enc_scope"
frame_cname = pyrex_prefix + "frame" frame_cname = pyrex_prefix + "frame"
frame_code_cname = pyrex_prefix + "frame_code" frame_code_cname = pyrex_prefix + "frame_code"
binding_cfunc = pyrex_prefix + "binding_PyCFunctionType" binding_cfunc = pyrex_prefix + "binding_PyCFunctionType"
quick_temp_cname = pyrex_prefix + "temp" # temp variable for quick'n'dirty temping
genexpr_id_ref = 'genexpr' genexpr_id_ref = 'genexpr'
......
...@@ -175,6 +175,46 @@ def multiplied_lists_with_side_effects(): ...@@ -175,6 +175,46 @@ def multiplied_lists_with_side_effects():
""" """
return [side_effect(1), side_effect(2), side_effect(3)] * 5 return [side_effect(1), side_effect(2), side_effect(3)] * 5
@cython.test_fail_if_path_exists("//MulNode")
def multiplied_lists_nonconst_with_side_effects(x):
"""
>>> multiplied_lists_nonconst_with_side_effects(5) == [1,2,3] * 5
1
2
3
True
"""
return [side_effect(1), side_effect(2), side_effect(3)] * x
@cython.test_fail_if_path_exists("//MulNode")
def multiplied_nonconst_tuple_arg(x):
"""
>>> multiplied_nonconst_tuple_arg(5) == (1,2) * 5
True
"""
return (1,2) * x
@cython.test_fail_if_path_exists("//MulNode")
def multiplied_nonconst_tuple(x):
"""
>>> multiplied_nonconst_tuple(5) == (1,2) * (5+1)
True
"""
return (1,2) * (x + 1)
MULT = 5
@cython.test_fail_if_path_exists("//MulNode")
def multiplied_global_nonconst_tuple():
"""
>>> multiplied_global_nonconst_tuple() == (1,2,3) * 5
1
2
3
True
"""
return (side_effect(1), side_effect(2), side_effect(3)) * MULT
@cython.test_fail_if_path_exists("//MulNode") @cython.test_fail_if_path_exists("//MulNode")
def multiplied_const_tuple(): def multiplied_const_tuple():
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment