Commit 995d565e authored by Robert Bradshaw's avatar Robert Bradshaw

Enable type inference of template function results.

parent 1b61bc34
...@@ -4940,6 +4940,7 @@ class CallNode(ExprNode): ...@@ -4940,6 +4940,7 @@ class CallNode(ExprNode):
may_return_none = None may_return_none = None
def infer_type(self, env): def infer_type(self, env):
# TODO(robertwb): Reduce redundancy with analyse_types.
function = self.function function = self.function
func_type = function.infer_type(env) func_type = function.infer_type(env)
if isinstance(function, NewExprNode): if isinstance(function, NewExprNode):
...@@ -4953,6 +4954,17 @@ class CallNode(ExprNode): ...@@ -4953,6 +4954,17 @@ class CallNode(ExprNode):
if func_type.is_ptr: if func_type.is_ptr:
func_type = func_type.base_type func_type = func_type.base_type
if func_type.is_cfunction: if func_type.is_cfunction:
if hasattr(self.function, 'entry'):
alternatives = self.function.entry.all_alternatives()
arg_types = [arg.infer_type(env) for arg in self.args]
func_entry = PyrexTypes.best_match(
arg_types, alternatives, self.pos, env)
if func_entry:
func_type = func_entry.type
if func_type.is_ptr:
func_type = func_type.base_type
return func_type.return_type
return func_type.return_type return func_type.return_type
elif func_type is type_type: elif func_type is type_type:
if function.is_name and function.entry and function.entry.type: if function.is_name and function.entry and function.entry.type:
...@@ -5173,7 +5185,8 @@ class SimpleCallNode(CallNode): ...@@ -5173,7 +5185,8 @@ class SimpleCallNode(CallNode):
else: else:
alternatives = overloaded_entry.all_alternatives() alternatives = overloaded_entry.all_alternatives()
entry = PyrexTypes.best_match(args, alternatives, self.pos, env) entry = PyrexTypes.best_match(
[arg.type for arg in args], alternatives, self.pos, env)
if not entry: if not entry:
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
......
...@@ -4021,7 +4021,7 @@ def is_promotion(src_type, dst_type): ...@@ -4021,7 +4021,7 @@ def is_promotion(src_type, dst_type):
return src_type.is_float and src_type.rank <= dst_type.rank return src_type.is_float and src_type.rank <= dst_type.rank
return False return False
def best_match(args, functions, pos=None, env=None): def best_match(arg_types, functions, pos=None, env=None):
""" """
Given a list args of arguments and a list of functions, choose one Given a list args of arguments and a list of functions, choose one
to call which seems to be the "best" fit for this list of arguments. to call which seems to be the "best" fit for this list of arguments.
...@@ -4044,7 +4044,7 @@ def best_match(args, functions, pos=None, env=None): ...@@ -4044,7 +4044,7 @@ def best_match(args, functions, pos=None, env=None):
is not None, we also generate an error. is not None, we also generate an error.
""" """
# TODO: args should be a list of types, not a list of Nodes. # TODO: args should be a list of types, not a list of Nodes.
actual_nargs = len(args) actual_nargs = len(arg_types)
candidates = [] candidates = []
errors = [] errors = []
...@@ -4075,7 +4075,6 @@ def best_match(args, functions, pos=None, env=None): ...@@ -4075,7 +4075,6 @@ def best_match(args, functions, pos=None, env=None):
errors.append((func, error_mesg)) errors.append((func, error_mesg))
continue continue
if func_type.templates: if func_type.templates:
arg_types = [arg.type for arg in args]
deductions = reduce( deductions = reduce(
merge_template_deductions, merge_template_deductions,
[pattern.type.deduce_template_params(actual) for (pattern, actual) in zip(func_type.args, arg_types)], [pattern.type.deduce_template_params(actual) for (pattern, actual) in zip(func_type.args, arg_types)],
......
...@@ -846,13 +846,16 @@ class Scope(object): ...@@ -846,13 +846,16 @@ class Scope(object):
obj_type = operands[0].type obj_type = operands[0].type
method = obj_type.scope.lookup("operator%s" % operator) method = obj_type.scope.lookup("operator%s" % operator)
if method is not None: if method is not None:
res = PyrexTypes.best_match(operands[1:], method.all_alternatives()) arg_types = [arg.type for arg in operands[1:]]
res = PyrexTypes.best_match([arg.type for arg in operands[1:]],
method.all_alternatives())
if res is not None: if res is not None:
return res return res
function = self.lookup("operator%s" % operator) function = self.lookup("operator%s" % operator)
if function is None: if function is None:
return None return None
return PyrexTypes.best_match(operands, function.all_alternatives()) return PyrexTypes.best_match([arg.type for arg in operands],
function.all_alternatives())
def lookup_operator_for_types(self, pos, operator, types): def lookup_operator_for_types(self, pos, operator, types):
from .Nodes import Node from .Nodes import Node
......
...@@ -16,8 +16,7 @@ class SignatureMatcherTest(unittest.TestCase): ...@@ -16,8 +16,7 @@ class SignatureMatcherTest(unittest.TestCase):
Test the signature matching algorithm for overloaded signatures. Test the signature matching algorithm for overloaded signatures.
""" """
def assertMatches(self, expected_type, arg_types, functions): def assertMatches(self, expected_type, arg_types, functions):
args = [ NameNode(None, type=arg_type) for arg_type in arg_types ] match = pt.best_match(arg_types, functions)
match = pt.best_match(args, functions)
if expected_type is not None: if expected_type is not None:
self.assertNotEqual(None, match) self.assertNotEqual(None, match)
self.assertEqual(expected_type, match.type) self.assertEqual(expected_type, match.type)
......
...@@ -475,6 +475,7 @@ class SimpleAssignmentTypeInferer(object): ...@@ -475,6 +475,7 @@ class SimpleAssignmentTypeInferer(object):
for assmt in entry.cf_assignments] for assmt in entry.cf_assignments]
new_type = spanning_type(types, entry.might_overflow, entry.pos, scope) new_type = spanning_type(types, entry.might_overflow, entry.pos, scope)
if new_type != entry.type: if new_type != entry.type:
print "FOUND", entry, entry.type, new_type, type(new_type)
self.set_entry_type(entry, new_type) self.set_entry_type(entry, new_type)
dirty = True dirty = True
return dirty return dirty
......
# tag: cpp # tag: cpp
cimport cython
from libcpp.pair cimport pair from libcpp.pair cimport pair
cdef extern from "cpp_template_functions_helper.h": cdef extern from "cpp_template_functions_helper.h":
...@@ -89,3 +90,12 @@ def test_deduce_through_pointers(int k): ...@@ -89,3 +90,12 @@ def test_deduce_through_pointers(int k):
""" """
cdef double x = k cdef double x = k
return pointer_param(&k)[0], pointer_param(&x)[0] return pointer_param(&k)[0], pointer_param(&x)[0]
def test_inference(int k):
"""
>>> test_inference(27)
27
"""
res = one_param(&k)
assert cython.typeof(res) == 'int *', cython.typeof(res)
return res[0]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment