Commit ab78f93b authored by Stefan Behnel's avatar Stefan Behnel

adapt and apply major refactoring of IndexNode originally written by Mark Florisson

parent 7da49602
......@@ -201,7 +201,13 @@ class BufferEntry(object):
self.type = entry.type
self.cname = entry.buffer_aux.buflocal_nd_var.cname
self.buf_ptr = "%s.rcbuffer->pybuffer.buf" % self.cname
self.buf_ptr_type = self.entry.type.buffer_ptr_type
self.buf_ptr_type = entry.type.buffer_ptr_type
self.init_attributes()
def init_attributes(self):
self.shape = self.get_buf_shapevars()
self.strides = self.get_buf_stridevars()
self.suboffsets = self.get_buf_suboffsetvars()
def get_buf_suboffsetvars(self):
return self._for_all_ndim("%s.diminfo[%d].suboffsets")
......
......@@ -322,6 +322,13 @@ class ExprNode(Node):
is_string_literal = False
is_attribute = False
is_subscript = False
is_slice = False
is_buffer_access = False
is_memview_index = False
is_memview_slice = False
is_memview_broadcast = False
is_memview_copy_assignment = False
saved_subexpr_nodes = None
is_temp = False
......@@ -330,9 +337,6 @@ class ExprNode(Node):
constant_result = constant_value_not_set
# whether this node with a memoryview type should be broadcast
memslice_broadcast = False
child_attrs = property(fget=operator.attrgetter('subexprs'))
def not_implemented(self, method_name):
......@@ -790,14 +794,12 @@ class ExprNode(Node):
if src.type.is_pyobject:
src = CoerceToMemViewSliceNode(src, dst_type, env)
elif src.type.is_array:
src = CythonArrayNode.from_carray(src, env).coerce_to(
dst_type, env)
src = CythonArrayNode.from_carray(src, env).coerce_to(dst_type, env)
elif not src_type.is_error:
error(self.pos,
"Cannot convert '%s' to memoryviewslice" %
(src_type,))
elif not MemoryView.src_conforms_to_dst(
src.type, dst_type, broadcast=self.memslice_broadcast):
"Cannot convert '%s' to memoryviewslice" % (src_type,))
elif not src.type.conforms_to(dst_type, broadcast=self.is_memview_broadcast,
copying=self.is_memview_copy_assignment):
if src.type.dtype.same_as(dst_type.dtype):
msg = "Memoryview '%s' not conformable to memoryview '%s'."
tup = src.type, dst_type
......@@ -1834,10 +1836,6 @@ class NameNode(AtomicExprNode):
self.gil_error()
elif entry.is_pyglobal:
self.gil_error()
elif self.entry.type.is_memoryviewslice:
if self.cf_is_null or self.cf_maybe_null:
from . import MemoryView
MemoryView.err_if_nogil_initialized_check(self.pos, env)
gil_message = "Accessing Python global or builtin"
......@@ -2915,14 +2913,43 @@ class ParallelThreadIdNode(AtomicExprNode): #, Nodes.ParallelNode):
#
#-------------------------------------------------------------------
class IndexNode(ExprNode):
class _IndexingBaseNode(ExprNode):
# Base class for indexing nodes.
#
# base ExprNode the value being indexed
def is_ephemeral(self):
# in most cases, indexing will return a safe reference to an object in a container,
# so we consider the result safe if the base object is
return self.base.is_ephemeral() or self.base.type in (
basestring_type, str_type, bytes_type, unicode_type)
def check_const_addr(self):
return self.base.check_const_addr() and self.index.check_const()
def is_lvalue(self):
# NOTE: references currently have both is_reference and is_ptr
# set. Since pointers and references have different lvalue
# rules, we must be careful to separate the two.
if self.type.is_reference:
if self.type.ref_base_type.is_array:
# fixed-sized arrays aren't l-values
return False
elif self.type.is_ptr:
# non-const pointers can always be reassigned
return True
# Just about everything else returned by the index operator
# can be an lvalue.
return True
class IndexNode(_IndexingBaseNode):
# Sequence indexing.
#
# base ExprNode
# index ExprNode
# indices [ExprNode]
# type_indices [PyrexType]
# is_buffer_access boolean Whether this is a buffer access.
#
# indices is used on buffer access, index on non-buffer access.
# The former contains a clean list of index parameters, the
......@@ -2931,33 +2958,18 @@ class IndexNode(ExprNode):
# is_fused_index boolean Whether the index is used to specialize a
# c(p)def function
subexprs = ['base', 'index', 'indices']
indices = None
subexprs = ['base', 'index']
type_indices = None
is_subscript = True
is_fused_index = False
# Whether we're assigning to a buffer (in that case it needs to be
# writable)
writable_needed = False
# Whether we are indexing or slicing a memoryviewslice
memslice_index = False
memslice_slice = False
is_memslice_copy = False
memslice_ellipsis_noop = False
warned_untyped_idx = False
# set by SingleAssignmentNode after analyse_types()
is_memslice_scalar_assignment = False
def __init__(self, pos, index, **kw):
ExprNode.__init__(self, pos, index=index, **kw)
self._index = index
def calculate_constant_result(self):
self.constant_result = \
self.base.constant_result[self.index.constant_result]
self.constant_result = self.base.constant_result[self.index.constant_result]
def compile_time_value(self, denv):
base = self.base.compile_time_value(denv)
......@@ -2967,18 +2979,7 @@ class IndexNode(ExprNode):
except Exception as e:
self.compile_time_value_error(e)
def is_ephemeral(self):
# in most cases, indexing will return a safe reference to an object in a container,
# so we consider the result safe if the base object is
return self.base.is_ephemeral() or self.base.type in (
basestring_type, str_type, bytes_type, unicode_type)
def is_simple(self):
if self.is_buffer_access or self.memslice_index:
return False
elif self.memslice_slice:
return True
base = self.base
return (base.is_simple() and self.index.is_simple()
and base.type and (base.type.is_ptr or base.type.is_array))
......@@ -3023,7 +3024,7 @@ class IndexNode(ExprNode):
def infer_type(self, env):
base_type = self.base.infer_type(env)
if isinstance(self.index, SliceNode):
if self.index.is_slice:
# slicing!
if base_type.is_string:
# sliced C strings must coerce to Python
......@@ -3105,7 +3106,7 @@ class IndexNode(ExprNode):
node = self.analyse_base_and_index_types(env, setting=True)
if node.type.is_const:
error(self.pos, "Assignment to const dereference")
if not node.is_lvalue():
if node is self and not node.is_lvalue():
error(self.pos, "Assignment to non-lvalue of type '%s'" % node.type)
return node
......@@ -3114,19 +3115,6 @@ class IndexNode(ExprNode):
# Note: This might be cleaned up by having IndexNode
# parsed in a saner way and only construct the tuple if
# needed.
# Note that this function must leave IndexNode in a cloneable state.
# For buffers, self.index is packed out on the initial analysis, and
# when cloning self.indices is copied.
self.is_buffer_access = False
# a[...] = b
self.is_memslice_copy = False
# incomplete indexing, Ellipsis indexing or slicing
self.memslice_slice = False
# integer indexing
self.memslice_index = False
if analyse_base:
self.base = self.base.analyse_types(env)
......@@ -3136,8 +3124,7 @@ class IndexNode(ExprNode):
self.type = PyrexTypes.error_type
return self
is_slice = isinstance(self.index, SliceNode)
is_slice = self.index.is_slice
if not env.directives['wraparound']:
if is_slice:
check_negative_indices(self.index.start, self.index.stop)
......@@ -3149,181 +3136,21 @@ class IndexNode(ExprNode):
self.index = self.index.coerce_to_pyobject(env)
is_memslice = self.base.type.is_memoryviewslice
# Handle the case where base is a literal char* (and we expect a string, not an int)
if not is_memslice and (isinstance(self.base, BytesNode) or is_slice):
if self.base.type.is_string or not (self.base.type.is_ptr or self.base.type.is_array):
self.base = self.base.coerce_to_pyobject(env)
skip_child_analysis = False
buffer_access = False
if self.indices:
indices = self.indices
elif isinstance(self.index, TupleNode):
indices = self.index.args
else:
indices = [self.index]
if (is_memslice and not self.indices and
isinstance(self.index, EllipsisNode)):
# Memoryviewslice copying
self.is_memslice_copy = True
elif is_memslice:
# memoryviewslice indexing or slicing
from . import MemoryView
skip_child_analysis = True
newaxes = [newaxis for newaxis in indices if newaxis.is_none]
have_slices, indices = MemoryView.unellipsify(indices,
newaxes,
self.base.type.ndim)
self.memslice_index = (not newaxes and
len(indices) == self.base.type.ndim)
axes = []
index_type = PyrexTypes.c_py_ssize_t_type
new_indices = []
if len(indices) - len(newaxes) > self.base.type.ndim:
self.type = error_type
error(indices[self.base.type.ndim].pos,
"Too many indices specified for type %s" %
self.base.type)
return self
axis_idx = 0
for i, index in enumerate(indices[:]):
index = index.analyse_types(env)
if not index.is_none:
access, packing = self.base.type.axes[axis_idx]
axis_idx += 1
if isinstance(index, SliceNode):
self.memslice_slice = True
if index.step.is_none:
axes.append((access, packing))
else:
axes.append((access, 'strided'))
# Coerce start, stop and step to temps of the right type
for attr in ('start', 'stop', 'step'):
value = getattr(index, attr)
if not value.is_none:
value = value.coerce_to(index_type, env)
#value = value.coerce_to_temp(env)
setattr(index, attr, value)
new_indices.append(value)
elif index.is_none:
self.memslice_slice = True
new_indices.append(index)
axes.append(('direct', 'strided'))
elif index.type.is_int or index.type.is_pyobject:
if index.type.is_pyobject and not self.warned_untyped_idx:
warning(index.pos, "Index should be typed for more "
"efficient access", level=2)
IndexNode.warned_untyped_idx = True
self.memslice_index = True
index = index.coerce_to(index_type, env)
indices[i] = index
new_indices.append(index)
else:
self.type = error_type
error(index.pos, "Invalid index for memoryview specified")
return self
self.memslice_index = self.memslice_index and not self.memslice_slice
self.original_indices = indices
# All indices with all start/stop/step for slices.
# We need to keep this around
self.indices = new_indices
self.env = env
elif self.base.type.is_buffer:
# Buffer indexing
if len(indices) == self.base.type.ndim:
buffer_access = True
skip_child_analysis = True
for x in indices:
x = x.analyse_types(env)
if not x.type.is_int:
buffer_access = False
if buffer_access and not self.base.type.is_memoryviewslice:
assert hasattr(self.base, "entry") # Must be a NameNode-like node
# On cloning, indices is cloned. Otherwise, unpack index into indices
assert not (buffer_access and isinstance(self.index, CloneNode))
replacement_node = self.analyse_as_buffer_operation(env, getting)
if replacement_node is not None:
return replacement_node
self.nogil = env.nogil
base_type = self.base.type
if buffer_access or self.memslice_index:
#if self.base.type.is_memoryviewslice and not self.base.is_name:
# self.base = self.base.coerce_to_temp(env)
self.base = self.base.coerce_to_simple(env)
self.indices = indices
self.index = None
self.type = self.base.type.dtype
self.is_buffer_access = True
self.buffer_type = self.base.type #self.base.entry.type
if getting and self.type.is_pyobject:
self.is_temp = True
if setting and self.base.type.is_memoryviewslice:
self.base.type.writable_needed = True
elif setting:
if not self.base.entry.type.writable:
error(self.pos, "Writing to readonly buffer")
else:
self.writable_needed = True
if self.base.type.is_buffer:
self.base.entry.buffer_aux.writable_needed = True
elif self.is_memslice_copy:
self.type = self.base.type
if getting:
self.memslice_ellipsis_noop = True
else:
self.memslice_broadcast = True
elif self.memslice_slice:
self.index = None
self.is_temp = True
self.use_managed_ref = True
if not MemoryView.validate_axes(self.pos, axes):
self.type = error_type
return self
self.type = PyrexTypes.MemoryViewSliceType(
self.base.type.dtype, axes)
if (self.base.type.is_memoryviewslice and not
self.base.is_name and not
self.base.result_in_temp()):
self.base = self.base.coerce_to_temp(env)
if setting:
self.memslice_broadcast = True
else:
base_type = self.base.type
if not base_type.is_cfunction:
if isinstance(self.index, TupleNode):
self.index = self.index.analyse_types(
env, skip_children=skip_child_analysis)
elif not skip_child_analysis:
self.index = self.index.analyse_types(env)
self.original_index_type = self.index.type
if not base_type.is_cfunction:
self.index = self.index.analyse_types(env)
self.original_index_type = self.index.type
if base_type.is_unicode_char:
# we infer Py_UNICODE/Py_UCS4 for unicode strings in some
......@@ -3335,125 +3162,173 @@ class IndexNode(ExprNode):
return self.base
self.base = self.base.coerce_to_pyobject(env)
base_type = self.base.type
if base_type.is_pyobject:
if self.index.type.is_int and base_type is not dict_type:
if (getting
and (base_type in (list_type, tuple_type, bytearray_type))
and (not self.index.type.signed
or not env.directives['wraparound']
or (isinstance(self.index, IntNode) and
self.index.has_constant_result() and self.index.constant_result >= 0))
and not env.directives['boundscheck']):
self.is_temp = 0
else:
self.is_temp = 1
self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
self.original_index_type.create_to_py_utility_code(env)
else:
self.index = self.index.coerce_to_pyobject(env)
self.is_temp = 1
if self.index.type.is_int and base_type is unicode_type:
# Py_UNICODE/Py_UCS4 will automatically coerce to a unicode string
# if required, so this is fast and safe
self.type = PyrexTypes.c_py_ucs4_type
elif self.index.type.is_int and base_type is bytearray_type:
if setting:
self.type = PyrexTypes.c_uchar_type
else:
# not using 'uchar' to enable fast and safe error reporting as '-1'
self.type = PyrexTypes.c_int_type
elif is_slice and base_type in (bytes_type, str_type, unicode_type, list_type, tuple_type):
self.type = base_type
else:
item_type = None
if base_type in (list_type, tuple_type) and self.index.type.is_int:
item_type = infer_sequence_item_type(
env, self.base, self.index, seq_type=base_type)
if item_type is None:
item_type = py_object_type
self.type = item_type
if base_type in (list_type, tuple_type, dict_type):
# do the None check explicitly (not in a helper) to allow optimising it away
self.base = self.base.as_none_safe_node("'NoneType' object is not subscriptable")
if base_type.is_pyobject:
return self.analyse_as_pyobject(env, is_slice, getting, setting)
elif base_type.is_ptr or base_type.is_array:
return self.analyse_as_c_array(env, is_slice)
elif base_type.is_cpp_class:
return self.analyse_as_cpp(env, setting)
elif base_type.is_cfunction:
return self.analyse_as_c_function(env)
elif base_type.is_ctuple:
return self.analyse_as_c_tuple(env, getting, setting)
else:
error(self.pos,
"Attempting to index non-array type '%s'" %
base_type)
self.type = PyrexTypes.error_type
return self
def analyse_as_pyobject(self, env, is_slice, getting, setting):
base_type = self.base.type
if self.index.type.is_int and base_type is not dict_type:
if (getting
and (base_type in (list_type, tuple_type, bytearray_type))
and (not self.index.type.signed
or not env.directives['wraparound']
or (isinstance(self.index, IntNode) and
self.index.has_constant_result() and self.index.constant_result >= 0))
and not env.directives['boundscheck']):
self.is_temp = 0
else:
if base_type.is_ptr or base_type.is_array:
self.type = base_type.base_type
if is_slice:
self.type = base_type
elif self.index.type.is_pyobject:
self.index = self.index.coerce_to(
PyrexTypes.c_py_ssize_t_type, env)
elif not self.index.type.is_int:
error(self.pos,
"Invalid index type '%s'" %
self.index.type)
elif base_type.is_cpp_class:
function = env.lookup_operator("[]", [self.base, self.index])
if function is None:
error(self.pos, "Indexing '%s' not supported for index type '%s'" % (base_type, self.index.type))
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return self
func_type = function.type
if func_type.is_ptr:
func_type = func_type.base_type
self.index = self.index.coerce_to(func_type.args[0].type, env)
self.type = func_type.return_type
if setting and not func_type.return_type.is_reference:
error(self.pos, "Can't set non-reference result '%s'" % self.type)
elif base_type.is_cfunction:
if base_type.is_fused:
self.parse_indexed_fused_cdef(env)
else:
self.type_indices = self.parse_index_as_types(env)
if base_type.templates is None:
error(self.pos, "Can only parameterize template functions.")
self.type = error_type
elif len(base_type.templates) != len(self.type_indices):
error(self.pos, "Wrong number of template arguments: expected %s, got %s" % (
(len(base_type.templates), len(self.type_indices))))
self.type = error_type
else:
self.type = base_type.specialize(dict(zip(base_type.templates, self.type_indices)))
elif base_type.is_ctuple:
if isinstance(self.index, IntNode) and self.index.has_constant_result():
index = self.index.constant_result
if -base_type.size <= index < base_type.size:
if index < 0:
index += base_type.size
self.type = base_type.components[index]
else:
error(self.pos,
"Index %s out of bounds for '%s'" %
(index, base_type))
self.type = PyrexTypes.error_type
else:
self.base = self.base.coerce_to_pyobject(env)
return self.analyse_base_and_index_types(env, getting=getting, setting=setting, analyse_base=False)
else:
error(self.pos,
"Attempting to index non-array type '%s'" %
base_type)
self.type = PyrexTypes.error_type
self.is_temp = 1
self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
self.original_index_type.create_to_py_utility_code(env)
else:
self.index = self.index.coerce_to_pyobject(env)
self.is_temp = 1
if self.index.type.is_int and base_type is unicode_type:
# Py_UNICODE/Py_UCS4 will automatically coerce to a unicode string
# if required, so this is fast and safe
self.type = PyrexTypes.c_py_ucs4_type
elif self.index.type.is_int and base_type is bytearray_type:
if setting:
self.type = PyrexTypes.c_uchar_type
else:
# not using 'uchar' to enable fast and safe error reporting as '-1'
self.type = PyrexTypes.c_int_type
elif is_slice and base_type in (bytes_type, str_type, unicode_type, list_type, tuple_type):
self.type = base_type
else:
item_type = None
if base_type in (list_type, tuple_type) and self.index.type.is_int:
item_type = infer_sequence_item_type(
env, self.base, self.index, seq_type=base_type)
if item_type is None:
item_type = py_object_type
self.type = item_type
if base_type in (list_type, tuple_type, dict_type):
# do the None check explicitly (not in a helper) to allow optimising it away
self.base = self.base.as_none_safe_node("'NoneType' object is not subscriptable")
self.wrap_in_nonecheck_node(env, getting)
return self
def wrap_in_nonecheck_node(self, env, getting):
if not env.directives['nonecheck'] or not self.base.may_be_none():
return
def analyse_as_c_array(self, env, is_slice):
base_type = self.base.type
self.type = base_type.base_type
if is_slice:
self.type = base_type
elif self.index.type.is_pyobject:
self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env)
elif not self.index.type.is_int:
error(self.pos, "Invalid index type '%s'" % self.index.type)
return self
if self.base.type.is_memoryviewslice:
if self.is_memslice_copy and not getting:
msg = "Cannot assign to None memoryview slice"
elif self.memslice_slice:
msg = "Cannot slice None memoryview slice"
def analyse_as_cpp(self, env, setting):
base_type = self.base.type
function = env.lookup_operator("[]", [self.base, self.index])
if function is None:
error(self.pos, "Indexing '%s' not supported for index type '%s'" % (base_type, self.index.type))
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return self
func_type = function.type
if func_type.is_ptr:
func_type = func_type.base_type
self.index = self.index.coerce_to(func_type.args[0].type, env)
self.type = func_type.return_type
if setting and not func_type.return_type.is_reference:
error(self.pos, "Can't set non-reference result '%s'" % self.type)
return self
def analyse_as_c_function(self, env):
base_type = self.base.type
if base_type.is_fused:
self.parse_indexed_fused_cdef(env)
else:
self.type_indices = self.parse_index_as_types(env)
if base_type.templates is None:
error(self.pos, "Can only parameterize template functions.")
self.type = error_type
elif len(base_type.templates) != len(self.type_indices):
error(self.pos, "Wrong number of template arguments: expected %s, got %s" % (
(len(base_type.templates), len(self.type_indices))))
self.type = error_type
else:
self.type = base_type.specialize(dict(zip(base_type.templates, self.type_indices)))
return self
def analyse_as_c_tuple(self, env, getting, setting):
base_type = self.base.type
if isinstance(self.index, IntNode) and self.index.has_constant_result():
index = self.index.constant_result
if -base_type.size <= index < base_type.size:
if index < 0:
index += base_type.size
self.type = base_type.components[index]
else:
msg = "Cannot index None memoryview slice"
error(self.pos,
"Index %s out of bounds for '%s'" %
(index, base_type))
self.type = PyrexTypes.error_type
return self
else:
self.base = self.base.coerce_to_pyobject(env)
return self.analyse_base_and_index_types(env, getting=getting, setting=setting, analyse_base=False)
def analyse_as_buffer_operation(self, env, getting):
"""
Analyse buffer indexing and memoryview indexing/slicing
"""
if isinstance(self.index, TupleNode):
indices = self.index.args
else:
msg = "'NoneType' object is not subscriptable"
indices = [self.index]
self.base = self.base.as_none_safe_node(msg)
base_type = self.base.type
replacement_node = None
if base_type.is_memoryviewslice:
# memoryviewslice indexing or slicing
from . import MemoryView
have_slices, indices, newaxes = MemoryView.unellipsify(indices, base_type.ndim)
if have_slices:
replacement_node = MemoryViewSliceNode(self.pos, indices=indices, base=self.base)
else:
replacement_node = MemoryViewIndexNode(self.pos, indices=indices, base=self.base)
elif base_type.is_buffer and len(indices) == base_type.ndim:
# Buffer indexing
is_buffer_access = True
for index in indices:
index = index.analyse_types(env)
if not index.type.is_int:
is_buffer_access = False
if is_buffer_access:
replacement_node = BufferIndexNode(self.pos, indices=indices, base=self.base)
# On cloning, indices is cloned. Otherwise, unpack index into indices.
assert not isinstance(self.index, CloneNode)
if replacement_node is not None:
replacement_node = replacement_node.analyse_types(env, getting)
return replacement_node
def wrap_in_nonecheck_node(self, env, getting):
if not env.directives['nonecheck'] or not self.base.may_be_none():
return
self.base = self.base.as_none_safe_node("'NoneType' object is not subscriptable")
def parse_index_as_types(self, env, required=True):
if isinstance(self.index, TupleNode):
......@@ -3563,43 +3438,8 @@ class IndexNode(ExprNode):
gil_message = "Indexing Python object"
def nogil_check(self, env):
if self.is_buffer_access or self.memslice_index or self.memslice_slice:
if not self.memslice_slice and env.directives['boundscheck']:
# error(self.pos, "Cannot check buffer index bounds without gil; "
# "use boundscheck(False) directive")
warning(self.pos, "Use boundscheck(False) for faster access",
level=1)
if self.type.is_pyobject:
error(self.pos, "Cannot access buffer with object dtype without gil")
return
super(IndexNode, self).nogil_check(env)
def check_const_addr(self):
return self.base.check_const_addr() and self.index.check_const()
def is_lvalue(self):
# NOTE: references currently have both is_reference and is_ptr
# set. Since pointers and references have different lvalue
# rules, we must be careful to separate the two.
if self.type.is_reference:
if self.type.ref_base_type.is_array:
# fixed-sized arrays aren't l-values
return False
elif self.type.is_ptr:
# non-const pointers can always be reassigned
return True
# Just about everything else returned by the index operator
# can be an lvalue.
return True
def calculate_result_code(self):
if self.is_buffer_access:
return "(*%s)" % self.buffer_ptr_code
elif self.is_memslice_copy:
return self.base.result()
elif self.base.type in (list_type, tuple_type, bytearray_type):
if self.base.type in (list_type, tuple_type, bytearray_type):
if self.base.type is list_type:
index_code = "PyList_GET_ITEM(%s, %s)"
elif self.base.type is tuple_type:
......@@ -3641,101 +3481,62 @@ class IndexNode(ExprNode):
else:
return ""
def generate_subexpr_evaluation_code(self, code):
self.base.generate_evaluation_code(code)
if self.type_indices is not None:
pass
elif self.indices is None:
self.index.generate_evaluation_code(code)
else:
for i in self.indices:
i.generate_evaluation_code(code)
def generate_subexpr_disposal_code(self, code):
self.base.generate_disposal_code(code)
if self.type_indices is not None:
pass
elif self.indices is None:
self.index.generate_disposal_code(code)
else:
for i in self.indices:
i.generate_disposal_code(code)
def free_subexpr_temps(self, code):
self.base.free_temps(code)
if self.indices is None:
self.index.free_temps(code)
else:
for i in self.indices:
i.free_temps(code)
def generate_result_code(self, code):
if self.is_buffer_access or self.memslice_index:
buffer_entry, self.buffer_ptr_code = self.buffer_lookup_code(code)
if self.type.is_pyobject:
# is_temp is True, so must pull out value and incref it.
# NOTE: object temporary results for nodes are declared
# as PyObject *, so we need a cast
code.putln("%s = (PyObject *) *%s;" % (self.temp_code,
self.buffer_ptr_code))
code.putln("__Pyx_INCREF((PyObject*)%s);" % self.temp_code)
elif self.memslice_slice:
self.put_memoryviewslice_slice_code(code)
elif self.is_temp:
if self.type.is_pyobject:
error_value = 'NULL'
if self.index.type.is_int:
if self.base.type is list_type:
function = "__Pyx_GetItemInt_List"
elif self.base.type is tuple_type:
function = "__Pyx_GetItemInt_Tuple"
else:
function = "__Pyx_GetItemInt"
code.globalstate.use_utility_code(
TempitaUtilityCode.load_cached("GetItemInt", "ObjectHandling.c"))
if not self.is_temp:
# all handled in self.calculate_result_code()
return
if self.type.is_pyobject:
error_value = 'NULL'
if self.index.type.is_int:
if self.base.type is list_type:
function = "__Pyx_GetItemInt_List"
elif self.base.type is tuple_type:
function = "__Pyx_GetItemInt_Tuple"
else:
if self.base.type is dict_type:
function = "__Pyx_PyDict_GetItem"
code.globalstate.use_utility_code(
UtilityCode.load_cached("DictGetItem", "ObjectHandling.c"))
else:
function = "PyObject_GetItem"
elif self.type.is_unicode_char and self.base.type is unicode_type:
assert self.index.type.is_int
function = "__Pyx_GetItemInt_Unicode"
error_value = '(Py_UCS4)-1'
code.globalstate.use_utility_code(
UtilityCode.load_cached("GetItemIntUnicode", "StringTools.c"))
elif self.base.type is bytearray_type:
assert self.index.type.is_int
assert self.type.is_int
function = "__Pyx_GetItemInt_ByteArray"
error_value = '-1'
function = "__Pyx_GetItemInt"
code.globalstate.use_utility_code(
UtilityCode.load_cached("GetItemIntByteArray", "StringTools.c"))
TempitaUtilityCode.load_cached("GetItemInt", "ObjectHandling.c"))
else:
assert False, "unexpected type %s and base type %s for indexing" % (
self.type, self.base.type)
if self.base.type is dict_type:
function = "__Pyx_PyDict_GetItem"
code.globalstate.use_utility_code(
UtilityCode.load_cached("DictGetItem", "ObjectHandling.c"))
else:
function = "PyObject_GetItem"
elif self.type.is_unicode_char and self.base.type is unicode_type:
assert self.index.type.is_int
function = "__Pyx_GetItemInt_Unicode"
error_value = '(Py_UCS4)-1'
code.globalstate.use_utility_code(
UtilityCode.load_cached("GetItemIntUnicode", "StringTools.c"))
elif self.base.type is bytearray_type:
assert self.index.type.is_int
assert self.type.is_int
function = "__Pyx_GetItemInt_ByteArray"
error_value = '-1'
code.globalstate.use_utility_code(
UtilityCode.load_cached("GetItemIntByteArray", "StringTools.c"))
else:
assert False, "unexpected type %s and base type %s for indexing" % (
self.type, self.base.type)
if self.index.type.is_int:
index_code = self.index.result()
else:
index_code = self.index.py_result()
if self.index.type.is_int:
index_code = self.index.result()
else:
index_code = self.index.py_result()
code.putln(
"%s = %s(%s, %s%s); if (unlikely(%s == %s)) %s;" % (
self.result(),
function,
self.base.py_result(),
index_code,
self.extra_index_params(code),
self.result(),
error_value,
code.error_goto(self.pos)))
if self.type.is_pyobject:
code.put_gotref(self.py_result())
code.putln(
"%s = %s(%s, %s%s); if (unlikely(%s == %s)) %s;" % (
self.result(),
function,
self.base.py_result(),
index_code,
self.extra_index_params(code),
self.result(),
error_value,
code.error_goto(self.pos)))
if self.type.is_pyobject:
code.put_gotref(self.py_result())
def generate_setitem_code(self, value_code, code):
if self.index.type.is_int:
......@@ -3770,57 +3571,20 @@ class IndexNode(ExprNode):
self.extra_index_params(code),
code.error_goto(self.pos)))
def generate_buffer_setitem_code(self, rhs, code, op=""):
# Used from generate_assignment_code and InPlaceAssignmentNode
buffer_entry, ptrexpr = self.buffer_lookup_code(code)
if self.buffer_type.dtype.is_pyobject:
# Must manage refcounts. Decref what is already there
# and incref what we put in.
ptr = code.funcstate.allocate_temp(buffer_entry.buf_ptr_type,
manage_ref=False)
rhs_code = rhs.result()
code.putln("%s = %s;" % (ptr, ptrexpr))
code.put_gotref("*%s" % ptr)
code.putln("__Pyx_INCREF(%s); __Pyx_DECREF(*%s);" % (
rhs_code, ptr))
code.putln("*%s %s= %s;" % (ptr, op, rhs_code))
code.put_giveref("*%s" % ptr)
code.funcstate.release_temp(ptr)
else:
# Simple case
code.putln("*%s %s= %s;" % (ptrexpr, op, rhs.result()))
def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
generate_evaluation_code = (self.is_memslice_scalar_assignment or
self.memslice_slice)
if generate_evaluation_code:
self.generate_evaluation_code(code)
else:
self.generate_subexpr_evaluation_code(code)
self.generate_subexpr_evaluation_code(code)
if self.is_buffer_access or self.memslice_index:
self.generate_buffer_setitem_code(rhs, code)
elif self.is_memslice_scalar_assignment:
self.generate_memoryviewslice_assign_scalar_code(rhs, code)
elif self.memslice_slice or self.is_memslice_copy:
self.generate_memoryviewslice_setslice_code(rhs, code)
elif self.type.is_pyobject:
if self.type.is_pyobject:
self.generate_setitem_code(rhs.py_result(), code)
elif self.base.type is bytearray_type:
value_code = self._check_byte_value(code, rhs)
self.generate_setitem_code(value_code, code)
else:
code.putln(
"%s = %s;" % (
self.result(), rhs.result()))
if generate_evaluation_code:
self.generate_disposal_code(code)
else:
self.generate_subexpr_disposal_code(code)
self.free_subexpr_temps(code)
"%s = %s;" % (self.result(), rhs.result()))
self.generate_subexpr_disposal_code(code)
self.free_subexpr_temps(code)
rhs.generate_disposal_code(code)
rhs.free_temps(code)
......@@ -3884,27 +3648,88 @@ class IndexNode(ExprNode):
self.generate_subexpr_disposal_code(code)
self.free_subexpr_temps(code)
def buffer_entry(self):
from . import Buffer, MemoryView
base = self.base
if self.base.is_nonecheck:
base = base.arg
class BufferIndexNode(_IndexingBaseNode):
"""
Indexing of buffers and memoryviews. This node is created during type
analysis from IndexNode and replaces it.
if base.is_name:
entry = base.entry
else:
# SimpleCallNode is_simple is not consistent with coerce_to_simple
assert base.is_simple() or base.is_temp
cname = base.result()
entry = Symtab.Entry(cname, cname, self.base.type, self.base.pos)
Attributes:
base - base node being indexed
indices - list of indexing expressions
"""
if entry.type.is_buffer:
buffer_entry = Buffer.BufferEntry(entry)
else:
buffer_entry = MemoryView.MemoryViewSliceBufferEntry(entry)
subexprs = ['base', 'indices']
is_buffer_access = True
# Whether we're assigning to a buffer (in that case it needs to be writable)
writable_needed = False
def analyse_target_types(self, env):
self.analyse_types(env, getting=False)
def analyse_types(self, env, getting=True):
"""
Analyse types for buffer indexing only. Overridden by memoryview
indexing and slicing subclasses
"""
# self.indices are already analyzed
if not self.base.is_name:
error(self.pos, "Can only index buffer variables")
self.type = error_type
return self
if not getting:
if not self.base.entry.type.writable:
error(self.pos, "Writing to readonly buffer")
else:
self.writable_needed = True
if self.base.type.is_buffer:
self.base.entry.buffer_aux.writable_needed = True
self.none_error_message = "'NoneType' object is not subscriptable"
self.analyse_buffer_index(env, getting)
self.wrap_in_nonecheck_node(env)
return self
def analyse_buffer_index(self, env, getting):
self.base = self.base.coerce_to_simple(env)
self.type = self.base.type.dtype
self.buffer_type = self.base.type
if getting and self.type.is_pyobject:
self.is_temp = True
return buffer_entry
def analyse_assignment(self, rhs):
"""
Called by IndexNode when this node is assigned to,
with the rhs of the assignment
"""
def wrap_in_nonecheck_node(self, env):
if not env.directives['nonecheck'] or not self.base.may_be_none():
return
self.base = self.base.as_none_safe_node(self.none_error_message)
def nogil_check(self, env):
if self.is_buffer_access or self.is_memview_index:
if env.directives['boundscheck']:
warning(self.pos, "Use boundscheck(False) for faster access",
level=1)
if self.type.is_pyobject:
error(self.pos, "Cannot access buffer with object dtype without gil")
self.type = error_type
def calculate_result_code(self):
return "(*%s)" % self.buffer_ptr_code
def buffer_entry(self):
base = self.base
if self.base.is_nonecheck:
base = base.arg
return base.type.get_entry(base)
def buffer_lookup_code(self, code):
"""
......@@ -3938,17 +3763,228 @@ class IndexNode(ExprNode):
negative_indices=negative_indices,
in_nogil_context=self.in_nogil_context)
def put_memoryviewslice_slice_code(self, code):
"memslice[:]"
def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
self.generate_subexpr_evaluation_code(code)
self.generate_buffer_setitem_code(rhs, code)
self.generate_subexpr_disposal_code(code)
self.free_subexpr_temps(code)
rhs.generate_disposal_code(code)
rhs.free_temps(code)
def generate_buffer_setitem_code(self, rhs, code, op=""):
# Used from generate_assignment_code and InPlaceAssignmentNode
buffer_entry, ptrexpr = self.buffer_lookup_code(code)
if self.buffer_type.dtype.is_pyobject:
# Must manage refcounts. Decref what is already there
# and incref what we put in.
ptr = code.funcstate.allocate_temp(buffer_entry.buf_ptr_type,
manage_ref=False)
rhs_code = rhs.result()
code.putln("%s = %s;" % (ptr, ptrexpr))
code.put_gotref("*%s" % ptr)
code.putln("__Pyx_INCREF(%s); __Pyx_DECREF(*%s);" % (
rhs_code, ptr))
code.putln("*%s %s= %s;" % (ptr, op, rhs_code))
code.put_giveref("*%s" % ptr)
code.funcstate.release_temp(ptr)
else:
# Simple case
code.putln("*%s %s= %s;" % (ptrexpr, op, rhs.result()))
def generate_result_code(self, code):
buffer_entry, self.buffer_ptr_code = self.buffer_lookup_code(code)
if self.type.is_pyobject:
# is_temp is True, so must pull out value and incref it.
# NOTE: object temporary results for nodes are declared
# as PyObject *, so we need a cast
code.putln("%s = (PyObject *) *%s;" % (self.result(), self.buffer_ptr_code))
code.putln("__Pyx_INCREF((PyObject*)%s);" % self.result())
class MemoryViewIndexNode(BufferIndexNode):
is_memview_index = True
is_buffer_access = False
warned_untyped_idx = False
def analyse_types(self, env, getting=True):
# memoryviewslice indexing or slicing
from . import MemoryView
indices = self.indices
have_slices, indices, newaxes = MemoryView.unellipsify(indices, self.base.type.ndim)
self.memslice_index = (not newaxes and len(indices) == self.base.type.ndim)
axes = []
index_type = PyrexTypes.c_py_ssize_t_type
new_indices = []
if len(indices) - len(newaxes) > self.base.type.ndim:
self.type = error_type
error(indices[self.base.type.ndim].pos,
"Too many indices specified for type %s" % self.base.type)
return self
axis_idx = 0
for i, index in enumerate(indices[:]):
index = index.analyse_types(env)
if index.is_none:
self.is_memview_slice = True
new_indices.append(index)
axes.append(('direct', 'strided'))
continue
access, packing = self.base.type.axes[axis_idx]
axis_idx += 1
if index.is_slice:
self.is_memview_slice = True
if index.step.is_none:
axes.append((access, packing))
else:
axes.append((access, 'strided'))
# Coerce start, stop and step to temps of the right type
for attr in ('start', 'stop', 'step'):
value = getattr(index, attr)
if not value.is_none:
value = value.coerce_to(index_type, env)
#value = value.coerce_to_temp(env)
setattr(index, attr, value)
new_indices.append(value)
elif index.type.is_int or index.type.is_pyobject:
if index.type.is_pyobject and not self.warned_untyped_idx:
warning(index.pos, "Index should be typed for more efficient access", level=2)
MemoryViewIndexNode.warned_untyped_idx = True
self.is_memview_index = True
index = index.coerce_to(index_type, env)
indices[i] = index
new_indices.append(index)
else:
self.type = error_type
error(index.pos, "Invalid index for memoryview specified, type %s" % index.type)
return self
### FIXME: replace by MemoryViewSliceNode if is_memview_slice ?
self.is_memview_index = self.is_memview_index and not self.is_memview_slice
self.indices = new_indices
# All indices with all start/stop/step for slices.
# We need to keep this around.
self.original_indices = indices
self.nogil = env.nogil
self.analyse_operation(env, getting, axes)
self.wrap_in_nonecheck_node(env)
return self
def analyse_operation(self, env, getting, axes):
self.none_error_message = "Cannot index None memoryview slice"
self.analyse_buffer_index(env, getting)
def analyse_broadcast_operation(self, rhs):
"""
Support broadcasting for slice assignment.
E.g.
m_2d[...] = m_1d # or,
m_1d[...] = m_2d # if the leading dimension has extent 1
"""
if self.type.is_memoryviewslice:
lhs = self
if lhs.is_memview_broadcast or rhs.is_memview_broadcast:
lhs.is_memview_broadcast = True
rhs.is_memview_broadcast = True
def analyse_as_memview_scalar_assignment(self, rhs):
lhs = self.analyse_assignment(rhs)
if lhs:
rhs.is_memview_copy_assignment = lhs.is_memview_copy_assignment
return lhs
return self
class MemoryViewSliceNode(MemoryViewIndexNode):
is_memview_slice = True
# No-op slicing operation, this node will be replaced
is_ellipsis_noop = False
is_memview_scalar_assignment = False
is_memview_index = False
is_memview_broadcast = False
def analyse_ellipsis_noop(self, env, getting):
"""Slicing operations needing no evaluation, i.e. m[...] or m[:, :]"""
### FIXME: replace directly
self.is_ellipsis_noop = all(
index.is_slice and index.start.is_none and index.stop.is_none and index.step.is_none
for index in self.indices)
if self.is_ellipsis_noop:
self.type = self.base.type
def analyse_operation(self, env, getting, axes):
from . import MemoryView
if not getting:
self.is_memview_broadcast = True
self.none_error_message = "Cannot assign to None memoryview slice"
else:
self.none_error_message = "Cannot slice None memoryview slice"
self.analyse_ellipsis_noop(env, getting)
if self.is_ellipsis_noop:
return
self.index = None
self.is_temp = True
self.use_managed_ref = True
if not MemoryView.validate_axes(self.pos, axes):
self.type = error_type
return
self.type = PyrexTypes.MemoryViewSliceType(self.base.type.dtype, axes)
if not (self.base.is_simple() or self.base.result_in_temp()):
self.base = self.base.coerce_to_temp(env)
def analyse_assignment(self, rhs):
if not rhs.type.is_memoryviewslice and (
self.type.dtype.assignable_from(rhs.type) or
rhs.type.is_pyobject):
# scalar assignment
return MemoryCopyScalar(self.pos, self)
else:
return MemoryCopySlice(self.pos, self)
def is_simple(self):
if self.is_ellipsis_noop:
# TODO: fix SimpleCallNode.is_simple()
return self.base.is_simple() or self.base.result_in_temp()
return self.result_in_temp()
def calculate_result_code(self):
"""This is called in case this is a no-op slicing node"""
return self.base.result()
def generate_result_code(self, code):
if self.is_ellipsis_noop:
return ### FIXME: remove
buffer_entry = self.buffer_entry()
have_gil = not self.in_nogil_context
# TODO Mark: this is insane, do it better
have_slices = False
it = iter(self.indices)
for index in self.original_indices:
is_slice = isinstance(index, SliceNode)
have_slices = have_slices or is_slice
if is_slice:
if index.is_slice:
have_slices = True
if not index.start.is_none:
index.start = next(it)
if not index.stop.is_none:
......@@ -3960,21 +3996,123 @@ class IndexNode(ExprNode):
assert not list(it)
buffer_entry.generate_buffer_slice_code(code, self.original_indices,
self.result(),
have_gil=have_gil,
have_slices=have_slices,
directives=code.globalstate.directives)
buffer_entry.generate_buffer_slice_code(
code, self.original_indices, self.result(),
have_gil=have_gil, have_slices=have_slices,
directives=code.globalstate.directives)
def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
if self.is_ellipsis_noop:
self.generate_subexpr_evaluation_code(code)
else:
self.generate_evaluation_code(code)
if self.is_memview_scalar_assignment:
self.generate_memoryviewslice_assign_scalar_code(rhs, code)
else:
self.generate_memoryviewslice_setslice_code(rhs, code)
if self.is_ellipsis_noop:
self.generate_subexpr_disposal_code(code)
else:
self.generate_disposal_code(code)
rhs.generate_disposal_code(code)
rhs.free_temps(code)
class MemoryCopyNode(ExprNode):
"""
Wraps a memoryview slice for slice assignment.
dst: destination mememoryview slice
"""
subexprs = ['dst']
def __init__(self, pos, dst):
super(MemoryCopyNode, self).__init__(pos)
self.dst = dst
self.type = dst.type
def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
self.dst.generate_evaluation_code(code)
self._generate_assignment_code(rhs, code)
self.dst.generate_disposal_code(code)
rhs.generate_disposal_code(code)
rhs.free_temps(code)
def generate_memoryviewslice_setslice_code(self, rhs, code):
"memslice1[...] = memslice2 or memslice1[:] = memslice2"
from . import MemoryView
MemoryView.copy_broadcast_memview_src_to_dst(rhs, self, code)
def generate_memoryviewslice_assign_scalar_code(self, rhs, code):
"memslice1[...] = 0.0 or memslice1[:] = 0.0"
class MemoryCopySlice(MemoryCopyNode):
"""
Copy the contents of slice src to slice dst. Does not support indirect
slices.
memslice1[...] = memslice2
memslice1[:] = memslice2
"""
is_memview_copy_assignment = True
copy_slice_cname = "__pyx_memoryview_copy_contents"
def _generate_assignment_code(self, src, code):
dst = self.dst
src.type.assert_direct_dims(src.pos)
dst.type.assert_direct_dims(dst.pos)
code.putln(code.error_goto_if_neg(
"%s(%s, %s, %d, %d, %d)" % (self.copy_slice_cname,
src.result(), dst.result(),
src.type.ndim, dst.type.ndim,
dst.type.dtype.is_pyobject),
dst.pos))
class MemoryCopyScalar(MemoryCopyNode):
"""
Assign a scalar to a slice. dst must be simple, scalar will be assigned
to a correct type and not just something assignable.
memslice1[...] = 0.0
memslice1[:] = 0.0
"""
def __init__(self, pos, dst):
super(MemoryCopyScalar, self).__init__(pos, dst)
self.type = dst.type.dtype
def _generate_assignment_code(self, scalar, code):
from . import MemoryView
MemoryView.assign_scalar(self, rhs, code)
self.dst.type.assert_direct_dims(self.dst.pos)
dtype = self.dst.type.dtype
type_decl = dtype.declaration_code("")
slice_decl = self.dst.type.declaration_code("")
code.begin_block()
code.putln("%s __pyx_temp_scalar = %s;" % (type_decl, scalar.result()))
if self.dst.result_in_temp() or self.dst.is_simple():
dst_temp = self.dst.result()
else:
code.putln("%s __pyx_temp_slice = %s;" % (slice_decl, self.dst.result()))
dst_temp = "__pyx_temp_slice"
slice_iter_obj = MemoryView.slice_iter(self.dst.type, dst_temp,
self.dst.type.ndim, code)
p = slice_iter_obj.start_loops()
if dtype.is_pyobject:
code.putln("Py_DECREF(*(PyObject **) %s);" % p)
code.putln("*((%s *) %s) = __pyx_temp_scalar;" % (type_decl, p))
if dtype.is_pyobject:
code.putln("Py_INCREF(__pyx_temp_scalar);")
slice_iter_obj.end_loops()
code.end_block()
class SliceIndexNode(ExprNode):
......@@ -4428,7 +4566,7 @@ class SliceNode(ExprNode):
# step ExprNode
subexprs = ['start', 'stop', 'step']
is_slice = True
type = slice_type
is_temp = 1
......@@ -4710,8 +4848,7 @@ class SimpleCallNode(CallNode):
return
elif hasattr(self.function, 'entry'):
overloaded_entry = self.function.entry
elif (isinstance(self.function, IndexNode) and
self.function.is_fused_index):
elif self.function.is_subscript and self.function.is_fused_index:
overloaded_entry = self.function.type.entry
else:
overloaded_entry = None
......@@ -6014,7 +6151,7 @@ class AttributeNode(ExprNode):
self.is_memslice_transpose = True
self.is_temp = True
self.use_managed_ref = True
self.type = self.obj.type
self.type = self.obj.type.transpose(self.pos)
return
else:
obj_type.declare_attribute(self.attribute, env, self.pos)
......@@ -6099,13 +6236,9 @@ class AttributeNode(ExprNode):
self.obj = self.obj.as_none_safe_node(msg, 'PyExc_AttributeError',
format_args=format_args)
def nogil_check(self, env):
if self.is_py_attr:
self.gil_error()
elif self.type.is_memoryviewslice:
from . import MemoryView
MemoryView.err_if_nogil_initialized_check(self.pos, env, 'attribute')
gil_message = "Accessing Python attribute"
......@@ -9246,7 +9379,7 @@ class AmpersandNode(CUnopNode):
if argtype.is_memoryviewslice:
self.error("Cannot take address of memoryview slice")
else:
self.error("Taking address of non-lvalue")
self.error("Taking address of non-lvalue (type %s)" % argtype)
return self
if argtype.is_pyobject:
self.error("Cannot take address of Python variable")
......@@ -9434,6 +9567,7 @@ ERR_STEPS = ("Strides may only be given to indicate contiguity. "
ERR_NOT_POINTER = "Can only create cython.array from pointer or array"
ERR_BASE_TYPE = "Pointer base type does not match cython.array base type"
class CythonArrayNode(ExprNode):
"""
Used when a pointer of base_type is cast to a memoryviewslice with that
......@@ -9474,8 +9608,6 @@ class CythonArrayNode(ExprNode):
array_dtype = self.base_type_node.base_type_node.analyse(env)
axes = self.base_type_node.axes
MemoryView.validate_memslice_dtype(self.pos, array_dtype)
self.type = error_type
self.shapes = []
ndim = len(axes)
......@@ -9564,6 +9696,7 @@ class CythonArrayNode(ExprNode):
axes[-1] = ('direct', 'contig')
self.coercion_type = PyrexTypes.MemoryViewSliceType(array_dtype, axes)
self.coercion_type.validate_memslice_dtype(self.pos)
self.type = self.get_cython_array_type(env)
MemoryView.use_cython_array_utility_code(env)
env.use_utility_code(MemoryView.typeinfo_to_format_code)
......@@ -11639,6 +11772,7 @@ class CoercionNode(ExprNode):
code.annotate((file, line, col-1), AnnotationItem(
style='coerce', tag='coerce', text='[%s] to [%s]' % (self.arg.type, self.type)))
class CoerceToMemViewSliceNode(CoercionNode):
"""
Coerce an object to a memoryview slice. This holds a new reference in
......
......@@ -200,7 +200,7 @@ class FusedCFuncDefNode(StatListNode):
if arg.type.is_fused:
arg.type = arg.type.specialize(fused_to_specific)
if arg.type.is_memoryviewslice:
MemoryView.validate_memslice_dtype(arg.pos, arg.type.dtype)
arg.type.validate_memslice_dtype(arg.pos)
def create_new_local_scope(self, node, env, f2s):
"""
......
......@@ -21,15 +21,11 @@ CF_ERR = "Invalid axis specification for a C/Fortran contiguous array."
ERR_UNINITIALIZED = ("Cannot check if memoryview %s is initialized without the "
"GIL, consider using initializedcheck(False)")
def err_if_nogil_initialized_check(pos, env, name='variable'):
"This raises an exception at runtime now"
pass
#if env.nogil and env.directives['initializedcheck']:
#error(pos, ERR_UNINITIALIZED % name)
def concat_flags(*flags):
return "(%s)" % "|".join(flags)
format_flag = "PyBUF_FORMAT"
memview_c_contiguous = "(PyBUF_C_CONTIGUOUS | PyBUF_FORMAT | PyBUF_WRITABLE)"
......@@ -71,18 +67,16 @@ memview_typeptr_cname = '__pyx_memoryview_type'
memview_objstruct_cname = '__pyx_memoryview_obj'
memviewslice_cname = u'__Pyx_memviewslice'
def put_init_entry(mv_cname, code):
code.putln("%s.data = NULL;" % mv_cname)
code.putln("%s.memview = NULL;" % mv_cname)
def mangle_dtype_name(dtype):
# a dumb wrapper for now; move Buffer.mangle_dtype_name in here later?
from . import Buffer
return Buffer.mangle_dtype_name(dtype)
#def axes_to_str(axes):
# return "".join([access[0].upper()+packing[0] for (access, packing) in axes])
def put_acquire_memoryviewslice(lhs_cname, lhs_type, lhs_pos, rhs, code,
have_gil=False, first_assignment=True):
"We can avoid decreffing the lhs if we know it is the first assignment"
......@@ -103,6 +97,7 @@ def put_acquire_memoryviewslice(lhs_cname, lhs_type, lhs_pos, rhs, code,
if not pretty_rhs:
code.funcstate.release_temp(rhstmp)
def put_assign_to_memviewslice(lhs_cname, rhs, rhs_cname, memviewslicetype, code,
have_gil=False, first_assignment=False):
if not first_assignment:
......@@ -113,6 +108,7 @@ def put_assign_to_memviewslice(lhs_cname, rhs, rhs_cname, memviewslicetype, code
code.putln("%s = %s;" % (lhs_cname, rhs_cname))
def get_buf_flags(specs):
is_c_contig, is_f_contig = is_cf_contig(specs)
......@@ -128,11 +124,13 @@ def get_buf_flags(specs):
else:
return memview_strided_access
def insert_newaxes(memoryviewtype, n):
axes = [('direct', 'strided')] * n
axes.extend(memoryviewtype.axes)
return PyrexTypes.MemoryViewSliceType(memoryviewtype.dtype, axes)
def broadcast_types(src, dst):
n = abs(src.ndim - dst.ndim)
if src.ndim < dst.ndim:
......@@ -140,37 +138,6 @@ def broadcast_types(src, dst):
else:
return src, insert_newaxes(dst, n)
def src_conforms_to_dst(src, dst, broadcast=False):
'''
returns True if src conforms to dst, False otherwise.
If conformable, the types are the same, the ndims are equal, and each axis spec is conformable.
Any packing/access spec is conformable to itself.
'direct' and 'ptr' are conformable to 'full'.
'contig' and 'follow' are conformable to 'strided'.
Any other combo is not conformable.
'''
if src.dtype != dst.dtype:
return False
if src.ndim != dst.ndim:
if broadcast:
src, dst = broadcast_types(src, dst)
else:
return False
for src_spec, dst_spec in zip(src.axes, dst.axes):
src_access, src_packing = src_spec
dst_access, dst_packing = dst_spec
if src_access != dst_access and dst_access != 'full':
return False
if src_packing != dst_packing and dst_packing != 'strided':
return False
return True
def valid_memslice_dtype(dtype, i=0):
"""
......@@ -204,22 +171,22 @@ def valid_memslice_dtype(dtype, i=0):
(dtype.is_typedef and valid_memslice_dtype(dtype.typedef_base_type))
)
def validate_memslice_dtype(pos, dtype):
if not valid_memslice_dtype(dtype):
error(pos, "Invalid base type for memoryview slice: %s" % dtype)
class MemoryViewSliceBufferEntry(Buffer.BufferEntry):
"""
May be used during code generation time to be queried for
shape/strides/suboffsets attributes, or to perform indexing or slicing.
"""
def __init__(self, entry):
self.entry = entry
self.type = entry.type
self.cname = entry.cname
self.buf_ptr = "%s.data" % self.cname
dtype = self.entry.type.dtype
dtype = PyrexTypes.CPtrType(dtype)
self.buf_ptr_type = dtype
self.buf_ptr_type = PyrexTypes.CPtrType(dtype)
self.init_attributes()
def get_buf_suboffsetvars(self):
return self._for_all_ndim("%s.suboffsets[%d]")
......@@ -236,6 +203,10 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry):
return self._generate_buffer_lookup_code(code, axes)
def _generate_buffer_lookup_code(self, code, axes, cast_result=True):
"""
Generate a single expression that indexes the memory view slice
in each dimension.
"""
bufp = self.buf_ptr
type_decl = self.type.dtype.empty_declaration_code()
......@@ -286,7 +257,9 @@ class MemoryViewSliceBufferEntry(Buffer.BufferEntry):
then it must be coercible to Py_ssize_t
Simply call __pyx_memoryview_slice_memviewslice with the right
arguments.
arguments, unless the dimension is omitted or a bare ':', in which
case we copy over the shape/strides/suboffsets attributes directly
for that dimension.
"""
src = self.cname
......@@ -368,11 +341,13 @@ def empty_slice(pos):
return ExprNodes.SliceNode(pos, start=none,
stop=none, step=none)
def unellipsify(indices, newaxes, ndim):
def unellipsify(indices, ndim):
result = []
seen_ellipsis = False
have_slices = False
newaxes = [newaxis for newaxis in indices if newaxis.is_none]
n_indices = len(indices) - len(newaxes)
for index in indices:
......@@ -387,9 +362,7 @@ def unellipsify(indices, newaxes, ndim):
result.extend([full_slice] * nslices)
seen_ellipsis = True
else:
have_slices = (have_slices or
isinstance(index, ExprNodes.SliceNode) or
index.is_none)
have_slices = have_slices or index.is_slice or index.is_none
result.append(index)
result_length = len(result) - len(newaxes)
......@@ -398,7 +371,8 @@ def unellipsify(indices, newaxes, ndim):
nslices = ndim - result_length
result.extend([empty_slice(indices[-1].pos)] * nslices)
return have_slices, result
return have_slices, result, newaxes
def get_memoryview_flag(access, packing):
if access == 'full' and packing in ('strided', 'follow'):
......@@ -415,9 +389,11 @@ def get_memoryview_flag(access, packing):
assert (access, packing) == ('direct', 'contig'), (access, packing)
return 'contiguous'
def get_is_contig_func_name(c_or_f, ndim):
return "__pyx_memviewslice_is_%s_contig%d" % (c_or_f, ndim)
def get_is_contig_utility(c_contig, ndim):
C = dict(context, ndim=ndim)
if c_contig:
......@@ -430,88 +406,21 @@ def get_is_contig_utility(c_contig, ndim):
return utility
def copy_src_to_dst_cname():
return "__pyx_memoryview_copy_contents"
def verify_direct_dimensions(node):
for access, packing in node.type.axes:
if access != 'direct':
error(node.pos, "All dimensions must be direct")
def copy_broadcast_memview_src_to_dst(src, dst, code):
"""
Copy the contents of slice src to slice dst. Does not support indirect
slices.
"""
verify_direct_dimensions(src)
verify_direct_dimensions(dst)
code.putln(code.error_goto_if_neg(
"%s(%s, %s, %d, %d, %d)" % (copy_src_to_dst_cname(),
src.result(), dst.result(),
src.type.ndim, dst.type.ndim,
dst.type.dtype.is_pyobject),
dst.pos))
def get_1d_fill_scalar_func(type, code):
dtype = type.dtype
type_decl = dtype.empty_declaration_code()
dtype_name = mangle_dtype_name(dtype)
context = dict(dtype_name=dtype_name, type_decl=type_decl)
utility = load_memview_c_utility("FillStrided1DScalar", context)
code.globalstate.use_utility_code(utility)
return '__pyx_fill_slice_%s' % dtype_name
def assign_scalar(dst, scalar, code):
"""
Assign a scalar to a slice. dst must be a temp, scalar will be assigned
to a correct type and not just something assignable.
"""
verify_direct_dimensions(dst)
dtype = dst.type.dtype
type_decl = dtype.empty_declaration_code()
slice_decl = dst.type.empty_declaration_code()
code.begin_block()
code.putln("%s __pyx_temp_scalar = %s;" % (type_decl, scalar.result()))
if dst.result_in_temp() or (dst.base.is_name and
isinstance(dst.index, ExprNodes.EllipsisNode)):
dst_temp = dst.result()
else:
code.putln("%s __pyx_temp_slice = %s;" % (slice_decl, dst.result()))
dst_temp = "__pyx_temp_slice"
# with slice_iter(dst.type, dst_temp, dst.type.ndim, code) as p:
slice_iter_obj = slice_iter(dst.type, dst_temp, dst.type.ndim, code)
p = slice_iter_obj.start_loops()
if dtype.is_pyobject:
code.putln("Py_DECREF(*(PyObject **) %s);" % p)
code.putln("*((%s *) %s) = __pyx_temp_scalar;" % (type_decl, p))
if dtype.is_pyobject:
code.putln("Py_INCREF(__pyx_temp_scalar);")
slice_iter_obj.end_loops()
code.end_block()
def slice_iter(slice_type, slice_temp, ndim, code):
def slice_iter(slice_type, slice_result, ndim, code):
if slice_type.is_c_contig or slice_type.is_f_contig:
return ContigSliceIter(slice_type, slice_temp, ndim, code)
return ContigSliceIter(slice_type, slice_result, ndim, code)
else:
return StridedSliceIter(slice_type, slice_temp, ndim, code)
return StridedSliceIter(slice_type, slice_result, ndim, code)
class SliceIter(object):
def __init__(self, slice_type, slice_temp, ndim, code):
def __init__(self, slice_type, slice_result, ndim, code):
self.slice_type = slice_type
self.slice_temp = slice_temp
self.slice_result = slice_result
self.code = code
self.ndim = ndim
class ContigSliceIter(SliceIter):
def start_loops(self):
code = self.code
......@@ -519,12 +428,12 @@ class ContigSliceIter(SliceIter):
type_decl = self.slice_type.dtype.empty_declaration_code()
total_size = ' * '.join("%s.shape[%d]" % (self.slice_temp, i)
for i in range(self.ndim))
total_size = ' * '.join("%s.shape[%d]" % (self.slice_result, i)
for i in range(self.ndim))
code.putln("Py_ssize_t __pyx_temp_extent = %s;" % total_size)
code.putln("Py_ssize_t __pyx_temp_idx;")
code.putln("%s *__pyx_temp_pointer = (%s *) %s.data;" % (
type_decl, type_decl, self.slice_temp))
type_decl, type_decl, self.slice_result))
code.putln("for (__pyx_temp_idx = 0; "
"__pyx_temp_idx < __pyx_temp_extent; "
"__pyx_temp_idx++) {")
......@@ -536,19 +445,20 @@ class ContigSliceIter(SliceIter):
self.code.putln("}")
self.code.end_block()
class StridedSliceIter(SliceIter):
def start_loops(self):
code = self.code
code.begin_block()
for i in range(self.ndim):
t = i, self.slice_temp, i
t = i, self.slice_result, i
code.putln("Py_ssize_t __pyx_temp_extent_%d = %s.shape[%d];" % t)
code.putln("Py_ssize_t __pyx_temp_stride_%d = %s.strides[%d];" % t)
code.putln("char *__pyx_temp_pointer_%d;" % i)
code.putln("Py_ssize_t __pyx_temp_idx_%d;" % i)
code.putln("__pyx_temp_pointer_0 = %s.data;" % self.slice_temp)
code.putln("__pyx_temp_pointer_0 = %s.data;" % self.slice_result)
for i in range(self.ndim):
if i > 0:
......
......@@ -1054,8 +1054,8 @@ class MemoryViewSliceTypeNode(CBaseTypeNode):
if not MemoryView.validate_axes(self.pos, axes_specs):
self.type = error_type
else:
MemoryView.validate_memslice_dtype(self.pos, base_type)
self.type = PyrexTypes.MemoryViewSliceType(base_type, axes_specs)
self.type.validate_memslice_dtype(self.pos)
self.use_memview_utilities(env)
return self.type
......@@ -4896,26 +4896,14 @@ class SingleAssignmentNode(AssignmentNode):
if unrolled_assignment:
return unrolled_assignment
if self.lhs.memslice_broadcast or self.rhs.memslice_broadcast:
self.lhs.memslice_broadcast = True
self.rhs.memslice_broadcast = True
if (self.lhs.is_subscript and not self.rhs.type.is_memoryviewslice and
(self.lhs.memslice_slice or self.lhs.is_memslice_copy) and
(self.lhs.type.dtype.assignable_from(self.rhs.type) or
self.rhs.type.is_pyobject)):
# scalar slice assignment
self.lhs.is_memslice_scalar_assignment = True
dtype = self.lhs.type.dtype
if isinstance(self.lhs, ExprNodes.MemoryViewIndexNode):
self.lhs.analyse_broadcast_operation(self.rhs)
self.lhs = self.lhs.analyse_as_memview_scalar_assignment(self.rhs)
elif self.lhs.type.is_array:
if not isinstance(self.lhs, ExprNodes.SliceIndexNode):
# cannot assign to C array, only to its full slice
self.lhs = ExprNodes.SliceIndexNode(
self.lhs.pos, base=self.lhs, start=None, stop=None)
self.lhs = ExprNodes.SliceIndexNode(self.lhs.pos, base=self.lhs, start=None, stop=None)
self.lhs = self.lhs.analyse_target_types(env)
dtype = self.lhs.type
else:
dtype = self.lhs.type
if self.lhs.type.is_cpp_class:
op = env.lookup_operator_for_types(self.pos, '=', [self.lhs.type, self.rhs.type])
......@@ -4923,9 +4911,10 @@ class SingleAssignmentNode(AssignmentNode):
rhs = self.rhs
self.is_overloaded_assignment = True
else:
rhs = self.rhs.coerce_to(dtype, env)
rhs = self.rhs.coerce_to(self.lhs.type, env)
else:
rhs = self.rhs.coerce_to(dtype, env)
rhs = self.rhs.coerce_to(self.lhs.type, env)
if use_temp or rhs.is_attribute or (
not rhs.is_name and not rhs.is_literal and
rhs.type.is_pyobject):
......@@ -5035,12 +5024,12 @@ class SingleAssignmentNode(AssignmentNode):
assignments = []
for lhs, rhs in zip(lhs_list, rhs_list):
assignments.append(SingleAssignmentNode(self.pos, lhs=lhs, rhs=rhs, first=self.first))
all = ParallelAssignmentNode(pos=self.pos, stats=assignments).analyse_expressions(env)
node = ParallelAssignmentNode(pos=self.pos, stats=assignments).analyse_expressions(env)
if check_node:
all = StatListNode(pos=self.pos, stats=[check_node, all])
node = StatListNode(pos=self.pos, stats=[check_node, node])
for ref in refs[::-1]:
all = UtilNodes.LetNode(ref, all)
return all
node = UtilNodes.LetNode(ref, node)
return node
def unroll_rhs(self, env):
from . import ExprNodes
......@@ -5059,7 +5048,7 @@ class SingleAssignmentNode(AssignmentNode):
if self.lhs.type.is_ctuple:
# Handled directly.
return
from . import ExprNodes, UtilNodes
from . import ExprNodes
if not isinstance(self.rhs, ExprNodes.TupleNode):
return
......@@ -5261,8 +5250,7 @@ class InPlaceAssignmentNode(AssignmentNode):
self.lhs = self.lhs.analyse_target_types(env)
# When assigning to a fully indexed buffer or memoryview, coerce the rhs
if (self.lhs.is_subscript and
(self.lhs.memslice_index or self.lhs.is_buffer_access)):
if self.lhs.is_memview_index or self.lhs.is_buffer_access:
self.rhs = self.rhs.coerce_to(self.lhs.type, env)
elif self.lhs.type.is_string and self.operator in '+-':
# use pointer arithmetic for char* LHS instead of string concat
......@@ -5271,28 +5259,30 @@ class InPlaceAssignmentNode(AssignmentNode):
def generate_execution_code(self, code):
code.mark_pos(self.pos)
self.rhs.generate_evaluation_code(code)
self.lhs.generate_subexpr_evaluation_code(code)
lhs, rhs = self.lhs, self.rhs
rhs.generate_evaluation_code(code)
lhs.generate_subexpr_evaluation_code(code)
c_op = self.operator
if c_op == "//":
c_op = "/"
elif c_op == "**":
error(self.pos, "No C inplace power operator")
if self.lhs.is_subscript and self.lhs.is_buffer_access:
if self.lhs.type.is_pyobject:
if lhs.is_buffer_access or lhs.is_memview_index:
if lhs.type.is_pyobject:
error(self.pos, "In-place operators not allowed on object buffers in this release.")
if (c_op in ('/', '%') and self.lhs.type.is_int
and not code.globalstate.directives['cdivision']):
if c_op in ('/', '%') and lhs.type.is_int and not code.globalstate.directives['cdivision']:
error(self.pos, "In-place non-c divide operators not allowed on int buffers.")
self.lhs.generate_buffer_setitem_code(self.rhs, code, c_op)
lhs.generate_buffer_setitem_code(rhs, code, c_op)
elif lhs.is_memview_slice:
error(self.pos, "Inplace operators not supported on memoryview slices")
else:
# C++
# TODO: make sure overload is declared
code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result()))
self.lhs.generate_subexpr_disposal_code(code)
self.lhs.free_subexpr_temps(code)
self.rhs.generate_disposal_code(code)
self.rhs.free_temps(code)
code.putln("%s %s= %s;" % (lhs.result(), c_op, rhs.result()))
lhs.generate_subexpr_disposal_code(code)
lhs.free_subexpr_temps(code)
rhs.generate_disposal_code(code)
rhs.free_temps(code)
def annotate(self, code):
self.lhs.annotate(code)
......@@ -6344,8 +6334,8 @@ class ForFromStatNode(LoopNode, StatNode):
"for-from loop variable must be c numeric type or Python object")
if target_type.is_numeric:
self.is_py_target = False
if isinstance(self.target, ExprNodes.IndexNode) and self.target.is_buffer_access:
raise error(self.pos, "Buffer indexing not allowed as for loop target.")
if isinstance(self.target, ExprNodes.BufferIndexNode):
raise error(self.pos, "Buffer or memoryview slicing/indexing not allowed as for-loop target.")
self.loopvar_node = self.target
self.py_loopvar_node = None
else:
......
......@@ -132,7 +132,7 @@ class IterationTransform(Visitor.EnvTransform):
pos = node.pos
result_ref = UtilNodes.ResultRefNode(node)
if isinstance(node.operand2, ExprNodes.IndexNode):
if node.operand2.is_subscript:
base_type = node.operand2.base.type.base_type
else:
base_type = node.operand2.type.base_type
......@@ -442,7 +442,7 @@ class IterationTransform(Visitor.EnvTransform):
error(slice_node.pos, "C array iteration requires known end index")
return node
elif isinstance(slice_node, ExprNodes.IndexNode):
elif slice_node.is_subscript:
assert isinstance(slice_node.index, ExprNodes.SliceNode)
slice_base = slice_node.base
index = slice_node.index
......@@ -564,7 +564,6 @@ class IterationTransform(Visitor.EnvTransform):
constant_result=0,
type=PyrexTypes.c_int_type),
base=counter_temp,
is_buffer_access=False,
type=ptr_type.base_type)
if target_value.type != node.target.type:
......@@ -1334,20 +1333,20 @@ class DropRefcountingTransform(Visitor.VisitorTransform):
node = node.arg
name_path = []
obj_node = node
while isinstance(obj_node, ExprNodes.AttributeNode):
while obj_node.is_attribute:
if obj_node.is_py_attr:
return False
name_path.append(obj_node.member)
obj_node = obj_node.obj
if isinstance(obj_node, ExprNodes.NameNode):
if obj_node.is_name:
name_path.append(obj_node.name)
names.append( ('.'.join(name_path[::-1]), node) )
elif isinstance(node, ExprNodes.IndexNode):
elif node.is_subscript:
if node.base.type != Builtin.list_type:
return False
if not node.index.type.is_int:
return False
if not isinstance(node.base, ExprNodes.NameNode):
if not node.base.is_name:
return False
indices.append(node)
else:
......@@ -1979,7 +1978,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
elif isinstance(arg, ExprNodes.SimpleCallNode):
if node.type.is_int or node.type.is_float:
return self._optimise_numeric_cast_call(node, arg)
elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
elif arg.is_subscript:
index_node = arg.index
if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
index_node = index_node.arg
......
......@@ -17,7 +17,7 @@ from . import Builtin
from .Visitor import VisitorTransform, TreeVisitor
from .Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform
from .UtilNodes import LetNode, LetRefNode, ResultRefNode
from .UtilNodes import LetNode, LetRefNode
from .TreeFragment import TreeFragment
from .StringEncoding import EncodedString, _unicode
from .Errors import error, warning, CompileError, InternalError
......@@ -1931,13 +1931,8 @@ class AnalyseExpressionsTransform(CythonTransform):
re-analyse the types.
"""
self.visit_Node(node)
if node.is_fused_index and not node.type.is_error:
node = node.base
elif node.memslice_ellipsis_noop:
# memoryviewslice[...] expression, drop the IndexNode
node = node.base
return node
......@@ -1971,26 +1966,26 @@ class ExpandInplaceOperators(EnvTransform):
if lhs.type.is_cpp_class:
# No getting around this exact operator here.
return node
if isinstance(lhs, ExprNodes.IndexNode) and lhs.is_buffer_access:
# There is code to handle this case.
if isinstance(lhs, ExprNodes.BufferIndexNode):
# There is code to handle this case in InPlaceAssignmentNode
return node
env = self.current_env()
def side_effect_free_reference(node, setting=False):
if isinstance(node, ExprNodes.NameNode):
if node.is_name:
return node, []
elif node.type.is_pyobject and not setting:
node = LetRefNode(node)
return node, [node]
elif isinstance(node, ExprNodes.IndexNode):
if node.is_buffer_access:
raise ValueError("Buffer access")
elif node.is_subscript:
base, temps = side_effect_free_reference(node.base)
index = LetRefNode(node.index)
return ExprNodes.IndexNode(node.pos, base=base, index=index), temps + [index]
elif isinstance(node, ExprNodes.AttributeNode):
elif node.is_attribute:
obj, temps = side_effect_free_reference(node.obj)
return ExprNodes.AttributeNode(node.pos, obj=obj, attribute=node.attribute), temps
elif isinstance(node, ExprNodes.BufferIndexNode):
raise ValueError("Don't allow things like attributes of buffer indexing operations")
else:
node = LetRefNode(node)
return node, [node]
......
......@@ -541,7 +541,7 @@ class MemoryViewSliceType(PyrexType):
the *first* axis' packing spec and 'follow' for all other packing
specs.
"""
from . import MemoryView
from . import Buffer, MemoryView
self.dtype = base_dtype
self.axes = axes
......@@ -555,7 +555,7 @@ class MemoryViewSliceType(PyrexType):
self.writable_needed = False
if not self.dtype.is_fused:
self.dtype_name = MemoryView.mangle_dtype_name(self.dtype)
self.dtype_name = Buffer.mangle_dtype_name(self.dtype)
def __hash__(self):
return hash(self.__class__) ^ hash(self.dtype) ^ hash(tuple(self.axes))
......@@ -638,25 +638,28 @@ class MemoryViewSliceType(PyrexType):
elif attribute in ("copy", "copy_fortran"):
ndim = len(self.axes)
to_axes_c = [('direct', 'contig')]
to_axes_f = [('direct', 'contig')]
if ndim - 1:
to_axes_c = [('direct', 'follow')]*(ndim-1) + to_axes_c
to_axes_f = to_axes_f + [('direct', 'follow')]*(ndim-1)
follow_dim = [('direct', 'follow')]
contig_dim = [('direct', 'contig')]
to_axes_c = follow_dim * (ndim - 1) + contig_dim
to_axes_f = contig_dim + follow_dim * (ndim -1)
to_memview_c = MemoryViewSliceType(self.dtype, to_axes_c)
to_memview_f = MemoryViewSliceType(self.dtype, to_axes_f)
for to_memview, cython_name in [(to_memview_c, "copy"),
(to_memview_f, "copy_fortran")]:
entry = scope.declare_cfunction(cython_name,
CFuncType(self, [CFuncTypeArg("memviewslice", self, None)]),
pos=pos,
defining=1,
cname=MemoryView.copy_c_or_fortran_cname(to_memview))
copy_func_type = CFuncType(
to_memview,
[CFuncTypeArg("memviewslice", self, None)])
copy_cname = MemoryView.copy_c_or_fortran_cname(to_memview)
entry = scope.declare_cfunction(
cython_name,
copy_func_type, pos=pos, defining=1,
cname=copy_cname)
#entry.utility_code_definition = \
env.use_utility_code(MemoryView.get_copy_new_utility(pos, self, to_memview))
utility = MemoryView.get_copy_new_utility(pos, self, to_memview)
env.use_utility_code(utility)
MemoryView.use_cython_array_utility_code(env)
......@@ -684,9 +687,102 @@ class MemoryViewSliceType(PyrexType):
return True
def get_entry(self, node, cname=None, type=None):
from . import MemoryView, Symtab
if cname is None:
assert node.is_simple() or node.is_temp or node.is_elemental
cname = node.result()
if type is None:
type = node.type
entry = Symtab.Entry(cname, cname, type, node.pos)
return MemoryView.MemoryViewSliceBufferEntry(entry)
def conforms_to(self, dst, broadcast=False, copying=False):
"""
Returns True if src conforms to dst, False otherwise.
If conformable, the types are the same, the ndims are equal, and each axis spec is conformable.
Any packing/access spec is conformable to itself.
'direct' and 'ptr' are conformable to 'full'.
'contig' and 'follow' are conformable to 'strided'.
Any other combo is not conformable.
"""
from . import MemoryView
src = self
if src.dtype != dst.dtype:
return False
if src.ndim != dst.ndim:
if broadcast:
src, dst = MemoryView.broadcast_types(src, dst)
else:
return False
for src_spec, dst_spec in zip(src.axes, dst.axes):
src_access, src_packing = src_spec
dst_access, dst_packing = dst_spec
if src_access != dst_access and dst_access != 'full':
return False
if src_packing != dst_packing and dst_packing != 'strided' and not copying:
return False
return True
def valid_dtype(self, dtype, i=0):
"""
Return whether type dtype can be used as the base type of a
memoryview slice.
We support structs, numeric types and objects
"""
if dtype.is_complex and dtype.real_type.is_int:
return False
if dtype.is_struct and dtype.kind == 'struct':
for member in dtype.scope.var_entries:
if not self.valid_dtype(member.type):
return False
return True
return (
dtype.is_error or
# Pointers are not valid (yet)
# (dtype.is_ptr and valid_memslice_dtype(dtype.base_type)) or
(dtype.is_array and i < 8 and self.valid_dtype(dtype.base_type, i + 1)) or
dtype.is_numeric or
dtype.is_pyobject or
dtype.is_fused or # accept this as it will be replaced by specializations later
(dtype.is_typedef and self.valid_dtype(dtype.typedef_base_type))
)
def validate_memslice_dtype(self, pos):
if not self.valid_dtype(self.dtype):
error(pos, "Invalid base type for memoryview slice: %s" % self.dtype)
def assert_direct_dims(self, pos):
for access, packing in self.axes:
if access != 'direct':
error(pos, "All dimensions must be direct")
return False
return True
def transpose(self, pos):
if not self.assert_direct_dims(pos):
return error_type
return MemoryViewSliceType(self.dtype, self.axes[::-1])
def specialization_name(self):
return super(MemoryViewSliceType,self).specialization_name() \
+ '_' + self.specialization_suffix()
return '%s_%s' % (
super(MemoryViewSliceType,self).specialization_name(),
self.specialization_suffix())
def specialization_suffix(self):
return "%s_%s" % (self.axes_to_name(), self.dtype_name)
......@@ -874,6 +970,11 @@ class BufferType(BaseType):
self.negative_indices, self.cast)
return self
def get_entry(self, node):
from . import Buffer
assert node.is_name
return Buffer.BufferEntry(node.entry)
def __getattr__(self, name):
return getattr(self.base, name)
......
......@@ -79,7 +79,7 @@ cdef extern from *:
size_t sizeof_dtype, int contig_flag,
bint dtype_is_object) nogil except *
bint slice_is_contig "__pyx_memviewslice_is_contig" (
{{memviewslice_name}} *mvs, char order, int ndim) nogil
{{memviewslice_name}} mvs, char order, int ndim) nogil
bint slices_overlap "__pyx_slices_overlap" ({{memviewslice_name}} *slice1,
{{memviewslice_name}} *slice2,
int ndim, size_t itemsize) nogil
......@@ -578,13 +578,13 @@ cdef class memoryview(object):
cdef {{memviewslice_name}} *mslice
cdef {{memviewslice_name}} tmp
mslice = get_slice_from_memview(self, &tmp)
return slice_is_contig(mslice, 'C', self.view.ndim)
return slice_is_contig(mslice[0], 'C', self.view.ndim)
def is_f_contig(self):
cdef {{memviewslice_name}} *mslice
cdef {{memviewslice_name}} tmp
mslice = get_slice_from_memview(self, &tmp)
return slice_is_contig(mslice, 'F', self.view.ndim)
return slice_is_contig(mslice[0], 'F', self.view.ndim)
def copy(self):
cdef {{memviewslice_name}} mslice
......@@ -1195,7 +1195,7 @@ cdef void *copy_data_to_temp({{memviewslice_name}} *src,
if tmpslice.shape[i] == 1:
tmpslice.strides[i] = 0
if slice_is_contig(src, order, ndim):
if slice_is_contig(src[0], order, ndim):
memcpy(result, src.data, size)
else:
copy_strided_to_strided(src, tmpslice, ndim, itemsize)
......@@ -1258,7 +1258,7 @@ cdef int memoryview_copy_contents({{memviewslice_name}} src,
if slices_overlap(&src, &dst, ndim, itemsize):
# slices overlap, copy to temp, copy temp to dst
if not slice_is_contig(&src, order, ndim):
if not slice_is_contig(src, order, ndim):
order = get_best_order(&dst, ndim)
tmpdata = copy_data_to_temp(&src, &tmp, order, ndim)
......@@ -1267,10 +1267,10 @@ cdef int memoryview_copy_contents({{memviewslice_name}} src,
if not broadcasting:
# See if both slices have equal contiguity, in that case perform a
# direct copy. This only works when we are not broadcasting.
if slice_is_contig(&src, 'C', ndim):
direct_copy = slice_is_contig(&dst, 'C', ndim)
elif slice_is_contig(&src, 'F', ndim):
direct_copy = slice_is_contig(&dst, 'F', ndim)
if slice_is_contig(src, 'C', ndim):
direct_copy = slice_is_contig(dst, 'C', ndim)
elif slice_is_contig(src, 'F', ndim):
direct_copy = slice_is_contig(dst, 'F', ndim)
if direct_copy:
# Contiguous slices with same order
......
......@@ -692,29 +692,29 @@ __pyx_slices_overlap({{memviewslice_name}} *slice1,
////////// MemviewSliceIsCContig.proto //////////
#define __pyx_memviewslice_is_c_contig{{ndim}}(slice) \
__pyx_memviewslice_is_contig(&slice, 'C', {{ndim}})
__pyx_memviewslice_is_contig(slice, 'C', {{ndim}})
////////// MemviewSliceIsFContig.proto //////////
#define __pyx_memviewslice_is_f_contig{{ndim}}(slice) \
__pyx_memviewslice_is_contig(&slice, 'F', {{ndim}})
__pyx_memviewslice_is_contig(slice, 'F', {{ndim}})
////////// MemviewSliceIsContig.proto //////////
static int __pyx_memviewslice_is_contig(const {{memviewslice_name}} *mvs,
static int __pyx_memviewslice_is_contig(const {{memviewslice_name}} mvs,
char order, int ndim);
////////// MemviewSliceIsContig //////////
static int
__pyx_memviewslice_is_contig(const {{memviewslice_name}} *mvs,
__pyx_memviewslice_is_contig(const {{memviewslice_name}} mvs,
char order, int ndim)
{
int i, index, step, start;
Py_ssize_t itemsize = mvs->memview->view.itemsize;
Py_ssize_t itemsize = mvs.memview->view.itemsize;
if (order == 'F') {
step = 1;
......@@ -726,10 +726,10 @@ __pyx_memviewslice_is_contig(const {{memviewslice_name}} *mvs,
for (i = 0; i < ndim; i++) {
index = start + step * i;
if (mvs->suboffsets[index] >= 0 || mvs->strides[index] != itemsize)
if (mvs.suboffsets[index] >= 0 || mvs.strides[index] != itemsize)
return 0;
itemsize *= mvs->shape[index];
itemsize *= mvs.shape[index];
}
return 1;
......
......@@ -14,6 +14,7 @@ from cython.view cimport array
import numpy as np
cimport numpy as np
@testcase
def test_shape_stride_suboffset():
u'''
......@@ -47,6 +48,7 @@ def test_shape_stride_suboffset():
print c_contig.strides[0], c_contig.strides[1], c_contig.strides[2]
print c_contig.suboffsets[0], c_contig.suboffsets[1], c_contig.suboffsets[2]
@testcase
def test_copy_to():
u'''
......@@ -57,15 +59,19 @@ def test_copy_to():
'''
cdef int[:, :, :] from_mvs, to_mvs
from_mvs = np.arange(8, dtype=np.int32).reshape(2,2,2)
cdef int *from_data = <int *> from_mvs._data
print ' '.join(str(from_data[i]) for i in range(2*2*2))
to_mvs = array((2,2,2), sizeof(int), 'i')
to_mvs[...] = from_mvs
# TODO Mark: remove this _data attribute
cdef int *to_data = <int*>to_mvs._data
print ' '.join(str(from_data[i]) for i in range(2*2*2))
print ' '.join(str(to_data[i]) for i in range(2*2*2))
@testcase
def test_overlapping_copy():
"""
......@@ -81,6 +87,22 @@ def test_overlapping_copy():
for i in range(10):
assert slice[i] == 10 - 1 - i
@testcase
def test_copy_return_type():
"""
>>> test_copy_return_type()
60.0
60.0
"""
cdef double[:, :, :] a = np.arange(5 * 5 * 5, dtype=np.float64).reshape(5, 5, 5)
cdef double[:, ::1] c_contig = a[..., 0].copy()
cdef double[::1, :] f_contig = a[..., 0].copy_fortran()
print(c_contig[2, 2])
print(f_contig[2, 2])
@testcase
def test_partly_overlapping():
"""
......@@ -170,30 +192,34 @@ def test_copy_mismatch():
mv1[...] = mv2
@testcase
def test_is_contiguous():
u'''
u"""
>>> test_is_contiguous()
True True
False True
True False
True False
<BLANKLINE>
False True
True False
'''
one sized is_c/f_contig True True
is_c/f_contig False True
f_contig.copy().is_c/f_contig True False
f_contig.copy_fortran().is_c/f_contig False True
one sized strided contig True True
strided False
"""
cdef int[::1, :, :] fort_contig = array((1,1,1), sizeof(int), 'i', mode='fortran')
print fort_contig.is_c_contig() , fort_contig.is_f_contig()
fort_contig = array((200,100,100), sizeof(int), 'i', mode='fortran')
print fort_contig.is_c_contig(), fort_contig.is_f_contig()
fort_contig = fort_contig.copy()
print fort_contig.is_c_contig(), fort_contig.is_f_contig()
cdef int[:,:,:] strided = fort_contig
print strided.is_c_contig(), strided.is_f_contig()
print
fort_contig = fort_contig.copy_fortran()
print fort_contig.is_c_contig(), fort_contig.is_f_contig()
print strided.is_c_contig(), strided.is_f_contig()
print 'one sized is_c/f_contig', fort_contig.is_c_contig(), fort_contig.is_f_contig()
fort_contig = array((2,2,2), sizeof(int), 'i', mode='fortran')
print 'is_c/f_contig', fort_contig.is_c_contig(), fort_contig.is_f_contig()
print 'f_contig.copy().is_c/f_contig', fort_contig.copy().is_c_contig(), \
fort_contig.copy().is_f_contig()
print 'f_contig.copy_fortran().is_c/f_contig', \
fort_contig.copy_fortran().is_c_contig(), \
fort_contig.copy_fortran().is_f_contig()
print 'one sized strided contig', strided.is_c_contig(), strided.is_f_contig()
print 'strided', strided[::2].is_c_contig()
@testcase
......@@ -272,6 +298,7 @@ def two_dee():
print (<long*>mv3._data)[0] , (<long*>mv3._data)[1] , (<long*>mv3._data)[2] , (<long*>mv3._data)[3]
@testcase
def fort_two_dee():
u'''
......@@ -283,7 +310,8 @@ def fort_two_dee():
1 2 3 -4
'''
cdef array arr = array((2,2), sizeof(long), 'l', mode='fortran')
cdef long[::1,:] mv1, mv2, mv3
cdef long[::1,:] mv1, mv2, mv4
cdef long[:, ::1] mv3
cdef long *arr_data
arr_data = <long*>arr.data
......@@ -311,6 +339,6 @@ def fort_two_dee():
print (<long*>mv3._data)[0], (<long*>mv3._data)[1], (<long*>mv3._data)[2], (<long*>mv3._data)[3]
mv3 = mv3.copy_fortran()
mv4 = mv3.copy_fortran()
print (<long*>mv3._data)[0], (<long*>mv3._data)[1], (<long*>mv3._data)[2], (<long*>mv3._data)[3]
print (<long*>mv4._data)[0], (<long*>mv4._data)[1], (<long*>mv4._data)[2], (<long*>mv4._data)[3]
......@@ -163,6 +163,7 @@ def test_ellipsis_memoryview(array):
ae(e.shape[0], e_obj.shape[0])
ae(e.strides[0], e_obj.strides[0])
@testcase
def test_transpose():
"""
......@@ -193,6 +194,20 @@ def test_transpose():
print a[3, 2], a.T[2, 3], a_obj[3, 2], a_obj.T[2, 3], numpy_obj[3, 2], numpy_obj.T[2, 3]
@testcase
def test_transpose_type(a):
"""
>>> a = np.zeros((5, 10), dtype=np.float64)
>>> a[4, 6] = 9
>>> test_transpose_type(a)
9.0
"""
cdef double[:, ::1] m = a
cdef double[::1, :] m_transpose = a.T
print m_transpose[6, 4]
@testcase_numpy_1_5
def test_numpy_like_attributes(cyarray):
"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment