Commit 7c3ca028 authored by scoder's avatar scoder

Merge pull request #414 from insertinterestingnamehere/assignment_diff_type

Further Assignment Operator Fixes
parents 8c563333 c4881c9e
......@@ -672,7 +672,7 @@ class ExprNode(Node):
else:
self.generate_subexpr_disposal_code(code)
def generate_assignment_code(self, rhs, code):
def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
# Stub method for nodes which are not legal as
# the LHS of an assignment. An error will have
# been reported earlier.
......@@ -1998,7 +1998,7 @@ class NameNode(AtomicExprNode):
if null_code and raise_unbound and (entry.type.is_pyobject or memslice_check):
code.put_error_if_unbound(self.pos, entry, self.in_nogil_context)
def generate_assignment_code(self, rhs, code):
def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
#print "NameNode.generate_assignment_code:", self.name ###
entry = self.entry
if entry is None:
......@@ -2088,8 +2088,11 @@ class NameNode(AtomicExprNode):
code.put_giveref(rhs.py_result())
if not self.type.is_memoryviewslice:
if not assigned:
code.putln('%s = %s;' % (
self.result(), rhs.result_as(self.ctype())))
if overloaded_assignment:
code.putln('%s = %s;' % (self.result(), rhs.result()))
else:
code.putln('%s = %s;' % (
self.result(), rhs.result_as(self.ctype())))
if debug_disposal_code:
print("NameNode.generate_assignment_code:")
print("...generating post-assignment code for %s" % rhs)
......@@ -3767,7 +3770,7 @@ class IndexNode(ExprNode):
# Simple case
code.putln("*%s %s= %s;" % (ptrexpr, op, rhs.result()))
def generate_assignment_code(self, rhs, code):
def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
generate_evaluation_code = (self.is_memslice_scalar_assignment or
self.memslice_slice)
if generate_evaluation_code:
......@@ -4227,7 +4230,7 @@ class SliceIndexNode(ExprNode):
code.error_goto_if_null(result, self.pos)))
code.put_gotref(self.py_result())
def generate_assignment_code(self, rhs, code):
def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
self.generate_subexpr_evaluation_code(code)
if self.type.is_pyobject:
code.globalstate.use_utility_code(self.set_slice_utility_code)
......@@ -6197,7 +6200,7 @@ class AttributeNode(ExprNode):
else:
ExprNode.generate_disposal_code(self, code)
def generate_assignment_code(self, rhs, code):
def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
self.obj.generate_evaluation_code(code)
if self.is_py_attr:
code.globalstate.use_utility_code(
......@@ -6534,7 +6537,7 @@ class SequenceNode(ExprNode):
if self.mult_factor:
self.mult_factor.generate_disposal_code(code)
def generate_assignment_code(self, rhs, code):
def generate_assignment_code(self, rhs, code, overloaded_assignment=False):
if self.starred_assignment:
self.generate_starred_assignment_code(rhs, code)
else:
......
......@@ -4784,12 +4784,14 @@ class SingleAssignmentNode(AssignmentNode):
#
# a = b
#
# lhs ExprNode Left hand side
# rhs ExprNode Right hand side
# first bool Is this guaranteed the first assignment to lhs?
# lhs ExprNode Left hand side
# rhs ExprNode Right hand side
# first bool Is this guaranteed the first assignment to lhs?
# is_overloaded_assignment bool Is this assignment done via an overloaded operator=
child_attrs = ["lhs", "rhs"]
first = False
is_overloaded_assignment = False
declaration_only = False
def analyse_declarations(self, env):
......@@ -4906,7 +4908,15 @@ class SingleAssignmentNode(AssignmentNode):
else:
dtype = self.lhs.type
rhs = self.rhs.coerce_to(dtype, env)
if self.lhs.type.is_cpp_class:
op = env.lookup_operator_for_types(self.pos, '=', [self.lhs.type, self.rhs.type])
if op:
rhs = self.rhs
self.is_overloaded_assignment = 1
else:
rhs = self.rhs.coerce_to(dtype, env)
else:
rhs = self.rhs.coerce_to(dtype, env)
if use_temp or rhs.is_attribute or (
not rhs.is_name and not rhs.is_literal and
rhs.type.is_pyobject):
......@@ -5054,7 +5064,11 @@ class SingleAssignmentNode(AssignmentNode):
self.rhs.generate_evaluation_code(code)
def generate_assignment_code(self, code):
self.lhs.generate_assignment_code(self.rhs, code)
if self.is_overloaded_assignment:
self.lhs.generate_assignment_code(self.rhs, code, overloaded_assignment=True)
else:
self.lhs.generate_assignment_code(self.rhs, code)
def generate_function_definitions(self, env, code):
self.rhs.generate_function_definitions(env, code)
......@@ -5074,12 +5088,14 @@ class CascadedAssignmentNode(AssignmentNode):
#
# Used internally:
#
# coerced_values [ExprNode] RHS coerced to all distinct LHS types
# cloned_values [ExprNode] cloned RHS value for each LHS
# coerced_values [ExprNode] RHS coerced to all distinct LHS types
# cloned_values [ExprNode] cloned RHS value for each LHS
# assignment_overloads [Bool] If each assignment uses a C++ operator=
child_attrs = ["lhs_list", "rhs", "coerced_values", "cloned_values"]
cloned_values = None
coerced_values = None
assignment_overloads = None
def analyse_declarations(self, env):
for lhs in self.lhs_list:
......@@ -5096,9 +5112,15 @@ class CascadedAssignmentNode(AssignmentNode):
lhs_types.add(lhs.type)
rhs = self.rhs.analyse_types(env)
# common special case: only one type needed on the LHS => coerce only once
if len(lhs_types) == 1:
# common special case: only one type needed on the LHS => coerce only once
rhs = rhs.coerce_to(lhs_types.pop(), env)
# Avoid coercion for overloaded assignment operators.
if next(iter(lhs_types)).is_cpp_class:
op = env.lookup_operator('=', [lhs, self.rhs])
if not op:
rhs = rhs.coerce_to(lhs_types.pop(), env)
else:
rhs = rhs.coerce_to(lhs_types.pop(), env)
if not rhs.is_name and not rhs.is_literal and (
use_temp or rhs.is_attribute or rhs.type.is_pyobject):
......@@ -5110,11 +5132,26 @@ class CascadedAssignmentNode(AssignmentNode):
# clone RHS and coerce it to all distinct LHS types
self.coerced_values = []
coerced_values = {}
self.assignment_overloads = []
for lhs in self.lhs_list:
overloaded = False
if lhs.type.is_cpp_class:
op = env.lookup_operator('=', [lhs, self.rhs])
if op:
rhs = self.rhs
self.assignment_overloads.append(True)
overloaded = True
else:
self.assignment_overloads.append(False)
else:
self.assignment_overloads.append(False)
if lhs.type not in coerced_values and lhs.type != rhs.type:
rhs = CloneNode(self.rhs).coerce_to(lhs.type, env)
if not overloaded:
rhs = CloneNode(self.rhs).coerce_to(lhs.type, env)
self.coerced_values.append(rhs)
coerced_values[lhs.type] = rhs
else:
self.assignment_overloads.append(False)
# clone coerced values for all LHS assignments
self.cloned_values = []
......@@ -5131,9 +5168,9 @@ class CascadedAssignmentNode(AssignmentNode):
for rhs in self.coerced_values:
rhs.generate_evaluation_code(code)
# assign clones to LHS
for lhs, rhs in zip(self.lhs_list, self.cloned_values):
for lhs, rhs, overload in zip(self.lhs_list, self.cloned_values, self.assignment_overloads):
rhs.generate_evaluation_code(code)
lhs.generate_assignment_code(rhs, code)
lhs.generate_assignment_code(rhs, code, overloaded_assignment=overload)
# dispose of coerced values and original RHS
for rhs_value in self.coerced_values:
rhs_value.generate_disposal_code(code)
......
......@@ -24,6 +24,10 @@ public:
this->val = other.val;
return *this;
}
wrapped_int &operator=(const long long val) {
this->val = val;
return *this;
}
};
......@@ -35,6 +39,7 @@ cdef extern from "assign.cpp" nogil:
wrapped_int()
wrapped_int(long long val)
wrapped_int& operator=(const wrapped_int &other)
wrapped_int& operator=(const long long &other)
######## assignment_overload.pyx ########
......@@ -44,6 +49,7 @@ from assign cimport wrapped_int
def test():
cdef wrapped_int a = wrapped_int(2)
cdef wrapped_int b = wrapped_int(3)
cdef long long c = 4
assert &a != &b
assert a.val != b.val
......@@ -51,3 +57,11 @@ def test():
a = b
assert &a != &b
assert a.val == b.val
a = c
assert a.val == c
a, b, c = 2, 3, 4
a = b = c
assert &a != &b
assert a.val == b.val
assert b.val == c
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment