Commit fe2b8aaf authored by Robert Bradshaw's avatar Robert Bradshaw

Specialization of C++ template classes.

parent 44875019
......@@ -1769,7 +1769,19 @@ class IndexNode(ExprNode):
def analyse_as_type(self, env):
base_type = self.base.analyse_as_type(env)
if base_type and not base_type.is_pyobject:
return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env)))
if base_type.is_cpp_class:
if isinstance(self.index, TupleExprNode):
template_values = self.index.args
else:
template_values = [self.index]
import Nodes
type_node = Nodes.TemplatedTypeNode(
pos = self.pos,
positional_args = template_values,
keyword_args = None)
return type_node.analyse(env, base_type = base_type)
else:
return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env)))
return None
def analyse_types(self, env):
......
......@@ -668,6 +668,9 @@ class CBaseTypeNode(Node):
pass
def analyse_as_type(self, env):
return self.analyse(env)
class CAnalysedBaseTypeNode(Node):
# type type
......@@ -739,31 +742,13 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
return PyrexTypes.error_type
class TemplatedTypeNode(CBaseTypeNode):
# name
# base_type_node CSimpleBaseTypeNode
# templates [CSimpleBaseTypeNode]
child_attrs = ["base_type_node", "templates"]
def analyse(self, env, could_be_name = False):
entry = env.lookup(self.base_type_node.name)
base_types = entry.type.templates
if not base_types:
error(self.pos, "%s type is not a template" % entry.type)
if len(base_types) != len(self.templates):
error(self.pos, "%s templated type receives %d arguments, got %d" %
(entry.type, len(base_types), len(self.templates)))
print entry.type
return entry.type
class CBufferAccessTypeNode(CBaseTypeNode):
# After parsing:
# positional_args [ExprNode] List of positional arguments
# keyword_args DictNode Keyword arguments
# base_type_node CBaseTypeNode
# After analysis:
# type PyrexType.BufferType ...containing the right options
# type PyrexTypes.BufferType or PyrexTypes.CppClassType ...containing the right options
child_attrs = ["base_type_node", "positional_args",
......@@ -773,19 +758,37 @@ class CBufferAccessTypeNode(CBaseTypeNode):
name = None
def analyse(self, env, could_be_name = False):
base_type = self.base_type_node.analyse(env)
def analyse(self, env, could_be_name = False, base_type = None):
if base_type is None:
base_type = self.base_type_node.analyse(env)
if base_type.is_error: return base_type
import Buffer
options = Buffer.analyse_buffer_options(
self.pos,
env,
self.positional_args,
self.keyword_args,
base_type.buffer_defaults)
self.type = PyrexTypes.BufferType(base_type, **options)
if base_type.is_cpp_class:
if len(self.keyword_args.key_value_pairs) != 0:
error(self.pos, "c++ templates cannot take keyword arguments");
self.type = PyrexTypes.error_type
else:
template_types = []
for template_node in self.positional_args:
template_types.append(template_node.analyse_as_type(env))
self.type = base_type.specialize(self.pos, template_types)
else:
if not isinstance(env, Symtab.LocalScope):
error(self.pos, ERR_BUF_LOCALONLY)
import Buffer
options = Buffer.analyse_buffer_options(
self.pos,
env,
self.positional_args,
self.keyword_args,
base_type.buffer_defaults)
self.type = PyrexTypes.BufferType(base_type, **options)
return self.type
class CComplexBaseTypeNode(CBaseTypeNode):
......@@ -954,7 +957,7 @@ class CppClassNode(CStructOrUnionDefNode):
else:
base_class_types.append(base_class_entry.type)
self.entry = env.declare_cpp_class(
self.name, "cppclass", scope, 0, self.pos,
self.name, scope, self.pos,
self.cname, base_class_types, visibility = self.visibility, templates = self.templates)
self.entry.is_cpp_class = 1
if self.attributes is not None:
......@@ -5809,3 +5812,5 @@ proto="""
""")
#------------------------------------------------------------------------------------
ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
......@@ -127,7 +127,6 @@ class PostParseError(CompileError): pass
# error strings checked by unit tests, so define them
ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'
class PostParse(CythonTransform):
......@@ -144,7 +143,7 @@ class PostParse(CythonTransform):
- Interpret some node structures into Python runtime values.
Some nodes take compile-time arguments (currently:
CBufferAccessTypeNode[args] and __cythonbufferdefaults__ = {args}),
TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}),
which should be interpreted. This happens in a general way
and other steps should be taken to ensure validity.
......@@ -153,7 +152,7 @@ class PostParse(CythonTransform):
- For __cythonbufferdefaults__ the arguments are checked for
validity.
CBufferAccessTypeNode has its options interpreted:
TemplatedTypeNode has its options interpreted:
Any first positional argument goes into the "dtype" attribute,
any "ndim" keyword argument goes into the "ndim" attribute and
so on. Also it is checked that the option combination is valid.
......@@ -242,11 +241,6 @@ class PostParse(CythonTransform):
self.context.nonfatal_error(e)
return None
def visit_CBufferAccessTypeNode(self, node):
if not self.scope_type == 'function':
raise PostParseError(node.pos, ERR_BUF_LOCALONLY)
return node
class PxdPostParse(CythonTransform, SkipDeclarations):
"""
Basic interpretation/validity checking that should only be
......
......@@ -1795,43 +1795,23 @@ def p_buffer_or_template(s, base_type_node):
# s.sy == '['
pos = s.position()
s.next()
if s.systring == 'int' or s.systring == 'long':
positional_args, keyword_args = (
p_positional_and_keyword_args(s, (']',), (0,), ('dtype',))
)
if keyword_args:
error(pos, "Keyword arguments not allowed for template types")
s.expect(']')
positional_args, keyword_args = (
p_positional_and_keyword_args(s, (']',), (0,), ('dtype',))
)
s.expect(']')
result = Nodes.TemplatedTypeNode(pos, base_type_node = base_type_node,
templates = positional_args)
else:
positional_args, keyword_args = (
p_positional_and_keyword_args(s, (']',), (0,), ('dtype',))
)
if positional_args:
if positional_args[0] != 'int' or positional_args != 'long':
if keyword_args:
error(pos, "Keyword arguments not allowed for template types")
s.expect(']')
keyword_dict = ExprNodes.DictNode(pos,
key_value_pairs = [
ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
for key, value in keyword_args
])
result = Nodes.TemplatedTypeNode(pos, base_type_node = base_type_node,
templates = positional_args)
else:
s.expect(']')
keyword_dict = ExprNodes.DictNode(pos,
key_value_pairs = [
ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
for key, value in keyword_args
])
result = Nodes.CBufferAccessTypeNode(pos,
positional_args = positional_args,
keyword_args = keyword_dict,
base_type_node = base_type_node)
result = Nodes.TemplatedTypeNode(pos,
positional_args = positional_args,
keyword_args = keyword_dict,
base_type_node = base_type_node)
return result
def looking_at_name(s):
......
......@@ -1369,39 +1369,37 @@ class CStructOrUnionType(CType):
class CppClassType(CType):
# name string
# cname string
# kind string "cppclass"
# scope CppClassScope
# typedef_flag boolean
# packed boolean
# templates [string] or None
is_cpp_class = 1
has_attributes = 1
base_classes = []
exception_check = True
def __init__(self, name, kind, scope, typedef_flag, cname, base_classes, packed=False,
templates = None):
def __init__(self, name, scope, cname, base_classes, templates = None):
self.name = name
self.cname = cname
self.kind = kind
self.scope = scope
self.typedef_flag = typedef_flag
self.exception_check = True
self._convert_code = None
self.packed = packed
self.base_classes = base_classes
self.operators = []
self.templates = templates
def specialize(self, pos, template_values):
if self.templates is None:
error(pos, "'%s' type is not a template" % self);
return PyrexTypes.error_type
if len(self.templates) != len(template_values):
error(pos, "%s templated type receives %d arguments, got %d" %
(base_type, len(self.templates), len(template_values)))
return PyrexTypes.error_type
return CppClassType(self.name, self.scope, self.cname, self.base_classes, template_values)
def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0):
templates = ""
if self.templates:
templates = "<"
for i in range(len(self.templates)-1):
templates += self.templates[i]
templates += ','
templates += self.templates[-1]
templates += ">"
template_strings = [param.declaration_code('', for_display, pyrex) for param in self.templates]
templates = "<" + ",".join(template_strings) + ">"
else:
templates = ""
if for_display or pyrex:
name = self.name
else:
......@@ -1419,6 +1417,7 @@ class CppClassType(CType):
def attributes_known(self):
return self.scope is not None
class TemplatedType(CType):
def __init__(self, name):
......@@ -1609,8 +1608,6 @@ c_anon_enum_type = CAnonEnumType(-1, 1)
c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer")
c_py_buffer_ptr_type = CPtrType(c_py_buffer_type)
cpp_class_type = CppClassType("cpp_class", "cppclass", None, 1, "cpp_class", [])
error_type = ErrorType()
unspecified_type = UnspecifiedType()
......
......@@ -1110,9 +1110,9 @@ class ModuleScope(Scope):
#
return entry
def declare_cpp_class(self, name, kind, scope,
typedef_flag, pos, cname = None, base_classes = [],
visibility = 'extern', packed = False, templates = None):
def declare_cpp_class(self, name, scope,
pos, cname = None, base_classes = [],
visibility = 'extern', templates = None):
if visibility != 'extern':
error(pos, "C++ classes may only be extern")
if cname is None:
......@@ -1120,22 +1120,19 @@ class ModuleScope(Scope):
entry = self.lookup(name)
if not entry:
type = PyrexTypes.CppClassType(
name, kind, scope, typedef_flag, cname, base_classes, packed, templates = templates)
name, scope, cname, base_classes, templates = templates)
entry = self.declare_type(name, type, pos, cname,
visibility = visibility, defining = scope is not None)
else:
if not (entry.is_type and entry.type.is_cpp_class
and entry.type.kind == kind):
if not (entry.is_type and entry.type.is_cpp_class):
warning(pos, "'%s' redeclared " % name, 0)
elif scope and entry.type.scope:
warning(pos, "'%s' already defined (ignoring second definition)" % name, 0)
else:
self.check_previous_typedef_flag(entry, typedef_flag, pos)
if scope:
entry.type.scope = scope
self.type_entries.append(entry)
if not scope and not entry.type.scope:
self.check_for_illegal_incomplete_ctypedef(typedef_flag, pos)
entry.type.scope = CppClassScope(name)
def declare_inherited_attributes(entry, base_classes):
......@@ -1145,10 +1142,6 @@ class ModuleScope(Scope):
declare_inherited_attributes(entry, base_classes)
return entry
def check_for_illegal_incomplete_ctypedef(self, typedef_flag, pos):
if typedef_flag and not self.in_cinclude:
error(pos, "Forward-referenced type must use 'cdef', not 'ctypedef'")
def allocate_vtable_names(self, entry):
# If extension type has a vtable, allocate vtable struct and
# slot names for it.
......@@ -1238,7 +1231,7 @@ class ModuleScope(Scope):
var_entry.is_readonly = 1
entry.as_variable = var_entry
class LocalScope(Scope):
class LocalScope(Scope):
def __init__(self, name, outer_scope):
Scope.__init__(self, name, outer_scope, outer_scope)
......
......@@ -21,7 +21,7 @@ class TestBufferParsing(CythonTest):
def test_basic(self):
t = self.parse(u"cdef object[float, 4, ndim=2, foo=foo] x")
bufnode = t.stats[0].base_type
self.assert_(isinstance(bufnode, CBufferAccessTypeNode))
self.assert_(isinstance(bufnode, TemplatedTypeNode))
self.assertEqual(2, len(bufnode.positional_args))
# print bufnode.dump()
# should put more here...
......@@ -65,7 +65,7 @@ class TestBufferOptions(CythonTest):
vardef = root.stats[0].body.stats[0]
assert isinstance(vardef, CVarDefNode) # use normal assert as this is to validate the test code
buftype = vardef.base_type
self.assert_(isinstance(buftype, CBufferAccessTypeNode))
self.assert_(isinstance(buftype, TemplatedTypeNode))
self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode))
self.assertEqual(u"object", buftype.base_type_node.name)
return buftype
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment