Commit fe2b8aaf authored by Robert Bradshaw's avatar Robert Bradshaw

Specialization of C++ template classes.

parent 44875019
...@@ -1769,7 +1769,19 @@ class IndexNode(ExprNode): ...@@ -1769,7 +1769,19 @@ class IndexNode(ExprNode):
def analyse_as_type(self, env): def analyse_as_type(self, env):
base_type = self.base.analyse_as_type(env) base_type = self.base.analyse_as_type(env)
if base_type and not base_type.is_pyobject: if base_type and not base_type.is_pyobject:
return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env))) if base_type.is_cpp_class:
if isinstance(self.index, TupleExprNode):
template_values = self.index.args
else:
template_values = [self.index]
import Nodes
type_node = Nodes.TemplatedTypeNode(
pos = self.pos,
positional_args = template_values,
keyword_args = None)
return type_node.analyse(env, base_type = base_type)
else:
return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env)))
return None return None
def analyse_types(self, env): def analyse_types(self, env):
......
...@@ -668,6 +668,9 @@ class CBaseTypeNode(Node): ...@@ -668,6 +668,9 @@ class CBaseTypeNode(Node):
pass pass
def analyse_as_type(self, env):
return self.analyse(env)
class CAnalysedBaseTypeNode(Node): class CAnalysedBaseTypeNode(Node):
# type type # type type
...@@ -739,31 +742,13 @@ class CSimpleBaseTypeNode(CBaseTypeNode): ...@@ -739,31 +742,13 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
return PyrexTypes.error_type return PyrexTypes.error_type
class TemplatedTypeNode(CBaseTypeNode): class TemplatedTypeNode(CBaseTypeNode):
# name
# base_type_node CSimpleBaseTypeNode
# templates [CSimpleBaseTypeNode]
child_attrs = ["base_type_node", "templates"]
def analyse(self, env, could_be_name = False):
entry = env.lookup(self.base_type_node.name)
base_types = entry.type.templates
if not base_types:
error(self.pos, "%s type is not a template" % entry.type)
if len(base_types) != len(self.templates):
error(self.pos, "%s templated type receives %d arguments, got %d" %
(entry.type, len(base_types), len(self.templates)))
print entry.type
return entry.type
class CBufferAccessTypeNode(CBaseTypeNode):
# After parsing: # After parsing:
# positional_args [ExprNode] List of positional arguments # positional_args [ExprNode] List of positional arguments
# keyword_args DictNode Keyword arguments # keyword_args DictNode Keyword arguments
# base_type_node CBaseTypeNode # base_type_node CBaseTypeNode
# After analysis: # After analysis:
# type PyrexType.BufferType ...containing the right options # type PyrexTypes.BufferType or PyrexTypes.CppClassType ...containing the right options
child_attrs = ["base_type_node", "positional_args", child_attrs = ["base_type_node", "positional_args",
...@@ -773,19 +758,37 @@ class CBufferAccessTypeNode(CBaseTypeNode): ...@@ -773,19 +758,37 @@ class CBufferAccessTypeNode(CBaseTypeNode):
name = None name = None
def analyse(self, env, could_be_name = False): def analyse(self, env, could_be_name = False, base_type = None):
base_type = self.base_type_node.analyse(env) if base_type is None:
base_type = self.base_type_node.analyse(env)
if base_type.is_error: return base_type if base_type.is_error: return base_type
import Buffer
options = Buffer.analyse_buffer_options(
self.pos,
env,
self.positional_args,
self.keyword_args,
base_type.buffer_defaults)
self.type = PyrexTypes.BufferType(base_type, **options) if base_type.is_cpp_class:
if len(self.keyword_args.key_value_pairs) != 0:
error(self.pos, "c++ templates cannot take keyword arguments");
self.type = PyrexTypes.error_type
else:
template_types = []
for template_node in self.positional_args:
template_types.append(template_node.analyse_as_type(env))
self.type = base_type.specialize(self.pos, template_types)
else:
if not isinstance(env, Symtab.LocalScope):
error(self.pos, ERR_BUF_LOCALONLY)
import Buffer
options = Buffer.analyse_buffer_options(
self.pos,
env,
self.positional_args,
self.keyword_args,
base_type.buffer_defaults)
self.type = PyrexTypes.BufferType(base_type, **options)
return self.type return self.type
class CComplexBaseTypeNode(CBaseTypeNode): class CComplexBaseTypeNode(CBaseTypeNode):
...@@ -954,7 +957,7 @@ class CppClassNode(CStructOrUnionDefNode): ...@@ -954,7 +957,7 @@ class CppClassNode(CStructOrUnionDefNode):
else: else:
base_class_types.append(base_class_entry.type) base_class_types.append(base_class_entry.type)
self.entry = env.declare_cpp_class( self.entry = env.declare_cpp_class(
self.name, "cppclass", scope, 0, self.pos, self.name, scope, self.pos,
self.cname, base_class_types, visibility = self.visibility, templates = self.templates) self.cname, base_class_types, visibility = self.visibility, templates = self.templates)
self.entry.is_cpp_class = 1 self.entry.is_cpp_class = 1
if self.attributes is not None: if self.attributes is not None:
...@@ -5809,3 +5812,5 @@ proto=""" ...@@ -5809,3 +5812,5 @@ proto="""
""") """)
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
...@@ -127,7 +127,6 @@ class PostParseError(CompileError): pass ...@@ -127,7 +127,6 @@ class PostParseError(CompileError): pass
# error strings checked by unit tests, so define them # error strings checked by unit tests, so define them
ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions' ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)' ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared' ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'
class PostParse(CythonTransform): class PostParse(CythonTransform):
...@@ -144,7 +143,7 @@ class PostParse(CythonTransform): ...@@ -144,7 +143,7 @@ class PostParse(CythonTransform):
- Interpret some node structures into Python runtime values. - Interpret some node structures into Python runtime values.
Some nodes take compile-time arguments (currently: Some nodes take compile-time arguments (currently:
CBufferAccessTypeNode[args] and __cythonbufferdefaults__ = {args}), TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}),
which should be interpreted. This happens in a general way which should be interpreted. This happens in a general way
and other steps should be taken to ensure validity. and other steps should be taken to ensure validity.
...@@ -153,7 +152,7 @@ class PostParse(CythonTransform): ...@@ -153,7 +152,7 @@ class PostParse(CythonTransform):
- For __cythonbufferdefaults__ the arguments are checked for - For __cythonbufferdefaults__ the arguments are checked for
validity. validity.
CBufferAccessTypeNode has its options interpreted: TemplatedTypeNode has its options interpreted:
Any first positional argument goes into the "dtype" attribute, Any first positional argument goes into the "dtype" attribute,
any "ndim" keyword argument goes into the "ndim" attribute and any "ndim" keyword argument goes into the "ndim" attribute and
so on. Also it is checked that the option combination is valid. so on. Also it is checked that the option combination is valid.
...@@ -242,11 +241,6 @@ class PostParse(CythonTransform): ...@@ -242,11 +241,6 @@ class PostParse(CythonTransform):
self.context.nonfatal_error(e) self.context.nonfatal_error(e)
return None return None
def visit_CBufferAccessTypeNode(self, node):
if not self.scope_type == 'function':
raise PostParseError(node.pos, ERR_BUF_LOCALONLY)
return node
class PxdPostParse(CythonTransform, SkipDeclarations): class PxdPostParse(CythonTransform, SkipDeclarations):
""" """
Basic interpretation/validity checking that should only be Basic interpretation/validity checking that should only be
......
...@@ -1795,43 +1795,23 @@ def p_buffer_or_template(s, base_type_node): ...@@ -1795,43 +1795,23 @@ def p_buffer_or_template(s, base_type_node):
# s.sy == '[' # s.sy == '['
pos = s.position() pos = s.position()
s.next() s.next()
if s.systring == 'int' or s.systring == 'long': positional_args, keyword_args = (
positional_args, keyword_args = ( p_positional_and_keyword_args(s, (']',), (0,), ('dtype',))
p_positional_and_keyword_args(s, (']',), (0,), ('dtype',)) )
) s.expect(']')
if keyword_args:
error(pos, "Keyword arguments not allowed for template types")
s.expect(']')
result = Nodes.TemplatedTypeNode(pos, base_type_node = base_type_node, keyword_dict = ExprNodes.DictNode(pos,
templates = positional_args) key_value_pairs = [
else: ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
positional_args, keyword_args = ( for key, value in keyword_args
p_positional_and_keyword_args(s, (']',), (0,), ('dtype',)) ])
)
if positional_args:
if positional_args[0] != 'int' or positional_args != 'long':
if keyword_args:
error(pos, "Keyword arguments not allowed for template types")
s.expect(']')
result = Nodes.TemplatedTypeNode(pos, base_type_node = base_type_node, result = Nodes.TemplatedTypeNode(pos,
templates = positional_args) positional_args = positional_args,
else: keyword_args = keyword_dict,
s.expect(']') base_type_node = base_type_node)
keyword_dict = ExprNodes.DictNode(pos,
key_value_pairs = [
ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
for key, value in keyword_args
])
result = Nodes.CBufferAccessTypeNode(pos,
positional_args = positional_args,
keyword_args = keyword_dict,
base_type_node = base_type_node)
return result return result
def looking_at_name(s): def looking_at_name(s):
......
...@@ -1369,39 +1369,37 @@ class CStructOrUnionType(CType): ...@@ -1369,39 +1369,37 @@ class CStructOrUnionType(CType):
class CppClassType(CType): class CppClassType(CType):
# name string # name string
# cname string # cname string
# kind string "cppclass"
# scope CppClassScope # scope CppClassScope
# typedef_flag boolean
# packed boolean
# templates [string] or None # templates [string] or None
is_cpp_class = 1 is_cpp_class = 1
has_attributes = 1 has_attributes = 1
base_classes = [] exception_check = True
def __init__(self, name, kind, scope, typedef_flag, cname, base_classes, packed=False, def __init__(self, name, scope, cname, base_classes, templates = None):
templates = None):
self.name = name self.name = name
self.cname = cname self.cname = cname
self.kind = kind
self.scope = scope self.scope = scope
self.typedef_flag = typedef_flag
self.exception_check = True
self._convert_code = None
self.packed = packed
self.base_classes = base_classes self.base_classes = base_classes
self.operators = [] self.operators = []
self.templates = templates self.templates = templates
def specialize(self, pos, template_values):
if self.templates is None:
error(pos, "'%s' type is not a template" % self);
return PyrexTypes.error_type
if len(self.templates) != len(template_values):
error(pos, "%s templated type receives %d arguments, got %d" %
(base_type, len(self.templates), len(template_values)))
return PyrexTypes.error_type
return CppClassType(self.name, self.scope, self.cname, self.base_classes, template_values)
def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0): def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0):
templates = ""
if self.templates: if self.templates:
templates = "<" template_strings = [param.declaration_code('', for_display, pyrex) for param in self.templates]
for i in range(len(self.templates)-1): templates = "<" + ",".join(template_strings) + ">"
templates += self.templates[i] else:
templates += ',' templates = ""
templates += self.templates[-1]
templates += ">"
if for_display or pyrex: if for_display or pyrex:
name = self.name name = self.name
else: else:
...@@ -1419,6 +1417,7 @@ class CppClassType(CType): ...@@ -1419,6 +1417,7 @@ class CppClassType(CType):
def attributes_known(self): def attributes_known(self):
return self.scope is not None return self.scope is not None
class TemplatedType(CType): class TemplatedType(CType):
def __init__(self, name): def __init__(self, name):
...@@ -1609,8 +1608,6 @@ c_anon_enum_type = CAnonEnumType(-1, 1) ...@@ -1609,8 +1608,6 @@ c_anon_enum_type = CAnonEnumType(-1, 1)
c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer") c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer")
c_py_buffer_ptr_type = CPtrType(c_py_buffer_type) c_py_buffer_ptr_type = CPtrType(c_py_buffer_type)
cpp_class_type = CppClassType("cpp_class", "cppclass", None, 1, "cpp_class", [])
error_type = ErrorType() error_type = ErrorType()
unspecified_type = UnspecifiedType() unspecified_type = UnspecifiedType()
......
...@@ -1110,9 +1110,9 @@ class ModuleScope(Scope): ...@@ -1110,9 +1110,9 @@ class ModuleScope(Scope):
# #
return entry return entry
def declare_cpp_class(self, name, kind, scope, def declare_cpp_class(self, name, scope,
typedef_flag, pos, cname = None, base_classes = [], pos, cname = None, base_classes = [],
visibility = 'extern', packed = False, templates = None): visibility = 'extern', templates = None):
if visibility != 'extern': if visibility != 'extern':
error(pos, "C++ classes may only be extern") error(pos, "C++ classes may only be extern")
if cname is None: if cname is None:
...@@ -1120,22 +1120,19 @@ class ModuleScope(Scope): ...@@ -1120,22 +1120,19 @@ class ModuleScope(Scope):
entry = self.lookup(name) entry = self.lookup(name)
if not entry: if not entry:
type = PyrexTypes.CppClassType( type = PyrexTypes.CppClassType(
name, kind, scope, typedef_flag, cname, base_classes, packed, templates = templates) name, scope, cname, base_classes, templates = templates)
entry = self.declare_type(name, type, pos, cname, entry = self.declare_type(name, type, pos, cname,
visibility = visibility, defining = scope is not None) visibility = visibility, defining = scope is not None)
else: else:
if not (entry.is_type and entry.type.is_cpp_class if not (entry.is_type and entry.type.is_cpp_class):
and entry.type.kind == kind):
warning(pos, "'%s' redeclared " % name, 0) warning(pos, "'%s' redeclared " % name, 0)
elif scope and entry.type.scope: elif scope and entry.type.scope:
warning(pos, "'%s' already defined (ignoring second definition)" % name, 0) warning(pos, "'%s' already defined (ignoring second definition)" % name, 0)
else: else:
self.check_previous_typedef_flag(entry, typedef_flag, pos)
if scope: if scope:
entry.type.scope = scope entry.type.scope = scope
self.type_entries.append(entry) self.type_entries.append(entry)
if not scope and not entry.type.scope: if not scope and not entry.type.scope:
self.check_for_illegal_incomplete_ctypedef(typedef_flag, pos)
entry.type.scope = CppClassScope(name) entry.type.scope = CppClassScope(name)
def declare_inherited_attributes(entry, base_classes): def declare_inherited_attributes(entry, base_classes):
...@@ -1145,10 +1142,6 @@ class ModuleScope(Scope): ...@@ -1145,10 +1142,6 @@ class ModuleScope(Scope):
declare_inherited_attributes(entry, base_classes) declare_inherited_attributes(entry, base_classes)
return entry return entry
def check_for_illegal_incomplete_ctypedef(self, typedef_flag, pos):
if typedef_flag and not self.in_cinclude:
error(pos, "Forward-referenced type must use 'cdef', not 'ctypedef'")
def allocate_vtable_names(self, entry): def allocate_vtable_names(self, entry):
# If extension type has a vtable, allocate vtable struct and # If extension type has a vtable, allocate vtable struct and
# slot names for it. # slot names for it.
...@@ -1238,7 +1231,7 @@ class ModuleScope(Scope): ...@@ -1238,7 +1231,7 @@ class ModuleScope(Scope):
var_entry.is_readonly = 1 var_entry.is_readonly = 1
entry.as_variable = var_entry entry.as_variable = var_entry
class LocalScope(Scope): class LocalScope(Scope):
def __init__(self, name, outer_scope): def __init__(self, name, outer_scope):
Scope.__init__(self, name, outer_scope, outer_scope) Scope.__init__(self, name, outer_scope, outer_scope)
......
...@@ -21,7 +21,7 @@ class TestBufferParsing(CythonTest): ...@@ -21,7 +21,7 @@ class TestBufferParsing(CythonTest):
def test_basic(self): def test_basic(self):
t = self.parse(u"cdef object[float, 4, ndim=2, foo=foo] x") t = self.parse(u"cdef object[float, 4, ndim=2, foo=foo] x")
bufnode = t.stats[0].base_type bufnode = t.stats[0].base_type
self.assert_(isinstance(bufnode, CBufferAccessTypeNode)) self.assert_(isinstance(bufnode, TemplatedTypeNode))
self.assertEqual(2, len(bufnode.positional_args)) self.assertEqual(2, len(bufnode.positional_args))
# print bufnode.dump() # print bufnode.dump()
# should put more here... # should put more here...
...@@ -65,7 +65,7 @@ class TestBufferOptions(CythonTest): ...@@ -65,7 +65,7 @@ class TestBufferOptions(CythonTest):
vardef = root.stats[0].body.stats[0] vardef = root.stats[0].body.stats[0]
assert isinstance(vardef, CVarDefNode) # use normal assert as this is to validate the test code assert isinstance(vardef, CVarDefNode) # use normal assert as this is to validate the test code
buftype = vardef.base_type buftype = vardef.base_type
self.assert_(isinstance(buftype, CBufferAccessTypeNode)) self.assert_(isinstance(buftype, TemplatedTypeNode))
self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode)) self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode))
self.assertEqual(u"object", buftype.base_type_node.name) self.assertEqual(u"object", buftype.base_type_node.name)
return buftype return buftype
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment