Commit a91b6e27 authored by Danilo Freitas's avatar Danilo Freitas

some fixes in best_match, operators and comparisons

parent bc4c5023
......@@ -3822,7 +3822,7 @@ class UnopNode(ExprNode):
"++": u"__inc__",
"--": u"__dec__",
"*": u"__deref__",
"!": u"__not__" # TODO(danilo): Also handle in NotNode.
"not": u"__not__" # TODO(danilo): Also handle in NotNode.
}
......@@ -4289,15 +4289,17 @@ class NumBinopNode(BinopNode):
function = entry.type.scope.lookup(self.operators[self.operator])
if not function:
error(self.pos, "'%s' operator not defined for '%s %s %s'"
% (self.operator, type1, type2, self.operator))
self.type_error()
% (self.operator, type1, self.operator, type2))
return
entry = PyrexTypes.best_match([self.operand1, self.operand2], function.all_alternatives(), self.pos)
if entry is None:
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return
self.type = entry.type.return_type
if (entry.type.is_ptr):
self.type = entry.type.base_type.return_type
else:
self.type = entry.type.return_type
def compute_c_result_type(self, type1, type2):
if self.c_types_okay(type1, type2):
......@@ -4356,17 +4358,8 @@ class NumBinopNode(BinopNode):
"&": u"__and__",
"|": u"__or__",
"^": u"__xor__",
# TODO(danilo): Handle these in CmpNode (perhaps dissallowing chaining).
"<": u"__le__",
">": u"__gt__",
"==": u"__eq__",
"<=": u"__le__",
">=": u"__ge__",
"!=": u"__ne__",
"<>": u"__ne__"
} #for now
"^": u"__xor__",
}
class IntBinopNode(NumBinopNode):
......@@ -4833,6 +4826,15 @@ class CmpNode(object):
result = result and cascade.compile_time_value(operand2, denv)
return result
def is_cpp_comparison(self):
type1 = self.operand1.type
type2 = self.operand2.type
if type1.is_ptr:
type1 = type1.base_type
if type2.is_ptr:
type2 = type2.base_type
return type1.is_cpp_class or type2.is_cpp_class
def is_python_comparison(self):
return (self.has_python_operands()
or (self.cascade and self.cascade.is_python_comparison())
......@@ -4965,6 +4967,9 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode):
def analyse_types(self, env):
self.operand1.analyse_types(env)
self.operand2.analyse_types(env)
if self.is_cpp_comparison():
self.analyse_cpp_comparison(env)
return
if self.cascade:
self.cascade.analyse_types(env, self.operand2)
self.is_pycmp = self.is_python_comparison()
......@@ -4987,6 +4992,29 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode):
if self.is_pycmp or self.cascade:
self.is_temp = 1
def analyse_cpp_comparison(self, env):
type1 = self.operand1.type
type2 = self.operand2.type
if type1.is_ptr:
type1 = type1.base_type
if type2.is_ptr:
type2 = type2.base_type
entry = env.lookup(type1.name)
function = entry.type.scope.lookup(self.operators[self.operator])
if not function:
error(self.pos, "'%s' operator not defined for '%s %s %s'"
% (self.operator, type1, self.operator, type2))
return
entry = PyrexTypes.best_match([self.operand1, self.operand2], function.all_alternatives(), self.pos)
if entry is None:
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return
if (entry.type.is_ptr):
self.type = entry.type.base_type.return_type
else:
self.type = entry.type.return_type
def check_operand_types(self, env):
self.check_types(env,
self.operand1, self.operator, self.operand2)
......@@ -5083,6 +5111,16 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode):
self.operand2.annotate(code)
if self.cascade:
self.cascade.annotate(code)
operators = {
"<": u"__le__",
">": u"__gt__",
"==": u"__eq__",
"<=": u"__le__",
">=": u"__ge__",
"!=": u"__ne__",
"<>": u"__ne__"
}
class CascadedCmpNode(Node, CmpNode):
......
......@@ -1768,6 +1768,8 @@ def best_match(args, functions, pos):
actual_nargs = len(args)
possibilities = []
bad_types = 0
from_type = None
target_type = None
for func in functions:
func_type = func.type
if func_type.is_ptr:
......@@ -1806,6 +1808,8 @@ def best_match(args, functions, pos):
score[0] += 1
else:
bad_types = func
from_type = src_type
target_type = dst_type
break
else:
possibilities.append((score, func)) # so we can sort it
......@@ -1816,8 +1820,8 @@ def best_match(args, functions, pos):
return None
return possibilities[0][1]
if bad_types:
# This will raise the right error.
return func
error(pos, "Invalid conversion from '%s' to '%s'" % (from_type, target_type))
return None
else:
error(pos, error_str)
return None
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment