Commit b629f965 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Compiler option decorator, with statement, and testcase for buffer boundscheck toggling

parent d2f592da
...@@ -266,11 +266,16 @@ class ResolveOptions(CythonTransform): ...@@ -266,11 +266,16 @@ class ResolveOptions(CythonTransform):
"options" attribute linking it to a dict containing the exact "options" attribute linking it to a dict containing the exact
options that are in effect for that node. Any corresponding decorators options that are in effect for that node. Any corresponding decorators
or with statements are removed in the process. or with statements are removed in the process.
Note that we have to run this prior to analysis, and so some minor
duplication of functionality has to occur: We manually track cimports
to correctly intercept @cython... and with cython...
""" """
def __init__(self, context, compilation_option_overrides): def __init__(self, context, compilation_option_overrides):
super(ResolveOptions, self).__init__(context) super(ResolveOptions, self).__init__(context)
self.compilation_option_overrides = compilation_option_overrides self.compilation_option_overrides = compilation_option_overrides
self.cython_module_names = set()
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
options = copy.copy(Options.option_defaults) options = copy.copy(Options.option_defaults)
...@@ -281,11 +286,91 @@ class ResolveOptions(CythonTransform): ...@@ -281,11 +286,91 @@ class ResolveOptions(CythonTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
# Track cimports of the cython module.
def visit_CImportStatNode(self, node):
if node.module_name == u"cython":
if node.as_name:
modname = node.as_name
else:
modname = u"cython"
self.cython_module_names.add(modname)
elif node.as_name and node.as_name in self.cython_module_names:
self.cython_module_names.remove(node.as_name)
return node
def visit_Node(self, node): def visit_Node(self, node):
node.options = self.options node.options = self.options
self.visitchildren(node) self.visitchildren(node)
return node return node
def try_to_parse_option(self, node):
# If node is the contents of an option (in a with statement or
# decorator), returns (optionname, value).
# Otherwise, returns None
if (isinstance(node, SimpleCallNode) and
isinstance(node.function, AttributeNode) and
isinstance(node.function.obj, NameNode) and
node.function.obj.name in self.cython_module_names):
optname = node.function.attribute
optiontype = Options.option_types.get(optname)
if optiontype:
args = node.args
if optiontype is bool:
if len(args) != 1 or not isinstance(args[0], BoolNode):
raise PostParseError(dec.function.pos,
'The %s option takes one compile-time boolean argument' % optname)
return (optname, args[0].value)
else:
assert False
else:
return None
options.append((dec.function.attribute, dec.args, dec.function.pos))
return False
else:
return None
def visit_with_options(self, node, options):
if not options:
return self.visit_Node(node)
else:
oldoptions = self.options
newoptions = copy.copy(oldoptions)
newoptions.update(options)
self.options = newoptions
node = self.visit_Node(node)
self.options = oldoptions
return node
# Handle decorators
def visit_DefNode(self, node):
options = {}
if node.decorators:
# Split the decorators into two lists -- real decorators and options
realdecs = []
for dec in node.decorators:
option = self.try_to_parse_option(dec.decorator)
if option is not None:
name, value = option
options[name] = value
else:
realdecs.append(dec)
node.decorators = realdecs
return self.visit_with_options(node, options)
# Handle with statements
def visit_WithStatNode(self, node):
option = self.try_to_parse_option(node.manager)
if option is not None:
if node.target is not None:
raise PostParseError(node.pos, "Compiler option with statements cannot contain 'as'")
name, value = option
self.visit_with_options(node.body, {name:value})
return node.body.stats
else:
return self.visit_Node(node)
class WithTransform(CythonTransform): class WithTransform(CythonTransform):
# EXCINFO is manually set to a variable that contains # EXCINFO is manually set to a variable that contains
......
...@@ -2,5 +2,17 @@ ...@@ -2,5 +2,17 @@
print 3 print 3
cimport python_dict as asadf, python_exc, cython as cy
@cy.boundscheck(False)
def f(object[int, 2] buf): def f(object[int, 2] buf):
print buf[3, 2] print buf[3, 2]
@cy.boundscheck(True)
def g(object[int, 2] buf):
print buf[3, 2]
def h(object[int, 2] buf):
print buf[3, 2]
with cy.boundscheck(True):
print buf[3,2]
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
cimport stdlib cimport stdlib
cimport python_buffer cimport python_buffer
cimport stdio cimport stdio
cimport cython
__test__ = {} __test__ = {}
setup_string = """ setup_string = """
...@@ -506,6 +508,70 @@ def strided(object[int, 1, 'strided'] buf): ...@@ -506,6 +508,70 @@ def strided(object[int, 1, 'strided'] buf):
""" """
return buf[2] return buf[2]
#
# Test compiler options for bounds checking. We create an array with a
# safe "boundary" (memory
# allocated outside of what it published) and then check whether we get back
# what we stored in the memory or an error.
@testcase
def safe_get(object[int] buf, int idx):
"""
>>> A = IntMockBuffer(None, range(10), shape=(3,), offset=5)
Validate our testing buffer...
>>> safe_get(A, 0)
5
>>> safe_get(A, 2)
7
>>> safe_get(A, -3)
5
Access outside it. This is already done above for bounds check
testing but we include it to tell the story right.
>>> safe_get(A, -4)
Traceback (most recent call last):
...
IndexError: Out of bounds on buffer access (axis 0)
>>> safe_get(A, 3)
Traceback (most recent call last):
...
IndexError: Out of bounds on buffer access (axis 0)
"""
return buf[idx]
@testcase
@cython.boundscheck(False)
def unsafe_get(object[int] buf, int idx):
"""
Access outside of the area the buffer publishes.
>>> A = IntMockBuffer(None, range(10), shape=(3,), offset=5)
>>> unsafe_get(A, -4)
4
>>> unsafe_get(A, -5)
3
>>> unsafe_get(A, 3)
8
"""
return buf[idx]
@testcase
def mixed_get(object[int] buf, int unsafe_idx, int safe_idx):
"""
>>> A = IntMockBuffer(None, range(10), shape=(3,), offset=5)
>>> mixed_get(A, -4, 0)
(4, 5)
>>> mixed_get(A, 0, -4)
Traceback (most recent call last):
...
IndexError: Out of bounds on buffer access (axis 0)
"""
with cython.boundscheck(False):
one = buf[unsafe_idx]
with cython.boundscheck(True):
two = buf[safe_idx]
return (one, two)
# #
# Coercions # Coercions
...@@ -658,7 +724,7 @@ available_flags = ( ...@@ -658,7 +724,7 @@ available_flags = (
) )
cdef class MockBuffer: cdef class MockBuffer:
cdef object format cdef object format, offset
cdef void* buffer cdef void* buffer
cdef int len, itemsize, ndim cdef int len, itemsize, ndim
cdef Py_ssize_t* strides cdef Py_ssize_t* strides
...@@ -669,10 +735,11 @@ cdef class MockBuffer: ...@@ -669,10 +735,11 @@ cdef class MockBuffer:
cdef readonly object recieved_flags, release_ok cdef readonly object recieved_flags, release_ok
cdef public object fail cdef public object fail
def __init__(self, label, data, shape=None, strides=None, format=None): def __init__(self, label, data, shape=None, strides=None, format=None, offset=0):
self.label = label self.label = label
self.release_ok = True self.release_ok = True
self.log = "" self.log = ""
self.offset = offset
self.itemsize = self.get_itemsize() self.itemsize = self.get_itemsize()
if format is None: format = self.get_default_format() if format is None: format = self.get_default_format()
if shape is None: shape = (len(data),) if shape is None: shape = (len(data),)
...@@ -765,7 +832,7 @@ cdef class MockBuffer: ...@@ -765,7 +832,7 @@ cdef class MockBuffer:
if (value & flags) == value: if (value & flags) == value:
self.recieved_flags.append(name) self.recieved_flags.append(name)
buffer.buf = self.buffer buffer.buf = <void*>(<char*>self.buffer + (<int>self.offset * self.itemsize))
buffer.len = self.len buffer.len = self.len
buffer.readonly = 0 buffer.readonly = 0
buffer.format = <char*>self.format buffer.format = <char*>self.format
...@@ -775,6 +842,7 @@ cdef class MockBuffer: ...@@ -775,6 +842,7 @@ cdef class MockBuffer:
buffer.suboffsets = self.suboffsets buffer.suboffsets = self.suboffsets
buffer.itemsize = self.itemsize buffer.itemsize = self.itemsize
buffer.internal = NULL buffer.internal = NULL
if self.label:
msg = "acquired %s" % self.label msg = "acquired %s" % self.label
print msg print msg
self.log += msg + "\n" self.log += msg + "\n"
...@@ -782,6 +850,7 @@ cdef class MockBuffer: ...@@ -782,6 +850,7 @@ cdef class MockBuffer:
def __releasebuffer__(MockBuffer self, Py_buffer* buffer): def __releasebuffer__(MockBuffer self, Py_buffer* buffer):
if buffer.suboffsets != self.suboffsets: if buffer.suboffsets != self.suboffsets:
self.release_ok = False self.release_ok = False
if self.label:
msg = "released %s" % self.label msg = "released %s" % self.label
print msg print msg
self.log += msg + "\n" self.log += msg + "\n"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment