Commit 74ee82e2 authored by Robert Bradshaw's avatar Robert Bradshaw

Fix C++ template subclassing.

parent a7c689c1
...@@ -1170,7 +1170,7 @@ class CppClassNode(CStructOrUnionDefNode): ...@@ -1170,7 +1170,7 @@ class CppClassNode(CStructOrUnionDefNode):
# in_pxd boolean # in_pxd boolean
# attributes [CVarDefNode] or None # attributes [CVarDefNode] or None
# entry Entry # entry Entry
# base_classes [string] # base_classes [CBaseTypeNode]
# templates [string] or None # templates [string] or None
def declare(self, env): def declare(self, env):
...@@ -1185,16 +1185,8 @@ class CppClassNode(CStructOrUnionDefNode): ...@@ -1185,16 +1185,8 @@ class CppClassNode(CStructOrUnionDefNode):
def analyse_declarations(self, env): def analyse_declarations(self, env):
scope = None scope = None
if self.attributes is not None: if self.attributes is not None:
scope = CppClassScope(self.name, env) scope = CppClassScope(self.name, env, templates = self.templates)
base_class_types = [] base_class_types = [b.analyse(scope) for b in self.base_classes]
for base_class_name in self.base_classes:
base_class_entry = env.lookup(base_class_name)
if base_class_entry is None:
error(self.pos, "'%s' not found" % base_class_name)
elif not base_class_entry.is_type or not base_class_entry.type.is_cpp_class:
error(self.pos, "'%s' is not a cpp class type" % base_class_name)
else:
base_class_types.append(base_class_entry.type)
if self.templates is None: if self.templates is None:
template_types = None template_types = None
else: else:
......
...@@ -3041,10 +3041,10 @@ def p_cpp_class_definition(s, pos, ctx): ...@@ -3041,10 +3041,10 @@ def p_cpp_class_definition(s, pos, ctx):
templates = None templates = None
if s.sy == '(': if s.sy == '(':
s.next() s.next()
base_classes = [p_dotted_name(s, False)[2]] base_classes = [p_c_base_type(s, templates = templates)]
while s.sy == ',': while s.sy == ',':
s.next() s.next()
base_classes.append(p_dotted_name(s, False)[2]) base_classes.append(p_c_base_type(s, templates = templates))
s.expect(')') s.expect(')')
else: else:
base_classes = [] base_classes = []
......
...@@ -514,10 +514,6 @@ class Scope(object): ...@@ -514,10 +514,6 @@ class Scope(object):
if templates or entry.type.templates: if templates or entry.type.templates:
if templates != entry.type.templates: if templates != entry.type.templates:
error(pos, "Template parameters do not match previous declaration") error(pos, "Template parameters do not match previous declaration")
if templates is not None and entry.type.scope is not None:
for T in templates:
template_entry = entry.type.scope.declare(T.name, T.name, T, None, 'extern')
template_entry.is_type = 1
def declare_inherited_attributes(entry, base_classes): def declare_inherited_attributes(entry, base_classes):
for base_class in base_classes: for base_class in base_classes:
...@@ -1993,10 +1989,15 @@ class CppClassScope(Scope): ...@@ -1993,10 +1989,15 @@ class CppClassScope(Scope):
default_constructor = None default_constructor = None
def __init__(self, name, outer_scope): def __init__(self, name, outer_scope, templates=None):
Scope.__init__(self, name, outer_scope, None) Scope.__init__(self, name, outer_scope, None)
self.directives = outer_scope.directives self.directives = outer_scope.directives
self.inherited_var_entries = [] self.inherited_var_entries = []
if templates is not None:
for T in templates:
template_entry = self.declare(
T, T, PyrexTypes.TemplatePlaceholderType(T), None, 'extern')
template_entry.is_type = 1
def declare_var(self, name, type, pos, def declare_var(self, name, type, pos,
cname = None, visibility = 'extern', cname = None, visibility = 'extern',
......
# tag: cpp
from cython.operator import dereference as deref
from libcpp.pair cimport pair
from libcpp.vector cimport vector
cdef extern from "cpp_template_subclasses_helper.h":
cdef cppclass Base:
char* name()
cdef cppclass A[A1](Base):
A1 funcA(A1)
cdef cppclass B[B1, B2](A[B2]):
pair[B1, B2] funcB(B1, B2)
cdef cppclass C[C1](B[long, C1]):
C1 funcC(C1)
cdef cppclass D[D1](C[pair[D1, D1]]):
pass
cdef cppclass E(D[double]):
pass
def testA(x):
"""
>>> testA(10)
10.0
"""
cdef Base *base
cdef A[double] *a = NULL
try:
a = new A[double]()
base = a
assert base.name() == b"A", base.name()
return a.funcA(x)
finally:
del a
def testB(x, y):
"""
>>> testB(1, 2)
>>> testB(1, 1.5)
"""
cdef Base *base
cdef A[double] *a
cdef B[long, double] *b = NULL
try:
base = a = b = new B[long, double]()
assert base.name() == b"B", base.name()
assert a.funcA(y) == y
assert <object>b.funcB(x, y) == (x, y)
finally:
del b
def testC(x, y):
"""
>>> testC(37, [1, 37])
>>> testC(25, [1, 5, 25])
>>> testC(105, [1, 3, 5, 7, 15, 21, 35, 105])
"""
cdef Base *base
cdef A[vector[long]] *a
cdef B[long, vector[long]] *b
cdef C[vector[long]] *c = NULL
try:
base = a = b = c = new C[vector[long]]()
assert base.name() == b"C", base.name()
assert <object>a.funcA(y) == y
assert <object>b.funcB(x, y) == (x, y)
assert <object>c.funcC(y) == y
finally:
del c
def testD(x, y):
"""
>>> testD(1, 1.0)
>>> testD(2, 0.5)
>>> testD(4, 0.25)
"""
cdef Base *base
cdef A[pair[double, double]] *a
cdef B[long, pair[double, double]] *b
cdef C[pair[double, double]] *c
cdef D[double] *d = NULL
try:
base = a = b = c = d = new D[double]()
assert base.name() == b"D", base.name()
assert <object>a.funcA((y, y)) == (y, y)
assert <object>b.funcB(x, (y, y + 1)) == (x, (y, y + 1))
assert <object>c.funcC((y, y)) == (y, y)
finally:
del d
def testE(x, y):
"""
>>> testD(1, 1.0)
>>> testD(2, 0.5)
>>> testD(4, 0.25)
"""
cdef Base *base
cdef A[pair[double, double]] *a
cdef B[long, pair[double, double]] *b
cdef C[pair[double, double]] *c
cdef D[double] *d
cdef E *e = NULL
try:
base = a = b = c = d = e = new E()
assert base.name() == b"E", base.name()
assert <object>a.funcA((y, y)) == (y, y)
assert <object>b.funcB(x, (y, y + 1)) == (x, (y, y + 1))
assert <object>c.funcC((y, y)) == (y, y)
finally:
del e
using namespace std;
class Base {
public:
virtual const char* name() { return "Base"; }
};
template <class A1>
class A : public Base {
public:
virtual const char* name() { return "A"; }
A1 funcA(A1 x) { return x; }
};
template <class B1, class B2>
class B : public A<B2> {
public:
virtual const char* name() { return "B"; }
pair<B1, B2> funcB(B1 x, B2 y) { return pair<B1, B2>(x, y); }
};
template <class C1>
class C : public B<long, C1> {
public:
virtual const char* name() { return "C"; }
C1 funcC(C1 x) { return x; }
};
template <class D1>
class D : public C<pair<D1, D1> > {
virtual const char* name() { return "D"; }
};
class E : public D<double> {
virtual const char* name() { return "E"; }
};
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment