Commit 09b4c92a authored by Stefan Behnel's avatar Stefan Behnel

implement for-in-reversed(list/tuple)

parent fabff41b
...@@ -1788,6 +1788,7 @@ class IteratorNode(ExprNode): ...@@ -1788,6 +1788,7 @@ class IteratorNode(ExprNode):
type = py_object_type type = py_object_type
iter_func_ptr = None iter_func_ptr = None
counter_cname = None counter_cname = None
reversed = False # currently only used for list/tuple types (see Optimize.py)
subexprs = ['sequence'] subexprs = ['sequence']
...@@ -1817,6 +1818,9 @@ class IteratorNode(ExprNode): ...@@ -1817,6 +1818,9 @@ class IteratorNode(ExprNode):
raise InternalError("for in carray slice not transformed") raise InternalError("for in carray slice not transformed")
is_builtin_sequence = sequence_type is list_type or \ is_builtin_sequence = sequence_type is list_type or \
sequence_type is tuple_type sequence_type is tuple_type
if not is_builtin_sequence:
# reversed() not currently optimised (see Optimize.py)
assert not self.reversed, "internal error: reversed() only implemented for list/tuple objects"
self.may_be_a_sequence = not sequence_type.is_builtin_type self.may_be_a_sequence = not sequence_type.is_builtin_type
if self.may_be_a_sequence: if self.may_be_a_sequence:
code.putln( code.putln(
...@@ -1826,12 +1830,21 @@ class IteratorNode(ExprNode): ...@@ -1826,12 +1830,21 @@ class IteratorNode(ExprNode):
if is_builtin_sequence or self.may_be_a_sequence: if is_builtin_sequence or self.may_be_a_sequence:
self.counter_cname = code.funcstate.allocate_temp( self.counter_cname = code.funcstate.allocate_temp(
PyrexTypes.c_py_ssize_t_type, manage_ref=False) PyrexTypes.c_py_ssize_t_type, manage_ref=False)
if self.reversed:
if sequence_type is list_type:
init_value = 'PyList_GET_SIZE(%s) - 1' % self.result()
else:
init_value = 'PyTuple_GET_SIZE(%s) - 1' % self.result()
else:
init_value = '0'
code.putln( code.putln(
"%s = 0; %s = %s; __Pyx_INCREF(%s);" % ( "%s = %s; __Pyx_INCREF(%s); %s = %s;" % (
self.counter_cname,
self.result(), self.result(),
self.sequence.py_result(), self.sequence.py_result(),
self.result())) self.result(),
self.counter_cname,
init_value
))
if not is_builtin_sequence: if not is_builtin_sequence:
self.iter_func_ptr = code.funcstate.allocate_temp(self._func_iternext_type, manage_ref=False) self.iter_func_ptr = code.funcstate.allocate_temp(self._func_iternext_type, manage_ref=False)
if self.may_be_a_sequence: if self.may_be_a_sequence:
...@@ -1854,17 +1867,24 @@ class IteratorNode(ExprNode): ...@@ -1854,17 +1867,24 @@ class IteratorNode(ExprNode):
self.counter_cname, self.counter_cname,
test_name, test_name,
self.py_result())) self.py_result()))
if self.reversed:
inc_dec = '--'
else:
inc_dec = '++'
code.putln( code.putln(
"%s = Py%s_GET_ITEM(%s, %s); __Pyx_INCREF(%s); %s++;" % ( "%s = Py%s_GET_ITEM(%s, %s); __Pyx_INCREF(%s); %s%s;" % (
result_name, result_name,
test_name, test_name,
self.py_result(), self.py_result(),
self.counter_cname, self.counter_cname,
result_name, result_name,
self.counter_cname)) self.counter_cname,
inc_dec))
def generate_iter_next_result_code(self, result_name, code): def generate_iter_next_result_code(self, result_name, code):
sequence_type = self.sequence.type sequence_type = self.sequence.type
if self.reversed:
code.putln("if (%s < 0) break;" % self.counter_cname)
if sequence_type is list_type: if sequence_type is list_type:
self.generate_next_sequence_item('List', result_name, code) self.generate_next_sequence_item('List', result_name, code)
return return
...@@ -1913,7 +1933,7 @@ class NextNode(AtomicExprNode): ...@@ -1913,7 +1933,7 @@ class NextNode(AtomicExprNode):
type = py_object_type type = py_object_type
def __init__(self, iterator, env): def __init__(self, iterator):
self.pos = iterator.pos self.pos = iterator.pos
self.iterator = iterator self.iterator = iterator
if iterator.type.is_ptr or iterator.type.is_array: if iterator.type.is_ptr or iterator.type.is_array:
......
...@@ -4631,7 +4631,7 @@ class ForInStatNode(LoopNode, StatNode): ...@@ -4631,7 +4631,7 @@ class ForInStatNode(LoopNode, StatNode):
import ExprNodes import ExprNodes
self.target.analyse_target_types(env) self.target.analyse_target_types(env)
self.iterator.analyse_expressions(env) self.iterator.analyse_expressions(env)
self.item = ExprNodes.NextNode(self.iterator, env) self.item = ExprNodes.NextNode(self.iterator)
if (self.iterator.type.is_ptr or self.iterator.type.is_array) and \ if (self.iterator.type.is_ptr or self.iterator.type.is_array) and \
self.target.type.assignable_from(self.iterator.type): self.target.type.assignable_from(self.iterator.type):
# C array slice optimization. # C array slice optimization.
......
...@@ -218,7 +218,15 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -218,7 +218,15 @@ class IterationTransform(Visitor.VisitorTransform):
error(reversed_function.pos, error(reversed_function.pos,
"reversed() takes exactly 1 argument") "reversed() takes exactly 1 argument")
return node return node
return self._optimise_for_loop(node, args[0], reversed=True) arg = args[0]
# reversed(list/tuple) ?
if arg.type in (Builtin.tuple_type, Builtin.list_type):
node.iterator.sequence = arg.as_none_safe_node("'NoneType' object is not iterable")
node.iterator.reversed = True
return node
return self._optimise_for_loop(node, arg, reversed=True)
PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType( PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_py_unicode_ptr_type, [ PyrexTypes.c_py_unicode_ptr_type, [
......
# mode: run
# tag: forin, builtins, reversed, enumerate
cimport cython cimport cython
...@@ -7,18 +9,73 @@ IS_PY3 = sys.version_info[0] >= 3 ...@@ -7,18 +9,73 @@ IS_PY3 = sys.version_info[0] >= 3
def _reversed(it): def _reversed(it):
return list(it)[::-1] return list(it)[::-1]
@cython.test_assert_path_exists('//ForInStatNode',
'//ForInStatNode/IteratorNode',
'//ForInStatNode/IteratorNode[@reversed = True]',
)
@cython.test_fail_if_path_exists('//ForInStatNode/IteratorNode//SimpleCallNode')
def reversed_list(list l): def reversed_list(list l):
""" """
>>> [ i for i in _reversed([1,2,3,4]) ] >>> [ i for i in _reversed([1,2,3,4]) ]
[4, 3, 2, 1] [4, 3, 2, 1]
>>> reversed_list([1,2,3,4]) >>> reversed_list([1,2,3,4])
[4, 3, 2, 1] [4, 3, 2, 1]
>>> reversed_list([])
[]
>>> reversed_list(None)
Traceback (most recent call last):
TypeError: 'NoneType' object is not iterable
""" """
result = [] result = []
for item in reversed(l): for item in reversed(l):
result.append(item) result.append(item)
return result return result
@cython.test_assert_path_exists('//ForInStatNode',
'//ForInStatNode/IteratorNode',
'//ForInStatNode/IteratorNode[@reversed = True]',
)
@cython.test_fail_if_path_exists('//ForInStatNode/IteratorNode//SimpleCallNode')
def reversed_tuple(tuple t):
"""
>>> [ i for i in _reversed((1,2,3,4)) ]
[4, 3, 2, 1]
>>> reversed_tuple((1,2,3,4))
[4, 3, 2, 1]
>>> reversed_tuple(())
[]
>>> reversed_tuple(None)
Traceback (most recent call last):
TypeError: 'NoneType' object is not iterable
"""
result = []
for item in reversed(t):
result.append(item)
return result
@cython.test_assert_path_exists('//ForInStatNode',
'//ForInStatNode/IteratorNode',
'//ForInStatNode/IteratorNode[@reversed = True]',
)
@cython.test_fail_if_path_exists('//ForInStatNode/IteratorNode//SimpleCallNode')
def enumerate_reversed_list(list l):
"""
>>> list(enumerate(_reversed([1,2,3])))
[(0, 3), (1, 2), (2, 1)]
>>> enumerate_reversed_list([1,2,3])
[(0, 3), (1, 2), (2, 1)]
>>> enumerate_reversed_list([])
[]
>>> enumerate_reversed_list(None)
Traceback (most recent call last):
TypeError: 'NoneType' object is not iterable
"""
result = []
cdef Py_ssize_t i
for i, item in enumerate(reversed(l)):
result.append((i, item))
return result
@cython.test_assert_path_exists('//ForFromStatNode') @cython.test_assert_path_exists('//ForFromStatNode')
def reversed_range(int N): def reversed_range(int N):
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment