Commit 92cf54f9 authored by Robert Bradshaw's avatar Robert Bradshaw

simple template deduction

parent 52eaf5ce
......@@ -81,6 +81,18 @@ class BaseType(object):
is_fused = property(_get_fused_types, doc="Whether this type or any of its "
"subtypes is a fused type")
def deduce_template_params(self, actual):
"""
Deduce any template params in this (argument) type given the actual
argument type.
http://en.cppreference.com/w/cpp/language/function_template#Template_argument_deduction
"""
if self == actual:
return {}
else:
return None
def __lt__(self, other):
"""
For sorting. The sorting order should correspond to the preference of
......@@ -3162,6 +3174,9 @@ class TemplatePlaceholderType(CType):
else:
return self
def deduce_template_params(self, actual):
return {self: actual}
def same_as_resolved_type(self, other_type):
if isinstance(other_type, TemplatePlaceholderType):
return self.name == other_type.name
......@@ -3481,7 +3496,27 @@ def best_match(args, functions, pos=None, env=None):
% (expectation, actual_nargs)
errors.append((func, error_mesg))
continue
candidates.append((func, func_type))
if func_type.templates:
arg_types = [arg.type for arg in args]
deductions = reduce(
merge_template_deductions,
[pattern.type.deduce_template_params(actual) for (pattern, actual) in zip(func_type.args, arg_types)],
{})
if deductions is not None:
if len(deductions) < len(func_type.templates):
errors.append((func, "Unable to deduce type parameter %s" % (
", ".join([param.name for param in set(func_type.templates) - set(deductions.keys())]))))
else:
type_list = [deductions[param] for param in func_type.templates]
from Symtab import Entry
specialization = Entry(
name = func.name + "[%s]" % ",".join([str(t) for t in type_list]),
cname = func.cname + "<%s>" % ",".join([t.declaration_code("") for t in type_list]),
type = func_type.specialize(deductions),
pos = func.pos)
candidates.append((specialization, specialization.type))
else:
candidates.append((func, func_type))
# Optimize the most common case of no overloading...
if len(candidates) == 1:
......@@ -3573,6 +3608,18 @@ def best_match(args, functions, pos=None, env=None):
return None
def merge_template_deductions(a, b):
if a is None or b is None:
return None
all = a
for param, value in b.iteritems():
if param in all:
if a[param] != b[param]:
return None
else:
all[param] = value
return all
def widest_numeric_type(type1, type2):
# Given two numeric types, return the narrowest type
# encompassing both of them.
......
......@@ -3,11 +3,19 @@
from libcpp.pair cimport pair
cdef extern from "cpp_template_functions_helper.h":
cdef T no_arg[T]()
cdef T one_param[T](T)
cdef pair[T, U] two_params[T, U](T, U)
cdef cppclass A[T]:
pair[T, U] method[U](T, U)
def test_no_arg():
"""
>>> test_no_arg()
0
"""
return no_arg[int]()
def test_one_param(int x):
"""
>>> test_one_param(3)
......@@ -31,3 +39,10 @@ def test_method(int x, int y):
cdef A[double] a_double
return a_int.method[float](x, y), a_double.method[int](x, y)
# return a_int.method[double](x, y), a_double.method[int](x, y)
def test_simple_deduction(int x, double y):
"""
>>> test_simple_deduction(1, 2)
(1, 2.0)
"""
return one_param(x), one_param(y)
template <typename T>
T no_arg() {
return T();
}
template <typename T>
T one_param(T value) {
return value;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment