Commit 235aaa7a authored by Stefan Behnel's avatar Stefan Behnel

allow list literals with a constant multiplier to be assigned to C arrays,...

allow list literals with a constant multiplier to be assigned to C arrays, extend array assignment tests
parent 903cfee7
...@@ -3947,6 +3947,11 @@ class SliceIndexNode(ExprNode): ...@@ -3947,6 +3947,11 @@ class SliceIndexNode(ExprNode):
"default encoding required for conversion from '%s' to '%s'" % "default encoding required for conversion from '%s' to '%s'" %
(self.base.type, dst_type)) (self.base.type, dst_type))
self.type = dst_type self.type = dst_type
if dst_type.is_array and self.base.type.is_array:
if not self.start and not self.stop:
# redundant slice building, copy C arrays directly
return self.base.coerce_to(dst_type, env)
# else: check array size if possible
return super(SliceIndexNode, self).coerce_to(dst_type, env) return super(SliceIndexNode, self).coerce_to(dst_type, env)
def generate_result_code(self, code): def generate_result_code(self, code):
...@@ -6553,16 +6558,25 @@ class ListNode(SequenceNode): ...@@ -6553,16 +6558,25 @@ class ListNode(SequenceNode):
self.obj_conversion_errors = [] self.obj_conversion_errors = []
if not self.type.subtype_of(dst_type): if not self.type.subtype_of(dst_type):
error(self.pos, "Cannot coerce list to type '%s'" % dst_type) error(self.pos, "Cannot coerce list to type '%s'" % dst_type)
elif self.mult_factor:
error(self.pos, "Cannot coerce multiplied list to '%s'" % dst_type)
elif (dst_type.is_array or dst_type.is_ptr) and dst_type.base_type is not PyrexTypes.c_void_type: elif (dst_type.is_array or dst_type.is_ptr) and dst_type.base_type is not PyrexTypes.c_void_type:
array_length = len(self.args)
if self.mult_factor:
if isinstance(self.mult_factor.constant_result, (int, long)):
if self.mult_factor.constant_result <= 0:
error(self.pos, "Cannot coerce non-positively multiplied list to '%s'" % dst_type)
else:
array_length *= self.mult_factor.constant_result
else:
error(self.pos, "Cannot coerce dynamically multiplied list to '%s'" % dst_type)
base_type = dst_type.base_type base_type = dst_type.base_type
self.type = PyrexTypes.CArrayType(base_type, len(self.args)) self.type = PyrexTypes.CArrayType(base_type, array_length)
for i in range(len(self.original_args)): for i in range(len(self.original_args)):
arg = self.args[i] arg = self.args[i]
if isinstance(arg, CoerceToPyTypeNode): if isinstance(arg, CoerceToPyTypeNode):
arg = arg.arg arg = arg.arg
self.args[i] = arg.coerce_to(base_type, env) self.args[i] = arg.coerce_to(base_type, env)
elif self.mult_factor:
error(self.pos, "Cannot coerce multiplied list to '%s'" % dst_type)
elif dst_type.is_struct: elif dst_type.is_struct:
if len(self.args) > len(dst_type.scope.var_entries): if len(self.args) > len(dst_type.scope.var_entries):
error(self.pos, "Too may members for '%s'" % dst_type) error(self.pos, "Too may members for '%s'" % dst_type)
...@@ -6619,11 +6633,23 @@ class ListNode(SequenceNode): ...@@ -6619,11 +6633,23 @@ class ListNode(SequenceNode):
report_error(err) report_error(err)
self.generate_sequence_packing_code(code) self.generate_sequence_packing_code(code)
elif self.type.is_array: elif self.type.is_array:
if self.mult_factor:
code.putln("{")
code.putln("Py_ssize_t %s;" % Naming.quick_temp_cname)
code.putln("for ({i} = 0; {i} < {count}; {i}++) {{".format(
i=Naming.quick_temp_cname, count=self.mult_factor.result()))
offset = '+ (%d * %s)' % (len(self.args), Naming.quick_temp_cname)
else:
offset = ''
for i, arg in enumerate(self.args): for i, arg in enumerate(self.args):
code.putln("%s[%s] = %s;" % ( code.putln("%s[%s%s] = %s;" % (
self.result(), self.result(),
i, i,
offset,
arg.result())) arg.result()))
if self.mult_factor:
code.putln("}")
code.putln("}")
elif self.type.is_struct: elif self.type.is_struct:
for arg, member in zip(self.args, self.type.scope.var_entries): for arg, member in zip(self.args, self.type.scope.var_entries):
code.putln("%s.%s = %s;" % ( code.putln("%s.%s = %s;" % (
......
...@@ -9,6 +9,15 @@ def test_literal_list(): ...@@ -9,6 +9,15 @@ def test_literal_list():
a = [1,2,3,4,5] a = [1,2,3,4,5]
return (a[0], a[1], a[2], a[3], a[4]) return (a[0], a[1], a[2], a[3], a[4])
def test_literal_list_multiplied():
"""
>>> test_literal_list_multiplied()
(1, 2, 1, 2, 1, 2)
"""
cdef int a[6]
a = [1,2] * 3
return (a[0], a[1], a[2], a[3], a[4], a[5])
def test_literal_list_slice_all(): def test_literal_list_slice_all():
""" """
>>> test_literal_list_slice_all() >>> test_literal_list_slice_all()
...@@ -297,3 +306,34 @@ def assign_to_wrong_csize(): ...@@ -297,3 +306,34 @@ def assign_to_wrong_csize():
v[2] = 3 v[2] = 3
d = v d = v
return d return d
def assign_full_array_slice_to_array():
"""
>>> assign_full_array_slice_to_array()
[1, 2, 3]
"""
cdef int[3] x, y
x[0] = 1
x[1] = 2
x[2] = 3
y = x[:]
return y
cdef class ArrayOwner:
cdef readonly int[3] array
def __init__(self, a, b, c):
self.array = (a, b, c)
def assign_from_array_attribute():
"""
>>> assign_from_array_attribute()
[1, 2, 3]
"""
cdef int[3] v
a = ArrayOwner(1, 2, 3)
v = a.array[:]
return v
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment