Commit 86f49e42 authored by Danilo Freitas's avatar Danilo Freitas

Overloading operators

parent b8860831
...@@ -609,6 +609,49 @@ class ExprNode(Node): ...@@ -609,6 +609,49 @@ class ExprNode(Node):
def as_cython_attribute(self): def as_cython_attribute(self):
return None return None
def best_match(self, args, env):
entries = [env] + env.overloaded_alternatives
possibilities = []
for entry in entries:
type = entry.type
if type.is_ptr:
type = type.base_type
score = [0,0,0]
for i in range(len(args)):
src_type = args[i].type
if entry.type.is_ptr:
dst_type = entry.type.base_type.args[i].type
else:
dst_type = entry.type.args[i].type
if dst_type.assignable_from(src_type):
if src_type == dst_type:
pass # score 0
elif PyrexTypes.is_promotion(src_type, dst_type):
score[2] += 1
elif not src_type.is_pyobject:
score[1] += 1
else:
score[0] += 1
else:
break
else:
possibilities.append((score, entry)) # so we can sort it
if len(possibilities):
possibilities.sort()
if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name)
self.args = None
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return None
return possibilities[0][1].type
error(self.pos,
"Call with wrong arguments")
self.args = None
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return None
class RemoveAllocateTemps(type): class RemoveAllocateTemps(type):
def __init__(cls, name, bases, dct): def __init__(cls, name, bases, dct):
...@@ -2390,8 +2433,8 @@ class SimpleCallNode(CallNode): ...@@ -2390,8 +2433,8 @@ class SimpleCallNode(CallNode):
self.args.insert(0, self.coerced_self) self.args.insert(0, self.coerced_self)
self.analyse_c_function_call(env) self.analyse_c_function_call(env)
def best_match(self): def best_match(self, args, env):
entries = [self.function.entry] + self.function.entry.overloaded_alternatives entries = [env] + env.overloaded_alternatives
actual_nargs = len(self.args) actual_nargs = len(self.args)
possibilities = [] possibilities = []
for entry in entries: for entry in entries:
...@@ -2449,7 +2492,7 @@ class SimpleCallNode(CallNode): ...@@ -2449,7 +2492,7 @@ class SimpleCallNode(CallNode):
return func_type return func_type
def analyse_c_function_call(self, env): def analyse_c_function_call(self, env):
entry = self.best_match() entry = self.best_match(self.args, self.function.entry)
if not entry: if not entry:
return return
self.function.entry = entry self.function.entry = entry
...@@ -3815,6 +3858,8 @@ class UnopNode(ExprNode): ...@@ -3815,6 +3858,8 @@ class UnopNode(ExprNode):
self.type = py_object_type self.type = py_object_type
self.gil_check(env) self.gil_check(env)
self.is_temp = 1 self.is_temp = 1
elif self.is_cpp_operation:
self.analyse_cpp_operation
else: else:
self.analyse_c_operation(env) self.analyse_c_operation(env)
...@@ -3824,6 +3869,9 @@ class UnopNode(ExprNode): ...@@ -3824,6 +3869,9 @@ class UnopNode(ExprNode):
def is_py_operation(self): def is_py_operation(self):
return self.operand.type.is_pyobject return self.operand.type.is_pyobject
def is_cpp_operation(self):
return self.operand.type.is_cpp_class
def coerce_operand_to_pyobject(self, env): def coerce_operand_to_pyobject(self, env):
self.operand = self.operand.coerce_to_pyobject(env) self.operand = self.operand.coerce_to_pyobject(env)
...@@ -3850,6 +3898,27 @@ class UnopNode(ExprNode): ...@@ -3850,6 +3898,27 @@ class UnopNode(ExprNode):
(self.operator, self.operand.type)) (self.operator, self.operand.type))
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
def analyse_cpp_operation(self, env):
type = operand.type
if type.is_ptr:
type = type.base_type
entry = env.lookup(type.name)
function = entry.type.scope.lookup(self.operators[self.operator])
if not function:
error(self.pos, "'%s' operator not defined for %s"
% (self.operator, type1, type2, self.operator))
self.type_error()
return
self.type = self.best_match([self.operand], function)
operator = {
"++": u"__inc__",
"--": u"__dec__",
"*": u"__deref__",
"!": u"__not__"
}
class NotNode(ExprNode): class NotNode(ExprNode):
# 'not' operator # 'not' operator
...@@ -4316,7 +4385,7 @@ class NumBinopNode(BinopNode): ...@@ -4316,7 +4385,7 @@ class NumBinopNode(BinopNode):
% (self.operator, type1, type2, self.operator)) % (self.operator, type1, type2, self.operator))
self.type_error() self.type_error()
return return
self.type = self.best_match(function) self.type = self.best_match([self.operand1, self.operand2], function)
def compute_c_result_type(self, type1, type2): def compute_c_result_type(self, type1, type2):
if self.c_types_okay(type1, type2): if self.c_types_okay(type1, type2):
...@@ -4347,50 +4416,6 @@ class NumBinopNode(BinopNode): ...@@ -4347,50 +4416,6 @@ class NumBinopNode(BinopNode):
def py_operation_function(self): def py_operation_function(self):
return self.py_functions[self.operator] return self.py_functions[self.operator]
def best_match(self, env):
entries = [env] + env.overloaded_alternatives
possibilities = []
args = [self.operand1, self.operand2]
for entry in entries:
type = entry.type
if type.is_ptr:
type = type.base_type
score = [0,0,0]
for i in range(len(args)):
src_type = args[i].type
if entry.type.is_ptr:
dst_type = entry.type.base_type.args[i].type
else:
dst_type = entry.type.args[i].type
if dst_type.assignable_from(src_type):
if src_type == dst_type:
pass # score 0
elif PyrexTypes.is_promotion(src_type, dst_type):
score[2] += 1
elif not src_type.is_pyobject:
score[1] += 1
else:
score[0] += 1
else:
break
else:
possibilities.append((score, entry)) # so we can sort it
if len(possibilities):
possibilities.sort()
if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name)
self.args = None
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return None
return possibilities[0][1].type
error(self.pos,
"Call with wrong arguments")# (expected %s, got %s)"
#% (expected_str, actual_nargs))
self.args = None
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return None
py_functions = { py_functions = {
"|": "PyNumber_Or", "|": "PyNumber_Or",
...@@ -4410,7 +4435,14 @@ class NumBinopNode(BinopNode): ...@@ -4410,7 +4435,14 @@ class NumBinopNode(BinopNode):
operators = { operators = {
"+": u"__add__", "+": u"__add__",
"-": u"__sub__", "-": u"__sub__",
"*": u"__mul__" "*": u"__mul__",
"<": u"__le__",
">": u"__gt__",
"==": u"__eq__",
"<=": u"__le__",
">=": u"__ge__",
"!=": u"__ne__",
"<>": u"__ne__"
} #for now } #for now
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment