Commit 86f49e42 authored by Danilo Freitas's avatar Danilo Freitas

Overloading operators

parent b8860831
......@@ -609,6 +609,49 @@ class ExprNode(Node):
def as_cython_attribute(self):
return None
def best_match(self, args, env):
entries = [env] + env.overloaded_alternatives
possibilities = []
for entry in entries:
type = entry.type
if type.is_ptr:
type = type.base_type
score = [0,0,0]
for i in range(len(args)):
src_type = args[i].type
if entry.type.is_ptr:
dst_type = entry.type.base_type.args[i].type
else:
dst_type = entry.type.args[i].type
if dst_type.assignable_from(src_type):
if src_type == dst_type:
pass # score 0
elif PyrexTypes.is_promotion(src_type, dst_type):
score[2] += 1
elif not src_type.is_pyobject:
score[1] += 1
else:
score[0] += 1
else:
break
else:
possibilities.append((score, entry)) # so we can sort it
if len(possibilities):
possibilities.sort()
if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name)
self.args = None
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return None
return possibilities[0][1].type
error(self.pos,
"Call with wrong arguments")
self.args = None
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return None
class RemoveAllocateTemps(type):
def __init__(cls, name, bases, dct):
......@@ -2390,8 +2433,8 @@ class SimpleCallNode(CallNode):
self.args.insert(0, self.coerced_self)
self.analyse_c_function_call(env)
def best_match(self):
entries = [self.function.entry] + self.function.entry.overloaded_alternatives
def best_match(self, args, env):
entries = [env] + env.overloaded_alternatives
actual_nargs = len(self.args)
possibilities = []
for entry in entries:
......@@ -2449,7 +2492,7 @@ class SimpleCallNode(CallNode):
return func_type
def analyse_c_function_call(self, env):
entry = self.best_match()
entry = self.best_match(self.args, self.function.entry)
if not entry:
return
self.function.entry = entry
......@@ -3815,6 +3858,8 @@ class UnopNode(ExprNode):
self.type = py_object_type
self.gil_check(env)
self.is_temp = 1
elif self.is_cpp_operation:
self.analyse_cpp_operation
else:
self.analyse_c_operation(env)
......@@ -3823,6 +3868,9 @@ class UnopNode(ExprNode):
def is_py_operation(self):
return self.operand.type.is_pyobject
def is_cpp_operation(self):
return self.operand.type.is_cpp_class
def coerce_operand_to_pyobject(self, env):
self.operand = self.operand.coerce_to_pyobject(env)
......@@ -3850,6 +3898,27 @@ class UnopNode(ExprNode):
(self.operator, self.operand.type))
self.type = PyrexTypes.error_type
def analyse_cpp_operation(self, env):
type = operand.type
if type.is_ptr:
type = type.base_type
entry = env.lookup(type.name)
function = entry.type.scope.lookup(self.operators[self.operator])
if not function:
error(self.pos, "'%s' operator not defined for %s"
% (self.operator, type1, type2, self.operator))
self.type_error()
return
self.type = self.best_match([self.operand], function)
operator = {
"++": u"__inc__",
"--": u"__dec__",
"*": u"__deref__",
"!": u"__not__"
}
class NotNode(ExprNode):
# 'not' operator
......@@ -4316,7 +4385,7 @@ class NumBinopNode(BinopNode):
% (self.operator, type1, type2, self.operator))
self.type_error()
return
self.type = self.best_match(function)
self.type = self.best_match([self.operand1, self.operand2], function)
def compute_c_result_type(self, type1, type2):
if self.c_types_okay(type1, type2):
......@@ -4347,50 +4416,6 @@ class NumBinopNode(BinopNode):
def py_operation_function(self):
return self.py_functions[self.operator]
def best_match(self, env):
entries = [env] + env.overloaded_alternatives
possibilities = []
args = [self.operand1, self.operand2]
for entry in entries:
type = entry.type
if type.is_ptr:
type = type.base_type
score = [0,0,0]
for i in range(len(args)):
src_type = args[i].type
if entry.type.is_ptr:
dst_type = entry.type.base_type.args[i].type
else:
dst_type = entry.type.args[i].type
if dst_type.assignable_from(src_type):
if src_type == dst_type:
pass # score 0
elif PyrexTypes.is_promotion(src_type, dst_type):
score[2] += 1
elif not src_type.is_pyobject:
score[1] += 1
else:
score[0] += 1
else:
break
else:
possibilities.append((score, entry)) # so we can sort it
if len(possibilities):
possibilities.sort()
if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name)
self.args = None
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return None
return possibilities[0][1].type
error(self.pos,
"Call with wrong arguments")# (expected %s, got %s)"
#% (expected_str, actual_nargs))
self.args = None
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return None
py_functions = {
"|": "PyNumber_Or",
......@@ -4410,7 +4435,14 @@ class NumBinopNode(BinopNode):
operators = {
"+": u"__add__",
"-": u"__sub__",
"*": u"__mul__"
"*": u"__mul__",
"<": u"__le__",
">": u"__gt__",
"==": u"__eq__",
"<=": u"__le__",
">=": u"__ge__",
"!=": u"__ne__",
"<>": u"__ne__"
} #for now
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment