Commit 7c3ca028 authored by scoder's avatar scoder

Merge pull request #414 from insertinterestingnamehere/assignment_diff_type

Further Assignment Operator Fixes
parents 8c563333 c4881c9e
...@@ -672,7 +672,7 @@ class ExprNode(Node): ...@@ -672,7 +672,7 @@ class ExprNode(Node):
else: else:
self.generate_subexpr_disposal_code(code) self.generate_subexpr_disposal_code(code)
def generate_assignment_code(self, rhs, code): def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
# Stub method for nodes which are not legal as # Stub method for nodes which are not legal as
# the LHS of an assignment. An error will have # the LHS of an assignment. An error will have
# been reported earlier. # been reported earlier.
...@@ -1998,7 +1998,7 @@ class NameNode(AtomicExprNode): ...@@ -1998,7 +1998,7 @@ class NameNode(AtomicExprNode):
if null_code and raise_unbound and (entry.type.is_pyobject or memslice_check): if null_code and raise_unbound and (entry.type.is_pyobject or memslice_check):
code.put_error_if_unbound(self.pos, entry, self.in_nogil_context) code.put_error_if_unbound(self.pos, entry, self.in_nogil_context)
def generate_assignment_code(self, rhs, code): def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
#print "NameNode.generate_assignment_code:", self.name ### #print "NameNode.generate_assignment_code:", self.name ###
entry = self.entry entry = self.entry
if entry is None: if entry is None:
...@@ -2088,6 +2088,9 @@ class NameNode(AtomicExprNode): ...@@ -2088,6 +2088,9 @@ class NameNode(AtomicExprNode):
code.put_giveref(rhs.py_result()) code.put_giveref(rhs.py_result())
if not self.type.is_memoryviewslice: if not self.type.is_memoryviewslice:
if not assigned: if not assigned:
if overloaded_assignment:
code.putln('%s = %s;' % (self.result(), rhs.result()))
else:
code.putln('%s = %s;' % ( code.putln('%s = %s;' % (
self.result(), rhs.result_as(self.ctype()))) self.result(), rhs.result_as(self.ctype())))
if debug_disposal_code: if debug_disposal_code:
...@@ -3767,7 +3770,7 @@ class IndexNode(ExprNode): ...@@ -3767,7 +3770,7 @@ class IndexNode(ExprNode):
# Simple case # Simple case
code.putln("*%s %s= %s;" % (ptrexpr, op, rhs.result())) code.putln("*%s %s= %s;" % (ptrexpr, op, rhs.result()))
def generate_assignment_code(self, rhs, code): def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
generate_evaluation_code = (self.is_memslice_scalar_assignment or generate_evaluation_code = (self.is_memslice_scalar_assignment or
self.memslice_slice) self.memslice_slice)
if generate_evaluation_code: if generate_evaluation_code:
...@@ -4227,7 +4230,7 @@ class SliceIndexNode(ExprNode): ...@@ -4227,7 +4230,7 @@ class SliceIndexNode(ExprNode):
code.error_goto_if_null(result, self.pos))) code.error_goto_if_null(result, self.pos)))
code.put_gotref(self.py_result()) code.put_gotref(self.py_result())
def generate_assignment_code(self, rhs, code): def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
self.generate_subexpr_evaluation_code(code) self.generate_subexpr_evaluation_code(code)
if self.type.is_pyobject: if self.type.is_pyobject:
code.globalstate.use_utility_code(self.set_slice_utility_code) code.globalstate.use_utility_code(self.set_slice_utility_code)
...@@ -6197,7 +6200,7 @@ class AttributeNode(ExprNode): ...@@ -6197,7 +6200,7 @@ class AttributeNode(ExprNode):
else: else:
ExprNode.generate_disposal_code(self, code) ExprNode.generate_disposal_code(self, code)
def generate_assignment_code(self, rhs, code): def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
self.obj.generate_evaluation_code(code) self.obj.generate_evaluation_code(code)
if self.is_py_attr: if self.is_py_attr:
code.globalstate.use_utility_code( code.globalstate.use_utility_code(
...@@ -6534,7 +6537,7 @@ class SequenceNode(ExprNode): ...@@ -6534,7 +6537,7 @@ class SequenceNode(ExprNode):
if self.mult_factor: if self.mult_factor:
self.mult_factor.generate_disposal_code(code) self.mult_factor.generate_disposal_code(code)
def generate_assignment_code(self, rhs, code): def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
if self.starred_assignment: if self.starred_assignment:
self.generate_starred_assignment_code(rhs, code) self.generate_starred_assignment_code(rhs, code)
else: else:
......
...@@ -4787,9 +4787,11 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4787,9 +4787,11 @@ class SingleAssignmentNode(AssignmentNode):
# lhs ExprNode Left hand side # lhs ExprNode Left hand side
# rhs ExprNode Right hand side # rhs ExprNode Right hand side
# first bool Is this guaranteed the first assignment to lhs? # first bool Is this guaranteed the first assignment to lhs?
# is_overloaded_assignment bool Is this assignment done via an overloaded operator=
child_attrs = ["lhs", "rhs"] child_attrs = ["lhs", "rhs"]
first = False first = False
is_overloaded_assignment = False
declaration_only = False declaration_only = False
def analyse_declarations(self, env): def analyse_declarations(self, env):
...@@ -4906,6 +4908,14 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4906,6 +4908,14 @@ class SingleAssignmentNode(AssignmentNode):
else: else:
dtype = self.lhs.type dtype = self.lhs.type
if self.lhs.type.is_cpp_class:
op = env.lookup_operator_for_types(self.pos, '=', [self.lhs.type, self.rhs.type])
if op:
rhs = self.rhs
self.is_overloaded_assignment = 1
else:
rhs = self.rhs.coerce_to(dtype, env)
else:
rhs = self.rhs.coerce_to(dtype, env) rhs = self.rhs.coerce_to(dtype, env)
if use_temp or rhs.is_attribute or ( if use_temp or rhs.is_attribute or (
not rhs.is_name and not rhs.is_literal and not rhs.is_name and not rhs.is_literal and
...@@ -5054,8 +5064,12 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -5054,8 +5064,12 @@ class SingleAssignmentNode(AssignmentNode):
self.rhs.generate_evaluation_code(code) self.rhs.generate_evaluation_code(code)
def generate_assignment_code(self, code): def generate_assignment_code(self, code):
if self.is_overloaded_assignment:
self.lhs.generate_assignment_code(self.rhs, code, overloaded_assignment=True)
else:
self.lhs.generate_assignment_code(self.rhs, code) self.lhs.generate_assignment_code(self.rhs, code)
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
self.rhs.generate_function_definitions(env, code) self.rhs.generate_function_definitions(env, code)
...@@ -5076,10 +5090,12 @@ class CascadedAssignmentNode(AssignmentNode): ...@@ -5076,10 +5090,12 @@ class CascadedAssignmentNode(AssignmentNode):
# #
# coerced_values [ExprNode] RHS coerced to all distinct LHS types # coerced_values [ExprNode] RHS coerced to all distinct LHS types
# cloned_values [ExprNode] cloned RHS value for each LHS # cloned_values [ExprNode] cloned RHS value for each LHS
# assignment_overloads [Bool] If each assignment uses a C++ operator=
child_attrs = ["lhs_list", "rhs", "coerced_values", "cloned_values"] child_attrs = ["lhs_list", "rhs", "coerced_values", "cloned_values"]
cloned_values = None cloned_values = None
coerced_values = None coerced_values = None
assignment_overloads = None
def analyse_declarations(self, env): def analyse_declarations(self, env):
for lhs in self.lhs_list: for lhs in self.lhs_list:
...@@ -5096,8 +5112,14 @@ class CascadedAssignmentNode(AssignmentNode): ...@@ -5096,8 +5112,14 @@ class CascadedAssignmentNode(AssignmentNode):
lhs_types.add(lhs.type) lhs_types.add(lhs.type)
rhs = self.rhs.analyse_types(env) rhs = self.rhs.analyse_types(env)
if len(lhs_types) == 1:
# common special case: only one type needed on the LHS => coerce only once # common special case: only one type needed on the LHS => coerce only once
if len(lhs_types) == 1:
# Avoid coercion for overloaded assignment operators.
if next(iter(lhs_types)).is_cpp_class:
op = env.lookup_operator('=', [lhs, self.rhs])
if not op:
rhs = rhs.coerce_to(lhs_types.pop(), env)
else:
rhs = rhs.coerce_to(lhs_types.pop(), env) rhs = rhs.coerce_to(lhs_types.pop(), env)
if not rhs.is_name and not rhs.is_literal and ( if not rhs.is_name and not rhs.is_literal and (
...@@ -5110,11 +5132,26 @@ class CascadedAssignmentNode(AssignmentNode): ...@@ -5110,11 +5132,26 @@ class CascadedAssignmentNode(AssignmentNode):
# clone RHS and coerce it to all distinct LHS types # clone RHS and coerce it to all distinct LHS types
self.coerced_values = [] self.coerced_values = []
coerced_values = {} coerced_values = {}
self.assignment_overloads = []
for lhs in self.lhs_list: for lhs in self.lhs_list:
overloaded = False
if lhs.type.is_cpp_class:
op = env.lookup_operator('=', [lhs, self.rhs])
if op:
rhs = self.rhs
self.assignment_overloads.append(True)
overloaded = True
else:
self.assignment_overloads.append(False)
else:
self.assignment_overloads.append(False)
if lhs.type not in coerced_values and lhs.type != rhs.type: if lhs.type not in coerced_values and lhs.type != rhs.type:
if not overloaded:
rhs = CloneNode(self.rhs).coerce_to(lhs.type, env) rhs = CloneNode(self.rhs).coerce_to(lhs.type, env)
self.coerced_values.append(rhs) self.coerced_values.append(rhs)
coerced_values[lhs.type] = rhs coerced_values[lhs.type] = rhs
else:
self.assignment_overloads.append(False)
# clone coerced values for all LHS assignments # clone coerced values for all LHS assignments
self.cloned_values = [] self.cloned_values = []
...@@ -5131,9 +5168,9 @@ class CascadedAssignmentNode(AssignmentNode): ...@@ -5131,9 +5168,9 @@ class CascadedAssignmentNode(AssignmentNode):
for rhs in self.coerced_values: for rhs in self.coerced_values:
rhs.generate_evaluation_code(code) rhs.generate_evaluation_code(code)
# assign clones to LHS # assign clones to LHS
for lhs, rhs in zip(self.lhs_list, self.cloned_values): for lhs, rhs, overload in zip(self.lhs_list, self.cloned_values, self.assignment_overloads):
rhs.generate_evaluation_code(code) rhs.generate_evaluation_code(code)
lhs.generate_assignment_code(rhs, code) lhs.generate_assignment_code(rhs, code, overloaded_assignment=overload)
# dispose of coerced values and original RHS # dispose of coerced values and original RHS
for rhs_value in self.coerced_values: for rhs_value in self.coerced_values:
rhs_value.generate_disposal_code(code) rhs_value.generate_disposal_code(code)
......
...@@ -24,6 +24,10 @@ public: ...@@ -24,6 +24,10 @@ public:
this->val = other.val; this->val = other.val;
return *this; return *this;
} }
wrapped_int &operator=(const long long val) {
this->val = val;
return *this;
}
}; };
...@@ -35,6 +39,7 @@ cdef extern from "assign.cpp" nogil: ...@@ -35,6 +39,7 @@ cdef extern from "assign.cpp" nogil:
wrapped_int() wrapped_int()
wrapped_int(long long val) wrapped_int(long long val)
wrapped_int& operator=(const wrapped_int &other) wrapped_int& operator=(const wrapped_int &other)
wrapped_int& operator=(const long long &other)
######## assignment_overload.pyx ######## ######## assignment_overload.pyx ########
...@@ -44,6 +49,7 @@ from assign cimport wrapped_int ...@@ -44,6 +49,7 @@ from assign cimport wrapped_int
def test(): def test():
cdef wrapped_int a = wrapped_int(2) cdef wrapped_int a = wrapped_int(2)
cdef wrapped_int b = wrapped_int(3) cdef wrapped_int b = wrapped_int(3)
cdef long long c = 4
assert &a != &b assert &a != &b
assert a.val != b.val assert a.val != b.val
...@@ -51,3 +57,11 @@ def test(): ...@@ -51,3 +57,11 @@ def test():
a = b a = b
assert &a != &b assert &a != &b
assert a.val == b.val assert a.val == b.val
a = c
assert a.val == c
a, b, c = 2, 3, 4
a = b = c
assert &a != &b
assert a.val == b.val
assert b.val == c
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment