Commit c099b42b authored by Stefan Behnel's avatar Stefan Behnel

feature complete implementation of PEP 3132

parent 87fa865d
...@@ -2898,8 +2898,16 @@ class SequenceNode(ExprNode): ...@@ -2898,8 +2898,16 @@ class SequenceNode(ExprNode):
self.iterator = PyTempNode(self.pos, env) self.iterator = PyTempNode(self.pos, env)
self.unpacked_items = [] self.unpacked_items = []
self.coerced_unpacked_items = [] self.coerced_unpacked_items = []
self.starred_assignment = False
for arg in self.args: for arg in self.args:
arg.analyse_target_types(env) arg.analyse_target_types(env)
if arg.is_starred:
if not arg.type.assignable_from(Builtin.list_type):
error(arg.pos,
"starred target must have Python object (list) type")
if arg.type is py_object_type:
arg.type = Builtin.list_type
self.starred_assignment = True
unpacked_item = PyTempNode(self.pos, env) unpacked_item = PyTempNode(self.pos, env)
coerced_unpacked_item = unpacked_item.coerce_to(arg.type, env) coerced_unpacked_item = unpacked_item.coerce_to(arg.type, env)
self.unpacked_items.append(unpacked_item) self.unpacked_items.append(unpacked_item)
...@@ -2911,6 +2919,16 @@ class SequenceNode(ExprNode): ...@@ -2911,6 +2919,16 @@ class SequenceNode(ExprNode):
self.generate_operation_code(code) self.generate_operation_code(code)
def generate_assignment_code(self, rhs, code): def generate_assignment_code(self, rhs, code):
if self.starred_assignment:
self.generate_starred_assignment_code(rhs, code)
else:
self.generate_normal_assignment_code(rhs, code)
for item in self.unpacked_items:
item.release(code)
rhs.free_temps(code)
def generate_normal_assignment_code(self, rhs, code):
# Need to work around the fact that generate_evaluation_code # Need to work around the fact that generate_evaluation_code
# allocates the temps in a rather hacky way -- the assignment # allocates the temps in a rather hacky way -- the assignment
# is evaluated twice, within each if-block. # is evaluated twice, within each if-block.
...@@ -2985,10 +3003,72 @@ class SequenceNode(ExprNode): ...@@ -2985,10 +3003,72 @@ class SequenceNode(ExprNode):
self.coerced_unpacked_items[i], code) self.coerced_unpacked_items[i], code)
code.putln("}") code.putln("}")
def generate_starred_assignment_code(self, rhs, code):
for i, arg in enumerate(self.args):
if arg.is_starred:
starred_target = self.unpacked_items[i]
fixed_args_left = self.args[:i]
fixed_args_right = self.args[i+1:]
break
self.iterator.allocate(code)
code.putln(
"%s = PyObject_GetIter(%s); %s" % (
self.iterator.result(),
rhs.py_result(),
code.error_goto_if_null(self.iterator.result(), self.pos)))
code.put_gotref(self.iterator.py_result())
rhs.generate_disposal_code(code)
for item in self.unpacked_items: for item in self.unpacked_items:
item.release(code) item.allocate(code)
rhs.free_temps(code) for i in range(len(fixed_args_left)):
item = self.unpacked_items[i]
unpack_code = "__Pyx_UnpackItem(%s, %d)" % (
self.iterator.py_result(), i)
code.putln(
"%s = %s; %s" % (
item.result(),
typecast(item.ctype(), py_object_type, unpack_code),
code.error_goto_if_null(item.result(), self.pos)))
code.put_gotref(item.py_result())
value_node = self.coerced_unpacked_items[i]
value_node.generate_evaluation_code(code)
target_list = starred_target.result()
code.putln("%s = PySequence_List(%s); %s" % (
target_list, self.iterator.py_result(),
code.error_goto_if_null(target_list, self.pos)))
code.put_gotref(target_list)
if fixed_args_right:
code.globalstate.use_utility_code(raise_need_more_values_to_unpack)
unpacked_right_args = self.unpacked_items[-len(fixed_args_right):]
code.putln("if (unlikely(PyList_GET_SIZE(%s) < %d)) {" % (
(target_list, len(unpacked_right_args))))
code.put("__Pyx_RaiseNeedMoreValuesError(%d+PyList_GET_SIZE(%s)); %s" % (
len(fixed_args_left), target_list,
code.error_goto(self.pos)))
code.putln('}')
for i, (arg, coerced_arg) in enumerate(zip(unpacked_right_args[::-1],
self.coerced_unpacked_items[::-1])):
code.putln(
"%s = PyList_GET_ITEM(%s, PyList_GET_SIZE(%s)-1); " % (
arg.py_result(),
target_list, target_list))
# resize the list the hard way
code.putln("((PyListObject*)%s)->ob_size--;" % target_list)
code.put_gotref(arg.py_result())
coerced_arg.generate_evaluation_code(code)
self.iterator.generate_disposal_code(code)
self.iterator.free_temps(code)
self.iterator.release(code)
for i in range(len(self.args)):
self.args[i].generate_assignment_code(
self.coerced_unpacked_items[i], code)
def annotate(self, code): def annotate(self, code):
for arg in self.args: for arg in self.args:
arg.annotate(code) arg.annotate(code)
......
...@@ -910,43 +910,96 @@ def p_expression_or_assignment(s): ...@@ -910,43 +910,96 @@ def p_expression_or_assignment(s):
return Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes) return Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
def flatten_parallel_assignments(input, output): def flatten_parallel_assignments(input, output):
# The input is a list of expression nodes, representing # The input is a list of expression nodes, representing the LHSs
# the LHSs and RHS of one (possibly cascaded) assignment # and RHS of one (possibly cascaded) assignment statement. For
# statement. If they are all sequence constructors with # sequence constructors, rearranges the matching parts of both
# the same number of arguments, rearranges them into a # sides into a list of equivalent assignments between the
# list of equivalent assignments between the individual # individual elements. This transformation is applied
# elements. This transformation is applied recursively. # recursively, so that nested structures get matched as well.
size = find_parallel_assignment_size(input) rhs = input[-1]
if size >= 0: if not rhs.is_sequence_constructor:
for i in range(size):
new_exprs = [expr.args[i] for expr in input]
flatten_parallel_assignments(new_exprs, output)
else:
output.append(input) output.append(input)
return
def find_parallel_assignment_size(input):
# The input is a list of expression nodes. If
# they are all sequence constructors with the same number
# of arguments, return that number, else return -1.
# Produces an error message if they are all sequence
# constructors but not all the same size.
for expr in input:
if not expr.is_sequence_constructor:
return -1
rhs = input[-1]
rhs_size = len(rhs.args) rhs_size = len(rhs.args)
lhs_targets = [ [] for _ in range(rhs_size) ]
starred_assignments = []
for lhs in input[:-1]: for lhs in input[:-1]:
starred_args = sum([1 for expr in lhs.args if expr.is_starred]) if not lhs.is_sequence_constructor:
if starred_args: if lhs.is_starred:
if starred_args > 1: error(lhs.pos, "starred assignment target must be in a list or tuple")
error(lhs.pos, "more than 1 starred expression in assignment") output.append(lhs)
return -1 continue
lhs_size = len(lhs.args) lhs_size = len(lhs.args)
if lhs_size != rhs_size: starred_targets = sum([1 for expr in lhs.args if expr.is_starred])
error(lhs.pos, "Unpacking sequence of wrong size (expected %d, got %d)" if starred_targets:
% (lhs_size, rhs_size)) if starred_targets > 1:
return -1 error(lhs.pos, "more than 1 starred expression in assignment")
return rhs_size elif lhs_size - starred_targets > rhs_size:
error(lhs.pos, "need more than %d value%s to unpack"
% (rhs_size, (rhs_size != 1) and 's' or ''))
map_starred_assignment(lhs_targets, starred_assignments,
lhs.args, rhs.args)
else:
if lhs_size > rhs_size:
error(lhs.pos, "need more than %d value%s to unpack"
% (rhs_size, (rhs_size != 1) and 's' or ''))
elif lhs_size < rhs_size:
error(lhs.pos, "too many values to unpack (expected %d, got %d)"
% (lhs_size, rhs_size))
else:
for targets, expr in zip(lhs_targets, lhs.args):
targets.append(expr)
# recursively flatten partial assignments
for cascade, rhs in zip(lhs_targets, rhs.args):
if cascade:
cascade.append(rhs)
flatten_parallel_assignments(cascade, output)
# recursively flatten starred assignments
for cascade in starred_assignments:
if cascade[0].is_sequence_constructor:
flatten_parallel_assignments(cascade, output)
else:
output.append(cascade)
def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
# Appends the fixed-position LHS targets to the target list that
# appear left and right of the starred argument.
#
# The starred_assignments list receives a new tuple
# (lhs_target, rhs_values_list) that maps the remaining arguments
# (those that match the starred target) to a list.
# left side of the starred target
for i, (targets, expr) in enumerate(zip(lhs_targets, lhs_args)):
if expr.is_starred:
starred = i
lhs_remaining = len(lhs_args) - i - 1
break
targets.append(expr)
else:
raise InternalError("no starred arg found when splitting starred assignment")
# right side of the starred target
for i, (targets, expr) in enumerate(zip(lhs_targets[-lhs_remaining:],
lhs_args[-lhs_remaining:])):
targets.append(expr)
# the starred target itself, must be assigned a (potentially empty) list
target = lhs_args[starred]
target.is_starred = False
starred_rhs = rhs_args[starred:]
if lhs_remaining:
starred_rhs = starred_rhs[:-lhs_remaining]
if starred_rhs:
pos = starred_rhs[0].pos
else:
pos = target.pos
starred_assignments.append([
target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])
def p_print_statement(s): def p_print_statement(s):
# s.sy == 'print' # s.sy == 'print'
......
...@@ -24,10 +24,15 @@ __doc__ = u""" ...@@ -24,10 +24,15 @@ __doc__ = u"""
(1, [2, 3]) (1, [2, 3])
(1, [2, 3, 4]) (1, [2, 3, 4])
3 3
(1, [2]) (1, [], 2)
(1, [2], 3) (1, [2], 3)
(1, [2, 3], 4) (1, [2, 3], 4)
>>> unpack_recursive((1,2,3,4))
(1, [2, 3], 4)
>>> unpack_typed((1,2))
([1], 2)
>>> assign() >>> assign()
(1, [2, 3, 4], 5) (1, [2, 3, 4], 5)
...@@ -94,7 +99,7 @@ ValueError: need more than 0 values to unpack ...@@ -94,7 +99,7 @@ ValueError: need more than 0 values to unpack
([], 1) ([], 1)
>>> unpack_left_list([1,2]) >>> unpack_left_list([1,2])
([1], 2) ([1], 2)
>>> unpack_left_list([1,2]) >>> unpack_left_list([1,2,3])
([1, 2], 3) ([1, 2], 3)
>>> unpack_left_tuple((1,)) >>> unpack_left_tuple((1,))
([], 1) ([], 1)
...@@ -152,13 +157,13 @@ ValueError: need more than 1 value to unpack ...@@ -152,13 +157,13 @@ ValueError: need more than 1 value to unpack
>>> a,b,c = unpack_middle(range(100)) >>> a,b,c = unpack_middle(range(100))
>>> a, len(b), c >>> a, len(b), c
0, 98, 99 (0, 98, 99)
>>> a,b,c = unpack_middle_list(range(100)) >>> a,b,c = unpack_middle_list(range(100))
>>> a, len(b), c >>> a, len(b), c
0, 98, 99 (0, 98, 99)
>>> a,b,c = unpack_middle_tuple(tuple(range(100))) >>> a,b,c = unpack_middle_tuple(tuple(range(100)))
>>> a, len(b), c >>> a, len(b), c
0, 98, 99 (0, 98, 99)
""" """
...@@ -176,38 +181,48 @@ def unpack_tuple(tuple t): ...@@ -176,38 +181,48 @@ def unpack_tuple(tuple t):
def assign(): def assign():
*a, b = 1,2,3,4,5 *a, b = 1,2,3,4,5
assert a+[b] == (1,2,3,4,5) assert a+[b] == [1,2,3,4,5], (a,b)
a, *b = 1,2,3,4,5 a, *b = 1,2,3,4,5
assert [a]+b == (1,2,3,4,5) assert [a]+b == [1,2,3,4,5], (a,b)
[a, *b, c] = 1,2,3,4,5 [a, *b, c] = 1,2,3,4,5
return a,b,c return a,b,c
def unpack_into_list(l): def unpack_into_list(l):
[*a, b] = l [*a, b] = l
assert a+[b] == l assert a+[b] == list(l), repr((a+[b],list(l)))
[a, *b] = l [a, *b] = l
assert [a]+b == l assert [a]+b == list(l), repr(([a]+b,list(l)))
[a, *b, c] = l [a, *b, c] = l
return a,b,c return a,b,c
def unpack_into_tuple(t): def unpack_into_tuple(t):
(*a, b) = t (*a, b) = t
assert a+(b,) == t assert a+[b] == list(t), repr((a+[b],list(t)))
(a, *b) = t (a, *b) = t
assert (a,)+b == t assert [a]+b == list(t), repr(([a]+b,list(t)))
(a, *b, c) = t (a, *b, c) = t
return a,b,c return a,b,c
def unpack_in_loop(list_of_sequences): def unpack_in_loop(list_of_sequences):
print 1 print 1
for *a,b in list_of_sequences: for *a,b in list_of_sequences:
print a,b print((a,b))
print 2 print 2
for a,*b in list_of_sequences: for a,*b in list_of_sequences:
print a,b print((a,b))
print 3 print 3
for a,*b, c in list_of_sequences: for a,*b, c in list_of_sequences:
print a,b,c print((a,b,c))
def unpack_recursive(t):
*(a, *b), c = t
return a,b,c
def unpack_typed(t):
cdef list a
*a, b = t
return a,b
def unpack_right(l): def unpack_right(l):
a, *b = l a, *b = l
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment