Commit 91d2df2b authored by Robert Bradshaw's avatar Robert Bradshaw

More template parameter deduction.

parent 92cf54f9
......@@ -1239,6 +1239,9 @@ class CConstType(BaseType):
else:
return CConstType(base_type)
def deduce_template_params(self, actual):
return self.const_base_type.deduce_template_params(actual)
def create_to_py_utility_code(self, env):
if self.const_base_type.create_to_py_utility_code(env):
self.to_py_function = self.const_base_type.to_py_function
......@@ -2178,6 +2181,19 @@ class CArrayType(CPointerBaseType):
def is_complete(self):
return self.size is not None
def specialize(self, values):
base_type = self.base_type.specialize(values)
if base_type == self.base_type:
return self
else:
return CArrayType(base_type)
def deduce_template_params(self, actual):
if isinstance(actual, CArrayType):
return self.base_type.deduce_template_params(actual.base_type)
else:
return None
class CPtrType(CPointerBaseType):
# base_type CType Reference type
......@@ -2239,6 +2255,12 @@ class CPtrType(CPointerBaseType):
else:
return CPtrType(base_type)
def deduce_template_params(self, actual):
if isinstance(actual, CPtrType):
return self.base_type.deduce_template_params(actual.base_type)
else:
return None
def invalid_value(self):
return "1"
......@@ -2279,6 +2301,9 @@ class CReferenceType(BaseType):
else:
return CReferenceType(base_type)
def deduce_template_params(self, actual):
return self.ref_base_type.deduce_template_params(actual)
def __getattr__(self, name):
return getattr(self.ref_base_type, name)
......@@ -3083,6 +3108,18 @@ class CppClassType(CType):
specialized.namespace = self.namespace.specialize(values)
return specialized
def deduce_template_params(self, actual):
if self == actual:
return {}
# TODO(robertwb): Actual type equality.
elif self.declaration_code("") == actual.template_type.declaration_code(""):
return reduce(
merge_template_deductions,
[formal_param.deduce_template_params(actual_param) for (formal_param, actual_param) in zip(self.templates, actual.templates)],
{})
else:
return None
def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0):
if self.templates:
......@@ -3502,19 +3539,20 @@ def best_match(args, functions, pos=None, env=None):
merge_template_deductions,
[pattern.type.deduce_template_params(actual) for (pattern, actual) in zip(func_type.args, arg_types)],
{})
if deductions is not None:
if len(deductions) < len(func_type.templates):
errors.append((func, "Unable to deduce type parameter %s" % (
", ".join([param.name for param in set(func_type.templates) - set(deductions.keys())]))))
else:
type_list = [deductions[param] for param in func_type.templates]
from Symtab import Entry
specialization = Entry(
name = func.name + "[%s]" % ",".join([str(t) for t in type_list]),
cname = func.cname + "<%s>" % ",".join([t.declaration_code("") for t in type_list]),
type = func_type.specialize(deductions),
pos = func.pos)
candidates.append((specialization, specialization.type))
if deductions is None:
errors.append((func, "Unable to deduce type parameters"))
elif len(deductions) < len(func_type.templates):
errors.append((func, "Unable to deduce type parameter %s" % (
", ".join([param.name for param in set(func_type.templates) - set(deductions.keys())]))))
else:
type_list = [deductions[param] for param in func_type.templates]
from Symtab import Entry
specialization = Entry(
name = func.name + "[%s]" % ",".join([str(t) for t in type_list]),
cname = func.cname + "<%s>" % ",".join([t.declaration_code("") for t in type_list]),
type = func_type.specialize(deductions),
pos = func.pos)
candidates.append((specialization, specialization.type))
else:
candidates.append((func, func_type))
......
......@@ -8,6 +8,8 @@ cdef extern from "cpp_template_functions_helper.h":
cdef pair[T, U] two_params[T, U](T, U)
cdef cppclass A[T]:
pair[T, U] method[U](T, U)
cdef T nested_deduction[T](const T*)
pair[T, U] pair_arg[T, U](pair[T, U] a)
def test_no_arg():
"""
......@@ -46,3 +48,18 @@ def test_simple_deduction(int x, double y):
(1, 2.0)
"""
return one_param(x), one_param(y)
def test_more_deductions(int x, double y):
"""
>>> test_more_deductions(1, 2)
(1, 2.0)
"""
return nested_deduction(&x), nested_deduction(&y)
def test_class_deductions(pair[long, double] x):
"""
>>> test_class_deductions((1, 1.5))
(1, 1.5)
"""
return pair_arg(x)
......@@ -21,3 +21,13 @@ class A {
return std::pair<T, U>(a, b);
}
};
template <typename T>
T nested_deduction(const T *a) {
return *a;
}
template <typename T, typename U>
std::pair<T, U> pair_arg(std::pair<T, U> a) {
return a;
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment