Commit 91d2df2b authored by Robert Bradshaw's avatar Robert Bradshaw

More template parameter deduction.

parent 92cf54f9
...@@ -1239,6 +1239,9 @@ class CConstType(BaseType): ...@@ -1239,6 +1239,9 @@ class CConstType(BaseType):
else: else:
return CConstType(base_type) return CConstType(base_type)
def deduce_template_params(self, actual):
return self.const_base_type.deduce_template_params(actual)
def create_to_py_utility_code(self, env): def create_to_py_utility_code(self, env):
if self.const_base_type.create_to_py_utility_code(env): if self.const_base_type.create_to_py_utility_code(env):
self.to_py_function = self.const_base_type.to_py_function self.to_py_function = self.const_base_type.to_py_function
...@@ -2178,6 +2181,19 @@ class CArrayType(CPointerBaseType): ...@@ -2178,6 +2181,19 @@ class CArrayType(CPointerBaseType):
def is_complete(self): def is_complete(self):
return self.size is not None return self.size is not None
def specialize(self, values):
base_type = self.base_type.specialize(values)
if base_type == self.base_type:
return self
else:
return CArrayType(base_type)
def deduce_template_params(self, actual):
if isinstance(actual, CArrayType):
return self.base_type.deduce_template_params(actual.base_type)
else:
return None
class CPtrType(CPointerBaseType): class CPtrType(CPointerBaseType):
# base_type CType Reference type # base_type CType Reference type
...@@ -2239,6 +2255,12 @@ class CPtrType(CPointerBaseType): ...@@ -2239,6 +2255,12 @@ class CPtrType(CPointerBaseType):
else: else:
return CPtrType(base_type) return CPtrType(base_type)
def deduce_template_params(self, actual):
if isinstance(actual, CPtrType):
return self.base_type.deduce_template_params(actual.base_type)
else:
return None
def invalid_value(self): def invalid_value(self):
return "1" return "1"
...@@ -2279,6 +2301,9 @@ class CReferenceType(BaseType): ...@@ -2279,6 +2301,9 @@ class CReferenceType(BaseType):
else: else:
return CReferenceType(base_type) return CReferenceType(base_type)
def deduce_template_params(self, actual):
return self.ref_base_type.deduce_template_params(actual)
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.ref_base_type, name) return getattr(self.ref_base_type, name)
...@@ -3083,6 +3108,18 @@ class CppClassType(CType): ...@@ -3083,6 +3108,18 @@ class CppClassType(CType):
specialized.namespace = self.namespace.specialize(values) specialized.namespace = self.namespace.specialize(values)
return specialized return specialized
def deduce_template_params(self, actual):
if self == actual:
return {}
# TODO(robertwb): Actual type equality.
elif self.declaration_code("") == actual.template_type.declaration_code(""):
return reduce(
merge_template_deductions,
[formal_param.deduce_template_params(actual_param) for (formal_param, actual_param) in zip(self.templates, actual.templates)],
{})
else:
return None
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
if self.templates: if self.templates:
...@@ -3502,8 +3539,9 @@ def best_match(args, functions, pos=None, env=None): ...@@ -3502,8 +3539,9 @@ def best_match(args, functions, pos=None, env=None):
merge_template_deductions, merge_template_deductions,
[pattern.type.deduce_template_params(actual) for (pattern, actual) in zip(func_type.args, arg_types)], [pattern.type.deduce_template_params(actual) for (pattern, actual) in zip(func_type.args, arg_types)],
{}) {})
if deductions is not None: if deductions is None:
if len(deductions) < len(func_type.templates): errors.append((func, "Unable to deduce type parameters"))
elif len(deductions) < len(func_type.templates):
errors.append((func, "Unable to deduce type parameter %s" % ( errors.append((func, "Unable to deduce type parameter %s" % (
", ".join([param.name for param in set(func_type.templates) - set(deductions.keys())])))) ", ".join([param.name for param in set(func_type.templates) - set(deductions.keys())]))))
else: else:
......
...@@ -8,6 +8,8 @@ cdef extern from "cpp_template_functions_helper.h": ...@@ -8,6 +8,8 @@ cdef extern from "cpp_template_functions_helper.h":
cdef pair[T, U] two_params[T, U](T, U) cdef pair[T, U] two_params[T, U](T, U)
cdef cppclass A[T]: cdef cppclass A[T]:
pair[T, U] method[U](T, U) pair[T, U] method[U](T, U)
cdef T nested_deduction[T](const T*)
pair[T, U] pair_arg[T, U](pair[T, U] a)
def test_no_arg(): def test_no_arg():
""" """
...@@ -46,3 +48,18 @@ def test_simple_deduction(int x, double y): ...@@ -46,3 +48,18 @@ def test_simple_deduction(int x, double y):
(1, 2.0) (1, 2.0)
""" """
return one_param(x), one_param(y) return one_param(x), one_param(y)
def test_more_deductions(int x, double y):
"""
>>> test_more_deductions(1, 2)
(1, 2.0)
"""
return nested_deduction(&x), nested_deduction(&y)
def test_class_deductions(pair[long, double] x):
"""
>>> test_class_deductions((1, 1.5))
(1, 1.5)
"""
return pair_arg(x)
...@@ -21,3 +21,13 @@ class A { ...@@ -21,3 +21,13 @@ class A {
return std::pair<T, U>(a, b); return std::pair<T, U>(a, b);
} }
}; };
template <typename T>
T nested_deduction(const T *a) {
return *a;
}
template <typename T, typename U>
std::pair<T, U> pair_arg(std::pair<T, U> a) {
return a;
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment