Commit 995d565e authored by Robert Bradshaw's avatar Robert Bradshaw

Enable type inference of template function results.

parent 1b61bc34
......@@ -4940,6 +4940,7 @@ class CallNode(ExprNode):
may_return_none = None
def infer_type(self, env):
# TODO(robertwb): Reduce redundancy with analyse_types.
function = self.function
func_type = function.infer_type(env)
if isinstance(function, NewExprNode):
......@@ -4953,6 +4954,17 @@ class CallNode(ExprNode):
if func_type.is_ptr:
func_type = func_type.base_type
if func_type.is_cfunction:
if hasattr(self.function, 'entry'):
alternatives = self.function.entry.all_alternatives()
arg_types = [arg.infer_type(env) for arg in self.args]
func_entry = PyrexTypes.best_match(
arg_types, alternatives, self.pos, env)
if func_entry:
func_type = func_entry.type
if func_type.is_ptr:
func_type = func_type.base_type
return func_type.return_type
return func_type.return_type
elif func_type is type_type:
if function.is_name and function.entry and function.entry.type:
......@@ -5173,7 +5185,8 @@ class SimpleCallNode(CallNode):
else:
alternatives = overloaded_entry.all_alternatives()
entry = PyrexTypes.best_match(args, alternatives, self.pos, env)
entry = PyrexTypes.best_match(
[arg.type for arg in args], alternatives, self.pos, env)
if not entry:
self.type = PyrexTypes.error_type
......
......@@ -4021,7 +4021,7 @@ def is_promotion(src_type, dst_type):
return src_type.is_float and src_type.rank <= dst_type.rank
return False
def best_match(args, functions, pos=None, env=None):
def best_match(arg_types, functions, pos=None, env=None):
"""
Given a list args of arguments and a list of functions, choose one
to call which seems to be the "best" fit for this list of arguments.
......@@ -4044,7 +4044,7 @@ def best_match(args, functions, pos=None, env=None):
is not None, we also generate an error.
"""
# TODO: args should be a list of types, not a list of Nodes.
actual_nargs = len(args)
actual_nargs = len(arg_types)
candidates = []
errors = []
......@@ -4075,7 +4075,6 @@ def best_match(args, functions, pos=None, env=None):
errors.append((func, error_mesg))
continue
if func_type.templates:
arg_types = [arg.type for arg in args]
deductions = reduce(
merge_template_deductions,
[pattern.type.deduce_template_params(actual) for (pattern, actual) in zip(func_type.args, arg_types)],
......
......@@ -846,13 +846,16 @@ class Scope(object):
obj_type = operands[0].type
method = obj_type.scope.lookup("operator%s" % operator)
if method is not None:
res = PyrexTypes.best_match(operands[1:], method.all_alternatives())
arg_types = [arg.type for arg in operands[1:]]
res = PyrexTypes.best_match([arg.type for arg in operands[1:]],
method.all_alternatives())
if res is not None:
return res
function = self.lookup("operator%s" % operator)
if function is None:
return None
return PyrexTypes.best_match(operands, function.all_alternatives())
return PyrexTypes.best_match([arg.type for arg in operands],
function.all_alternatives())
def lookup_operator_for_types(self, pos, operator, types):
from .Nodes import Node
......
......@@ -16,8 +16,7 @@ class SignatureMatcherTest(unittest.TestCase):
Test the signature matching algorithm for overloaded signatures.
"""
def assertMatches(self, expected_type, arg_types, functions):
args = [ NameNode(None, type=arg_type) for arg_type in arg_types ]
match = pt.best_match(args, functions)
match = pt.best_match(arg_types, functions)
if expected_type is not None:
self.assertNotEqual(None, match)
self.assertEqual(expected_type, match.type)
......
......@@ -475,6 +475,7 @@ class SimpleAssignmentTypeInferer(object):
for assmt in entry.cf_assignments]
new_type = spanning_type(types, entry.might_overflow, entry.pos, scope)
if new_type != entry.type:
print "FOUND", entry, entry.type, new_type, type(new_type)
self.set_entry_type(entry, new_type)
dirty = True
return dirty
......
# tag: cpp
cimport cython
from libcpp.pair cimport pair
cdef extern from "cpp_template_functions_helper.h":
......@@ -89,3 +90,12 @@ def test_deduce_through_pointers(int k):
"""
cdef double x = k
return pointer_param(&k)[0], pointer_param(&x)[0]
def test_inference(int k):
"""
>>> test_inference(27)
27
"""
res = one_param(&k)
assert cython.typeof(res) == 'int *', cython.typeof(res)
return res[0]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment