Commit e77db9c9 authored by Ian Henriksen's avatar Ian Henriksen

Allow Cython-style cascaded assignment to use overloaded assignment

operators declared for C++ classes.
parent 2bb76871
...@@ -5066,10 +5066,12 @@ class CascadedAssignmentNode(AssignmentNode): ...@@ -5066,10 +5066,12 @@ class CascadedAssignmentNode(AssignmentNode):
# #
# coerced_values [ExprNode] RHS coerced to all distinct LHS types # coerced_values [ExprNode] RHS coerced to all distinct LHS types
# cloned_values [ExprNode] cloned RHS value for each LHS # cloned_values [ExprNode] cloned RHS value for each LHS
# assignment_overloads [Bool] If each assignment uses a C++ operator=
child_attrs = ["lhs_list", "rhs", "coerced_values", "cloned_values"] child_attrs = ["lhs_list", "rhs", "coerced_values", "cloned_values"]
cloned_values = None cloned_values = None
coerced_values = None coerced_values = None
assignment_overloads = None
def analyse_declarations(self, env): def analyse_declarations(self, env):
for lhs in self.lhs_list: for lhs in self.lhs_list:
...@@ -5100,11 +5102,26 @@ class CascadedAssignmentNode(AssignmentNode): ...@@ -5100,11 +5102,26 @@ class CascadedAssignmentNode(AssignmentNode):
# clone RHS and coerce it to all distinct LHS types # clone RHS and coerce it to all distinct LHS types
self.coerced_values = [] self.coerced_values = []
coerced_values = {} coerced_values = {}
self.assignment_overloads = []
for lhs in self.lhs_list: for lhs in self.lhs_list:
overloaded = False
if lhs.type.is_cpp_class:
op = env.lookup_operator('=', [lhs, self.rhs])
if op:
rhs = self.rhs
self.assignment_overloads.append(True)
overloaded = True
else:
self.assignment_overloads.append(False)
else:
self.assignment_overloads.append(False)
if lhs.type not in coerced_values and lhs.type != rhs.type: if lhs.type not in coerced_values and lhs.type != rhs.type:
if not overloaded:
rhs = CloneNode(self.rhs).coerce_to(lhs.type, env) rhs = CloneNode(self.rhs).coerce_to(lhs.type, env)
self.coerced_values.append(rhs) self.coerced_values.append(rhs)
coerced_values[lhs.type] = rhs coerced_values[lhs.type] = rhs
else:
self.assignment_overloads.append(False)
# clone coerced values for all LHS assignments # clone coerced values for all LHS assignments
self.cloned_values = [] self.cloned_values = []
...@@ -5121,9 +5138,9 @@ class CascadedAssignmentNode(AssignmentNode): ...@@ -5121,9 +5138,9 @@ class CascadedAssignmentNode(AssignmentNode):
for rhs in self.coerced_values: for rhs in self.coerced_values:
rhs.generate_evaluation_code(code) rhs.generate_evaluation_code(code)
# assign clones to LHS # assign clones to LHS
for lhs, rhs in zip(self.lhs_list, self.cloned_values): for lhs, rhs, overload in zip(self.lhs_list, self.cloned_values, self.assignment_overloads):
rhs.generate_evaluation_code(code) rhs.generate_evaluation_code(code)
lhs.generate_assignment_code(rhs, code) lhs.generate_assignment_code(rhs, code, overloaded_assignment=overload)
# dispose of coerced values and original RHS # dispose of coerced values and original RHS
for rhs_value in self.coerced_values: for rhs_value in self.coerced_values:
rhs_value.generate_disposal_code(code) rhs_value.generate_disposal_code(code)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment