Commit a531a31c authored by Robert Bradshaw's avatar Robert Bradshaw

Walk cpp class hierarchy for class template deduction.

parent 2a360337
...@@ -3498,18 +3498,24 @@ class CppClassType(CType): ...@@ -3498,18 +3498,24 @@ class CppClassType(CType):
def deduce_template_params(self, actual): def deduce_template_params(self, actual):
if self == actual: if self == actual:
return {} return {}
elif not hasattr(actual, 'template_type'): elif actual.is_cpp_class:
# Untemplated type? self_template_type = self.template_type or self
return None def all_bases(cls):
# TODO(robertwb): Actual type equality. yield cls
elif (self.template_type or self).empty_declaration_code() == actual.template_type.empty_declaration_code(): for parent in cls.base_classes:
for base in all_bases(parent):
yield base
for actual_base in all_bases(actual):
if (actual_base.template_type
and self_template_type.empty_declaration_code()
== actual_base.template_type.empty_declaration_code()):
return reduce( return reduce(
merge_template_deductions, merge_template_deductions,
[formal_param.deduce_template_params(actual_param) [formal_param.deduce_template_params(actual_param)
for (formal_param, actual_param) in zip(self.templates, actual.templates)], for (formal_param, actual_param) in zip(self.templates, actual_base.templates)],
{}) {})
else: else:
return None return {}
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0, for_display = 0, dll_linkage = None, pyrex = 0,
......
...@@ -9,9 +9,12 @@ cdef extern from "cpp_template_functions_helper.h": ...@@ -9,9 +9,12 @@ cdef extern from "cpp_template_functions_helper.h":
cdef cppclass A[T]: cdef cppclass A[T]:
pair[T, U] method[U](T, U) pair[T, U] method[U](T, U)
U part_method[U](pair[T, U]) U part_method[U](pair[T, U])
U part_method_ref[U](pair[T, U]&)
cdef T nested_deduction[T](const T*) cdef T nested_deduction[T](const T*)
pair[T, U] pair_arg[T, U](pair[T, U] a) pair[T, U] pair_arg[T, U](pair[T, U] a)
cdef T* pointer_param[T](T*) cdef T* pointer_param[T](T*)
cdef cppclass double_pair(pair[double, double]):
double_pair(double, double)
def test_no_arg(): def test_no_arg():
""" """
...@@ -48,13 +51,15 @@ def test_method(int x, int y): ...@@ -48,13 +51,15 @@ def test_method(int x, int y):
def test_part_method(int x, int y): def test_part_method(int x, int y):
""" """
>>> test_part_method(5, 10) >>> test_part_method(5, 10)
(10.0, 10) (10.0, 10, 10.0)
""" """
cdef A[int] a_int cdef A[int] a_int
cdef pair[int, double] p_int = (x, y) cdef pair[int, double] p_int = (x, y)
cdef A[double] a_double cdef A[double] a_double
cdef pair[double, int] p_double = (x, y) cdef pair[double, int] p_double = (x, y)
return a_int.part_method(p_int), a_double.part_method(p_double) return (a_int.part_method(p_int),
a_double.part_method(p_double),
a_double.part_method_ref(double_pair(x, y)))
def test_simple_deduction(int x, double y): def test_simple_deduction(int x, double y):
""" """
......
...@@ -24,6 +24,10 @@ class A { ...@@ -24,6 +24,10 @@ class A {
U part_method(std::pair<T, U> p) { U part_method(std::pair<T, U> p) {
return p.second; return p.second;
} }
template <typename U>
U part_method_ref(const std::pair<T, U>& p) {
return p.second;
}
}; };
template <typename T> template <typename T>
...@@ -40,3 +44,8 @@ template <typename T> ...@@ -40,3 +44,8 @@ template <typename T>
T* pointer_param(T* param) { T* pointer_param(T* param) {
return param; return param;
} }
class double_pair : public std::pair<double, double> {
public:
double_pair(double x, double y) : std::pair<double, double>(x, y) { };
};
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment