Commit 1e0ab7ca authored by Stefan Behnel's avatar Stefan Behnel

handle value coercion correctly in dict iteration

parent cfb3beff
......@@ -4504,8 +4504,6 @@ class PyTypeTestNode(CoercionNode):
self.type = dst_type
self.gil_check(env)
self.result_ctype = arg.ctype()
if not dst_type.is_builtin_type:
env.use_utility_code(type_test_utility_code)
gil_message = "Python type test"
......@@ -4523,6 +4521,8 @@ class PyTypeTestNode(CoercionNode):
def generate_result_code(self, code):
if self.type.typeobj_is_available():
if not dst_type.is_builtin_type:
code.globalstate.use_utility_code(type_test_utility_code)
code.putln(
"if (!(%s)) %s" % (
self.type.type_test_code(self.arg.py_result()),
......
......@@ -111,16 +111,25 @@ class DictIterTransform(Visitor.VisitorTransform):
else:
tuple_target = node.target
if keys:
key_cast = ExprNodes.TypecastNode(
pos = key_target.pos,
operand = key_temp,
type = key_target.type)
if values:
value_cast = ExprNodes.TypecastNode(
pos = value_target.pos,
operand = value_temp,
type = value_target.type)
def coerce_object_to(obj_node, dest_type):
class FakeEnv(object):
nogil = False
if dest_type.is_pyobject:
if dest_type.is_extension_type or dest_type.is_builtin_type:
return (obj_node, ExprNodes.PyTypeTestNode(obj_node, dest_type, FakeEnv()))
else:
return (obj_node, None)
else:
temp = UtilNodes.TempHandle(dest_type)
temps.append(temp)
temp_result = temp.ref(obj_node.pos)
class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
# FIXME: remove this after result-code refactoring
def result(self):
return temp_result.result()
def generate_execution_code(self, code):
self.generate_result_code(code)
return (temp_result, CoercedTempNode(dest_type, obj_node, FakeEnv()))
if isinstance(node.body, Nodes.StatListNode):
body = node.body
......@@ -129,7 +138,7 @@ class DictIterTransform(Visitor.VisitorTransform):
stats = [node.body])
if tuple_target:
temp = UtilNodes.TempHandle(py_object_ptr)
temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
temps.append(temp)
temp_tuple = temp.ref(tuple_target.pos)
class TempTupleNode(ExprNodes.TupleNode):
......@@ -139,7 +148,7 @@ class DictIterTransform(Visitor.VisitorTransform):
tuple_result = TempTupleNode(
pos = tuple_target.pos,
args = [key_cast, value_cast],
args = [key_temp, value_temp],
is_temp = 1,
type = Builtin.tuple_type,
)
......@@ -148,18 +157,30 @@ class DictIterTransform(Visitor.VisitorTransform):
lhs = tuple_target,
rhs = tuple_result))
else:
if values:
body.stats.insert(
0, Nodes.SingleAssignmentNode(
pos = value_target.pos,
lhs = value_target,
rhs = value_cast))
# execute all coercions before the assignments
coercion_stats = []
assign_stats = []
if keys:
body.stats.insert(
0, Nodes.SingleAssignmentNode(
pos = key_target.pos,
lhs = key_target,
rhs = key_cast))
temp_result, coercion = coerce_object_to(
key_temp, key_target.type)
if coercion:
coercion_stats.append(coercion)
assign_stats.append(
Nodes.SingleAssignmentNode(
pos = key_temp.pos,
rhs = temp_result,
lhs = key_target))
if values:
temp_result, coercion = coerce_object_to(
value_temp, value_target.type)
if coercion:
coercion_stats.append(coercion)
assign_stats.append(
Nodes.SingleAssignmentNode(
pos = value_temp.pos,
rhs = temp_result,
lhs = value_target))
body.stats[0:0] = coercion_stats + assign_stats
result_code = [
Nodes.SingleAssignmentNode(
......
......@@ -6,14 +6,22 @@ __doc__ = u"""
[(10, 0), (11, 1), (12, 2), (13, 3)]
>>> iteritems(d)
[(10, 0), (11, 1), (12, 2), (13, 3)]
>>> iteritems_int(d)
[(10, 0), (11, 1), (12, 2), (13, 3)]
>>> iteritems_tuple(d)
[(10, 0), (11, 1), (12, 2), (13, 3)]
>>> iterkeys(d)
[10, 11, 12, 13]
>>> iterkeys_int(d)
[10, 11, 12, 13]
>>> iterdict(d)
[10, 11, 12, 13]
>>> iterdict_int(d)
[10, 11, 12, 13]
>>> itervalues(d)
[0, 1, 2, 3]
>>> itervalues_int(d)
[0, 1, 2, 3]
"""
def items(dict d):
......@@ -30,6 +38,14 @@ def iteritems(dict d):
l.sort()
return l
def iteritems_int(dict d):
cdef int k,v
l = []
for k,v in d.iteritems():
l.append((k,v))
l.sort()
return l
def iteritems_tuple(dict d):
l = []
for t in d.iteritems():
......@@ -44,6 +60,14 @@ def iterkeys(dict d):
l.sort()
return l
def iterkeys_int(dict d):
cdef int k
l = []
for k in d.iterkeys():
l.append(k)
l.sort()
return l
def iterdict(dict d):
l = []
for k in d:
......@@ -51,9 +75,25 @@ def iterdict(dict d):
l.sort()
return l
def iterdict_int(dict d):
cdef int k
l = []
for k in d:
l.append(k)
l.sort()
return l
def itervalues(dict d):
l = []
for v in d.itervalues():
l.append(v)
l.sort()
return l
def itervalues_int(dict d):
cdef int v
l = []
for v in d.itervalues():
l.append(v)
l.sort()
return l
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment