Commit a598f9e6 authored by DaniloFreitas's avatar DaniloFreitas

Some work with operator

parent c4808363
...@@ -2392,6 +2392,7 @@ class SimpleCallNode(CallNode): ...@@ -2392,6 +2392,7 @@ class SimpleCallNode(CallNode):
def best_match(self): def best_match(self):
entries = [self.function.entry] + self.function.entry.overloaded_alternatives entries = [self.function.entry] + self.function.entry.overloaded_alternatives
#print self.function.entry.name, self.function.entry.type, self.function.entry.overloaded_alternatives
actual_nargs = len(self.args) actual_nargs = len(self.args)
possibilities = [] possibilities = []
for entry in entries: for entry in entries:
...@@ -2407,7 +2408,10 @@ class SimpleCallNode(CallNode): ...@@ -2407,7 +2408,10 @@ class SimpleCallNode(CallNode):
score = [0,0,0] score = [0,0,0]
for i in range(len(self.args)): for i in range(len(self.args)):
src_type = self.args[i].type src_type = self.args[i].type
dst_type = entry.type.base_type.args[i].type if entry.type.is_ptr:
dst_type = entry.type.base_type.args[i].type
else:
dst_type = entry.type.args[i].type
if dst_type.assignable_from(src_type): if dst_type.assignable_from(src_type):
if src_type == dst_type: if src_type == dst_type:
pass # score 0 pass # score 0
...@@ -2429,9 +2433,11 @@ class SimpleCallNode(CallNode): ...@@ -2429,9 +2433,11 @@ class SimpleCallNode(CallNode):
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
self.result_code = "<error>" self.result_code = "<error>"
return None return None
#for (score, entry) in possibilities:
#print entry.name, entry.type, score
return possibilities[0][1] return possibilities[0][1]
error(self.pos, error(self.pos,
"Call with wrong number of arguments")# (expected %s, got %s)" "Call with wrong arguments")# (expected %s, got %s)"
#% (expected_str, actual_nargs)) #% (expected_str, actual_nargs))
self.args = None self.args = None
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
...@@ -4225,6 +4231,8 @@ class BinopNode(NewTempExprNode): ...@@ -4225,6 +4231,8 @@ class BinopNode(NewTempExprNode):
self.is_temp = 1 self.is_temp = 1
if Options.incref_local_binop and self.operand1.type.is_pyobject: if Options.incref_local_binop and self.operand1.type.is_pyobject:
self.operand1 = self.operand1.coerce_to_temp(env) self.operand1 = self.operand1.coerce_to_temp(env)
elif self.is_cpp_operation():
self.analyse_cpp_operation(env)
else: else:
self.analyse_c_operation(env) self.analyse_c_operation(env)
...@@ -4232,6 +4240,16 @@ class BinopNode(NewTempExprNode): ...@@ -4232,6 +4240,16 @@ class BinopNode(NewTempExprNode):
return (self.operand1.type.is_pyobject return (self.operand1.type.is_pyobject
or self.operand2.type.is_pyobject) or self.operand2.type.is_pyobject)
def is_cpp_operation(self):
type1 = self.operand1.type
type2 = self.operand2.type
if type1.is_ptr:
type1 = type1.base_type
if type2.is_ptr:
type2 = type2.base_type
return (type1.is_cpp_class
or type2.is_cpp_class)
def coerce_operands_to_pyobjects(self, env): def coerce_operands_to_pyobjects(self, env):
self.operand1 = self.operand1.coerce_to_pyobject(env) self.operand1 = self.operand1.coerce_to_pyobject(env)
self.operand2 = self.operand2.coerce_to_pyobject(env) self.operand2 = self.operand2.coerce_to_pyobject(env)
...@@ -4345,6 +4363,74 @@ class IntBinopNode(NumBinopNode): ...@@ -4345,6 +4363,74 @@ class IntBinopNode(NumBinopNode):
class AddNode(NumBinopNode): class AddNode(NumBinopNode):
# '+' operator. # '+' operator.
def analyse_cpp_operation(self, env):
type1 = self.operand1.type
type2 = self.operand2.type
if type1.is_ptr:
type1 = type1.base_type
if type2.is_ptr:
type2 = type2.base_type
entry1 = env.lookup(type1.name)
entry2 = env.lookup(type2.name)
entry = entry1.scope.lookup_here("__add__")
if not entry:
error(self.pos, "'+' operator not defined for '%s + %s'"
% (self.operand1.type, self.operand2.type))
self.type_error()
return
self.type = self.best_match(entry)
def best_match(self, entry):
entries = [entry] + entry.overloaded_alternatives
actual_nargs = 2
possibilities = []
for entry in entries:
type = entry.type
if type.is_ptr:
type = type.base_type
# Check no. of args
max_nargs = len(type.args)
expected_nargs = max_nargs - type.optional_arg_count
if actual_nargs < expected_nargs \
or (not type.has_varargs and actual_nargs > max_nargs):
continue
score = [0,0,0]
for i in range(len(self.args)):
src_type = self.args[i].type
if entry.type.is_ptr:
dst_type = entry.type.base_type.args[i].type
else:
dst_type = entry.type.args[i].type
if dst_type.assignable_from(src_type):
if src_type == dst_type:
pass # score 0
elif PyrexTypes.is_promotion(src_type, dst_type):
score[2] += 1
elif not src_type.is_pyobject:
score[1] += 1
else:
score[0] += 1
else:
break
else:
possibilities.append((score, entry)) # so we can sort it
if len(possibilities):
possibilities.sort()
if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name)
self.args = None
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return None
return possibilities[0][1]
error(self.pos,
"Call with wrong arguments")# (expected %s, got %s)"
#% (expected_str, actual_nargs))
self.args = None
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return None
def is_py_operation(self): def is_py_operation(self):
if self.operand1.type.is_string \ if self.operand1.type.is_string \
and self.operand2.type.is_string: and self.operand2.type.is_string:
......
...@@ -359,6 +359,7 @@ class StatListNode(Node): ...@@ -359,6 +359,7 @@ class StatListNode(Node):
def analyse_expressions(self, env): def analyse_expressions(self, env):
#print "StatListNode.analyse_expressions" ### #print "StatListNode.analyse_expressions" ###
entry = env.entries.get("cpp_sum", None)
for stat in self.stats: for stat in self.stats:
stat.analyse_expressions(env) stat.analyse_expressions(env)
......
...@@ -1385,6 +1385,7 @@ class CppClassType(CType): ...@@ -1385,6 +1385,7 @@ class CppClassType(CType):
self._convert_code = None self._convert_code = None
self.packed = packed self.packed = packed
self.base_classes = base_classes self.base_classes = base_classes
self.operators = []
def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0): def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0):
if for_display or pyrex: if for_display or pyrex:
......
...@@ -454,24 +454,29 @@ class Scope(object): ...@@ -454,24 +454,29 @@ class Scope(object):
cname = None, visibility = 'private', defining = 0, cname = None, visibility = 'private', defining = 0,
api = 0, in_pxd = 0, modifiers = ()): api = 0, in_pxd = 0, modifiers = ()):
# Add an entry for a C function. # Add an entry for a C function.
if not cname:
if api or visibility != 'private':
cname = name
else:
cname = self.mangle(Naming.func_prefix, name)
entry = self.lookup_here(name) entry = self.lookup_here(name)
if entry: if entry:
entry.overloaded_alternatives.append(self.add_cfunction(name, type, pos, cname, visibility, modifiers))
if visibility != 'private' and visibility != entry.visibility: if visibility != 'private' and visibility != entry.visibility:
warning(pos, "Function '%s' previously declared as '%s'" % (name, entry.visibility), 1) warning(pos, "Function '%s' previously declared as '%s'" % (name, entry.visibility), 1)
if not entry.type.same_as(type): if not entry.type.same_as(type):
if visibility == 'extern' and entry.visibility == 'extern': if visibility == 'extern' and entry.visibility == 'extern':
warning(pos, "Function signature does not match previous declaration", 1) warning(pos, "Function signature does not match previous declaration", 1)
entry.type = type #entry.type = type
else: else:
error(pos, "Function signature does not match previous declaration") error(pos, "Function signature does not match previous declaration")
else: else:
if not cname:
if api or visibility != 'private':
cname = name
else:
cname = self.mangle(Naming.func_prefix, name)
entry = self.add_cfunction(name, type, pos, cname, visibility, modifiers) entry = self.add_cfunction(name, type, pos, cname, visibility, modifiers)
entry.func_cname = cname entry.func_cname = cname
#try:
# print entry.name, entry.type, entry.overloaded_alternatives
#except:
# pass
if in_pxd and visibility != 'extern': if in_pxd and visibility != 'extern':
entry.defined_in_pxd = 1 entry.defined_in_pxd = 1
if api: if api:
...@@ -482,6 +487,12 @@ class Scope(object): ...@@ -482,6 +487,12 @@ class Scope(object):
entry.is_implemented = True entry.is_implemented = True
if modifiers: if modifiers:
entry.func_modifiers = modifiers entry.func_modifiers = modifiers
#try:
# print entry.name, entry.type, entry.overloaded_alternatives
#except:
# pass
#if len(entry.overloaded_alternatives) > 0:
# print entry.name, entry.type, entry.overloaded_alternatives[0].type
return entry return entry
def add_cfunction(self, name, type, pos, cname, visibility, modifiers): def add_cfunction(self, name, type, pos, cname, visibility, modifiers):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment