Commit ab78f93b authored by Stefan Behnel's avatar Stefan Behnel

adapt and apply major refactoring of IndexNode originally written by Mark Florisson

parent 7da49602
...@@ -201,7 +201,13 @@ class BufferEntry(object): ...@@ -201,7 +201,13 @@ class BufferEntry(object):
self.type = entry.type self.type = entry.type
self.cname = entry.buffer_aux.buflocal_nd_var.cname self.cname = entry.buffer_aux.buflocal_nd_var.cname
self.buf_ptr = "%s.rcbuffer->pybuffer.buf" % self.cname self.buf_ptr = "%s.rcbuffer->pybuffer.buf" % self.cname
self.buf_ptr_type = self.entry.type.buffer_ptr_type self.buf_ptr_type = entry.type.buffer_ptr_type
self.init_attributes()
def init_attributes(self):
self.shape = self.get_buf_shapevars()
self.strides = self.get_buf_stridevars()
self.suboffsets = self.get_buf_suboffsetvars()
def get_buf_suboffsetvars(self): def get_buf_suboffsetvars(self):
return self._for_all_ndim("%s.diminfo[%d].suboffsets") return self._for_all_ndim("%s.diminfo[%d].suboffsets")
......
This diff is collapsed.
...@@ -200,7 +200,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -200,7 +200,7 @@ class FusedCFuncDefNode(StatListNode):
if arg.type.is_fused: if arg.type.is_fused:
arg.type = arg.type.specialize(fused_to_specific) arg.type = arg.type.specialize(fused_to_specific)
if arg.type.is_memoryviewslice: if arg.type.is_memoryviewslice:
MemoryView.validate_memslice_dtype(arg.pos, arg.type.dtype) arg.type.validate_memslice_dtype(arg.pos)
def create_new_local_scope(self, node, env, f2s): def create_new_local_scope(self, node, env, f2s):
""" """
......
This diff is collapsed.
...@@ -1054,8 +1054,8 @@ class MemoryViewSliceTypeNode(CBaseTypeNode): ...@@ -1054,8 +1054,8 @@ class MemoryViewSliceTypeNode(CBaseTypeNode):
if not MemoryView.validate_axes(self.pos, axes_specs): if not MemoryView.validate_axes(self.pos, axes_specs):
self.type = error_type self.type = error_type
else: else:
MemoryView.validate_memslice_dtype(self.pos, base_type)
self.type = PyrexTypes.MemoryViewSliceType(base_type, axes_specs) self.type = PyrexTypes.MemoryViewSliceType(base_type, axes_specs)
self.type.validate_memslice_dtype(self.pos)
self.use_memview_utilities(env) self.use_memview_utilities(env)
return self.type return self.type
...@@ -4896,26 +4896,14 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4896,26 +4896,14 @@ class SingleAssignmentNode(AssignmentNode):
if unrolled_assignment: if unrolled_assignment:
return unrolled_assignment return unrolled_assignment
if self.lhs.memslice_broadcast or self.rhs.memslice_broadcast: if isinstance(self.lhs, ExprNodes.MemoryViewIndexNode):
self.lhs.memslice_broadcast = True self.lhs.analyse_broadcast_operation(self.rhs)
self.rhs.memslice_broadcast = True self.lhs = self.lhs.analyse_as_memview_scalar_assignment(self.rhs)
if (self.lhs.is_subscript and not self.rhs.type.is_memoryviewslice and
(self.lhs.memslice_slice or self.lhs.is_memslice_copy) and
(self.lhs.type.dtype.assignable_from(self.rhs.type) or
self.rhs.type.is_pyobject)):
# scalar slice assignment
self.lhs.is_memslice_scalar_assignment = True
dtype = self.lhs.type.dtype
elif self.lhs.type.is_array: elif self.lhs.type.is_array:
if not isinstance(self.lhs, ExprNodes.SliceIndexNode): if not isinstance(self.lhs, ExprNodes.SliceIndexNode):
# cannot assign to C array, only to its full slice # cannot assign to C array, only to its full slice
self.lhs = ExprNodes.SliceIndexNode( self.lhs = ExprNodes.SliceIndexNode(self.lhs.pos, base=self.lhs, start=None, stop=None)
self.lhs.pos, base=self.lhs, start=None, stop=None)
self.lhs = self.lhs.analyse_target_types(env) self.lhs = self.lhs.analyse_target_types(env)
dtype = self.lhs.type
else:
dtype = self.lhs.type
if self.lhs.type.is_cpp_class: if self.lhs.type.is_cpp_class:
op = env.lookup_operator_for_types(self.pos, '=', [self.lhs.type, self.rhs.type]) op = env.lookup_operator_for_types(self.pos, '=', [self.lhs.type, self.rhs.type])
...@@ -4923,9 +4911,10 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4923,9 +4911,10 @@ class SingleAssignmentNode(AssignmentNode):
rhs = self.rhs rhs = self.rhs
self.is_overloaded_assignment = True self.is_overloaded_assignment = True
else: else:
rhs = self.rhs.coerce_to(dtype, env) rhs = self.rhs.coerce_to(self.lhs.type, env)
else: else:
rhs = self.rhs.coerce_to(dtype, env) rhs = self.rhs.coerce_to(self.lhs.type, env)
if use_temp or rhs.is_attribute or ( if use_temp or rhs.is_attribute or (
not rhs.is_name and not rhs.is_literal and not rhs.is_name and not rhs.is_literal and
rhs.type.is_pyobject): rhs.type.is_pyobject):
...@@ -5035,12 +5024,12 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -5035,12 +5024,12 @@ class SingleAssignmentNode(AssignmentNode):
assignments = [] assignments = []
for lhs, rhs in zip(lhs_list, rhs_list): for lhs, rhs in zip(lhs_list, rhs_list):
assignments.append(SingleAssignmentNode(self.pos, lhs=lhs, rhs=rhs, first=self.first)) assignments.append(SingleAssignmentNode(self.pos, lhs=lhs, rhs=rhs, first=self.first))
all = ParallelAssignmentNode(pos=self.pos, stats=assignments).analyse_expressions(env) node = ParallelAssignmentNode(pos=self.pos, stats=assignments).analyse_expressions(env)
if check_node: if check_node:
all = StatListNode(pos=self.pos, stats=[check_node, all]) node = StatListNode(pos=self.pos, stats=[check_node, node])
for ref in refs[::-1]: for ref in refs[::-1]:
all = UtilNodes.LetNode(ref, all) node = UtilNodes.LetNode(ref, node)
return all return node
def unroll_rhs(self, env): def unroll_rhs(self, env):
from . import ExprNodes from . import ExprNodes
...@@ -5059,7 +5048,7 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -5059,7 +5048,7 @@ class SingleAssignmentNode(AssignmentNode):
if self.lhs.type.is_ctuple: if self.lhs.type.is_ctuple:
# Handled directly. # Handled directly.
return return
from . import ExprNodes, UtilNodes from . import ExprNodes
if not isinstance(self.rhs, ExprNodes.TupleNode): if not isinstance(self.rhs, ExprNodes.TupleNode):
return return
...@@ -5261,8 +5250,7 @@ class InPlaceAssignmentNode(AssignmentNode): ...@@ -5261,8 +5250,7 @@ class InPlaceAssignmentNode(AssignmentNode):
self.lhs = self.lhs.analyse_target_types(env) self.lhs = self.lhs.analyse_target_types(env)
# When assigning to a fully indexed buffer or memoryview, coerce the rhs # When assigning to a fully indexed buffer or memoryview, coerce the rhs
if (self.lhs.is_subscript and if self.lhs.is_memview_index or self.lhs.is_buffer_access:
(self.lhs.memslice_index or self.lhs.is_buffer_access)):
self.rhs = self.rhs.coerce_to(self.lhs.type, env) self.rhs = self.rhs.coerce_to(self.lhs.type, env)
elif self.lhs.type.is_string and self.operator in '+-': elif self.lhs.type.is_string and self.operator in '+-':
# use pointer arithmetic for char* LHS instead of string concat # use pointer arithmetic for char* LHS instead of string concat
...@@ -5271,28 +5259,30 @@ class InPlaceAssignmentNode(AssignmentNode): ...@@ -5271,28 +5259,30 @@ class InPlaceAssignmentNode(AssignmentNode):
def generate_execution_code(self, code): def generate_execution_code(self, code):
code.mark_pos(self.pos) code.mark_pos(self.pos)
self.rhs.generate_evaluation_code(code) lhs, rhs = self.lhs, self.rhs
self.lhs.generate_subexpr_evaluation_code(code) rhs.generate_evaluation_code(code)
lhs.generate_subexpr_evaluation_code(code)
c_op = self.operator c_op = self.operator
if c_op == "//": if c_op == "//":
c_op = "/" c_op = "/"
elif c_op == "**": elif c_op == "**":
error(self.pos, "No C inplace power operator") error(self.pos, "No C inplace power operator")
if self.lhs.is_subscript and self.lhs.is_buffer_access: if lhs.is_buffer_access or lhs.is_memview_index:
if self.lhs.type.is_pyobject: if lhs.type.is_pyobject:
error(self.pos, "In-place operators not allowed on object buffers in this release.") error(self.pos, "In-place operators not allowed on object buffers in this release.")
if (c_op in ('/', '%') and self.lhs.type.is_int if c_op in ('/', '%') and lhs.type.is_int and not code.globalstate.directives['cdivision']:
and not code.globalstate.directives['cdivision']):
error(self.pos, "In-place non-c divide operators not allowed on int buffers.") error(self.pos, "In-place non-c divide operators not allowed on int buffers.")
self.lhs.generate_buffer_setitem_code(self.rhs, code, c_op) lhs.generate_buffer_setitem_code(rhs, code, c_op)
elif lhs.is_memview_slice:
error(self.pos, "Inplace operators not supported on memoryview slices")
else: else:
# C++ # C++
# TODO: make sure overload is declared # TODO: make sure overload is declared
code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result())) code.putln("%s %s= %s;" % (lhs.result(), c_op, rhs.result()))
self.lhs.generate_subexpr_disposal_code(code) lhs.generate_subexpr_disposal_code(code)
self.lhs.free_subexpr_temps(code) lhs.free_subexpr_temps(code)
self.rhs.generate_disposal_code(code) rhs.generate_disposal_code(code)
self.rhs.free_temps(code) rhs.free_temps(code)
def annotate(self, code): def annotate(self, code):
self.lhs.annotate(code) self.lhs.annotate(code)
...@@ -6344,8 +6334,8 @@ class ForFromStatNode(LoopNode, StatNode): ...@@ -6344,8 +6334,8 @@ class ForFromStatNode(LoopNode, StatNode):
"for-from loop variable must be c numeric type or Python object") "for-from loop variable must be c numeric type or Python object")
if target_type.is_numeric: if target_type.is_numeric:
self.is_py_target = False self.is_py_target = False
if isinstance(self.target, ExprNodes.IndexNode) and self.target.is_buffer_access: if isinstance(self.target, ExprNodes.BufferIndexNode):
raise error(self.pos, "Buffer indexing not allowed as for loop target.") raise error(self.pos, "Buffer or memoryview slicing/indexing not allowed as for-loop target.")
self.loopvar_node = self.target self.loopvar_node = self.target
self.py_loopvar_node = None self.py_loopvar_node = None
else: else:
......
...@@ -132,7 +132,7 @@ class IterationTransform(Visitor.EnvTransform): ...@@ -132,7 +132,7 @@ class IterationTransform(Visitor.EnvTransform):
pos = node.pos pos = node.pos
result_ref = UtilNodes.ResultRefNode(node) result_ref = UtilNodes.ResultRefNode(node)
if isinstance(node.operand2, ExprNodes.IndexNode): if node.operand2.is_subscript:
base_type = node.operand2.base.type.base_type base_type = node.operand2.base.type.base_type
else: else:
base_type = node.operand2.type.base_type base_type = node.operand2.type.base_type
...@@ -442,7 +442,7 @@ class IterationTransform(Visitor.EnvTransform): ...@@ -442,7 +442,7 @@ class IterationTransform(Visitor.EnvTransform):
error(slice_node.pos, "C array iteration requires known end index") error(slice_node.pos, "C array iteration requires known end index")
return node return node
elif isinstance(slice_node, ExprNodes.IndexNode): elif slice_node.is_subscript:
assert isinstance(slice_node.index, ExprNodes.SliceNode) assert isinstance(slice_node.index, ExprNodes.SliceNode)
slice_base = slice_node.base slice_base = slice_node.base
index = slice_node.index index = slice_node.index
...@@ -564,7 +564,6 @@ class IterationTransform(Visitor.EnvTransform): ...@@ -564,7 +564,6 @@ class IterationTransform(Visitor.EnvTransform):
constant_result=0, constant_result=0,
type=PyrexTypes.c_int_type), type=PyrexTypes.c_int_type),
base=counter_temp, base=counter_temp,
is_buffer_access=False,
type=ptr_type.base_type) type=ptr_type.base_type)
if target_value.type != node.target.type: if target_value.type != node.target.type:
...@@ -1334,20 +1333,20 @@ class DropRefcountingTransform(Visitor.VisitorTransform): ...@@ -1334,20 +1333,20 @@ class DropRefcountingTransform(Visitor.VisitorTransform):
node = node.arg node = node.arg
name_path = [] name_path = []
obj_node = node obj_node = node
while isinstance(obj_node, ExprNodes.AttributeNode): while obj_node.is_attribute:
if obj_node.is_py_attr: if obj_node.is_py_attr:
return False return False
name_path.append(obj_node.member) name_path.append(obj_node.member)
obj_node = obj_node.obj obj_node = obj_node.obj
if isinstance(obj_node, ExprNodes.NameNode): if obj_node.is_name:
name_path.append(obj_node.name) name_path.append(obj_node.name)
names.append( ('.'.join(name_path[::-1]), node) ) names.append( ('.'.join(name_path[::-1]), node) )
elif isinstance(node, ExprNodes.IndexNode): elif node.is_subscript:
if node.base.type != Builtin.list_type: if node.base.type != Builtin.list_type:
return False return False
if not node.index.type.is_int: if not node.index.type.is_int:
return False return False
if not isinstance(node.base, ExprNodes.NameNode): if not node.base.is_name:
return False return False
indices.append(node) indices.append(node)
else: else:
...@@ -1979,7 +1978,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ...@@ -1979,7 +1978,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
elif isinstance(arg, ExprNodes.SimpleCallNode): elif isinstance(arg, ExprNodes.SimpleCallNode):
if node.type.is_int or node.type.is_float: if node.type.is_int or node.type.is_float:
return self._optimise_numeric_cast_call(node, arg) return self._optimise_numeric_cast_call(node, arg)
elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access: elif arg.is_subscript:
index_node = arg.index index_node = arg.index
if isinstance(index_node, ExprNodes.CoerceToPyTypeNode): if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
index_node = index_node.arg index_node = index_node.arg
......
...@@ -17,7 +17,7 @@ from . import Builtin ...@@ -17,7 +17,7 @@ from . import Builtin
from .Visitor import VisitorTransform, TreeVisitor from .Visitor import VisitorTransform, TreeVisitor
from .Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform from .Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform
from .UtilNodes import LetNode, LetRefNode, ResultRefNode from .UtilNodes import LetNode, LetRefNode
from .TreeFragment import TreeFragment from .TreeFragment import TreeFragment
from .StringEncoding import EncodedString, _unicode from .StringEncoding import EncodedString, _unicode
from .Errors import error, warning, CompileError, InternalError from .Errors import error, warning, CompileError, InternalError
...@@ -1931,13 +1931,8 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1931,13 +1931,8 @@ class AnalyseExpressionsTransform(CythonTransform):
re-analyse the types. re-analyse the types.
""" """
self.visit_Node(node) self.visit_Node(node)
if node.is_fused_index and not node.type.is_error: if node.is_fused_index and not node.type.is_error:
node = node.base node = node.base
elif node.memslice_ellipsis_noop:
# memoryviewslice[...] expression, drop the IndexNode
node = node.base
return node return node
...@@ -1971,26 +1966,26 @@ class ExpandInplaceOperators(EnvTransform): ...@@ -1971,26 +1966,26 @@ class ExpandInplaceOperators(EnvTransform):
if lhs.type.is_cpp_class: if lhs.type.is_cpp_class:
# No getting around this exact operator here. # No getting around this exact operator here.
return node return node
if isinstance(lhs, ExprNodes.IndexNode) and lhs.is_buffer_access: if isinstance(lhs, ExprNodes.BufferIndexNode):
# There is code to handle this case. # There is code to handle this case in InPlaceAssignmentNode
return node return node
env = self.current_env() env = self.current_env()
def side_effect_free_reference(node, setting=False): def side_effect_free_reference(node, setting=False):
if isinstance(node, ExprNodes.NameNode): if node.is_name:
return node, [] return node, []
elif node.type.is_pyobject and not setting: elif node.type.is_pyobject and not setting:
node = LetRefNode(node) node = LetRefNode(node)
return node, [node] return node, [node]
elif isinstance(node, ExprNodes.IndexNode): elif node.is_subscript:
if node.is_buffer_access:
raise ValueError("Buffer access")
base, temps = side_effect_free_reference(node.base) base, temps = side_effect_free_reference(node.base)
index = LetRefNode(node.index) index = LetRefNode(node.index)
return ExprNodes.IndexNode(node.pos, base=base, index=index), temps + [index] return ExprNodes.IndexNode(node.pos, base=base, index=index), temps + [index]
elif isinstance(node, ExprNodes.AttributeNode): elif node.is_attribute:
obj, temps = side_effect_free_reference(node.obj) obj, temps = side_effect_free_reference(node.obj)
return ExprNodes.AttributeNode(node.pos, obj=obj, attribute=node.attribute), temps return ExprNodes.AttributeNode(node.pos, obj=obj, attribute=node.attribute), temps
elif isinstance(node, ExprNodes.BufferIndexNode):
raise ValueError("Don't allow things like attributes of buffer indexing operations")
else: else:
node = LetRefNode(node) node = LetRefNode(node)
return node, [node] return node, [node]
......
...@@ -541,7 +541,7 @@ class MemoryViewSliceType(PyrexType): ...@@ -541,7 +541,7 @@ class MemoryViewSliceType(PyrexType):
the *first* axis' packing spec and 'follow' for all other packing the *first* axis' packing spec and 'follow' for all other packing
specs. specs.
""" """
from . import MemoryView from . import Buffer, MemoryView
self.dtype = base_dtype self.dtype = base_dtype
self.axes = axes self.axes = axes
...@@ -555,7 +555,7 @@ class MemoryViewSliceType(PyrexType): ...@@ -555,7 +555,7 @@ class MemoryViewSliceType(PyrexType):
self.writable_needed = False self.writable_needed = False
if not self.dtype.is_fused: if not self.dtype.is_fused:
self.dtype_name = MemoryView.mangle_dtype_name(self.dtype) self.dtype_name = Buffer.mangle_dtype_name(self.dtype)
def __hash__(self): def __hash__(self):
return hash(self.__class__) ^ hash(self.dtype) ^ hash(tuple(self.axes)) return hash(self.__class__) ^ hash(self.dtype) ^ hash(tuple(self.axes))
...@@ -638,25 +638,28 @@ class MemoryViewSliceType(PyrexType): ...@@ -638,25 +638,28 @@ class MemoryViewSliceType(PyrexType):
elif attribute in ("copy", "copy_fortran"): elif attribute in ("copy", "copy_fortran"):
ndim = len(self.axes) ndim = len(self.axes)
to_axes_c = [('direct', 'contig')] follow_dim = [('direct', 'follow')]
to_axes_f = [('direct', 'contig')] contig_dim = [('direct', 'contig')]
if ndim - 1: to_axes_c = follow_dim * (ndim - 1) + contig_dim
to_axes_c = [('direct', 'follow')]*(ndim-1) + to_axes_c to_axes_f = contig_dim + follow_dim * (ndim -1)
to_axes_f = to_axes_f + [('direct', 'follow')]*(ndim-1)
to_memview_c = MemoryViewSliceType(self.dtype, to_axes_c) to_memview_c = MemoryViewSliceType(self.dtype, to_axes_c)
to_memview_f = MemoryViewSliceType(self.dtype, to_axes_f) to_memview_f = MemoryViewSliceType(self.dtype, to_axes_f)
for to_memview, cython_name in [(to_memview_c, "copy"), for to_memview, cython_name in [(to_memview_c, "copy"),
(to_memview_f, "copy_fortran")]: (to_memview_f, "copy_fortran")]:
entry = scope.declare_cfunction(cython_name, copy_func_type = CFuncType(
CFuncType(self, [CFuncTypeArg("memviewslice", self, None)]), to_memview,
pos=pos, [CFuncTypeArg("memviewslice", self, None)])
defining=1, copy_cname = MemoryView.copy_c_or_fortran_cname(to_memview)
cname=MemoryView.copy_c_or_fortran_cname(to_memview))
entry = scope.declare_cfunction(
cython_name,
copy_func_type, pos=pos, defining=1,
cname=copy_cname)
#entry.utility_code_definition = \ utility = MemoryView.get_copy_new_utility(pos, self, to_memview)
env.use_utility_code(MemoryView.get_copy_new_utility(pos, self, to_memview)) env.use_utility_code(utility)
MemoryView.use_cython_array_utility_code(env) MemoryView.use_cython_array_utility_code(env)
...@@ -684,9 +687,102 @@ class MemoryViewSliceType(PyrexType): ...@@ -684,9 +687,102 @@ class MemoryViewSliceType(PyrexType):
return True return True
def get_entry(self, node, cname=None, type=None):
from . import MemoryView, Symtab
if cname is None:
assert node.is_simple() or node.is_temp or node.is_elemental
cname = node.result()
if type is None:
type = node.type
entry = Symtab.Entry(cname, cname, type, node.pos)
return MemoryView.MemoryViewSliceBufferEntry(entry)
def conforms_to(self, dst, broadcast=False, copying=False):
"""
Returns True if src conforms to dst, False otherwise.
If conformable, the types are the same, the ndims are equal, and each axis spec is conformable.
Any packing/access spec is conformable to itself.
'direct' and 'ptr' are conformable to 'full'.
'contig' and 'follow' are conformable to 'strided'.
Any other combo is not conformable.
"""
from . import MemoryView
src = self
if src.dtype != dst.dtype:
return False
if src.ndim != dst.ndim:
if broadcast:
src, dst = MemoryView.broadcast_types(src, dst)
else:
return False
for src_spec, dst_spec in zip(src.axes, dst.axes):
src_access, src_packing = src_spec
dst_access, dst_packing = dst_spec
if src_access != dst_access and dst_access != 'full':
return False
if src_packing != dst_packing and dst_packing != 'strided' and not copying:
return False
return True
def valid_dtype(self, dtype, i=0):
"""
Return whether type dtype can be used as the base type of a
memoryview slice.
We support structs, numeric types and objects
"""
if dtype.is_complex and dtype.real_type.is_int:
return False
if dtype.is_struct and dtype.kind == 'struct':
for member in dtype.scope.var_entries:
if not self.valid_dtype(member.type):
return False
return True
return (
dtype.is_error or
# Pointers are not valid (yet)
# (dtype.is_ptr and valid_memslice_dtype(dtype.base_type)) or
(dtype.is_array and i < 8 and self.valid_dtype(dtype.base_type, i + 1)) or
dtype.is_numeric or
dtype.is_pyobject or
dtype.is_fused or # accept this as it will be replaced by specializations later
(dtype.is_typedef and self.valid_dtype(dtype.typedef_base_type))
)
def validate_memslice_dtype(self, pos):
if not self.valid_dtype(self.dtype):
error(pos, "Invalid base type for memoryview slice: %s" % self.dtype)
def assert_direct_dims(self, pos):
for access, packing in self.axes:
if access != 'direct':
error(pos, "All dimensions must be direct")
return False
return True
def transpose(self, pos):
if not self.assert_direct_dims(pos):
return error_type
return MemoryViewSliceType(self.dtype, self.axes[::-1])
def specialization_name(self): def specialization_name(self):
return super(MemoryViewSliceType,self).specialization_name() \ return '%s_%s' % (
+ '_' + self.specialization_suffix() super(MemoryViewSliceType,self).specialization_name(),
self.specialization_suffix())
def specialization_suffix(self): def specialization_suffix(self):
return "%s_%s" % (self.axes_to_name(), self.dtype_name) return "%s_%s" % (self.axes_to_name(), self.dtype_name)
...@@ -874,6 +970,11 @@ class BufferType(BaseType): ...@@ -874,6 +970,11 @@ class BufferType(BaseType):
self.negative_indices, self.cast) self.negative_indices, self.cast)
return self return self
def get_entry(self, node):
from . import Buffer
assert node.is_name
return Buffer.BufferEntry(node.entry)
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.base, name) return getattr(self.base, name)
......
...@@ -79,7 +79,7 @@ cdef extern from *: ...@@ -79,7 +79,7 @@ cdef extern from *:
size_t sizeof_dtype, int contig_flag, size_t sizeof_dtype, int contig_flag,
bint dtype_is_object) nogil except * bint dtype_is_object) nogil except *
bint slice_is_contig "__pyx_memviewslice_is_contig" ( bint slice_is_contig "__pyx_memviewslice_is_contig" (
{{memviewslice_name}} *mvs, char order, int ndim) nogil {{memviewslice_name}} mvs, char order, int ndim) nogil
bint slices_overlap "__pyx_slices_overlap" ({{memviewslice_name}} *slice1, bint slices_overlap "__pyx_slices_overlap" ({{memviewslice_name}} *slice1,
{{memviewslice_name}} *slice2, {{memviewslice_name}} *slice2,
int ndim, size_t itemsize) nogil int ndim, size_t itemsize) nogil
...@@ -578,13 +578,13 @@ cdef class memoryview(object): ...@@ -578,13 +578,13 @@ cdef class memoryview(object):
cdef {{memviewslice_name}} *mslice cdef {{memviewslice_name}} *mslice
cdef {{memviewslice_name}} tmp cdef {{memviewslice_name}} tmp
mslice = get_slice_from_memview(self, &tmp) mslice = get_slice_from_memview(self, &tmp)
return slice_is_contig(mslice, 'C', self.view.ndim) return slice_is_contig(mslice[0], 'C', self.view.ndim)
def is_f_contig(self): def is_f_contig(self):
cdef {{memviewslice_name}} *mslice cdef {{memviewslice_name}} *mslice
cdef {{memviewslice_name}} tmp cdef {{memviewslice_name}} tmp
mslice = get_slice_from_memview(self, &tmp) mslice = get_slice_from_memview(self, &tmp)
return slice_is_contig(mslice, 'F', self.view.ndim) return slice_is_contig(mslice[0], 'F', self.view.ndim)
def copy(self): def copy(self):
cdef {{memviewslice_name}} mslice cdef {{memviewslice_name}} mslice
...@@ -1195,7 +1195,7 @@ cdef void *copy_data_to_temp({{memviewslice_name}} *src, ...@@ -1195,7 +1195,7 @@ cdef void *copy_data_to_temp({{memviewslice_name}} *src,
if tmpslice.shape[i] == 1: if tmpslice.shape[i] == 1:
tmpslice.strides[i] = 0 tmpslice.strides[i] = 0
if slice_is_contig(src, order, ndim): if slice_is_contig(src[0], order, ndim):
memcpy(result, src.data, size) memcpy(result, src.data, size)
else: else:
copy_strided_to_strided(src, tmpslice, ndim, itemsize) copy_strided_to_strided(src, tmpslice, ndim, itemsize)
...@@ -1258,7 +1258,7 @@ cdef int memoryview_copy_contents({{memviewslice_name}} src, ...@@ -1258,7 +1258,7 @@ cdef int memoryview_copy_contents({{memviewslice_name}} src,
if slices_overlap(&src, &dst, ndim, itemsize): if slices_overlap(&src, &dst, ndim, itemsize):
# slices overlap, copy to temp, copy temp to dst # slices overlap, copy to temp, copy temp to dst
if not slice_is_contig(&src, order, ndim): if not slice_is_contig(src, order, ndim):
order = get_best_order(&dst, ndim) order = get_best_order(&dst, ndim)
tmpdata = copy_data_to_temp(&src, &tmp, order, ndim) tmpdata = copy_data_to_temp(&src, &tmp, order, ndim)
...@@ -1267,10 +1267,10 @@ cdef int memoryview_copy_contents({{memviewslice_name}} src, ...@@ -1267,10 +1267,10 @@ cdef int memoryview_copy_contents({{memviewslice_name}} src,
if not broadcasting: if not broadcasting:
# See if both slices have equal contiguity, in that case perform a # See if both slices have equal contiguity, in that case perform a
# direct copy. This only works when we are not broadcasting. # direct copy. This only works when we are not broadcasting.
if slice_is_contig(&src, 'C', ndim): if slice_is_contig(src, 'C', ndim):
direct_copy = slice_is_contig(&dst, 'C', ndim) direct_copy = slice_is_contig(dst, 'C', ndim)
elif slice_is_contig(&src, 'F', ndim): elif slice_is_contig(src, 'F', ndim):
direct_copy = slice_is_contig(&dst, 'F', ndim) direct_copy = slice_is_contig(dst, 'F', ndim)
if direct_copy: if direct_copy:
# Contiguous slices with same order # Contiguous slices with same order
......
...@@ -692,29 +692,29 @@ __pyx_slices_overlap({{memviewslice_name}} *slice1, ...@@ -692,29 +692,29 @@ __pyx_slices_overlap({{memviewslice_name}} *slice1,
////////// MemviewSliceIsCContig.proto ////////// ////////// MemviewSliceIsCContig.proto //////////
#define __pyx_memviewslice_is_c_contig{{ndim}}(slice) \ #define __pyx_memviewslice_is_c_contig{{ndim}}(slice) \
__pyx_memviewslice_is_contig(&slice, 'C', {{ndim}}) __pyx_memviewslice_is_contig(slice, 'C', {{ndim}})
////////// MemviewSliceIsFContig.proto ////////// ////////// MemviewSliceIsFContig.proto //////////
#define __pyx_memviewslice_is_f_contig{{ndim}}(slice) \ #define __pyx_memviewslice_is_f_contig{{ndim}}(slice) \
__pyx_memviewslice_is_contig(&slice, 'F', {{ndim}}) __pyx_memviewslice_is_contig(slice, 'F', {{ndim}})
////////// MemviewSliceIsContig.proto ////////// ////////// MemviewSliceIsContig.proto //////////
static int __pyx_memviewslice_is_contig(const {{memviewslice_name}} *mvs, static int __pyx_memviewslice_is_contig(const {{memviewslice_name}} mvs,
char order, int ndim); char order, int ndim);
////////// MemviewSliceIsContig ////////// ////////// MemviewSliceIsContig //////////
static int static int
__pyx_memviewslice_is_contig(const {{memviewslice_name}} *mvs, __pyx_memviewslice_is_contig(const {{memviewslice_name}} mvs,
char order, int ndim) char order, int ndim)
{ {
int i, index, step, start; int i, index, step, start;
Py_ssize_t itemsize = mvs->memview->view.itemsize; Py_ssize_t itemsize = mvs.memview->view.itemsize;
if (order == 'F') { if (order == 'F') {
step = 1; step = 1;
...@@ -726,10 +726,10 @@ __pyx_memviewslice_is_contig(const {{memviewslice_name}} *mvs, ...@@ -726,10 +726,10 @@ __pyx_memviewslice_is_contig(const {{memviewslice_name}} *mvs,
for (i = 0; i < ndim; i++) { for (i = 0; i < ndim; i++) {
index = start + step * i; index = start + step * i;
if (mvs->suboffsets[index] >= 0 || mvs->strides[index] != itemsize) if (mvs.suboffsets[index] >= 0 || mvs.strides[index] != itemsize)
return 0; return 0;
itemsize *= mvs->shape[index]; itemsize *= mvs.shape[index];
} }
return 1; return 1;
......
...@@ -14,6 +14,7 @@ from cython.view cimport array ...@@ -14,6 +14,7 @@ from cython.view cimport array
import numpy as np import numpy as np
cimport numpy as np cimport numpy as np
@testcase @testcase
def test_shape_stride_suboffset(): def test_shape_stride_suboffset():
u''' u'''
...@@ -47,6 +48,7 @@ def test_shape_stride_suboffset(): ...@@ -47,6 +48,7 @@ def test_shape_stride_suboffset():
print c_contig.strides[0], c_contig.strides[1], c_contig.strides[2] print c_contig.strides[0], c_contig.strides[1], c_contig.strides[2]
print c_contig.suboffsets[0], c_contig.suboffsets[1], c_contig.suboffsets[2] print c_contig.suboffsets[0], c_contig.suboffsets[1], c_contig.suboffsets[2]
@testcase @testcase
def test_copy_to(): def test_copy_to():
u''' u'''
...@@ -57,15 +59,19 @@ def test_copy_to(): ...@@ -57,15 +59,19 @@ def test_copy_to():
''' '''
cdef int[:, :, :] from_mvs, to_mvs cdef int[:, :, :] from_mvs, to_mvs
from_mvs = np.arange(8, dtype=np.int32).reshape(2,2,2) from_mvs = np.arange(8, dtype=np.int32).reshape(2,2,2)
cdef int *from_data = <int *> from_mvs._data cdef int *from_data = <int *> from_mvs._data
print ' '.join(str(from_data[i]) for i in range(2*2*2)) print ' '.join(str(from_data[i]) for i in range(2*2*2))
to_mvs = array((2,2,2), sizeof(int), 'i') to_mvs = array((2,2,2), sizeof(int), 'i')
to_mvs[...] = from_mvs to_mvs[...] = from_mvs
# TODO Mark: remove this _data attribute
cdef int *to_data = <int*>to_mvs._data cdef int *to_data = <int*>to_mvs._data
print ' '.join(str(from_data[i]) for i in range(2*2*2)) print ' '.join(str(from_data[i]) for i in range(2*2*2))
print ' '.join(str(to_data[i]) for i in range(2*2*2)) print ' '.join(str(to_data[i]) for i in range(2*2*2))
@testcase @testcase
def test_overlapping_copy(): def test_overlapping_copy():
""" """
...@@ -81,6 +87,22 @@ def test_overlapping_copy(): ...@@ -81,6 +87,22 @@ def test_overlapping_copy():
for i in range(10): for i in range(10):
assert slice[i] == 10 - 1 - i assert slice[i] == 10 - 1 - i
@testcase
def test_copy_return_type():
"""
>>> test_copy_return_type()
60.0
60.0
"""
cdef double[:, :, :] a = np.arange(5 * 5 * 5, dtype=np.float64).reshape(5, 5, 5)
cdef double[:, ::1] c_contig = a[..., 0].copy()
cdef double[::1, :] f_contig = a[..., 0].copy_fortran()
print(c_contig[2, 2])
print(f_contig[2, 2])
@testcase @testcase
def test_partly_overlapping(): def test_partly_overlapping():
""" """
...@@ -170,30 +192,34 @@ def test_copy_mismatch(): ...@@ -170,30 +192,34 @@ def test_copy_mismatch():
mv1[...] = mv2 mv1[...] = mv2
@testcase @testcase
def test_is_contiguous(): def test_is_contiguous():
u''' u"""
>>> test_is_contiguous() >>> test_is_contiguous()
True True one sized is_c/f_contig True True
False True is_c/f_contig False True
True False f_contig.copy().is_c/f_contig True False
True False f_contig.copy_fortran().is_c/f_contig False True
<BLANKLINE> one sized strided contig True True
False True strided False
True False """
'''
cdef int[::1, :, :] fort_contig = array((1,1,1), sizeof(int), 'i', mode='fortran') cdef int[::1, :, :] fort_contig = array((1,1,1), sizeof(int), 'i', mode='fortran')
print fort_contig.is_c_contig() , fort_contig.is_f_contig()
fort_contig = array((200,100,100), sizeof(int), 'i', mode='fortran')
print fort_contig.is_c_contig(), fort_contig.is_f_contig()
fort_contig = fort_contig.copy()
print fort_contig.is_c_contig(), fort_contig.is_f_contig()
cdef int[:,:,:] strided = fort_contig cdef int[:,:,:] strided = fort_contig
print strided.is_c_contig(), strided.is_f_contig()
print print 'one sized is_c/f_contig', fort_contig.is_c_contig(), fort_contig.is_f_contig()
fort_contig = fort_contig.copy_fortran() fort_contig = array((2,2,2), sizeof(int), 'i', mode='fortran')
print fort_contig.is_c_contig(), fort_contig.is_f_contig() print 'is_c/f_contig', fort_contig.is_c_contig(), fort_contig.is_f_contig()
print strided.is_c_contig(), strided.is_f_contig()
print 'f_contig.copy().is_c/f_contig', fort_contig.copy().is_c_contig(), \
fort_contig.copy().is_f_contig()
print 'f_contig.copy_fortran().is_c/f_contig', \
fort_contig.copy_fortran().is_c_contig(), \
fort_contig.copy_fortran().is_f_contig()
print 'one sized strided contig', strided.is_c_contig(), strided.is_f_contig()
print 'strided', strided[::2].is_c_contig()
@testcase @testcase
...@@ -272,6 +298,7 @@ def two_dee(): ...@@ -272,6 +298,7 @@ def two_dee():
print (<long*>mv3._data)[0] , (<long*>mv3._data)[1] , (<long*>mv3._data)[2] , (<long*>mv3._data)[3] print (<long*>mv3._data)[0] , (<long*>mv3._data)[1] , (<long*>mv3._data)[2] , (<long*>mv3._data)[3]
@testcase @testcase
def fort_two_dee(): def fort_two_dee():
u''' u'''
...@@ -283,7 +310,8 @@ def fort_two_dee(): ...@@ -283,7 +310,8 @@ def fort_two_dee():
1 2 3 -4 1 2 3 -4
''' '''
cdef array arr = array((2,2), sizeof(long), 'l', mode='fortran') cdef array arr = array((2,2), sizeof(long), 'l', mode='fortran')
cdef long[::1,:] mv1, mv2, mv3 cdef long[::1,:] mv1, mv2, mv4
cdef long[:, ::1] mv3
cdef long *arr_data cdef long *arr_data
arr_data = <long*>arr.data arr_data = <long*>arr.data
...@@ -311,6 +339,6 @@ def fort_two_dee(): ...@@ -311,6 +339,6 @@ def fort_two_dee():
print (<long*>mv3._data)[0], (<long*>mv3._data)[1], (<long*>mv3._data)[2], (<long*>mv3._data)[3] print (<long*>mv3._data)[0], (<long*>mv3._data)[1], (<long*>mv3._data)[2], (<long*>mv3._data)[3]
mv3 = mv3.copy_fortran() mv4 = mv3.copy_fortran()
print (<long*>mv3._data)[0], (<long*>mv3._data)[1], (<long*>mv3._data)[2], (<long*>mv3._data)[3] print (<long*>mv4._data)[0], (<long*>mv4._data)[1], (<long*>mv4._data)[2], (<long*>mv4._data)[3]
...@@ -163,6 +163,7 @@ def test_ellipsis_memoryview(array): ...@@ -163,6 +163,7 @@ def test_ellipsis_memoryview(array):
ae(e.shape[0], e_obj.shape[0]) ae(e.shape[0], e_obj.shape[0])
ae(e.strides[0], e_obj.strides[0]) ae(e.strides[0], e_obj.strides[0])
@testcase @testcase
def test_transpose(): def test_transpose():
""" """
...@@ -193,6 +194,20 @@ def test_transpose(): ...@@ -193,6 +194,20 @@ def test_transpose():
print a[3, 2], a.T[2, 3], a_obj[3, 2], a_obj.T[2, 3], numpy_obj[3, 2], numpy_obj.T[2, 3] print a[3, 2], a.T[2, 3], a_obj[3, 2], a_obj.T[2, 3], numpy_obj[3, 2], numpy_obj.T[2, 3]
@testcase
def test_transpose_type(a):
"""
>>> a = np.zeros((5, 10), dtype=np.float64)
>>> a[4, 6] = 9
>>> test_transpose_type(a)
9.0
"""
cdef double[:, ::1] m = a
cdef double[::1, :] m_transpose = a.T
print m_transpose[6, 4]
@testcase_numpy_1_5 @testcase_numpy_1_5
def test_numpy_like_attributes(cyarray): def test_numpy_like_attributes(cyarray):
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment