Commit a068e66f authored by Stefan Behnel's avatar Stefan Behnel

partial rewrite of dict iteration to fix two bugs:

- iteration over empty dicts could crash
- dict modification during iteration failed to raise an exception
parent 5ae635eb
......@@ -4507,8 +4507,8 @@ class WhileStatNode(LoopNode, StatNode):
self.else_clause.analyse_declarations(env)
def analyse_expressions(self, env):
self.condition = \
self.condition.analyse_temp_boolean_expression(env)
if self.condition:
self.condition = self.condition.analyse_temp_boolean_expression(env)
self.body.analyse_expressions(env)
if self.else_clause:
self.else_clause.analyse_expressions(env)
......@@ -4517,12 +4517,13 @@ class WhileStatNode(LoopNode, StatNode):
old_loop_labels = code.new_loop_labels()
code.putln(
"while (1) {")
self.condition.generate_evaluation_code(code)
self.condition.generate_disposal_code(code)
code.putln(
"if (!%s) break;" %
self.condition.result())
self.condition.free_temps(code)
if self.condition:
self.condition.generate_evaluation_code(code)
self.condition.generate_disposal_code(code)
code.putln(
"if (!%s) break;" %
self.condition.result())
self.condition.free_temps(code)
self.body.generate_execution_code(code)
code.put_label(code.continue_label)
code.putln("}")
......@@ -4535,18 +4536,64 @@ class WhileStatNode(LoopNode, StatNode):
code.put_label(break_label)
def generate_function_definitions(self, env, code):
self.condition.generate_function_definitions(env, code)
if self.condition:
self.condition.generate_function_definitions(env, code)
self.body.generate_function_definitions(env, code)
if self.else_clause is not None:
self.else_clause.generate_function_definitions(env, code)
def annotate(self, code):
self.condition.annotate(code)
if self.condition:
self.condition.annotate(code)
self.body.annotate(code)
if self.else_clause:
self.else_clause.annotate(code)
class DictIterationNextNode(Node):
# Helper node for calling PyDict_Next() inside of a WhileStatNode
# and checking the dictionary size for changes. Created in
# Optimize.py.
child_attrs = ['dict_obj', 'expected_size', 'pos_index_addr', 'key_addr', 'value_addr']
def __init__(self, dict_obj, expected_size, pos_index_addr, key_addr, value_addr):
Node.__init__(
self, dict_obj.pos,
dict_obj = dict_obj,
expected_size = expected_size,
pos_index_addr = pos_index_addr,
key_addr = key_addr,
value_addr = value_addr,
type = PyrexTypes.c_bint_type)
def analyse_expressions(self, env):
self.dict_obj.analyse_types(env)
self.expected_size.analyse_types(env)
self.pos_index_addr.analyse_types(env)
self.key_addr.analyse_types(env)
self.value_addr.analyse_types(env)
def generate_function_definitions(self, env, code):
self.dict_obj.generate_function_definitions(env, code)
def generate_execution_code(self, code):
self.dict_obj.generate_evaluation_code(code)
code.putln("if (unlikely(%s != PyDict_Size(%s))) {" % (
self.expected_size.result(),
self.dict_obj.py_result(),
))
code.putln('PyErr_SetString(PyExc_RuntimeError, "dictionary changed size during iteration"); %s' % (
code.error_goto(self.pos)))
code.putln("}")
self.pos_index_addr.generate_evaluation_code(code)
code.putln("if (!PyDict_Next(%s, %s, %s, %s)) break;" % (
self.dict_obj.py_result(),
self.pos_index_addr.result(),
self.key_addr.result(),
self.value_addr.result()))
def ForStatNode(pos, **kw):
if 'iterator' in kw:
return ForInStatNode(pos, **kw)
......
......@@ -63,18 +63,15 @@ class IterationTransform(Visitor.VisitorTransform):
- for-in-enumerate is replaced by an external counter variable
- for-in-range loop becomes a plain C for loop
"""
PyDict_Next_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_bint_type, [
PyDict_Size_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_py_ssize_t_type, [
PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
])
PyDict_Next_name = EncodedString("PyDict_Next")
PyDict_Size_name = EncodedString("PyDict_Size")
PyDict_Next_entry = Symtab.Entry(
PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)
PyDict_Size_entry = Symtab.Entry(
PyDict_Size_name, PyDict_Size_name, PyDict_Size_func_type)
visit_Node = Visitor.VisitorTransform.recurse_to_children
......@@ -544,7 +541,7 @@ class IterationTransform(Visitor.VisitorTransform):
return for_node
def _transform_dict_iteration(self, node, dict_obj, keys, values):
py_object_ptr = PyrexTypes.c_void_ptr_type
py_object_ptr = PyrexTypes.py_object_type
temps = []
temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
......@@ -556,9 +553,12 @@ class IterationTransform(Visitor.VisitorTransform):
pos_temp_addr = ExprNodes.AmpersandNode(
node.pos, operand=pos_temp,
type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
target_temps = []
if keys:
temp = UtilNodes.TempHandle(py_object_ptr)
temps.append(temp)
temp = UtilNodes.TempHandle(
py_object_ptr, needs_cleanup=False) # ref will be stolen
target_temps.append(temp)
key_temp = temp.ref(node.target.pos)
key_temp_addr = ExprNodes.AmpersandNode(
node.target.pos, operand=key_temp,
......@@ -567,8 +567,9 @@ class IterationTransform(Visitor.VisitorTransform):
key_temp_addr = key_temp = ExprNodes.NullNode(
pos=node.target.pos)
if values:
temp = UtilNodes.TempHandle(py_object_ptr)
temps.append(temp)
temp = UtilNodes.TempHandle(
py_object_ptr, needs_cleanup=False) # ref will be stolen
target_temps.append(temp)
value_temp = temp.ref(node.target.pos)
value_temp_addr = ExprNodes.AmpersandNode(
node.target.pos, operand=value_temp,
......@@ -602,7 +603,7 @@ class IterationTransform(Visitor.VisitorTransform):
return (result, None)
else:
temp = UtilNodes.TempHandle(dest_type)
temps.append(temp)
target_temps.append(temp)
temp_result = temp.ref(obj_node.pos)
class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
def result(self):
......@@ -611,12 +612,6 @@ class IterationTransform(Visitor.VisitorTransform):
self.generate_result_code(code)
return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))
if isinstance(node.body, Nodes.StatListNode):
body = node.body
else:
body = Nodes.StatListNode(pos = node.body.pos,
stats = [node.body])
if tuple_target:
tuple_result = ExprNodes.TupleNode(
pos = tuple_target.pos,
......@@ -624,11 +619,12 @@ class IterationTransform(Visitor.VisitorTransform):
is_temp = 1,
type = Builtin.tuple_type,
)
body.stats.insert(
0, Nodes.SingleAssignmentNode(
body_init_stats = [
Nodes.SingleAssignmentNode(
pos = tuple_target.pos,
lhs = tuple_target,
rhs = tuple_result))
rhs = tuple_result)
]
else:
# execute all coercions before the assignments
coercion_stats = []
......@@ -653,7 +649,29 @@ class IterationTransform(Visitor.VisitorTransform):
pos = value_temp.pos,
lhs = value_target,
rhs = temp_result))
body.stats[0:0] = coercion_stats + assign_stats
body_init_stats = coercion_stats + assign_stats
if isinstance(node.body, Nodes.StatListNode):
body = node.body
else:
body = Nodes.StatListNode(pos = node.body.pos,
stats = [node.body])
# keep original length to guard against dict modification
dict_len_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
temps.append(dict_len_temp)
body_init_stats.insert(0, Nodes.DictIterationNextNode(
dict_temp,
dict_len_temp.ref(dict_obj.pos),
pos_temp_addr, key_temp_addr, value_temp_addr
))
body.stats[0:0] = [UtilNodes.TempsBlockNode(
node.pos,
temps = target_temps,
body = Nodes.StatListNode(pos = node.pos,
stats = body_init_stats)
)]
result_code = [
Nodes.SingleAssignmentNode(
......@@ -665,19 +683,22 @@ class IterationTransform(Visitor.VisitorTransform):
lhs = pos_temp,
rhs = ExprNodes.IntNode(node.pos, value='0',
constant_result=0)),
Nodes.WhileStatNode(
pos = node.pos,
condition = ExprNodes.SimpleCallNode(
Nodes.SingleAssignmentNode(
pos = dict_obj.pos,
lhs = dict_len_temp.ref(dict_obj.pos),
rhs = ExprNodes.SimpleCallNode(
pos = dict_obj.pos,
type = PyrexTypes.c_bint_type,
type = PyrexTypes.c_py_ssize_t_type,
function = ExprNodes.NameNode(
pos = dict_obj.pos,
name = self.PyDict_Next_name,
type = self.PyDict_Next_func_type,
entry = self.PyDict_Next_entry),
args = [dict_temp, pos_temp_addr,
key_temp_addr, value_temp_addr]
),
name = self.PyDict_Size_name,
type = self.PyDict_Size_func_type,
entry = self.PyDict_Size_entry),
args = [dict_temp],
)),
Nodes.WhileStatNode(
pos = node.pos,
condition = None,
body = body,
else_clause = node.else_clause
)
......
......@@ -19,12 +19,13 @@ def items(dict d):
@cython.test_assert_path_exists(
"//WhileStatNode",
"//WhileStatNode/SimpleCallNode",
"//WhileStatNode/SimpleCallNode/NameNode")
"//WhileStatNode//DictIterationNextNode")
def iteritems(dict d):
"""
>>> iteritems(d)
[(10, 0), (11, 1), (12, 2), (13, 3)]
>>> iteritems({})
[]
"""
l = []
for k,v in d.iteritems():
......@@ -34,12 +35,13 @@ def iteritems(dict d):
@cython.test_assert_path_exists(
"//WhileStatNode",
"//WhileStatNode/SimpleCallNode",
"//WhileStatNode/SimpleCallNode/NameNode")
"//WhileStatNode//DictIterationNextNode")
def iteritems_int(dict d):
"""
>>> iteritems_int(d)
[(10, 0), (11, 1), (12, 2), (13, 3)]
>>> iteritems_int({})
[]
"""
cdef int k,v
l = []
......@@ -50,12 +52,13 @@ def iteritems_int(dict d):
@cython.test_assert_path_exists(
"//WhileStatNode",
"//WhileStatNode/SimpleCallNode",
"//WhileStatNode/SimpleCallNode/NameNode")
"//WhileStatNode//DictIterationNextNode")
def iteritems_tuple(dict d):
"""
>>> iteritems_tuple(d)
[(10, 0), (11, 1), (12, 2), (13, 3)]
>>> iteritems_tuple({})
[]
"""
l = []
for t in d.iteritems():
......@@ -65,8 +68,7 @@ def iteritems_tuple(dict d):
@cython.test_assert_path_exists(
"//WhileStatNode",
"//WhileStatNode/SimpleCallNode",
"//WhileStatNode/SimpleCallNode/NameNode")
"//WhileStatNode//DictIterationNextNode")
def iteritems_listcomp(dict d):
cdef list l = [(k,v) for k,v in d.iteritems()]
l.sort()
......@@ -74,12 +76,13 @@ def iteritems_listcomp(dict d):
@cython.test_assert_path_exists(
"//WhileStatNode",
"//WhileStatNode/SimpleCallNode",
"//WhileStatNode/SimpleCallNode/NameNode")
"//WhileStatNode//DictIterationNextNode")
def iterkeys(dict d):
"""
>>> iterkeys(d)
[10, 11, 12, 13]
>>> iterkeys({})
[]
"""
l = []
for k in d.iterkeys():
......@@ -89,12 +92,13 @@ def iterkeys(dict d):
@cython.test_assert_path_exists(
"//WhileStatNode",
"//WhileStatNode/SimpleCallNode",
"//WhileStatNode/SimpleCallNode/NameNode")
"//WhileStatNode//DictIterationNextNode")
def iterkeys_int(dict d):
"""
>>> iterkeys_int(d)
[10, 11, 12, 13]
>>> iterkeys_int({})
[]
"""
cdef int k
l = []
......@@ -105,12 +109,13 @@ def iterkeys_int(dict d):
@cython.test_assert_path_exists(
"//WhileStatNode",
"//WhileStatNode/SimpleCallNode",
"//WhileStatNode/SimpleCallNode/NameNode")
"//WhileStatNode//DictIterationNextNode")
def iterdict(dict d):
"""
>>> iterdict(d)
[10, 11, 12, 13]
>>> iterdict({})
[]
"""
l = []
for k in d:
......@@ -120,12 +125,13 @@ def iterdict(dict d):
@cython.test_assert_path_exists(
"//WhileStatNode",
"//WhileStatNode/SimpleCallNode",
"//WhileStatNode/SimpleCallNode/NameNode")
"//WhileStatNode//DictIterationNextNode")
def iterdict_int(dict d):
"""
>>> iterdict_int(d)
[10, 11, 12, 13]
>>> iterdict_int({})
[]
"""
cdef int k
l = []
......@@ -136,12 +142,13 @@ def iterdict_int(dict d):
@cython.test_assert_path_exists(
"//WhileStatNode",
"//WhileStatNode/SimpleCallNode",
"//WhileStatNode/SimpleCallNode/NameNode")
"//WhileStatNode//DictIterationNextNode")
def iterdict_reassign(dict d):
"""
>>> iterdict_reassign(d)
[10, 11, 12, 13]
>>> iterdict_reassign({})
[]
"""
cdef dict d_new = {}
l = []
......@@ -153,12 +160,13 @@ def iterdict_reassign(dict d):
@cython.test_assert_path_exists(
"//WhileStatNode",
"//WhileStatNode/SimpleCallNode",
"//WhileStatNode/SimpleCallNode/NameNode")
"//WhileStatNode//DictIterationNextNode")
def iterdict_listcomp(dict d):
"""
>>> iterdict_listcomp(d)
[10, 11, 12, 13]
>>> iterdict_listcomp({})
[]
"""
cdef list l = [k for k in d]
l.sort()
......@@ -166,12 +174,13 @@ def iterdict_listcomp(dict d):
@cython.test_assert_path_exists(
"//WhileStatNode",
"//WhileStatNode/SimpleCallNode",
"//WhileStatNode/SimpleCallNode/NameNode")
"//WhileStatNode//DictIterationNextNode")
def itervalues(dict d):
"""
>>> itervalues(d)
[0, 1, 2, 3]
>>> itervalues({})
[]
"""
l = []
for v in d.itervalues():
......@@ -181,12 +190,13 @@ def itervalues(dict d):
@cython.test_assert_path_exists(
"//WhileStatNode",
"//WhileStatNode/SimpleCallNode",
"//WhileStatNode/SimpleCallNode/NameNode")
"//WhileStatNode//DictIterationNextNode")
def itervalues_int(dict d):
"""
>>> itervalues_int(d)
[0, 1, 2, 3]
>>> itervalues_int({})
[]
"""
cdef int v
l = []
......@@ -197,12 +207,13 @@ def itervalues_int(dict d):
@cython.test_assert_path_exists(
"//WhileStatNode",
"//WhileStatNode/SimpleCallNode",
"//WhileStatNode/SimpleCallNode/NameNode")
"//WhileStatNode//DictIterationNextNode")
def itervalues_listcomp(dict d):
"""
>>> itervalues_listcomp(d)
[0, 1, 2, 3]
>>> itervalues_listcomp({})
[]
"""
cdef list l = [v for v in d.itervalues()]
l.sort()
......@@ -210,13 +221,44 @@ def itervalues_listcomp(dict d):
@cython.test_assert_path_exists(
"//WhileStatNode",
"//WhileStatNode/SimpleCallNode",
"//WhileStatNode/SimpleCallNode/NameNode")
"//WhileStatNode//DictIterationNextNode")
def itervalues_kwargs(**d):
"""
>>> itervalues_kwargs(a=1, b=2, c=3, d=4)
[1, 2, 3, 4]
>>> itervalues_kwargs()
[]
"""
cdef list l = [v for v in d.itervalues()]
l.sort()
return l
@cython.test_assert_path_exists(
"//WhileStatNode",
"//WhileStatNode//DictIterationNextNode")
def iterdict_change_size(dict d):
"""
>>> count, i = 0, -1
>>> d = {1:2, 10:20}
>>> for i in d:
... d[i+1] = 5
... count += 1
... if count > 5:
... break # safety
Traceback (most recent call last):
RuntimeError: dictionary changed size during iteration
>>> iterdict_change_size({1:2, 10:20})
Traceback (most recent call last):
RuntimeError: dictionary changed size during iteration
>>> print( iterdict_change_size({}) )
DONE
"""
cdef int count = 0
i = -1
for i in d:
d[i+1] = 5
count += 1
if count > 5:
break # safety
return "DONE"
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment