Commit 4fce6db0 authored by Kurt Smith's avatar Kurt Smith Committed by Mark Florisson

cleanup in MemoryViewSliceType

parent b297ae99
......@@ -302,7 +302,42 @@ no_fail:
}
'''
def get_copy_contents_code(from_mvs, to_mvs, cfunc_name):
def memoryviewslice_get_copy_func(from_memview, to_memview, mode, scope):
from PyrexTypes import CFuncType, CFuncTypeArg
if mode == 'c':
cython_name = "copy"
copy_name = '__Pyx_BufferNew_C_From_'+from_memview.specialization_suffix()
contig_flag = 'PyBUF_C_CONTIGUOUS'
elif mode == 'fortran':
cython_name = "copy_fortran"
copy_name = "__Pyx_BufferNew_F_From_"+from_memview.specialization_suffix()
contig_flag = 'PyBUF_F_CONTIGUOUS'
else:
assert False
copy_contents_name = get_copy_contents_name(from_memview, to_memview)
scope.declare_cfunction(cython_name,
CFuncType(from_memview,
[CFuncTypeArg("memviewslice", from_memview, None)]),
pos = None,
defining = 1,
cname = copy_name)
copy_impl = copy_template % dict(
copy_name=copy_name,
mode=mode,
sizeof_dtype="sizeof(%s)" % from_memview.dtype.declaration_code(''),
contig_flag=contig_flag,
copy_contents_name=copy_contents_name)
copy_decl = ("static __Pyx_memviewslice "
"%s(const __Pyx_memviewslice); /* proto */\n" % (copy_name,))
return (copy_decl, copy_impl)
def get_copy_contents_func(from_mvs, to_mvs, cfunc_name):
assert from_mvs.dtype == to_mvs.dtype
assert len(from_mvs.axes) == len(to_mvs.axes)
......@@ -313,7 +348,10 @@ def get_copy_contents_code(from_mvs, to_mvs, cfunc_name):
if access != 'direct':
raise NotImplementedError("only direct access supported currently.")
code = '''
code_decl = ("static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs,"
"__Pyx_memviewslice *to_mvs); /* proto */" % {'cfunc_name' : cfunc_name})
code_impl = '''
static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs, __Pyx_memviewslice *to_mvs) {
......@@ -338,44 +376,44 @@ static int %(cfunc_name)s(const __Pyx_memviewslice *from_mvs, __Pyx_memviewslice
# 'i' always goes up from zero to ndim-1.
# 'idx' is the same as 'i' for c_contig, and goes from ndim-1 to 0 for f_contig.
# this makes the loop code below identical in both cases.
code += INDENT+"Py_ssize_t i%d = 0, idx%d = 0;\n" % (i,i)
code += INDENT+"Py_ssize_t stride%(i)d = from_mvs->diminfo[%(idx)d].strides;\n" % {'i':i, 'idx':idx}
code += INDENT+"Py_ssize_t shape%(i)d = from_mvs->diminfo[%(idx)d].shape;\n" % {'i':i, 'idx':idx}
code_impl += INDENT+"Py_ssize_t i%d = 0, idx%d = 0;\n" % (i,i)
code_impl += INDENT+"Py_ssize_t stride%(i)d = from_mvs->diminfo[%(idx)d].strides;\n" % {'i':i, 'idx':idx}
code_impl += INDENT+"Py_ssize_t shape%(i)d = from_mvs->diminfo[%(idx)d].shape;\n" % {'i':i, 'idx':idx}
code += "\n"
code_impl += "\n"
# put down the nested for-loop.
for k in range(ndim):
code += INDENT*(k+1) + "for(i%(k)d=0; i%(k)d<shape%(k)d; i%(k)d++) {\n" % {'k' : k}
code_impl += INDENT*(k+1) + "for(i%(k)d=0; i%(k)d<shape%(k)d; i%(k)d++) {\n" % {'k' : k}
if k >= 1:
code += INDENT*(k+2) + "idx%(k)d = i%(k)d * stride%(k)d + idx%(km1)d;\n" % {'k' : k, 'km1' : k-1}
code_impl += INDENT*(k+2) + "idx%(k)d = i%(k)d * stride%(k)d + idx%(km1)d;\n" % {'k' : k, 'km1' : k-1}
else:
code += INDENT*(k+2) + "idx%(k)d = i%(k)d * stride%(k)d;\n" % {'k' : k}
code_impl += INDENT*(k+2) + "idx%(k)d = i%(k)d * stride%(k)d;\n" % {'k' : k}
# the inner part of the loop.
dtype_decl = from_mvs.dtype.declaration_code("")
last_idx = ndim-1
code += INDENT*ndim+"memcpy(to_buf, from_buf+idx%(last_idx)d, sizeof(%(dtype_decl)s));\n" % locals()
code += INDENT*ndim+"to_buf += sizeof(%(dtype_decl)s);\n" % locals()
code_impl += INDENT*ndim+"memcpy(to_buf, from_buf+idx%(last_idx)d, sizeof(%(dtype_decl)s));\n" % locals()
code_impl += INDENT*ndim+"to_buf += sizeof(%(dtype_decl)s);\n" % locals()
# for-loop closing braces
for k in range(ndim-1, -1, -1):
code += INDENT*(k+1)+"}\n"
code_impl += INDENT*(k+1)+"}\n"
# init to_mvs->data and to_mvs->diminfo.
code += INDENT+"temp_memview = to_mvs->memview;\n"
code += INDENT+"temp_data = to_mvs->data;\n"
code += INDENT+"to_mvs->memview = 0; to_mvs->data = 0;\n"
code += INDENT+"if(unlikely(-1 == __Pyx_init_memviewslice(temp_memview, %d, to_mvs))) {\n" % (ndim,)
code += INDENT*2+"return -1;\n"
code += INDENT+"}\n"
code_impl += INDENT+"temp_memview = to_mvs->memview;\n"
code_impl += INDENT+"temp_data = to_mvs->data;\n"
code_impl += INDENT+"to_mvs->memview = 0; to_mvs->data = 0;\n"
code_impl += INDENT+"if(unlikely(-1 == __Pyx_init_memviewslice(temp_memview, %d, to_mvs))) {\n" % (ndim,)
code_impl += INDENT*2+"return -1;\n"
code_impl += INDENT+"}\n"
code += INDENT + "return 0;\n"
code_impl += INDENT + "return 0;\n"
code += '}\n'
code_impl += '}\n'
return code
return code_decl, code_impl
def get_axes_specs(env, axes):
'''
......
......@@ -406,71 +406,42 @@ class MemoryViewSliceType(PyrexType):
to_memview_c = MemoryViewSliceType(self.dtype, to_axes_c, self.env)
to_memview_f = MemoryViewSliceType(self.dtype, to_axes_f, self.env)
cython_name_c = 'copy'
cython_name_f = 'copy_fortran'
copy_name_c = '__Pyx_BufferNew_C_From_'+self.specialization_suffix()
copy_name_f = '__Pyx_BufferNew_F_From_'+self.specialization_suffix()
c_copy_util_code = UtilityCode()
f_copy_util_code = UtilityCode()
for (to_memview, copy_name, cython_name, mode, contig_flag, util_code) in (
(to_memview_c, copy_name_c, cython_name_c, 'c', 'PyBUF_C_CONTIGUOUS', c_copy_util_code),
(to_memview_f, copy_name_f, cython_name_f, 'fortran', 'PyBUF_F_CONTIGUOUS', f_copy_util_code)):
copy_contents_name = MemoryView.get_copy_contents_name(self, to_memview)
scope.declare_cfunction(cython_name,
CFuncType(self,
[CFuncTypeArg("memviewslice", self, None)]),
pos = None,
defining = 1,
cname = copy_name)
copy_impl = MemoryView.copy_template %\
dict(copy_name=copy_name,
mode=mode,
sizeof_dtype="sizeof(%s)" % self.dtype.declaration_code(''),
contig_flag=contig_flag,
copy_contents_name=copy_contents_name)
copy_decl = '''\
static __Pyx_memviewslice %s(const __Pyx_memviewslice); /* proto */
''' % (copy_name,)
util_code.proto = copy_decl
util_code.impl = copy_impl
copy_contents_name_c = MemoryView.get_copy_contents_name(self, to_memview_c)
copy_contents_name_f = MemoryView.get_copy_contents_name(self, to_memview_f)
c_copy_util_code.proto += ('static int %s'
'(const __Pyx_memviewslice *,'
' __Pyx_memviewslice *); /* proto */\n' %
(copy_contents_name_c,))
c_copy_util_code.impl += \
MemoryView.get_copy_contents_code(self, to_memview_c, copy_contents_name_c)
copy_contents_name_c =\
MemoryView.get_copy_contents_name(self, to_memview_c)
copy_contents_name_f =\
MemoryView.get_copy_contents_name(self, to_memview_f)
c_copy_decl, c_copy_impl = \
MemoryView.memoryviewslice_get_copy_func(self, to_memview_c, 'c', self.scope)
f_copy_decl, f_copy_impl = \
MemoryView.memoryviewslice_get_copy_func(self, to_memview_f, 'fortran', self.scope)
c_copy_contents_decl, c_copy_contents_impl = \
MemoryView.get_copy_contents_func(
self, to_memview_c, copy_contents_name_c)
f_copy_contents_decl, f_copy_contents_impl = \
MemoryView.get_copy_contents_func(
self, to_memview_f, copy_contents_name_f)
c_util_code = UtilityCode(
proto = "%s%s" % (c_copy_decl, c_copy_contents_decl),
impl = "%s%s" % (c_copy_impl, c_copy_contents_impl))
f_util_code = UtilityCode(
proto = f_copy_decl,
impl = f_copy_impl)
if copy_contents_name_c != copy_contents_name_f:
f_copy_util_code.proto += ('static int %s'
'(const __Pyx_memviewslice *,'
' __Pyx_memviewslice *); /* proto */\n' %
(copy_contents_name_f,))
f_copy_util_code.impl += \
MemoryView.get_copy_contents_code(self, to_memview_f, copy_contents_name_f)
f_util_code.proto += f_copy_contents_decl
f_util_code.impl += f_copy_contents_impl
c_copy_used = [1 for util_code in self.env.global_scope().utility_code_list if c_copy_util_code.proto == util_code.proto]
f_copy_used = [1 for util_code in self.env.global_scope().utility_code_list if f_copy_util_code.proto == util_code.proto]
c_copy_used = [1 for util_code in self.env.global_scope().utility_code_list if c_util_code.proto == util_code.proto]
f_copy_used = [1 for util_code in self.env.global_scope().utility_code_list if f_util_code.proto == util_code.proto]
if not c_copy_used:
self.env.use_utility_code(c_copy_util_code)
self.env.use_utility_code(c_util_code)
if not f_copy_used:
self.env.use_utility_code(f_copy_util_code)
self.env.use_utility_code(f_util_code)
# is_c_contiguous and is_f_contiguous functions
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment