Commit 0756de96 authored by Robert Bradshaw's avatar Robert Bradshaw

CTuple type unpacking.

parent 4598f534
...@@ -1163,7 +1163,7 @@ class CTupleBaseTypeNode(CBaseTypeNode): ...@@ -1163,7 +1163,7 @@ class CTupleBaseTypeNode(CBaseTypeNode):
child_attrs = ["components"] child_attrs = ["components"]
def analyse(self, env): def analyse(self, env, could_be_name=False):
component_types = [] component_types = []
for c in self.components: for c in self.components:
type = c.analyse(env) type = c.analyse(env)
...@@ -1171,7 +1171,7 @@ class CTupleBaseTypeNode(CBaseTypeNode): ...@@ -1171,7 +1171,7 @@ class CTupleBaseTypeNode(CBaseTypeNode):
error(type_node.pos, "Tuple types can't (yet) contain Python objects.") error(type_node.pos, "Tuple types can't (yet) contain Python objects.")
return PyrexType.error_type return PyrexType.error_type
component_types.append(type) component_types.append(type)
type = PyrexTypes.c_tuple_type(tuple(component_types)) type = PyrexTypes.c_tuple_type(component_types)
env.declare_tuple_type(self.pos, type) env.declare_tuple_type(self.pos, type)
return type return type
...@@ -1195,7 +1195,7 @@ class FusedTypeNode(CBaseTypeNode): ...@@ -1195,7 +1195,7 @@ class FusedTypeNode(CBaseTypeNode):
# Omit the typedef declaration that self.declarator would produce # Omit the typedef declaration that self.declarator would produce
entry.in_cinclude = True entry.in_cinclude = True
def analyse(self, env): def analyse(self, env, could_be_name = False):
types = [] types = []
for type_node in self.types: for type_node in self.types:
type = type_node.analyse_as_type(env) type = type_node.analyse_as_type(env)
...@@ -4644,7 +4644,7 @@ class AssignmentNode(StatNode): ...@@ -4644,7 +4644,7 @@ class AssignmentNode(StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
node = self.analyse_types(env) node = self.analyse_types(env)
if isinstance(node, AssignmentNode): if isinstance(node, AssignmentNode) and not isinstance(node, ParallelAssignmentNode):
if node.rhs.type.is_ptr and node.rhs.is_ephemeral(): if node.rhs.type.is_ptr and node.rhs.is_ephemeral():
error(self.pos, "Storing unsafe C derivative of temporary Python reference") error(self.pos, "Storing unsafe C derivative of temporary Python reference")
return node return node
...@@ -4750,12 +4750,30 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4750,12 +4750,30 @@ class SingleAssignmentNode(AssignmentNode):
self.lhs.analyse_target_declaration(env) self.lhs.analyse_target_declaration(env)
def analyse_types(self, env, use_temp = 0): def analyse_types(self, env, use_temp = 0):
from . import ExprNodes from . import ExprNodes, UtilNodes
self.rhs = self.rhs.analyse_types(env) self.rhs = self.rhs.analyse_types(env)
self.lhs = self.lhs.analyse_target_types(env) self.lhs = self.lhs.analyse_target_types(env)
self.lhs.gil_assignment_check(env) self.lhs.gil_assignment_check(env)
if self.rhs.type.is_ctuple and isinstance(self.lhs, ExprNodes.TupleNode):
if self.rhs.type.size == len(self.lhs.args):
rhs = UtilNodes.LetRefNode(self.rhs)
nodes = []
for ix, lhs in enumerate(self.lhs.args):
nodes.append(SingleAssignmentNode(
pos = self.pos,
lhs = lhs,
rhs = ExprNodes.IndexNode(
pos=self.pos,
base=rhs,
index=ExprNodes.IntNode(pos=self.pos, value=str(ix))),
first = self.first))
return UtilNodes.LetNode(rhs, ParallelAssignmentNode(pos=self.pos, stats=nodes).analyse_expressions(env))
else:
error(self.pos, "Unpacking type %s requires exactly %s arguments." % (
self.rhs.type, self.rhs.type.size))
if self.lhs.memslice_broadcast or self.rhs.memslice_broadcast: if self.lhs.memslice_broadcast or self.rhs.memslice_broadcast:
self.lhs.memslice_broadcast = True self.lhs.memslice_broadcast = True
self.rhs.memslice_broadcast = True self.rhs.memslice_broadcast = True
......
...@@ -3384,6 +3384,7 @@ class CTupleType(CType): ...@@ -3384,6 +3384,7 @@ class CTupleType(CType):
c_tuple_types = {} c_tuple_types = {}
def c_tuple_type(components): def c_tuple_type(components):
components = tuple(components)
tuple_type = c_tuple_types.get(components) tuple_type = c_tuple_types.get(components)
if tuple_type is None: if tuple_type is None:
cname = '__pyx_tuple_' + '___'.join( cname = '__pyx_tuple_' + '___'.join(
......
...@@ -15,16 +15,48 @@ def simple_convert(*o): ...@@ -15,16 +15,48 @@ def simple_convert(*o):
cdef (int, double) xy = o cdef (int, double) xy = o
return xy return xy
def rotate_via_indexing((int, int, double) xyz): def indexing((int, double) xy):
""" """
>>> rotate_via_indexing((1, 2, 3)) >>> indexing((1, 2))
(2, 3, 1.0) (2, 3.0)
""" """
a = xyz[0] x = xy[0]
xyz[0] = xyz[1] y = xy[1]
xyz[1] = <int>xyz[2] xy[0] = x + 1
xyz[-1] = a xy[1] = y + 1
return xyz return xy
def unpacking((int, double) xy):
"""
>>> unpacking((1, 2))
(1, 2.0)
"""
x, y = xy
return x, y
cdef (int, double) side_effect((int, double) xy):
print "called with", xy
return xy
def unpacking_with_side_effect((int, double) xy):
"""
>>> unpacking_with_side_effect((1, 2))
called with (1, 2.0)
(1, 2.0)
"""
x, y = side_effect(xy)
return x, y
def c_types(int a, double b):
"""
>>> c_types(1, 2)
(1, 2.0)
"""
cdef (int*, double*) ab
ab[0] = &a
ab[1] = &b
a_ptr, b_ptr = ab[0], ab[1]
return a_ptr[0], b_ptr[0]
cpdef (int, double) ctuple_return_type(x, y): cpdef (int, double) ctuple_return_type(x, y):
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment