Commit b7a1be05 authored by Kurt Smith's avatar Kurt Smith Committed by Mark Florisson

refactored assignment to memoryviewslice for return values.

parent 4ce841dd
...@@ -1655,7 +1655,10 @@ class NameNode(AtomicExprNode): ...@@ -1655,7 +1655,10 @@ class NameNode(AtomicExprNode):
rhs.free_temps(code) rhs.free_temps(code)
else: else:
if self.type.is_memoryviewslice: if self.type.is_memoryviewslice:
self.generate_acquire_memoryviewslice(rhs, code) import MemoryView
MemoryView.gen_acquire_memoryviewslice(rhs, self.type,
self.entry.is_cglobal, self.result(), self.pos, code)
# self.generate_acquire_memoryviewslice(rhs, code)
elif self.type.is_buffer: elif self.type.is_buffer:
# Generate code for doing the buffer release/acquisition. # Generate code for doing the buffer release/acquisition.
...@@ -1699,30 +1702,6 @@ class NameNode(AtomicExprNode): ...@@ -1699,30 +1702,6 @@ class NameNode(AtomicExprNode):
rhs.generate_post_assignment_code(code) rhs.generate_post_assignment_code(code)
rhs.free_temps(code) rhs.free_temps(code)
def generate_acquire_memoryviewslice(self, rhs, code):
# to explicitly manange the memviewslice.memview object correctly.
import MemoryView
assert rhs.type.is_memoryviewslice
if not rhs.result_in_temp():
code.put_incref("%s.memview" % rhs.result(), cython_memoryview_ptr_type)
if self.entry.is_cglobal:
code.put_gotref("%s.memview" % self.result())
if not self.lhs_of_first_assignment:
if self.entry.is_local and not Options.init_local_none:
code.put_xdecref("%s.memview" % self.result(), cython_memoryview_ptr_type)
else:
code.put_decref("%s.memview" % self.result(), cython_memoryview_ptr_type)
if self.entry.is_cglobal:
code.put_giveref("%s.memview" % rhs.result())
MemoryView.put_assign_to_memviewslice(self.result(), rhs.result(), self.type,
pos=self.pos, code=code)
if rhs.result_in_temp():
code.putln("%s.memview = 0;" % rhs.result())
def generate_acquire_buffer(self, rhs, code): def generate_acquire_buffer(self, rhs, code):
# rhstmp is only used in case the rhs is a complicated expression leading to # rhstmp is only used in case the rhs is a complicated expression leading to
# the object, to avoid repeating the same C expression for every reference # the object, to avoid repeating the same C expression for every reference
...@@ -7623,7 +7602,7 @@ class CoerceToMemViewSliceNode(CoercionNode): ...@@ -7623,7 +7602,7 @@ class CoerceToMemViewSliceNode(CoercionNode):
buf_flag = MemoryView.get_buf_flag(self.type.axes) buf_flag = MemoryView.get_buf_flag(self.type.axes)
code.putln("%s = (PyObject *)" code.putln("%s = (PyObject *)"
"__pyx_viewaxis_memoryview_cwrapper(%s, %s);" %\ "__pyx_viewaxis_memoryview_cwrapper(%s, %s);" %\
(memviewobj, self.arg.result(), buf_flag)) (memviewobj, self.arg.py_result(), buf_flag))
ndim = len(self.type.axes) ndim = len(self.type.axes)
spec_int_arr = code.funcstate.allocate_temp( spec_int_arr = code.funcstate.allocate_temp(
PyrexTypes.c_array_type(PyrexTypes.c_int_type, ndim), PyrexTypes.c_array_type(PyrexTypes.c_int_type, ndim),
......
...@@ -67,8 +67,41 @@ def format_from_type(base_type): ...@@ -67,8 +67,41 @@ def format_from_type(base_type):
def put_init_entry(mv_cname, code): def put_init_entry(mv_cname, code):
code.putln("%s.data = NULL;" % mv_cname) code.putln("%s.data = NULL;" % mv_cname)
code.put_init_to_py_none("%s.memview" % mv_cname, cython_memoryview_ptr_type) code.putln("%s.memview = NULL;" % mv_cname)
code.put_giveref(code.as_pyobject("%s.memview" % mv_cname, cython_memoryview_ptr_type))
def gen_acquire_memoryviewslice(rhs, lhs_type, lhs_is_cglobal, lhs_result, lhs_pos, code):
# import MemoryView
assert rhs.type.is_memoryviewslice
pretty_rhs = isinstance(rhs, NameNode) or rhs.result_in_temp()
if pretty_rhs:
rhstmp = rhs.result()
else:
rhstmp = code.funcstate.allocate_temp(lhs_type, manage_ref=False)
code.putln("%s = %s;" % (rhstmp, rhs.result_as(lhs_type)))
if not rhs.result_in_temp():
code.put_incref("%s.memview" % rhstmp, cython_memoryview_ptr_type)
if lhs_is_cglobal:
code.put_gotref("%s.memview" % lhs_result)
#XXX: this is here because self.lhs_of_first_assignment is not set correctly,
# once that is working this should take that flag into account.
# See NameNode.generate_assignment_code
code.put_xdecref("%s.memview" % lhs_result, cython_memoryview_ptr_type)
if lhs_is_cglobal:
code.put_giveref("%s.memview" % rhstmp)
put_assign_to_memviewslice(lhs_result, rhstmp, lhs_type,
lhs_pos, code=code)
if rhs.result_in_temp() or not pretty_rhs:
code.putln("%s.memview = 0;" % rhstmp)
if not pretty_rhs:
code.funcstate.release_temp(rhstmp)
def put_assign_to_memviewslice(lhs_cname, rhs_cname, memviewslicetype, pos, code): def put_assign_to_memviewslice(lhs_cname, rhs_cname, memviewslicetype, pos, code):
......
...@@ -1349,6 +1349,9 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1349,6 +1349,9 @@ class FuncDefNode(StatNode, BlockNode):
"%s%s;" % "%s%s;" %
(self.return_type.declaration_code(Naming.retval_cname), (self.return_type.declaration_code(Naming.retval_cname),
init)) init))
if self.return_type.is_memoryviewslice:
import MemoryView
MemoryView.put_init_entry(Naming.retval_cname, code)
tempvardecl_code = code.insertion_point() tempvardecl_code = code.insertion_point()
self.generate_keyword_list(code) self.generate_keyword_list(code)
...@@ -1561,6 +1564,8 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1561,6 +1564,8 @@ class FuncDefNode(StatNode, BlockNode):
err_val = default_retval err_val = default_retval
if self.return_type.is_pyobject: if self.return_type.is_pyobject:
code.put_xgiveref(self.return_type.as_pyobject(Naming.retval_cname)) code.put_xgiveref(self.return_type.as_pyobject(Naming.retval_cname))
elif self.return_type.is_memoryviewslice:
code.put_xgiveref(code.as_pyobject("%s.memview" % Naming.retval_cname,cython_memoryview_ptr_type))
if self.entry.is_special and self.entry.name == "__hash__": if self.entry.is_special and self.entry.name == "__hash__":
# Returning -1 for __hash__ is supposed to signal an error # Returning -1 for __hash__ is supposed to signal an error
...@@ -4265,10 +4270,22 @@ class ReturnStatNode(StatNode): ...@@ -4265,10 +4270,22 @@ class ReturnStatNode(StatNode):
if self.return_type.is_pyobject: if self.return_type.is_pyobject:
code.put_xdecref(Naming.retval_cname, code.put_xdecref(Naming.retval_cname,
self.return_type) self.return_type)
elif self.return_type.is_memoryviewslice:
code.put_xdecref("%s.memview" % Naming.retval_cname,
self.return_type)
if self.value: if self.value:
self.value.generate_evaluation_code(code) self.value.generate_evaluation_code(code)
if self.return_type.is_memoryviewslice:
import MemoryView
MemoryView.gen_acquire_memoryviewslice(self.value, self.return_type,
False, Naming.retval_cname, None, code)
else:
self.value.make_owned_reference(code) self.value.make_owned_reference(code)
self.put_return(code, self.value.result_as(self.return_type)) code.putln(
"%s = %s;" % (
Naming.retval_cname,
self.value.result_as(self.return_type)))
self.value.generate_post_assignment_code(code) self.value.generate_post_assignment_code(code)
self.value.free_temps(code) self.value.free_temps(code)
else: else:
......
...@@ -317,6 +317,9 @@ class MemoryViewSliceType(PyrexType): ...@@ -317,6 +317,9 @@ class MemoryViewSliceType(PyrexType):
is_memoryviewslice = 1 is_memoryviewslice = 1
has_attributes = 1
scope = None
def __init__(self, base_dtype, axes, env): def __init__(self, base_dtype, axes, env):
''' '''
MemoryViewSliceType(base, axes) MemoryViewSliceType(base, axes)
...@@ -372,6 +375,15 @@ class MemoryViewSliceType(PyrexType): ...@@ -372,6 +375,15 @@ class MemoryViewSliceType(PyrexType):
MemoryView.memviewslice_cname, MemoryView.memviewslice_cname,
entity_code) entity_code)
def attributes_known(self):
if self.scope is None:
import Symtab
self.scope = Symtab.StructOrUnionScope(self.specalization_name())
# XXX: we don't necessarily want to have this exposed -- for
# testing purposes currently.
self.scope.declare_var("data", c_char_ptr_type, None, "data")
return True
def global_init_code(self, entry, code): def global_init_code(self, entry, code):
code.putln("%s.data = NULL;" % entry.cname) code.putln("%s.data = NULL;" % entry.cname)
code.put_init_to_py_none("%s.memview" % entry.cname, cython_memoryview_ptr_type, nanny=False) code.put_init_to_py_none("%s.memview" % entry.cname, cython_memoryview_ptr_type, nanny=False)
......
...@@ -2,6 +2,7 @@ u''' ...@@ -2,6 +2,7 @@ u'''
>>> f() >>> f()
>>> g() >>> g()
>>> call() >>> call()
>>> assignmvs()
''' '''
from cython.view cimport memoryview from cython.view cimport memoryview
...@@ -9,10 +10,22 @@ from cython cimport array, PyBUF_C_CONTIGUOUS ...@@ -9,10 +10,22 @@ from cython cimport array, PyBUF_C_CONTIGUOUS
def init_obj(): def init_obj():
return 3 return 3
cdef passmvs(float[:,::1] mvs, object foo): cdef passmvs(float[:,::1] mvs, object foo):
mvs = array((10,10), itemsize=sizeof(float), format='f') mvs = array((10,10), itemsize=sizeof(float), format='f')
foo = init_obj() foo = init_obj()
cdef object returnobj():
cdef obj = object()
return obj
cdef float[::1] returnmvs_inner():
return array((10,), itemsize=sizeof(float), format='f')
cdef float[::1] returnmvs():
cdef float[::1] mvs = returnmvs_inner()
return mvs
def f(): def f():
cdef array arr = array(shape=(10,10), itemsize=sizeof(int), format='i') cdef array arr = array(shape=(10,10), itemsize=sizeof(int), format='i')
cdef memoryview mv = memoryview(arr, PyBUF_C_CONTIGUOUS) cdef memoryview mv = memoryview(arr, PyBUF_C_CONTIGUOUS)
...@@ -43,10 +56,20 @@ cdef float[:,::1] global_mv = array((10,10), itemsize=sizeof(float), format='f') ...@@ -43,10 +56,20 @@ cdef float[:,::1] global_mv = array((10,10), itemsize=sizeof(float), format='f')
global_mv = array((10,10), itemsize=sizeof(float), format='f') global_mv = array((10,10), itemsize=sizeof(float), format='f')
cdef object global_obj cdef object global_obj
def assignmvs():
cdef int[::1] mv1, mv2
mv1 = array((10,), itemsize=sizeof(int), format='i')
mv2 = mv1
mv1 = mv2
def call(): def call():
global global_mv global global_mv
passmvs(global_mv, global_obj) passmvs(global_mv, global_obj)
global_mv = array((3,3), itemsize=sizeof(float), format='f') global_mv = array((3,3), itemsize=sizeof(float), format='f')
cdef float[::1] getmvs = returnmvs()
returnmvs()
cdef object obj = returnobj()
cdg() cdg()
f = Foo() f = Foo()
pf = pyfoo() pf = pyfoo()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment