Commit b3f27fe5 authored by Robert Bradshaw's avatar Robert Bradshaw

Uniformize and cleanup operator overloading.

parent ed063894
......@@ -1922,14 +1922,10 @@ class IndexNode(ExprNode):
"Invalid index type '%s'" %
self.index.type)
elif self.base.type.is_cpp_class:
function = env.lookup_operator("[]", [self.base, self.index])
function = self.base.type.scope.lookup("operator[]")
if function is None:
error(self.pos, "Indexing '%s' not supported" % self.base.type)
else:
function = PyrexTypes.best_match([self.index], function.all_alternatives(), self.pos)
if function is None:
error(self.pos, "Invalid index type '%s'" % self.index.type)
if function is None:
error(self.pos, "Indexing '%s' not supported for index type '%s'" % (self.base.type, self.index.type))
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return
......@@ -4682,6 +4678,9 @@ class BinopNode(ExprNode):
def is_py_operation(self):
return self.is_py_operation_types(self.operand1.type, self.operand2.type)
def is_py_operation_types(self, type1, type2):
return type1.is_pyobject or type2.is_pyobject
def is_cpp_operation(self):
type1 = self.operand1.type
type2 = self.operand2.type
......@@ -4692,9 +4691,23 @@ class BinopNode(ExprNode):
return (type1.is_cpp_class
or type2.is_cpp_class)
def is_py_operation_types(self, type1, type2):
return type1.is_pyobject or type2.is_pyobject
def analyse_cpp_operation(self, env):
type1 = self.operand1.type
type2 = self.operand2.type
entry = env.lookup_operator(self.operator, [self.operand1, self.operand2])
if not entry:
self.type_error()
return
func_type = entry.type
if func_type.is_ptr:
func_type = func_type.base_type
if len(func_type.args) == 1:
self.operand2 = self.operand2.coerce_to(func_type.args[0].type, env)
else:
self.operand1 = self.operand1.coerce_to(func_type.args[0].type, env)
self.operand2 = self.operand2.coerce_to(func_type.args[1].type, env)
self.type = func_type.return_type
def result_type(self, type1, type2):
if self.is_py_operation_types(type1, type2):
return py_object_type
......@@ -4756,34 +4769,6 @@ class NumBinopNode(BinopNode):
if not self.infix:
self.operand1 = self.operand1.coerce_to(self.type, env)
self.operand2 = self.operand2.coerce_to(self.type, env)
def analyse_cpp_operation(self, env):
type1 = self.operand1.type
type2 = self.operand2.type
if type1.is_ptr:
type1 = type1.base_type
if type2.is_ptr:
type2 = type2.base_type
entry = env.lookup(type1.name)
# Shouldn't this be type1.scope?
function = entry.type.scope.lookup("operator%s" % self.operator)
if function is not None:
operands = [self.operand2]
else:
function = env.lookup("operator%s" % self.operator)
operands = [self.operand1, self.operand2]
if not function:
self.type_error()
return
entry = PyrexTypes.best_match(operands, function.all_alternatives(), self.pos)
if entry is None:
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return
if (entry.type.is_ptr):
self.type = entry.type.base_type.return_type
else:
self.type = entry.type.return_type
def compute_c_result_type(self, type1, type2):
if self.c_types_okay(type1, type2):
......@@ -5632,25 +5617,22 @@ class PrimaryCmpNode(ExprNode, CmpNode):
def analyse_cpp_comparison(self, env):
type1 = self.operand1.type
type2 = self.operand2.type
if type1.is_reference:
type1 = type1.base_type
if type2.is_reference:
type2 = type2.base_type
entry = env.lookup(type1.name)
function = entry.type.scope.lookup("operator%s" % self.operator)
if not function:
entry = env.lookup_operator(self.operator, [self.operand1, self.operand2])
if entry is None:
error(self.pos, "Invalid types for '%s' (%s, %s)" %
(self.operator, type1, type2))
return
entry = PyrexTypes.best_match([self.operand2], function.all_alternatives(), self.pos)
if entry is None:
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return
if entry.type.is_ptr:
self.type = entry.type.base_type.return_type
func_type = entry.type
if func_type.is_ptr:
func_type = func_type.base_type
if len(func_type.args) == 1:
self.operand2 = self.operand2.coerce_to(func_type.args[0].type, env)
else:
self.type = entry.type.return_type
self.operand1 = self.operand1.coerce_to(func_type.args[0].type, env)
self.operand2 = self.operand2.coerce_to(func_type.args[1].type, env)
self.type = func_type.return_type
def has_python_operands(self):
return (self.operand1.type.is_pyobject
......
......@@ -2201,7 +2201,7 @@ def is_promotion(type, other_type):
else:
return False
def best_match(args, functions, pos):
def best_match(args, functions, pos=None):
"""
Finds the best function to be called
Error if no function fits the call or an ambiguity is find (two or more possible functions)
......@@ -2217,7 +2217,7 @@ def best_match(args, functions, pos):
func_type = func_type.base_type
# Check function type
if not func_type.is_cfunction:
if not func_type.is_error:
if not func_type.is_error and pos is not None:
error(pos, "Calling non-function type '%s'" % func_type)
return None
# Check no. of args
......@@ -2262,14 +2262,15 @@ def best_match(args, functions, pos):
if len(possibilities):
possibilities.sort()
if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
error(pos, "ambiguous overloaded method")
if pos is not None:
error(pos, "ambiguous overloaded method")
return None
return possibilities[0][1]
if bad_types:
error(pos, "Invalid conversion from '%s' to '%s'" % (from_type, target_type))
return None
else:
error(pos, error_str)
if pos is not None:
if bad_types:
error(pos, "Invalid conversion from '%s' to '%s'" % (from_type, target_type))
else:
error(pos, error_str)
return None
......
......@@ -554,6 +554,21 @@ class Scope(object):
entry = self.lookup(name)
if entry and entry.is_type:
return entry.type
def lookup_operator(self, operator, operands):
if operands[0].type.is_cpp_class:
obj_type = operands[0].type
if obj_type.is_reference:
obj_type = obj_type.base_type
method = obj_type.scope.lookup("operator%s" % operator)
if method is not None:
res = PyrexTypes.best_match(operands[1:], method.all_alternatives())
if res is not None:
return res
function = self.lookup("operator%s" % operator)
if function is None:
return None
return PyrexTypes.best_match(operands, function.all_alternatives())
def use_utility_code(self, new_code):
self.global_scope().use_utility_code(new_code)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment