Commit c599ab4b authored by Robert Bradshaw's avatar Robert Bradshaw

Consolidate best_match, minor refactoring.

parent 1f2c8742
...@@ -609,49 +609,6 @@ class ExprNode(Node): ...@@ -609,49 +609,6 @@ class ExprNode(Node):
def as_cython_attribute(self): def as_cython_attribute(self):
return None return None
def best_match(self, args, env):
entries = [env] + env.overloaded_alternatives
possibilities = []
for entry in entries:
type = entry.type
if type.is_ptr:
type = type.base_type
score = [0,0,0]
for i in range(len(args)):
src_type = args[i].type
if entry.type.is_ptr:
dst_type = entry.type.base_type.args[i].type
else:
dst_type = entry.type.args[i].type
if dst_type.assignable_from(src_type):
if src_type == dst_type:
pass # score 0
elif PyrexTypes.is_promotion(src_type, dst_type):
score[2] += 1
elif not src_type.is_pyobject:
score[1] += 1
else:
score[0] += 1
else:
break
else:
possibilities.append((score, entry)) # so we can sort it
if len(possibilities):
possibilities.sort()
if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name)
self.args = None
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return None
return possibilities[0][1].type
error(self.pos,
"Call with wrong arguments")
self.args = None
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return None
class RemoveAllocateTemps(type): class RemoveAllocateTemps(type):
def __init__(cls, name, bases, dct): def __init__(cls, name, bases, dct):
...@@ -2446,56 +2403,6 @@ class SimpleCallNode(CallNode): ...@@ -2446,56 +2403,6 @@ class SimpleCallNode(CallNode):
self.args.insert(0, self.coerced_self) self.args.insert(0, self.coerced_self)
self.analyse_c_function_call(env) self.analyse_c_function_call(env)
def best_match(self, args, env):
entries = [env] + env.overloaded_alternatives
actual_nargs = len(self.args)
possibilities = []
for entry in entries:
type = entry.type
if type.is_ptr:
type = type.base_type
# Check no. of args
max_nargs = len(type.args)
expected_nargs = max_nargs - type.optional_arg_count
if actual_nargs < expected_nargs \
or (not type.has_varargs and actual_nargs > max_nargs):
continue
score = [0,0,0]
for i in range(len(self.args)):
src_type = self.args[i].type
if entry.type.is_ptr:
dst_type = entry.type.base_type.args[i].type
else:
dst_type = entry.type.args[i].type
if dst_type.assignable_from(src_type):
if src_type == dst_type:
pass # score 0
elif PyrexTypes.is_promotion(src_type, dst_type):
score[2] += 1
elif not src_type.is_pyobject:
score[1] += 1
else:
score[0] += 1
else:
break
else:
possibilities.append((score, entry)) # so we can sort it
if len(possibilities):
possibilities.sort()
if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name)
self.args = None
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return None
return possibilities[0][1]
error(self.pos,
"Call with wrong arguments")
self.args = None
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return None
def function_type(self): def function_type(self):
# Return the type of the function being called, coercing a function # Return the type of the function being called, coercing a function
# pointer to a function if necessary. # pointer to a function if necessary.
...@@ -2505,8 +2412,10 @@ class SimpleCallNode(CallNode): ...@@ -2505,8 +2412,10 @@ class SimpleCallNode(CallNode):
return func_type return func_type
def analyse_c_function_call(self, env): def analyse_c_function_call(self, env):
entry = self.best_match(self.args, self.function.entry) entry = PyrexTypes.best_match(self.args, self.function.entry.all_alternatives(), self.pos)
if not entry: if not entry:
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return return
self.function.entry = entry self.function.entry = entry
self.function.type = entry.type self.function.type = entry.type
...@@ -2523,23 +2432,6 @@ class SimpleCallNode(CallNode): ...@@ -2523,23 +2432,6 @@ class SimpleCallNode(CallNode):
max_nargs = len(func_type.args) max_nargs = len(func_type.args)
expected_nargs = max_nargs - func_type.optional_arg_count expected_nargs = max_nargs - func_type.optional_arg_count
actual_nargs = len(self.args) actual_nargs = len(self.args)
#if actual_nargs < expected_nargs \
# or (not func_type.has_varargs and actual_nargs > max_nargs):
# expected_str = str(expected_nargs)
# if func_type.has_varargs:
# expected_str = "at least " + expected_str
# elif func_type.optional_arg_count:
# if actual_nargs < max_nargs:
# expected_str = "at least " + expected_str
# else:
# expected_str = "at most " + str(max_nargs)
#error(self.pos,
# "Call with wrong number of arguments (expected %s, got %s)"
# % (expected_str, actual_nargs))
#self.args = None
#self.type = PyrexTypes.error_type
#self.result_code = "<error>"
#return
# Coerce arguments # Coerce arguments
for i in range(min(max_nargs, actual_nargs)): for i in range(min(max_nargs, actual_nargs)):
formal_type = func_type.args[i].type formal_type = func_type.args[i].type
...@@ -3922,7 +3814,7 @@ class UnopNode(ExprNode): ...@@ -3922,7 +3814,7 @@ class UnopNode(ExprNode):
% (self.operator, type1, type2, self.operator)) % (self.operator, type1, type2, self.operator))
self.type_error() self.type_error()
return return
self.type = self.best_match([self.operand], function) self.type = function.type.return_type
operator = { operator = {
"++": u"__inc__", "++": u"__inc__",
...@@ -4398,7 +4290,12 @@ class NumBinopNode(BinopNode): ...@@ -4398,7 +4290,12 @@ class NumBinopNode(BinopNode):
% (self.operator, type1, type2, self.operator)) % (self.operator, type1, type2, self.operator))
self.type_error() self.type_error()
return return
self.type = self.best_match([self.operand1, self.operand2], function) entry = PyrexTypes.best_match([self.operand1, self.operand2], function.all_alternatives(), self.pos)
if entry is None:
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return
self.type = entry.type.return_type
def compute_c_result_type(self, type1, type2): def compute_c_result_type(self, type1, type2):
if self.c_types_okay(type1, type2): if self.c_types_okay(type1, type2):
...@@ -4449,6 +4346,17 @@ class NumBinopNode(BinopNode): ...@@ -4449,6 +4346,17 @@ class NumBinopNode(BinopNode):
"+": u"__add__", "+": u"__add__",
"-": u"__sub__", "-": u"__sub__",
"*": u"__mul__", "*": u"__mul__",
"/": u"__div__",
"%": u"__mod__",
"<<": u"__lshift__",
">>": u"__rshift__",
"&": u"__and__",
"|": u"__or__",
"^": u"__xor__",
# TODO(danilo): Handle these in CmpNode (perhaps dissallowing chaining).
"<": u"__le__", "<": u"__le__",
">": u"__gt__", ">": u"__gt__",
"==": u"__eq__", "==": u"__eq__",
......
...@@ -6,6 +6,7 @@ from Cython.Utils import UtilityCode ...@@ -6,6 +6,7 @@ from Cython.Utils import UtilityCode
import StringEncoding import StringEncoding
import Naming import Naming
import copy import copy
from Errors import error
class BaseType(object): class BaseType(object):
# #
...@@ -1437,8 +1438,8 @@ class CppClassType(CType): ...@@ -1437,8 +1438,8 @@ class CppClassType(CType):
if other_type.is_cpp_class: if other_type.is_cpp_class:
if self == other_type: if self == other_type:
return 1 return 1
elif self.template_type == other.template_type: elif self.template_type == other_type.template_type:
for t1, t2 in zip(self.templates, other.templates): for t1, t2 in zip(self.templates, other_type.templates):
if not t1.same_as_resolved_type(t2): if not t1.same_as_resolved_type(t2):
return 0 return 0
return 1 return 1
...@@ -1454,7 +1455,10 @@ class TemplatePlaceholderType(CType): ...@@ -1454,7 +1455,10 @@ class TemplatePlaceholderType(CType):
self.name = name self.name = name
def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0): def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0):
if entity_code:
return self.name + " " + entity_code return self.name + " " + entity_code
else:
return self.name
def specialize(self, values): def specialize(self, values):
if self in values: if self in values:
...@@ -1464,7 +1468,7 @@ class TemplatePlaceholderType(CType): ...@@ -1464,7 +1468,7 @@ class TemplatePlaceholderType(CType):
def same_as_resolved_type(self, other_type): def same_as_resolved_type(self, other_type):
if isinstance(other_type, TemplatePlaceholderType): if isinstance(other_type, TemplatePlaceholderType):
return self.name == other.name return self.name == other_type.name
else: else:
return 0 return 0
...@@ -1736,6 +1740,54 @@ def is_promotion(type, other_type): ...@@ -1736,6 +1740,54 @@ def is_promotion(type, other_type):
or (type.is_float and other_type.is_float) \ or (type.is_float and other_type.is_float) \
or (type.is_enum and other_type.is_int) or (type.is_enum and other_type.is_int)
def best_match(args, functions, pos):
actual_nargs = len(args)
possibilities = []
bad_types = 0
for func in functions:
func_type = func.type
if func_type.is_ptr:
func_type = func_type.base_type
# Check no. of args
max_nargs = len(func_type.args)
min_nargs = max_nargs - func_type.optional_arg_count
if actual_nargs < min_nargs \
or (not func_type.has_varargs and actual_nargs > max_nargs):
continue
score = [0,0,0]
for i in range(len(args)):
src_type = args[i].type
dst_type = func_type.args[i].type
if dst_type.assignable_from(src_type):
if src_type == dst_type:
pass # score 0
elif is_promotion(src_type, dst_type):
score[2] += 1
elif not src_type.is_pyobject:
score[1] += 1
else:
score[0] += 1
else:
bad_types = func
break
else:
possibilities.append((score, func)) # so we can sort it
if len(possibilities):
possibilities.sort()
if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
error(pos, "ambiguous overloaded method")
return None
return possibilities[0][1]
if bad_types:
# This will raise the right error.
return func
else:
error(pos, "Call with wrong number of arguments (expected %s, got %s)"
% (expected_str, actual_nargs))
return None
def widest_numeric_type(type1, type2): def widest_numeric_type(type1, type2):
# Given two numeric types, return the narrowest type # Given two numeric types, return the narrowest type
# encompassing both of them. # encompassing both of them.
......
...@@ -178,6 +178,9 @@ class Entry(object): ...@@ -178,6 +178,9 @@ class Entry(object):
error(pos, "'%s' does not match previous declaration" % self.name) error(pos, "'%s' does not match previous declaration" % self.name)
error(self.pos, "Previous declaration is here") error(self.pos, "Previous declaration is here")
def all_alternatives(self):
return [self] + self.overloaded_alternatives
class Scope(object): class Scope(object):
# name string Unqualified name # name string Unqualified name
# outer_scope Scope or None Enclosing scope # outer_scope Scope or None Enclosing scope
...@@ -1621,7 +1624,7 @@ class CppClassScope(Scope): ...@@ -1621,7 +1624,7 @@ class CppClassScope(Scope):
def declare_cfunction(self, name, type, pos, def declare_cfunction(self, name, type, pos,
cname = None, visibility = 'extern', defining = 0, cname = None, visibility = 'extern', defining = 0,
api = 0, in_pxd = 0, modifiers = ()): api = 0, in_pxd = 0, modifiers = ()):
self.declare_var(name, type, pos, cname, visibility) entry = self.declare_var(name, type, pos, cname, visibility)
def declare_inherited_cpp_attributes(self, base_scope): def declare_inherited_cpp_attributes(self, base_scope):
# Declare entries for all the C++ attributes of an # Declare entries for all the C++ attributes of an
...@@ -1642,7 +1645,11 @@ class CppClassScope(Scope): ...@@ -1642,7 +1645,11 @@ class CppClassScope(Scope):
def specialize(self, values): def specialize(self, values):
scope = CppClassScope() scope = CppClassScope()
for entry in self.entries.values(): for entry in self.entries.values():
scope.declare_var(entry.name, entry.type.specialize(values), entry.pos, entry.cname, entry.visibility) scope.declare_var(entry.name,
entry.type.specialize(values),
entry.pos,
entry.cname,
entry.visibility)
return scope return scope
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment