Commit fef78f5e authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Creating buffer type

parents be51c0a5 c2c1596c
...@@ -607,15 +607,29 @@ class CSimpleBaseTypeNode(CBaseTypeNode): ...@@ -607,15 +607,29 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
return PyrexTypes.error_type return PyrexTypes.error_type
class CBufferAccessTypeNode(Node): class CBufferAccessTypeNode(Node):
# base_type_node CBaseTypeNode # After parsing:
# positional_args [ExprNode] List of positional arguments # positional_args [ExprNode] List of positional arguments
# keyword_args DictNode Keyword arguments # keyword_args DictNode Keyword arguments
# base_type_node CBaseTypeNode
child_attrs = ["base_type_node", "positional_args", "keyword_args"] # After PostParse:
# dtype_node CBaseTypeNode
# ndim int
def analyse(self, env): # After analysis:
# type PyrexType.PyrexType
return self.base_type_node.analyse(env) child_attrs = ["base_type_node", "positional_args", "keyword_args",
"dtype_node"]
dtype_node = None
def analyse(self, env):
base_type = self.base_type_node.analyse(env)
dtype = self.dtype_node.analyse(env)
options = PyrexTypes.BufferOptions(dtype=dtype, ndim=self.ndim)
self.type = PyrexTypes.create_buffer_type(base_type, options)
return self.type
class CComplexBaseTypeNode(CBaseTypeNode): class CComplexBaseTypeNode(CBaseTypeNode):
# base_type CBaseTypeNode # base_type CBaseTypeNode
......
...@@ -124,7 +124,7 @@ class PostParse(CythonTransform): ...@@ -124,7 +124,7 @@ class PostParse(CythonTransform):
# get dtype # get dtype
dtype = options.get("dtype") dtype = options.get("dtype")
if dtype is None: raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype') if dtype is None: raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype')
node.dtype = dtype node.dtype_node = dtype
# get ndim # get ndim
if "ndim" in provided: if "ndim" in provided:
...@@ -143,8 +143,6 @@ class PostParse(CythonTransform): ...@@ -143,8 +143,6 @@ class PostParse(CythonTransform):
node.keyword_args = None node.keyword_args = None
return node return node
class WithTransform(CythonTransform): class WithTransform(CythonTransform):
# EXCINFO is manually set to a variable that contains # EXCINFO is manually set to a variable that contains
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from Cython import Utils from Cython import Utils
import Naming import Naming
import copy
class BaseType: class BaseType:
# #
...@@ -183,6 +184,20 @@ class CTypedefType(BaseType): ...@@ -183,6 +184,20 @@ class CTypedefType(BaseType):
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.typedef_base_type, name) return getattr(self.typedef_base_type, name)
class BufferOptions:
# dtype PyrexType
# ndim int
def __init__(self, dtype, ndim):
self.dtype = dtype
self.ndim = ndim
def create_buffer_type(base_type, buffer_options):
# Make a shallow copy of base_type and then annotate it
# with the buffer information
result = copy.copy(base_type)
result.buffer_options = buffer_options
return result
class PyObjectType(PyrexType): class PyObjectType(PyrexType):
# #
...@@ -193,6 +208,7 @@ class PyObjectType(PyrexType): ...@@ -193,6 +208,7 @@ class PyObjectType(PyrexType):
default_value = "0" default_value = "0"
parsetuple_format = "O" parsetuple_format = "O"
pymemberdef_typecode = "T_OBJECT" pymemberdef_typecode = "T_OBJECT"
buffer_options = None # can contain a BufferOptions instance
def __str__(self): def __str__(self):
return "Python object" return "Python object"
......
from Cython.TestUtils import CythonTest
import Cython.Compiler.Errors as Errors
from Cython.Compiler.Nodes import *
from Cython.Compiler.ParseTreeTransforms import *
class TestBufferParsing(CythonTest):
# First, we only test the raw parser, i.e.
# the number and contents of arguments are NOT checked.
# However "dtype"/the first positional argument is special-cased
# to parse a type argument rather than an expression
def parse(self, s):
return self.should_not_fail(lambda: self.fragment(s)).root
def not_parseable(self, expected_error, s):
e = self.should_fail(lambda: self.fragment(s), Errors.CompileError)
self.assertEqual(expected_error, e.message_only)
def test_basic(self):
t = self.parse(u"cdef object[float, 4, ndim=2, foo=foo] x")
bufnode = t.stats[0].base_type
self.assert_(isinstance(bufnode, CBufferAccessTypeNode))
self.assertEqual(2, len(bufnode.positional_args))
# print bufnode.dump()
# should put more here...
def test_type_fail(self):
self.not_parseable("Expected: type",
u"cdef object[2] x")
def test_type_pos(self):
self.parse(u"cdef object[short unsigned int, 3] x")
def test_type_keyword(self):
self.parse(u"cdef object[foo=foo, dtype=short unsigned int] x")
def test_notype_as_expr1(self):
self.not_parseable("Expected: expression",
u"cdef object[foo2=short unsigned int] x")
def test_notype_as_expr2(self):
self.not_parseable("Expected: expression",
u"cdef object[int, short unsigned int] x")
def test_pos_after_key(self):
self.not_parseable("Non-keyword arg following keyword arg",
u"cdef object[foo=1, 2] x")
class TestBufferOptions(CythonTest):
# Tests the full parsing of the options within the brackets
def parse_opts(self, opts):
s = u"cdef object[%s] x" % opts
root = self.fragment(s, pipeline=[PostParse(self)]).root
buftype = root.stats[0].base_type
self.assert_(isinstance(buftype, CBufferAccessTypeNode))
self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode))
self.assertEqual(u"object", buftype.base_type_node.name)
return buftype
def non_parse(self, expected_err, opts):
e = self.should_fail(lambda: self.parse_opts(opts))
self.assertEqual(expected_err, e.message_only)
def test_basic(self):
buf = self.parse_opts(u"unsigned short int, 3")
self.assert_(isinstance(buf.dtype_node, CSimpleBaseTypeNode))
self.assert_(buf.dtype_node.signed == 0 and buf.dtype_node.longness == -1)
self.assertEqual(3, buf.ndim)
def test_dict(self):
buf = self.parse_opts(u"ndim=3, dtype=unsigned short int")
self.assert_(isinstance(buf.dtype_node, CSimpleBaseTypeNode))
self.assert_(buf.dtype_node.signed == 0 and buf.dtype_node.longness == -1)
self.assertEqual(3, buf.ndim)
def test_dtype(self):
self.non_parse(ERR_BUF_MISSING % 'dtype', u"")
def test_ndim(self):
self.parse_opts(u"int, 2")
self.non_parse(ERR_BUF_INT % 'ndim', u"int, 'a'")
self.non_parse(ERR_BUF_NONNEG % 'ndim', u"int, -34")
def test_use_DEF(self):
t = self.fragment(u"""
DEF ndim = 3
cdef object[int, ndim] x
cdef object[ndim=ndim, dtype=int] y
""", pipeline=[PostParse(self)]).root
self.assert_(t.stats[1].base_type.ndim == 3)
self.assert_(t.stats[2].base_type.ndim == 3)
# add exotic and impossible combinations as they come along
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment