Commit daa03512 authored by Xavier Thompson's avatar Xavier Thompson

Detect ambiguous cpp/py operations at compile time instead of falling back on Python runtime

parent 3dc38b5c
...@@ -11522,7 +11522,19 @@ class BinopNode(ExprNode): ...@@ -11522,7 +11522,19 @@ class BinopNode(ExprNode):
def analyse_cpp_py_operation(self, env): def analyse_cpp_py_operation(self, env):
operator = self.operator if not self.inplace else self.operator+"=" operator = self.operator if not self.inplace else self.operator+"="
entry = env.lookup_operator(operator, [self.operand1, self.operand2]) entry = None
try:
entry = env.lookup_operator(operator, [self.operand1, self.operand2], throw=True)
except (PyrexTypes.AmbiguousCallException, PyrexTypes.NoTypeMatchCallException):
error(self.pos, ("Ambiguous operation with PyObject operand\n"
"To select one of the alternatives, explicitly cast the PyObject operand\n"
"To let Python handle the operation instead, cast the other operand to 'object'"
"\n"))
self.type = PyrexTypes.error_type
return
except PyrexTypes.CallException:
pass
if entry: if entry:
self.analyse_cpp_operation(env) self.analyse_cpp_operation(env)
else: else:
......
...@@ -4755,7 +4755,19 @@ def is_promotion(src_type, dst_type): ...@@ -4755,7 +4755,19 @@ def is_promotion(src_type, dst_type):
return src_type.is_float and src_type.rank <= dst_type.rank return src_type.is_float and src_type.rank <= dst_type.rank
return False return False
def best_match(arg_types, functions, pos=None, env=None, args=None): class CallException(Exception):
pass
class NoCandidateCallException(CallException):
pass
class AmbiguousCallException(CallException):
pass
class NoTypeMatchCallException(CallException):
pass
def best_match(arg_types, functions, pos=None, env=None, args=None, throw=False):
""" """
Given a list args of arguments and a list of functions, choose one Given a list args of arguments and a list of functions, choose one
to call which seems to be the "best" fit for this list of arguments. to call which seems to be the "best" fit for this list of arguments.
...@@ -4776,6 +4788,12 @@ def best_match(arg_types, functions, pos=None, env=None, args=None): ...@@ -4776,6 +4788,12 @@ def best_match(arg_types, functions, pos=None, env=None, args=None):
If no function is deemed a good fit, or if two or more functions have If no function is deemed a good fit, or if two or more functions have
the same weight, we return None (as there is no best match). If pos the same weight, we return None (as there is no best match). If pos
is not None, we also generate an error. is not None, we also generate an error.
If throw is True, an exception is raised instead of returning None:
* NoCandidateCallException: when no candidate is found
* AmbiguousCallException: when the call is ambiguous
* NoTypeMatchCallException: when no candidate matches
This allows the caller to determine why no best match was found.
""" """
# TODO: args should be a list of types, not a list of Nodes. # TODO: args should be a list of types, not a list of Nodes.
actual_nargs = len(arg_types) actual_nargs = len(arg_types)
...@@ -4833,12 +4851,16 @@ def best_match(arg_types, functions, pos=None, env=None, args=None): ...@@ -4833,12 +4851,16 @@ def best_match(arg_types, functions, pos=None, env=None, args=None):
if len(candidates) == 1: if len(candidates) == 1:
return candidates[0][0] return candidates[0][0]
elif len(candidates) == 0: elif len(candidates) == 0:
if pos is not None: if pos is not None or throw:
func, errmsg = errors[0] func, errmsg = errors[0]
if len(errors) == 1 or [1 for func, e in errors if e == errmsg]: if len(errors) == 1 or [1 for func, e in errors if e == errmsg]:
error(pos, errmsg) pass
else: else:
error(pos, "no suitable method found") errmsg = "no suitable method found"
if pos is not None:
error(pos, errmsg)
if throw:
raise NoCandidateCallException(errmsg)
return None return None
possibilities = [] possibilities = []
...@@ -4908,8 +4930,11 @@ def best_match(arg_types, functions, pos=None, env=None, args=None): ...@@ -4908,8 +4930,11 @@ def best_match(arg_types, functions, pos=None, env=None, args=None):
score1 = possibilities[0][0] score1 = possibilities[0][0]
score2 = possibilities[1][0] score2 = possibilities[1][0]
if score1 == score2: if score1 == score2:
errmsg = "ambiguous overloaded method"
if pos is not None: if pos is not None:
error(pos, "ambiguous overloaded method") error(pos, errmsg)
if throw:
raise AmbiguousCallException(errmsg)
return None return None
function = possibilities[0][-1] function = possibilities[0][-1]
...@@ -4920,12 +4945,16 @@ def best_match(arg_types, functions, pos=None, env=None, args=None): ...@@ -4920,12 +4945,16 @@ def best_match(arg_types, functions, pos=None, env=None, args=None):
return function return function
if len(bad_types) == 1:
errmsg = bad_types[0][1]
else:
errmsg = "no suitable method found"
if pos is not None: if pos is not None:
if len(bad_types) == 1: error(pos, errmsg)
error(pos, bad_types[0][1])
else:
error(pos, "no suitable method found")
if throw:
raise NoTypeMatchCallException(errmsg)
return None return None
def merge_template_deductions(a, b): def merge_template_deductions(a, b):
......
...@@ -1205,13 +1205,13 @@ class Scope(object): ...@@ -1205,13 +1205,13 @@ class Scope(object):
return entry.type.specialize(self.fused_to_specific) return entry.type.specialize(self.fused_to_specific)
return entry.type return entry.type
def lookup_operator(self, operator, operands): def lookup_operator(self, operator, operands, throw=False):
if operands[0].type.is_cpp_class: if operands[0].type.is_cpp_class:
obj_type = operands[0].type obj_type = operands[0].type
method = obj_type.scope.lookup("operator%s" % operator) method = obj_type.scope.lookup("operator%s" % operator)
if method is not None: if method is not None:
arg_types = [arg.type for arg in operands[1:]] arg_types = [arg.type for arg in operands[1:]]
res = PyrexTypes.best_match(arg_types, method.all_alternatives()) res = PyrexTypes.best_match(arg_types, method.all_alternatives(), throw=throw)
if res is not None: if res is not None:
return res return res
function = self.lookup("operator%s" % operator) function = self.lookup("operator%s" % operator)
...@@ -1236,7 +1236,7 @@ class Scope(object): ...@@ -1236,7 +1236,7 @@ class Scope(object):
all_alternatives = list(set(method_alternatives + function_alternatives)) all_alternatives = list(set(method_alternatives + function_alternatives))
return PyrexTypes.best_match([arg.type for arg in operands], return PyrexTypes.best_match([arg.type for arg in operands],
all_alternatives) all_alternatives, throw=throw)
def lookup_operator_for_types(self, pos, operator, types): def lookup_operator_for_types(self, pos, operator, types):
from .Nodes import Node from .Nodes import Node
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment