Commit 766bfd93 authored by Kurt Smith's avatar Kurt Smith Committed by Mark Florisson

correct reference handling for memoryviewslices.

parent 985d50f3
......@@ -22,7 +22,7 @@ import Nodes
from Nodes import Node
import PyrexTypes
from PyrexTypes import py_object_type, c_long_type, typecast, error_type, \
unspecified_type
unspecified_type, cython_memoryview_ptr_type
import TypeSlots
from Builtin import list_type, tuple_type, set_type, dict_type, \
unicode_type, str_type, bytes_type, type_type
......@@ -1693,18 +1693,35 @@ class NameNode(AtomicExprNode):
code.put_giveref(rhs.py_result())
if not self.type.is_memoryviewslice:
code.putln('%s = %s;' % (self.result(), rhs.result_as(self.ctype())))
if debug_disposal_code:
print("NameNode.generate_assignment_code:")
print("...generating post-assignment code for %s" % rhs)
rhs.generate_post_assignment_code(code)
if debug_disposal_code:
print("NameNode.generate_assignment_code:")
print("...generating post-assignment code for %s" % rhs)
rhs.generate_post_assignment_code(code)
rhs.free_temps(code)
def generate_acquire_memoryviewslice(self, rhs, code):
# to explicitly manange the memviewslice.memview object correctly.
import MemoryView
assert rhs.type.is_memoryviewslice
if not rhs.result_in_temp():
code.put_incref("%s.memview" % rhs.result(), cython_memoryview_ptr_type)
if self.entry.is_cglobal:
code.put_gotref("%s.memview" % self.result())
if not self.lhs_of_first_assignment:
if self.entry.is_local and not Options.init_local_none:
code.put_xdecref("%s.memview" % self.result(), cython_memoryview_ptr_type)
else:
code.put_decref("%s.memview" % self.result(), cython_memoryview_ptr_type)
if self.entry.is_cglobal:
code.put_giveref("%s.memview" % rhs.result())
MemoryView.put_assign_to_memviewslice(self.result(), rhs.result(), self.type,
pos=self.pos, code=code)
if rhs.is_temp:
code.put_xdecref_clear("%s.memview" % rhs.result(), py_object_type)
if rhs.result_in_temp():
code.putln("%s.memview = 0;" % rhs.result())
def generate_acquire_buffer(self, rhs, code):
# rhstmp is only used in case the rhs is a complicated expression leading to
......@@ -3949,7 +3966,7 @@ class AttributeNode(ExprNode):
MemoryView.put_assign_to_memviewslice(select_code, rhs.result(), self.type,
pos=self.pos, code=code)
if rhs.is_temp:
code.put_xdecref_clear("%s.memview" % rhs.result(), py_object_type)
code.put_xdecref_clear("%s.memview" % rhs.result(), cython_memoryview_ptr_type)
if not self.type.is_memoryviewslice:
code.putln(
"%s = %s;" % (
......@@ -7620,7 +7637,7 @@ class CoerceToMemViewSliceNode(CoercionNode):
code.putln("__pyx_viewaxis_init_memviewslice_from_memview"
"((struct __pyx_obj_memoryview *)%s, %s, %d, sizeof(%s), \"%s\", &%s);" %\
(memviewobj, spec_int_arr, ndim, itemsize, format, self.result()))
code.put_gotref("%s.memview" % self.result())
code.put_gotref(code.as_pyobject("%s.memview" % self.result(), cython_memoryview_ptr_type))
code.funcstate.release_temp(memviewobj)
code.funcstate.release_temp(spec_int_arr)
......
......@@ -5,7 +5,7 @@ import Options
import CythonScope
from Code import UtilityCode
from UtilityCode import CythonUtilityCode
from PyrexTypes import py_object_type, cython_memoryview_type
from PyrexTypes import py_object_type, cython_memoryview_ptr_type
START_ERR = "there must be nothing or the value 0 (zero) in the start slot."
STOP_ERR = "Axis specification only allowed in the 'stop' slot."
......@@ -67,17 +67,11 @@ def format_from_type(base_type):
def put_init_entry(mv_cname, code):
code.putln("%s.data = NULL;" % mv_cname)
code.put_init_to_py_none("%s.memview" % mv_cname, cython_memoryview_type)
code.put_giveref("%s.memview" % mv_cname)
code.put_init_to_py_none("%s.memview" % mv_cname, cython_memoryview_ptr_type)
code.put_giveref(code.as_pyobject("%s.memview" % mv_cname, cython_memoryview_ptr_type))
def put_assign_to_memviewslice(lhs_cname, rhs_cname, memviewslicetype, pos, code):
# XXX: add error checks!
code.put_giveref("%s.memview" % (rhs_cname))
code.put_incref("%s.memview" % (rhs_cname), py_object_type)
code.put_gotref("%s.memview" % (lhs_cname))
code.put_xdecref("%s.memview" % (lhs_cname), py_object_type)
code.putln("%s.memview = %s.memview;" % (lhs_cname, rhs_cname))
code.putln("%s.data = %s.data;" % (lhs_cname, rhs_cname))
ndim = len(memviewslicetype.axes)
......
......@@ -1237,7 +1237,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
for entry in memviewslice_attrs:
code.putln("p->%s.data = NULL;" % entry.cname)
code.put_init_to_py_none("p->%s.memview" % entry.cname,
PyrexTypes.cython_memoryview_type, nanny=False)
PyrexTypes.cython_memoryview_ptr_type, nanny=False)
entry = scope.lookup_here("__new__")
if entry and entry.is_special:
if entry.trivial_signature:
......
......@@ -17,7 +17,7 @@ from Errors import error, warning, InternalError, CompileError
import Naming
import PyrexTypes
import TypeSlots
from PyrexTypes import py_object_type, error_type, CFuncType
from PyrexTypes import py_object_type, error_type, CTypedefType, CFuncType, cython_memoryview_ptr_type
from Symtab import ModuleScope, LocalScope, ClosureScope, \
StructOrUnionScope, PyClassScope, CClassScope, CppClassScope
from Cython.Utils import open_new_file, replace_suffix
......@@ -1424,12 +1424,14 @@ class FuncDefNode(StatNode, BlockNode):
if entry.type.is_pyobject:
if (acquire_gil or entry.assignments) and not entry.in_closure:
code.put_var_incref(entry)
if entry.type.is_memoryviewslice:
code.put_incref("%s.memview" % entry.cname, cython_memoryview_ptr_type)
# ----- Initialise local buffer auxiliary variables
for entry in lenv.var_entries + lenv.arg_entries:
if entry.type.is_buffer and entry.buffer_aux.buflocal_nd_var.used:
Buffer.put_init_vars(entry, code)
# ----- Initialise local memoryview slices
for entry in lenv.var_entries + lenv.arg_entries:
# ----- Initialise local memoryviewslices
for entry in lenv.var_entries:
if entry.type.is_memoryviewslice:
MemoryView.put_init_entry(entry.cname, code)
# ----- Check and convert arguments
......@@ -1533,14 +1535,20 @@ class FuncDefNode(StatNode, BlockNode):
code.put_label(code.return_from_error_cleanup_label)
for entry in lenv.var_entries:
if not entry.used or entry.in_closure:
continue
if entry.type.is_memoryviewslice:
code.put_xdecref("%s.memview" % entry.cname, cython_memoryview_ptr_type)
if entry.type.is_pyobject:
if entry.used and not entry.in_closure:
code.put_var_decref(entry)
code.put_var_decref(entry)
# Decref any increfed args
for entry in lenv.arg_entries:
if entry.type.is_pyobject:
if (acquire_gil or entry.assignments) and not entry.in_closure:
code.put_var_decref(entry)
if entry.type.is_memoryviewslice:
code.put_decref("%s.memview" % entry.cname, cython_memoryview_ptr_type)
if self.needs_closure:
code.put_decref(Naming.cur_scope_cname, lenv.scope_class.type)
......@@ -1600,7 +1608,7 @@ class FuncDefNode(StatNode, BlockNode):
def declare_argument(self, env, arg):
if arg.type.is_void:
error(arg.pos, "Invalid use of 'void'")
elif not arg.type.is_complete() and not arg.type.is_array:
elif not arg.type.is_complete() and not (arg.type.is_array or arg.type.is_memoryviewslice):
error(arg.pos,
"Argument type '%s' is incomplete" % arg.type)
return env.declare_arg(arg.name, arg.type, arg.pos)
......
......@@ -172,7 +172,7 @@ class PyrexType(BaseType):
def global_init_code(self, entry, code):
# abstract
raise NotImplementedError()
pass
def public_decl(base_code, dll_linkage):
......@@ -374,7 +374,7 @@ class MemoryViewSliceType(PyrexType):
def global_init_code(self, entry, code):
code.putln("%s.data = NULL;" % entry.cname)
code.put_init_to_py_none("%s.memview" % entry.cname, cython_memoryview_type, nanny=False)
code.put_init_to_py_none("%s.memview" % entry.cname, cython_memoryview_ptr_type, nanny=False)
class BufferType(BaseType):
#
......@@ -2562,9 +2562,10 @@ c_pyx_buffer_ptr_type = CPtrType(c_pyx_buffer_type)
c_pyx_buffer_nd_type = CStructOrUnionType("__Pyx_LocalBuf_ND", "struct",
None, 1, "__Pyx_LocalBuf_ND")
cython_memoryview_type = CPtrType(CStructOrUnionType("__pyx_obj_memoryview", "struct",
None, 0, "__pyx_obj_memoryview"))
cython_memoryview_type = CStructOrUnionType("__pyx_obj_memoryview", "struct",
None, 0, "__pyx_obj_memoryview")
cython_memoryview_ptr_type = CPtrType(cython_memoryview_type)
error_type = ErrorType()
unspecified_type = UnspecifiedType()
......
......@@ -7,12 +7,20 @@ u'''
from cython.view cimport memoryview
from cython cimport array, PyBUF_C_CONTIGUOUS
def init_obj():
return 3
cdef passmvs(float[:,::1] mvs, object foo):
mvs = array((10,10), itemsize=sizeof(float), format='f')
foo = init_obj()
def f():
cdef array arr = array(shape=(10,10), itemsize=sizeof(int), format='i')
cdef memoryview mv = memoryview(arr, PyBUF_C_CONTIGUOUS)
def g():
cdef object obj = init_obj()
cdef int[::1] mview = array((10,), itemsize=sizeof(int), format='i')
obj = init_obj()
mview = array((10,), itemsize=sizeof(int), format='i')
cdef class Foo:
......@@ -33,9 +41,11 @@ cdef cdg():
cdef float[:,::1] global_mv = array((10,10), itemsize=sizeof(float), format='f')
global_mv = array((10,10), itemsize=sizeof(float), format='f')
cdef object global_obj
def call():
global global_mv
passmvs(global_mv, global_obj)
global_mv = array((3,3), itemsize=sizeof(float), format='f')
cdg()
f = Foo()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment