Commit b3f27fe5 authored by Robert Bradshaw's avatar Robert Bradshaw

Uniformize and cleanup operator overloading.

parent ed063894
...@@ -1922,14 +1922,10 @@ class IndexNode(ExprNode): ...@@ -1922,14 +1922,10 @@ class IndexNode(ExprNode):
"Invalid index type '%s'" % "Invalid index type '%s'" %
self.index.type) self.index.type)
elif self.base.type.is_cpp_class: elif self.base.type.is_cpp_class:
function = env.lookup_operator("[]", [self.base, self.index])
function = self.base.type.scope.lookup("operator[]") function = self.base.type.scope.lookup("operator[]")
if function is None: if function is None:
error(self.pos, "Indexing '%s' not supported" % self.base.type) error(self.pos, "Indexing '%s' not supported for index type '%s'" % (self.base.type, self.index.type))
else:
function = PyrexTypes.best_match([self.index], function.all_alternatives(), self.pos)
if function is None:
error(self.pos, "Invalid index type '%s'" % self.index.type)
if function is None:
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
self.result_code = "<error>" self.result_code = "<error>"
return return
...@@ -4682,6 +4678,9 @@ class BinopNode(ExprNode): ...@@ -4682,6 +4678,9 @@ class BinopNode(ExprNode):
def is_py_operation(self): def is_py_operation(self):
return self.is_py_operation_types(self.operand1.type, self.operand2.type) return self.is_py_operation_types(self.operand1.type, self.operand2.type)
def is_py_operation_types(self, type1, type2):
return type1.is_pyobject or type2.is_pyobject
def is_cpp_operation(self): def is_cpp_operation(self):
type1 = self.operand1.type type1 = self.operand1.type
type2 = self.operand2.type type2 = self.operand2.type
...@@ -4692,9 +4691,23 @@ class BinopNode(ExprNode): ...@@ -4692,9 +4691,23 @@ class BinopNode(ExprNode):
return (type1.is_cpp_class return (type1.is_cpp_class
or type2.is_cpp_class) or type2.is_cpp_class)
def is_py_operation_types(self, type1, type2): def analyse_cpp_operation(self, env):
return type1.is_pyobject or type2.is_pyobject type1 = self.operand1.type
type2 = self.operand2.type
entry = env.lookup_operator(self.operator, [self.operand1, self.operand2])
if not entry:
self.type_error()
return
func_type = entry.type
if func_type.is_ptr:
func_type = func_type.base_type
if len(func_type.args) == 1:
self.operand2 = self.operand2.coerce_to(func_type.args[0].type, env)
else:
self.operand1 = self.operand1.coerce_to(func_type.args[0].type, env)
self.operand2 = self.operand2.coerce_to(func_type.args[1].type, env)
self.type = func_type.return_type
def result_type(self, type1, type2): def result_type(self, type1, type2):
if self.is_py_operation_types(type1, type2): if self.is_py_operation_types(type1, type2):
return py_object_type return py_object_type
...@@ -4756,34 +4769,6 @@ class NumBinopNode(BinopNode): ...@@ -4756,34 +4769,6 @@ class NumBinopNode(BinopNode):
if not self.infix: if not self.infix:
self.operand1 = self.operand1.coerce_to(self.type, env) self.operand1 = self.operand1.coerce_to(self.type, env)
self.operand2 = self.operand2.coerce_to(self.type, env) self.operand2 = self.operand2.coerce_to(self.type, env)
def analyse_cpp_operation(self, env):
type1 = self.operand1.type
type2 = self.operand2.type
if type1.is_ptr:
type1 = type1.base_type
if type2.is_ptr:
type2 = type2.base_type
entry = env.lookup(type1.name)
# Shouldn't this be type1.scope?
function = entry.type.scope.lookup("operator%s" % self.operator)
if function is not None:
operands = [self.operand2]
else:
function = env.lookup("operator%s" % self.operator)
operands = [self.operand1, self.operand2]
if not function:
self.type_error()
return
entry = PyrexTypes.best_match(operands, function.all_alternatives(), self.pos)
if entry is None:
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return
if (entry.type.is_ptr):
self.type = entry.type.base_type.return_type
else:
self.type = entry.type.return_type
def compute_c_result_type(self, type1, type2): def compute_c_result_type(self, type1, type2):
if self.c_types_okay(type1, type2): if self.c_types_okay(type1, type2):
...@@ -5632,25 +5617,22 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -5632,25 +5617,22 @@ class PrimaryCmpNode(ExprNode, CmpNode):
def analyse_cpp_comparison(self, env): def analyse_cpp_comparison(self, env):
type1 = self.operand1.type type1 = self.operand1.type
type2 = self.operand2.type type2 = self.operand2.type
if type1.is_reference: entry = env.lookup_operator(self.operator, [self.operand1, self.operand2])
type1 = type1.base_type if entry is None:
if type2.is_reference:
type2 = type2.base_type
entry = env.lookup(type1.name)
function = entry.type.scope.lookup("operator%s" % self.operator)
if not function:
error(self.pos, "Invalid types for '%s' (%s, %s)" % error(self.pos, "Invalid types for '%s' (%s, %s)" %
(self.operator, type1, type2)) (self.operator, type1, type2))
return
entry = PyrexTypes.best_match([self.operand2], function.all_alternatives(), self.pos)
if entry is None:
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
self.result_code = "<error>" self.result_code = "<error>"
return return
if entry.type.is_ptr: func_type = entry.type
self.type = entry.type.base_type.return_type if func_type.is_ptr:
func_type = func_type.base_type
if len(func_type.args) == 1:
self.operand2 = self.operand2.coerce_to(func_type.args[0].type, env)
else: else:
self.type = entry.type.return_type self.operand1 = self.operand1.coerce_to(func_type.args[0].type, env)
self.operand2 = self.operand2.coerce_to(func_type.args[1].type, env)
self.type = func_type.return_type
def has_python_operands(self): def has_python_operands(self):
return (self.operand1.type.is_pyobject return (self.operand1.type.is_pyobject
......
...@@ -2201,7 +2201,7 @@ def is_promotion(type, other_type): ...@@ -2201,7 +2201,7 @@ def is_promotion(type, other_type):
else: else:
return False return False
def best_match(args, functions, pos): def best_match(args, functions, pos=None):
""" """
Finds the best function to be called Finds the best function to be called
Error if no function fits the call or an ambiguity is find (two or more possible functions) Error if no function fits the call or an ambiguity is find (two or more possible functions)
...@@ -2217,7 +2217,7 @@ def best_match(args, functions, pos): ...@@ -2217,7 +2217,7 @@ def best_match(args, functions, pos):
func_type = func_type.base_type func_type = func_type.base_type
# Check function type # Check function type
if not func_type.is_cfunction: if not func_type.is_cfunction:
if not func_type.is_error: if not func_type.is_error and pos is not None:
error(pos, "Calling non-function type '%s'" % func_type) error(pos, "Calling non-function type '%s'" % func_type)
return None return None
# Check no. of args # Check no. of args
...@@ -2262,14 +2262,15 @@ def best_match(args, functions, pos): ...@@ -2262,14 +2262,15 @@ def best_match(args, functions, pos):
if len(possibilities): if len(possibilities):
possibilities.sort() possibilities.sort()
if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]: if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
error(pos, "ambiguous overloaded method") if pos is not None:
error(pos, "ambiguous overloaded method")
return None return None
return possibilities[0][1] return possibilities[0][1]
if bad_types: if pos is not None:
error(pos, "Invalid conversion from '%s' to '%s'" % (from_type, target_type)) if bad_types:
return None error(pos, "Invalid conversion from '%s' to '%s'" % (from_type, target_type))
else: else:
error(pos, error_str) error(pos, error_str)
return None return None
......
...@@ -554,6 +554,21 @@ class Scope(object): ...@@ -554,6 +554,21 @@ class Scope(object):
entry = self.lookup(name) entry = self.lookup(name)
if entry and entry.is_type: if entry and entry.is_type:
return entry.type return entry.type
def lookup_operator(self, operator, operands):
if operands[0].type.is_cpp_class:
obj_type = operands[0].type
if obj_type.is_reference:
obj_type = obj_type.base_type
method = obj_type.scope.lookup("operator%s" % operator)
if method is not None:
res = PyrexTypes.best_match(operands[1:], method.all_alternatives())
if res is not None:
return res
function = self.lookup("operator%s" % operator)
if function is None:
return None
return PyrexTypes.best_match(operands, function.all_alternatives())
def use_utility_code(self, new_code): def use_utility_code(self, new_code):
self.global_scope().use_utility_code(new_code) self.global_scope().use_utility_code(new_code)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment