Unverified Commit ab1d7284 authored by da-woods's avatar da-woods Committed by GitHub

Implement generic optimized loop iterator with indexing and type inference for...

Implement generic optimized loop iterator with indexing and type inference for memoryviews (GH-3617)

* Adds bytearray iteration since that was not previously optimised (because it allows changing length during iteration).
* Always set `entry.init` for memoryviewslice.
parent 8b228a71
......@@ -3564,6 +3564,8 @@ class IndexNode(_IndexingBaseNode):
bytearray_type, list_type, tuple_type):
# slicing these returns the same type
return base_type
elif base_type.is_memoryviewslice:
return base_type
else:
# TODO: Handle buffers (hopefully without too much redundancy).
return py_object_type
......@@ -3606,6 +3608,23 @@ class IndexNode(_IndexingBaseNode):
index += base_type.size
if 0 <= index < base_type.size:
return base_type.components[index]
elif base_type.is_memoryviewslice:
if base_type.ndim == 0:
pass # probably an error, but definitely don't know what to do - return pyobject for now
if base_type.ndim == 1:
return base_type.dtype
else:
return PyrexTypes.MemoryViewSliceType(base_type.dtype, base_type.axes[1:])
if self.index.is_sequence_constructor and base_type.is_memoryviewslice:
inferred_type = base_type
for a in self.index.args:
if not inferred_type.is_memoryviewslice:
break # something's gone wrong
inferred_type = IndexNode(self.pos, base=ExprNode(self.base.pos, type=inferred_type),
index=a).infer_type(env)
else:
return inferred_type
if base_type.is_cpp_class:
class FakeOperand:
......@@ -13466,6 +13485,9 @@ class CoerceToTempNode(CoercionNode):
# The arg is always already analysed
return self
def may_be_none(self):
return self.arg.may_be_none()
def coerce_to_boolean(self, env):
self.arg = self.arg.coerce_to_boolean(env)
if self.arg.is_simple():
......
......@@ -228,6 +228,12 @@ class IterationTransform(Visitor.EnvTransform):
return self._transform_bytes_iteration(node, iterable, reversed=reversed)
if iterable.type is Builtin.unicode_type:
return self._transform_unicode_iteration(node, iterable, reversed=reversed)
# in principle _transform_indexable_iteration would work on most of the above, and
# also tuple and list. However, it probably isn't quite as optimized
if iterable.type is Builtin.bytearray_type:
return self._transform_indexable_iteration(node, iterable, is_mutable=True, reversed=reversed)
if isinstance(iterable, ExprNodes.CoerceToPyTypeNode) and iterable.arg.type.is_memoryviewslice:
return self._transform_indexable_iteration(node, iterable.arg, is_mutable=False, reversed=reversed)
# the rest is based on function calls
if not isinstance(iterable, ExprNodes.SimpleCallNode):
......@@ -333,6 +339,92 @@ class IterationTransform(Visitor.EnvTransform):
PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
])
def _transform_indexable_iteration(self, node, slice_node, is_mutable, reversed=False):
"""In principle can handle any iterable that Cython has a len() for and knows how to index"""
unpack_temp_node = UtilNodes.LetRefNode(
slice_node.as_none_safe_node("'NoneType' is not iterable"),
may_hold_none=False, is_temp=True
)
start_node = ExprNodes.IntNode(
node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type)
def make_length_call():
# helper function since we need to create this node for a couple of places
builtin_len = ExprNodes.NameNode(node.pos, name="len",
entry=Builtin.builtin_scope.lookup("len"))
return ExprNodes.SimpleCallNode(node.pos,
function=builtin_len,
args=[unpack_temp_node]
)
length_temp = UtilNodes.LetRefNode(make_length_call(), type=PyrexTypes.c_py_ssize_t_type, is_temp=True)
end_node = length_temp
if reversed:
relation1, relation2 = '>', '>='
start_node, end_node = end_node, start_node
else:
relation1, relation2 = '<=', '<'
counter_ref = UtilNodes.LetRefNode(pos=node.pos, type=PyrexTypes.c_py_ssize_t_type)
target_value = ExprNodes.IndexNode(slice_node.pos, base=unpack_temp_node,
index=counter_ref)
target_assign = Nodes.SingleAssignmentNode(
pos = node.target.pos,
lhs = node.target,
rhs = target_value)
# analyse with boundscheck and wraparound
# off (because we're confident we know the size)
env = self.current_env()
new_directives = Options.copy_inherited_directives(env.directives, boundscheck=False, wraparound=False)
target_assign = Nodes.CompilerDirectivesNode(
target_assign.pos,
directives=new_directives,
body=target_assign,
)
body = Nodes.StatListNode(
node.pos,
stats = [target_assign]) # exclude node.body for now to not reanalyse it
if is_mutable:
# We need to be slightly careful here that we are actually modifying the loop
# bounds and not a temp copy of it. Setting is_temp=True on length_temp seems
# to ensure this.
# If this starts to fail then we could insert an "if out_of_bounds: break" instead
loop_length_reassign = Nodes.SingleAssignmentNode(node.pos,
lhs = length_temp,
rhs = make_length_call())
body.stats.append(loop_length_reassign)
loop_node = Nodes.ForFromStatNode(
node.pos,
bound1=start_node, relation1=relation1,
target=counter_ref,
relation2=relation2, bound2=end_node,
step=None, body=body,
else_clause=node.else_clause,
from_range=True)
ret = UtilNodes.LetNode(
unpack_temp_node,
UtilNodes.LetNode(
length_temp,
# TempResultFromStatNode provides the framework where the "counter_ref"
# temp is set up and can be assigned to. However, we don't need the
# result it returns so wrap it in an ExprStatNode.
Nodes.ExprStatNode(node.pos,
expr=UtilNodes.TempResultFromStatNode(
counter_ref,
loop_node
)
)
)
).analyse_expressions(env)
body.stats.insert(1, node.body)
return ret
def _transform_bytes_iteration(self, node, slice_node, reversed=False):
target_type = node.target.type
if not target_type.is_int and target_type is not Builtin.bytes_type:
......
......@@ -166,6 +166,16 @@ def get_directive_defaults():
_directive_defaults[old_option.directive_name] = value
return _directive_defaults
def copy_inherited_directives(outer_directives, **new_directives):
# A few directives are not copied downwards and this function removes them.
# For example, test_assert_path_exists and test_fail_if_path_exists should not be inherited
# otherwise they can produce very misleading test failures
new_directives_out = dict(outer_directives)
for name in ('test_assert_path_exists', 'test_fail_if_path_exists'):
new_directives_out.pop(name, None)
new_directives_out.update(new_directives)
return new_directives_out
# Declare compiler directives
_directive_defaults = {
'binding': True, # was False before 3.0
......
......@@ -992,12 +992,7 @@ class InterpretCompilerDirectives(CythonTransform):
return self.visit_Node(node)
old_directives = self.directives
new_directives = dict(old_directives)
# test_assert_path_exists and test_fail_if_path_exists should not be inherited
# otherwise they can produce very misleading test failures
new_directives.pop('test_assert_path_exists', None)
new_directives.pop('test_fail_if_path_exists', None)
new_directives.update(directives)
new_directives = Options.copy_inherited_directives(old_directives, **directives)
if new_directives == old_directives:
return self.visit_Node(node)
......
......@@ -672,6 +672,10 @@ class MemoryViewSliceType(PyrexType):
else:
return False
def __ne__(self, other):
# TODO drop when Python2 is dropped
return not (self == other)
def same_as_resolved_type(self, other_type):
return ((other_type.is_memoryviewslice and
#self.writable_needed == other_type.writable_needed and # FIXME: should be only uni-directional
......@@ -2516,6 +2520,7 @@ class CPointerBaseType(CType):
if self.is_string:
assert isinstance(value, str)
return '"%s"' % StringEncoding.escape_byte_string(value)
return str(value)
class CArrayType(CPointerBaseType):
......
......@@ -140,7 +140,6 @@ class MarkParallelAssignments(EnvTransform):
'+',
sequence.args[0],
sequence.args[2]))
if not is_special:
# A for-loop basically translates to subsequent calls to
# __getitem__(), so using an IndexNode here allows us to
......@@ -360,9 +359,11 @@ class SimpleAssignmentTypeInferer(object):
applies to nested scopes in top-down order.
"""
def set_entry_type(self, entry, entry_type):
entry.type = entry_type
for e in entry.all_entries():
e.type = entry_type
if e.type.is_memoryviewslice:
# memoryview slices crash if they don't get initialized
e.init = e.type.default_value
def infer_types(self, scope):
enabled = scope.directives['infer_types']
......@@ -577,6 +578,8 @@ def safe_spanning_type(types, might_overflow, pos, scope):
# used, won't arise in pure Python, and there shouldn't be side
# effects, so I'm declaring this safe.
return result_type
elif result_type.is_memoryviewslice:
return result_type
# TODO: double complex should be OK as well, but we need
# to make sure everything is supported.
elif (result_type.is_int or result_type.is_enum) and not might_overflow:
......
......@@ -360,3 +360,6 @@ class TempResultFromStatNode(ExprNodes.ExprNode):
def generate_result_code(self, code):
self.result_ref.result_code = self.result()
self.body.generate_execution_code(code)
def generate_function_definitions(self, env, code):
self.body.generate_function_definitions(env, code)
......@@ -247,7 +247,7 @@ def basic_struct(MyStruct[:] mslice):
>>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="ccqii"))
[('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)]
"""
buf = mslice
cdef object buf = mslice
print sorted([(k, int(v)) for k, v in buf[0].items()])
def nested_struct(NestedStruct[:] mslice):
......@@ -259,7 +259,7 @@ def nested_struct(NestedStruct[:] mslice):
>>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="T{ii}T{2i}i"))
1 2 3 4 5
"""
buf = mslice
cdef object buf = mslice
d = buf[0]
print d['x']['a'], d['x']['b'], d['y']['a'], d['y']['b'], d['z']
......@@ -275,7 +275,7 @@ def packed_struct(PackedStruct[:] mslice):
1 2
"""
buf = mslice
cdef object buf = mslice
print buf[0]['a'], buf[0]['b']
def nested_packed_struct(NestedPackedStruct[:] mslice):
......@@ -289,7 +289,7 @@ def nested_packed_struct(NestedPackedStruct[:] mslice):
>>> nested_packed_struct(NestedPackedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="^c@i^ci@i"))
1 2 3 4 5
"""
buf = mslice
cdef object buf = mslice
d = buf[0]
print d['a'], d['b'], d['sub']['a'], d['sub']['b'], d['c']
......@@ -299,7 +299,7 @@ def complex_dtype(long double complex[:] mslice):
>>> complex_dtype(LongComplexMockBuffer(None, [(0, -1)]))
-1j
"""
buf = mslice
cdef object buf = mslice
print buf[0]
def complex_inplace(long double complex[:] mslice):
......@@ -307,7 +307,7 @@ def complex_inplace(long double complex[:] mslice):
>>> complex_inplace(LongComplexMockBuffer(None, [(0, -1)]))
(1+1j)
"""
buf = mslice
cdef object buf = mslice
buf[0] = buf[0] + 1 + 2j
print buf[0]
......@@ -318,7 +318,7 @@ def complex_struct_dtype(LongComplex[:] mslice):
>>> complex_struct_dtype(LongComplexMockBuffer(None, [(0, -1)]))
0.0 -1.0
"""
buf = mslice
cdef object buf = mslice
print buf[0]['real'], buf[0]['imag']
#
......@@ -356,7 +356,7 @@ def get_int_2d(int[:, :] mslice, int i, int j):
...
IndexError: Out of bounds on buffer access (axis 1)
"""
buf = mslice
cdef object buf = mslice
return buf[i, j]
def set_int_2d(int[:, :] mslice, int i, int j, int value):
......@@ -409,10 +409,47 @@ def set_int_2d(int[:, :] mslice, int i, int j, int value):
IndexError: Out of bounds on buffer access (axis 1)
"""
buf = mslice
cdef object buf = mslice
buf[i, j] = value
#
# auto type inference
# (note that for most numeric types "might_overflow" stops the type inference from working well)
#
def type_infer(double[:, :] arg):
"""
>>> type_infer(DoubleMockBuffer(None, range(6), (2,3)))
double
double[:]
double[:]
double[:, :]
"""
a = arg[0,0]
print(cython.typeof(a))
b = arg[0]
print(cython.typeof(b))
c = arg[0,:]
print(cython.typeof(c))
d = arg[:,:]
print(cython.typeof(d))
#
# Loop optimization
#
@cython.test_fail_if_path_exists("//CoerceToPyTypeNode")
def memview_iter(double[:, :] arg):
"""
memview_iter(DoubleMockBuffer("C", range(6), (2,3)))
True
"""
cdef double total = 0
for mview1d in arg:
for val in mview1d:
total += val
if total == 15:
return True
#
# Test all kinds of indexing and flags
#
......@@ -426,7 +463,7 @@ def writable(unsigned short int[:, :, :] mslice):
>>> [str(x) for x in R.received_flags] # Py2/3
['FORMAT', 'ND', 'STRIDES', 'WRITABLE']
"""
buf = mslice
cdef object buf = mslice
buf[2, 2, 1] = 23
def strided(int[:] mslice):
......@@ -441,7 +478,7 @@ def strided(int[:] mslice):
>>> A.release_ok
True
"""
buf = mslice
cdef object buf = mslice
return buf[2]
def c_contig(int[::1] mslice):
......@@ -450,7 +487,7 @@ def c_contig(int[::1] mslice):
>>> c_contig(A)
2
"""
buf = mslice
cdef object buf = mslice
return buf[2]
def c_contig_2d(int[:, ::1] mslice):
......@@ -461,7 +498,7 @@ def c_contig_2d(int[:, ::1] mslice):
>>> c_contig_2d(A)
7
"""
buf = mslice
cdef object buf = mslice
return buf[1, 3]
def f_contig(int[::1, :] mslice):
......@@ -470,7 +507,7 @@ def f_contig(int[::1, :] mslice):
>>> f_contig(A)
2
"""
buf = mslice
cdef object buf = mslice
return buf[0, 1]
def f_contig_2d(int[::1, :] mslice):
......@@ -481,7 +518,7 @@ def f_contig_2d(int[::1, :] mslice):
>>> f_contig_2d(A)
7
"""
buf = mslice
cdef object buf = mslice
return buf[3, 1]
def generic(int[::view.generic, ::view.generic] mslice1,
......@@ -552,7 +589,7 @@ def printbuf_td_cy_int(td_cy_int[:] mslice, shape):
...
ValueError: Buffer dtype mismatch, expected 'td_cy_int' but got 'short'
"""
buf = mslice
cdef object buf = mslice
cdef int i
for i in range(shape[0]):
print buf[i],
......@@ -567,7 +604,7 @@ def printbuf_td_h_short(td_h_short[:] mslice, shape):
...
ValueError: Buffer dtype mismatch, expected 'td_h_short' but got 'int'
"""
buf = mslice
cdef object buf = mslice
cdef int i
for i in range(shape[0]):
print buf[i],
......@@ -582,7 +619,7 @@ def printbuf_td_h_cy_short(td_h_cy_short[:] mslice, shape):
...
ValueError: Buffer dtype mismatch, expected 'td_h_cy_short' but got 'int'
"""
buf = mslice
cdef object buf = mslice
cdef int i
for i in range(shape[0]):
print buf[i],
......@@ -597,7 +634,7 @@ def printbuf_td_h_ushort(td_h_ushort[:] mslice, shape):
...
ValueError: Buffer dtype mismatch, expected 'td_h_ushort' but got 'short'
"""
buf = mslice
cdef object buf = mslice
cdef int i
for i in range(shape[0]):
print buf[i],
......@@ -612,7 +649,7 @@ def printbuf_td_h_double(td_h_double[:] mslice, shape):
...
ValueError: Buffer dtype mismatch, expected 'td_h_double' but got 'float'
"""
buf = mslice
cdef object buf = mslice
cdef int i
for i in range(shape[0]):
print buf[i],
......@@ -649,7 +686,7 @@ def printbuf_object(object[:] mslice, shape):
{4: 23} 2
[34, 3] 2
"""
buf = mslice
cdef object buf = mslice
cdef int i
for i in range(shape[0]):
print repr(buf[i]), (<PyObject*>buf[i]).ob_refcnt
......@@ -670,7 +707,7 @@ def assign_to_object(object[:] mslice, int idx, obj):
(2, 3)
>>> decref(b)
"""
buf = mslice
cdef object buf = mslice
buf[idx] = obj
def assign_temporary_to_object(object[:] mslice):
......@@ -697,7 +734,7 @@ def assign_temporary_to_object(object[:] mslice):
>>> assign_to_object(A, 1, a)
>>> decref(a)
"""
buf = mslice
cdef object buf = mslice
buf[1] = {3-2: 2+(2*4)-2}
......@@ -745,7 +782,7 @@ def test_generic_slicing(arg, indirect=False):
"""
cdef int[::view.generic, ::view.generic, :] _a = arg
a = _a
cdef object a = _a
b = a[2:8:2, -4:1:-1, 1:3]
print b.shape
......@@ -828,7 +865,7 @@ def test_direct_slicing(arg):
released A
"""
cdef int[:, :, :] _a = arg
a = _a
cdef object a = _a
b = a[2:8:2, -4:1:-1, 1:3]
print b.shape
......@@ -856,7 +893,7 @@ def test_slicing_and_indexing(arg):
released A
"""
cdef int[:, :, :] _a = arg
a = _a
cdef object a = _a
b = a[-5:, 1, 1::2]
c = b[4:1:-1, ::-1]
d = c[2, 1:2]
......
......@@ -1525,7 +1525,7 @@ def test_index_slicing_away_direct_indirect():
All dimensions preceding dimension 1 must be indexed and not sliced
"""
cdef int[:, ::view.indirect, :] a = TestIndexSlicingDirectIndirectDims()
a_obj = a
cdef object a_obj = a
print a[1][2][3]
print a[1, 2, 3]
......
......@@ -186,7 +186,7 @@ def test_transpose():
numpy_obj = np.arange(4 * 3, dtype=np.int32).reshape(4, 3)
a = numpy_obj
a_obj = a
cdef object a_obj = a
cdef dtype_t[:, :] b = a.T
print a.T.shape[0], a.T.shape[1]
......@@ -244,7 +244,7 @@ def test_copy_and_contig_attributes(a):
>>> test_copy_and_contig_attributes(a)
"""
cdef np.int32_t[:, :] mslice = a
m = mslice
cdef object m = mslice # object copy
# Test object copy attributes
assert np.all(a == np.array(m.copy()))
......
# mode: run
# tag: pure3, pure2
import cython
@cython.test_assert_path_exists("//ForFromStatNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
@cython.locals(x=bytearray)
def basic_bytearray_iter(x):
"""
>>> basic_bytearray_iter(bytearray(b"hello"))
h
e
l
l
o
"""
for a in x:
print(chr(a))
@cython.test_assert_path_exists("//ForFromStatNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
@cython.locals(x=bytearray)
def reversed_bytearray_iter(x):
"""
>>> reversed_bytearray_iter(bytearray(b"hello"))
o
l
l
e
h
"""
for a in reversed(x):
print(chr(a))
@cython.test_assert_path_exists("//ForFromStatNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
@cython.locals(x=bytearray)
def modifying_bytearray_iter1(x):
"""
>>> modifying_bytearray_iter1(bytearray(b"abcdef"))
a
b
c
3
"""
count = 0
for a in x:
print(chr(a))
del x[-1]
count += 1
print(count)
@cython.test_assert_path_exists("//ForFromStatNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
@cython.locals(x=bytearray)
def modifying_bytearray_iter2(x):
"""
>>> modifying_bytearray_iter2(bytearray(b"abcdef"))
a
c
e
3
"""
count = 0
for a in x:
print(chr(a))
del x[0]
count += 1
print(count)
@cython.test_assert_path_exists("//ForFromStatNode")
@cython.test_fail_if_path_exists("//ForInStatNode")
@cython.locals(x=bytearray)
def modifying_reversed_bytearray_iter(x):
"""
NOTE - I'm not 100% sure how well-defined this behaviour is in Python.
However, for the moment Python and Cython seem to do the same thing.
Testing that it doesn't crash is probably more important than the exact output!
>>> modifying_reversed_bytearray_iter(bytearray(b"abcdef"))
f
f
f
f
f
f
"""
for a in reversed(x):
print(chr(a))
del x[0]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment