Commit a531a31c authored by Robert Bradshaw's avatar Robert Bradshaw

Walk cpp class hierarchy for class template deduction.

parent 2a360337
......@@ -3498,18 +3498,24 @@ class CppClassType(CType):
def deduce_template_params(self, actual):
if self == actual:
return {}
elif not hasattr(actual, 'template_type'):
# Untemplated type?
return None
# TODO(robertwb): Actual type equality.
elif (self.template_type or self).empty_declaration_code() == actual.template_type.empty_declaration_code():
return reduce(
merge_template_deductions,
[formal_param.deduce_template_params(actual_param)
for (formal_param, actual_param) in zip(self.templates, actual.templates)],
{})
elif actual.is_cpp_class:
self_template_type = self.template_type or self
def all_bases(cls):
yield cls
for parent in cls.base_classes:
for base in all_bases(parent):
yield base
for actual_base in all_bases(actual):
if (actual_base.template_type
and self_template_type.empty_declaration_code()
== actual_base.template_type.empty_declaration_code()):
return reduce(
merge_template_deductions,
[formal_param.deduce_template_params(actual_param)
for (formal_param, actual_param) in zip(self.templates, actual_base.templates)],
{})
else:
return None
return {}
def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0,
......
......@@ -9,9 +9,12 @@ cdef extern from "cpp_template_functions_helper.h":
cdef cppclass A[T]:
pair[T, U] method[U](T, U)
U part_method[U](pair[T, U])
U part_method_ref[U](pair[T, U]&)
cdef T nested_deduction[T](const T*)
pair[T, U] pair_arg[T, U](pair[T, U] a)
cdef T* pointer_param[T](T*)
cdef cppclass double_pair(pair[double, double]):
double_pair(double, double)
def test_no_arg():
"""
......@@ -48,13 +51,15 @@ def test_method(int x, int y):
def test_part_method(int x, int y):
"""
>>> test_part_method(5, 10)
(10.0, 10)
(10.0, 10, 10.0)
"""
cdef A[int] a_int
cdef pair[int, double] p_int = (x, y)
cdef A[double] a_double
cdef pair[double, int] p_double = (x, y)
return a_int.part_method(p_int), a_double.part_method(p_double)
return (a_int.part_method(p_int),
a_double.part_method(p_double),
a_double.part_method_ref(double_pair(x, y)))
def test_simple_deduction(int x, double y):
"""
......
......@@ -24,6 +24,10 @@ class A {
U part_method(std::pair<T, U> p) {
return p.second;
}
template <typename U>
U part_method_ref(const std::pair<T, U>& p) {
return p.second;
}
};
template <typename T>
......@@ -40,3 +44,8 @@ template <typename T>
T* pointer_param(T* param) {
return param;
}
class double_pair : public std::pair<double, double> {
public:
double_pair(double x, double y) : std::pair<double, double>(x, y) { };
};
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment