Commit ceb39bf6 authored by Stefan Behnel's avatar Stefan Behnel

clean up C tuple assignment unrolling code

parent d7246a58
...@@ -4839,7 +4839,7 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4839,7 +4839,7 @@ class SingleAssignmentNode(AssignmentNode):
check_node = None check_node = None
else: else:
error(self.pos, "Unpacking type %s requires exactly %s arguments." % ( error(self.pos, "Unpacking type %s requires exactly %s arguments." % (
node.type, node.type.size)) node.type, node.type.size))
return return
elif node.type.is_ptr: elif node.type.is_ptr:
...@@ -4853,7 +4853,10 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4853,7 +4853,10 @@ class SingleAssignmentNode(AssignmentNode):
stop_node = stop_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env) stop_node = stop_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env)
else: else:
if node.type.is_array and node.type.size: if node.type.is_array and node.type.size:
stop_node = ExprNodes.IntNode(pos=self.pos, value=str(rhs.type.size)) stop_node = ExprNodes.IntNode(
self.pos, value=str(node.type.size),
constant_result=(node.type.size if isinstance(node.type.size, (int, long))
else ExprNodes.constant_value_not_set))
else: else:
error(self.pos, "C array iteration requires known end index") error(self.pos, "C array iteration requires known end index")
return return
...@@ -4886,38 +4889,41 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4886,38 +4889,41 @@ class SingleAssignmentNode(AssignmentNode):
return return
items = [] items = []
base_ref = UtilNodes.LetRefNode(base) base = UtilNodes.LetRefNode(base)
refs = [base_ref] refs = [base]
if start_node: if start_node and not start_node.is_literal:
start_node = UtilNodes.LetRefNode(start_node) start_node = UtilNodes.LetRefNode(start_node)
refs.append(start_node) refs.append(start_node)
if stop_node: if stop_node and not stop_node.is_literal:
stop_node = UtilNodes.LetRefNode(stop_node) stop_node = UtilNodes.LetRefNode(stop_node)
refs.append(stop_node) refs.append(stop_node)
if step_node: if step_node and not step_node.is_literal:
step_node = UtilNodes.LetRefNode(step_node) step_node = UtilNodes.LetRefNode(step_node)
refs.append(step_node) refs.append(step_node)
for ix in range(target_size): for ix in range(target_size):
ix_node = ExprNodes.IntNode(pos=self.pos, value=str(ix)) ix_node = ExprNodes.IntNode(self.pos, value=str(ix), constant_result=ix, type=PyrexTypes.c_py_ssize_t_type)
if step_node is not None: if step_node is not None:
ix_node = ExprNodes.MulNode(pos=self.pos, operator='*', operand1=step_node, operand2=ix_node).analyse_types(env) if step_node.has_constant_result():
step_value = ix_node.constant_result * step_node.constant_result
ix_node = ExprNodes.IntNode(self.pos, value=str(step_value), constant_result=step_value)
else:
ix_node = ExprNodes.MulNode(self.pos, operator='*', operand1=step_node, operand2=ix_node)
if start_node is not None: if start_node is not None:
ix_node = ExprNodes.AddNode(pos=self.pos, operator='+', operand1=start_node, operand2=ix_node).analyse_types(env) if start_node.has_constant_result() and ix_node.has_constant_result():
items.append(ExprNodes.IndexNode( index_value = ix_node.constant_result + start_node.constant_result
pos=self.pos, ix_node = ExprNodes.IntNode(self.pos, value=str(index_value), constant_result=index_value)
base=base_ref, else:
index=ix_node)) ix_node = ExprNodes.AddNode(
self.pos, operator='+', operand1=start_node, operand2=ix_node)
items.append(ExprNodes.IndexNode(self.pos, base=base, index=ix_node.analyse_types(env)))
return check_node, refs, items return check_node, refs, items
def unroll_assignments(self, refs, check_node, lhs_list, rhs_list, env): def unroll_assignments(self, refs, check_node, lhs_list, rhs_list, env):
from . import ExprNodes, UtilNodes from . import UtilNodes
assignments = [] assignments = []
for lhs, rhs in zip(lhs_list, rhs_list): for lhs, rhs in zip(lhs_list, rhs_list):
assignments.append(SingleAssignmentNode( assignments.append(SingleAssignmentNode(self.pos, lhs=lhs, rhs=rhs, first=self.first))
pos = self.pos,
lhs = lhs,
rhs = rhs,
first = self.first))
all = ParallelAssignmentNode(pos=self.pos, stats=assignments).analyse_expressions(env) all = ParallelAssignmentNode(pos=self.pos, stats=assignments).analyse_expressions(env)
if check_node: if check_node:
all = StatListNode(pos=self.pos, stats=[check_node, all]) all = StatListNode(pos=self.pos, stats=[check_node, all])
...@@ -4926,7 +4932,7 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4926,7 +4932,7 @@ class SingleAssignmentNode(AssignmentNode):
return all return all
def unroll_rhs(self, env): def unroll_rhs(self, env):
from . import ExprNodes, UtilNodes from . import ExprNodes
if not isinstance(self.lhs, ExprNodes.TupleNode): if not isinstance(self.lhs, ExprNodes.TupleNode):
return return
for arg in self.lhs.args: for arg in self.lhs.args:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment