Commit a91b6e27 authored by Danilo Freitas's avatar Danilo Freitas

some fixes in best_match, operators and comparisons

parent bc4c5023
...@@ -3822,7 +3822,7 @@ class UnopNode(ExprNode): ...@@ -3822,7 +3822,7 @@ class UnopNode(ExprNode):
"++": u"__inc__", "++": u"__inc__",
"--": u"__dec__", "--": u"__dec__",
"*": u"__deref__", "*": u"__deref__",
"!": u"__not__" # TODO(danilo): Also handle in NotNode. "not": u"__not__" # TODO(danilo): Also handle in NotNode.
} }
...@@ -4289,14 +4289,16 @@ class NumBinopNode(BinopNode): ...@@ -4289,14 +4289,16 @@ class NumBinopNode(BinopNode):
function = entry.type.scope.lookup(self.operators[self.operator]) function = entry.type.scope.lookup(self.operators[self.operator])
if not function: if not function:
error(self.pos, "'%s' operator not defined for '%s %s %s'" error(self.pos, "'%s' operator not defined for '%s %s %s'"
% (self.operator, type1, type2, self.operator)) % (self.operator, type1, self.operator, type2))
self.type_error()
return return
entry = PyrexTypes.best_match([self.operand1, self.operand2], function.all_alternatives(), self.pos) entry = PyrexTypes.best_match([self.operand1, self.operand2], function.all_alternatives(), self.pos)
if entry is None: if entry is None:
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
self.result_code = "<error>" self.result_code = "<error>"
return return
if (entry.type.is_ptr):
self.type = entry.type.base_type.return_type
else:
self.type = entry.type.return_type self.type = entry.type.return_type
def compute_c_result_type(self, type1, type2): def compute_c_result_type(self, type1, type2):
...@@ -4357,16 +4359,7 @@ class NumBinopNode(BinopNode): ...@@ -4357,16 +4359,7 @@ class NumBinopNode(BinopNode):
"&": u"__and__", "&": u"__and__",
"|": u"__or__", "|": u"__or__",
"^": u"__xor__", "^": u"__xor__",
}
# TODO(danilo): Handle these in CmpNode (perhaps dissallowing chaining).
"<": u"__le__",
">": u"__gt__",
"==": u"__eq__",
"<=": u"__le__",
">=": u"__ge__",
"!=": u"__ne__",
"<>": u"__ne__"
} #for now
class IntBinopNode(NumBinopNode): class IntBinopNode(NumBinopNode):
...@@ -4833,6 +4826,15 @@ class CmpNode(object): ...@@ -4833,6 +4826,15 @@ class CmpNode(object):
result = result and cascade.compile_time_value(operand2, denv) result = result and cascade.compile_time_value(operand2, denv)
return result return result
def is_cpp_comparison(self):
type1 = self.operand1.type
type2 = self.operand2.type
if type1.is_ptr:
type1 = type1.base_type
if type2.is_ptr:
type2 = type2.base_type
return type1.is_cpp_class or type2.is_cpp_class
def is_python_comparison(self): def is_python_comparison(self):
return (self.has_python_operands() return (self.has_python_operands()
or (self.cascade and self.cascade.is_python_comparison()) or (self.cascade and self.cascade.is_python_comparison())
...@@ -4965,6 +4967,9 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode): ...@@ -4965,6 +4967,9 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode):
def analyse_types(self, env): def analyse_types(self, env):
self.operand1.analyse_types(env) self.operand1.analyse_types(env)
self.operand2.analyse_types(env) self.operand2.analyse_types(env)
if self.is_cpp_comparison():
self.analyse_cpp_comparison(env)
return
if self.cascade: if self.cascade:
self.cascade.analyse_types(env, self.operand2) self.cascade.analyse_types(env, self.operand2)
self.is_pycmp = self.is_python_comparison() self.is_pycmp = self.is_python_comparison()
...@@ -4987,6 +4992,29 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode): ...@@ -4987,6 +4992,29 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode):
if self.is_pycmp or self.cascade: if self.is_pycmp or self.cascade:
self.is_temp = 1 self.is_temp = 1
def analyse_cpp_comparison(self, env):
type1 = self.operand1.type
type2 = self.operand2.type
if type1.is_ptr:
type1 = type1.base_type
if type2.is_ptr:
type2 = type2.base_type
entry = env.lookup(type1.name)
function = entry.type.scope.lookup(self.operators[self.operator])
if not function:
error(self.pos, "'%s' operator not defined for '%s %s %s'"
% (self.operator, type1, self.operator, type2))
return
entry = PyrexTypes.best_match([self.operand1, self.operand2], function.all_alternatives(), self.pos)
if entry is None:
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return
if (entry.type.is_ptr):
self.type = entry.type.base_type.return_type
else:
self.type = entry.type.return_type
def check_operand_types(self, env): def check_operand_types(self, env):
self.check_types(env, self.check_types(env,
self.operand1, self.operator, self.operand2) self.operand1, self.operator, self.operand2)
...@@ -5084,6 +5112,16 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode): ...@@ -5084,6 +5112,16 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode):
if self.cascade: if self.cascade:
self.cascade.annotate(code) self.cascade.annotate(code)
operators = {
"<": u"__le__",
">": u"__gt__",
"==": u"__eq__",
"<=": u"__le__",
">=": u"__ge__",
"!=": u"__ne__",
"<>": u"__ne__"
}
class CascadedCmpNode(Node, CmpNode): class CascadedCmpNode(Node, CmpNode):
# A CascadedCmpNode is not a complete expression node. It # A CascadedCmpNode is not a complete expression node. It
......
...@@ -1768,6 +1768,8 @@ def best_match(args, functions, pos): ...@@ -1768,6 +1768,8 @@ def best_match(args, functions, pos):
actual_nargs = len(args) actual_nargs = len(args)
possibilities = [] possibilities = []
bad_types = 0 bad_types = 0
from_type = None
target_type = None
for func in functions: for func in functions:
func_type = func.type func_type = func.type
if func_type.is_ptr: if func_type.is_ptr:
...@@ -1806,6 +1808,8 @@ def best_match(args, functions, pos): ...@@ -1806,6 +1808,8 @@ def best_match(args, functions, pos):
score[0] += 1 score[0] += 1
else: else:
bad_types = func bad_types = func
from_type = src_type
target_type = dst_type
break break
else: else:
possibilities.append((score, func)) # so we can sort it possibilities.append((score, func)) # so we can sort it
...@@ -1816,8 +1820,8 @@ def best_match(args, functions, pos): ...@@ -1816,8 +1820,8 @@ def best_match(args, functions, pos):
return None return None
return possibilities[0][1] return possibilities[0][1]
if bad_types: if bad_types:
# This will raise the right error. error(pos, "Invalid conversion from '%s' to '%s'" % (from_type, target_type))
return func return None
else: else:
error(pos, error_str) error(pos, error_str)
return None return None
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment