Commit efd0d426 authored by Robert Bradshaw's avatar Robert Bradshaw

Assignment to array slice now works.

parent 5f8e0e70
...@@ -4760,6 +4760,9 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4760,6 +4760,9 @@ class SingleAssignmentNode(AssignmentNode):
self.lhs = self.lhs.analyse_target_types(env) self.lhs = self.lhs.analyse_target_types(env)
self.lhs.gil_assignment_check(env) self.lhs.gil_assignment_check(env)
unrolled_assignment = self.unroll_lhs(env)
if unrolled_assignment:
return unrolled_assignment
if self.lhs.memslice_broadcast or self.rhs.memslice_broadcast: if self.lhs.memslice_broadcast or self.rhs.memslice_broadcast:
self.lhs.memslice_broadcast = True self.lhs.memslice_broadcast = True
...@@ -4791,32 +4794,32 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4791,32 +4794,32 @@ class SingleAssignmentNode(AssignmentNode):
from . import ExprNodes, UtilNodes from . import ExprNodes, UtilNodes
if node.type.is_ctuple: if node.type.is_ctuple:
if node.type.size == target_size: if node.type.size == target_size:
base = self.rhs base = node
start_node = None start_node = None
stop_node = None stop_node = None
step_node = None step_node = None
check_node = None check_node = None
else: else:
error(self.pos, "Unpacking type %s requires exactly %s arguments." % ( error(self.pos, "Unpacking type %s requires exactly %s arguments." % (
self.rhs.type, self.rhs.type.size)) node.type, node.type.size))
return return
elif node.type.is_ptr: elif node.type.is_ptr:
if isinstance(self.rhs, ExprNodes.SliceIndexNode): if isinstance(node, ExprNodes.SliceIndexNode):
base = self.rhs.base base = node.base
start_node = self.rhs.start start_node = node.start
if start_node: if start_node:
start_node = start_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env) start_node = start_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env)
stop_node = self.rhs.stop stop_node = node.stop
if stop_node: if stop_node:
stop_node = stop_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env) stop_node = stop_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env)
else: else:
if rhs.is_array and rhs.type.size: if node.type.is_array and node.type.size:
stop_node = ExprNodes.IntNode(pos=self.pos, value=str(rhs.type.size)) stop_node = ExprNodes.IntNode(pos=self.pos, value=str(rhs.type.size))
else: else:
error(self.pos, "C array iteration requires known end index") error(self.pos, "C array iteration requires known end index")
return return
step_node = None #self.rhs.step step_node = None #node.step
if step_node: if step_node:
step_node = step_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env) step_node = step_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env)
# TODO: Factor out SliceIndexNode.generate_slice_guard_code() for use here. # TODO: Factor out SliceIndexNode.generate_slice_guard_code() for use here.
...@@ -4867,17 +4870,10 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4867,17 +4870,10 @@ class SingleAssignmentNode(AssignmentNode):
index=ix_node)) index=ix_node))
return check_node, refs, items return check_node, refs, items
def unroll_rhs(self, env): def unroll_assignments(self, refs, check_node, lhs_list, rhs_list, env):
from . import ExprNodes, UtilNodes from . import ExprNodes, UtilNodes
if not isinstance(self.lhs, ExprNodes.TupleNode):
return
unrolled = self.unroll(self.rhs, len(self.lhs.args), env)
if not unrolled:
return
check_node, refs, items = unrolled
assignments = [] assignments = []
for lhs, rhs in zip(self.lhs.args, items): for lhs, rhs in zip(lhs_list, rhs_list):
assignments.append(SingleAssignmentNode( assignments.append(SingleAssignmentNode(
pos = self.pos, pos = self.pos,
lhs = lhs, lhs = lhs,
...@@ -4890,6 +4886,31 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4890,6 +4886,31 @@ class SingleAssignmentNode(AssignmentNode):
all = UtilNodes.LetNode(ref, all) all = UtilNodes.LetNode(ref, all)
return all return all
def unroll_rhs(self, env):
from . import ExprNodes, UtilNodes
if not isinstance(self.lhs, ExprNodes.TupleNode):
return
unrolled = self.unroll(self.rhs, len(self.lhs.args), env)
if not unrolled:
return
check_node, refs, rhs = unrolled
return self.unroll_assignments(refs, check_node, self.lhs.args, rhs, env)
def unroll_lhs(self, env):
if self.lhs.type.is_ctuple:
# Handled directly.
return
from . import ExprNodes, UtilNodes
if not isinstance(self.rhs, ExprNodes.TupleNode):
return
unrolled = self.unroll(self.lhs, len(self.rhs.args), env)
if not unrolled:
return
check_node, refs, lhs = unrolled
return self.unroll_assignments(refs, check_node, lhs, self.rhs.args, env)
def generate_rhs_evaluation_code(self, code): def generate_rhs_evaluation_code(self, code):
self.rhs.generate_evaluation_code(code) self.rhs.generate_evaluation_code(code)
......
...@@ -142,11 +142,20 @@ def test_ptr_literal_list_slice_end(): ...@@ -142,11 +142,20 @@ def test_ptr_literal_list_slice_end():
# a[:] = l # a[:] = l
# return (a[0], a[1], a[2], a[3], a[4]) # return (a[0], a[1], a[2], a[3], a[4])
def test_from_ptr(): def test_multiple_from_slice():
""" """
>>> test_from_ptr() >>> test_multiple_from_slice()
(5, 4, 3) (5, 4, 3)
""" """
cdef int *a = [6,5,4,3,2,1] cdef int *a = [6,5,4,3,2,1]
x, y, z = a[1:4] x, y, z = a[1:4]
return x, y, z return x, y, z
def test_slice_from_multiple():
"""
>>> test_slice_from_multiple()
(6, -1, -2, -3, 2, 1)
"""
cdef int *a = [6,5,4,3,2,1]
a[1:4] = -1, -2, -3
return a[0], a[1], a[2], a[3], a[4], a[5]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment