Commit c8852889 authored by Stefan Behnel's avatar Stefan Behnel

implement "for int_var in bytes_string" and "for int_var in unicode_string"

parent fb77dce0
...@@ -95,7 +95,11 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -95,7 +95,11 @@ class IterationTransform(Visitor.VisitorTransform):
return self._transform_carray_iteration(node, iterator) return self._transform_carray_iteration(node, iterator)
elif iterator.type.is_array: elif iterator.type.is_array:
return self._transform_carray_iteration(node, iterator) return self._transform_carray_iteration(node, iterator)
elif not isinstance(iterator, ExprNodes.SimpleCallNode): elif iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
return self._transform_string_iteration(node, iterator)
# the rest is based on function calls
if not isinstance(iterator, ExprNodes.SimpleCallNode):
return node return node
function = iterator.function function = iterator.function
...@@ -132,6 +136,71 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -132,6 +136,71 @@ class IterationTransform(Visitor.VisitorTransform):
return node return node
PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_int_ptr_type, [ # FIXME: return type is actually Py_UNICODE*
PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
])
PyUnicode_GET_SIZE_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_py_ssize_t_type, [
PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
])
PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_char_ptr_type, [
PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
])
PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_py_ssize_t_type, [
PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
])
def _transform_string_iteration(self, node, slice_node):
if not node.target.type.is_int:
return node
if slice_node.type is Builtin.unicode_type:
unpack_func = "PyUnicode_AS_UNICODE"
len_func = "PyUnicode_GET_SIZE"
unpack_func_type = self.PyUnicode_AS_UNICODE_func_type
len_func_type = self.PyUnicode_GET_SIZE_func_type
elif slice_node.type is Builtin.bytes_type:
unpack_func = "PyBytes_AS_STRING"
unpack_func_type = self.PyBytes_AS_STRING_func_type
len_func = "PyBytes_GET_SIZE"
len_func_type = self.PyBytes_GET_SIZE_func_type
else:
return node
unpack_temp_node = UtilNodes.LetRefNode(
ExprNodes.NoneCheckNode(
slice_node, "PyExc_TypeError", "'NoneType' is not iterable"))
slice_base_node = ExprNodes.PythonCapiCallNode(
slice_node.pos, unpack_func, unpack_func_type,
args = [unpack_temp_node],
is_temp = 0,
)
len_node = ExprNodes.PythonCapiCallNode(
slice_node.pos, len_func, len_func_type,
args = [unpack_temp_node],
is_temp = 0,
)
return UtilNodes.LetNode(
unpack_temp_node,
self._transform_carray_iteration(
node,
ExprNodes.SliceIndexNode(
slice_node.pos,
base = slice_base_node,
start = None,
step = None,
stop = len_node,
type = slice_base_node.type,
is_temp = 1,
)))
def _transform_carray_iteration(self, node, slice_node): def _transform_carray_iteration(self, node, slice_node):
if isinstance(slice_node, ExprNodes.SliceIndexNode): if isinstance(slice_node, ExprNodes.SliceIndexNode):
slice_base = slice_node.base slice_base = slice_node.base
...@@ -166,7 +235,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -166,7 +235,7 @@ class IterationTransform(Visitor.VisitorTransform):
stop_ptr_node = ExprNodes.AddNode( stop_ptr_node = ExprNodes.AddNode(
stop.pos, stop.pos,
operand1=carray_ptr, operand1=ExprNodes.CloneNode(carray_ptr),
operator='+', operator='+',
operand2=stop, operand2=stop,
type=ptr_type type=ptr_type
......
bytes_abc = b'abc'
bytes_ABC = b'ABC'
unicode_abc = u'abc'
unicode_ABC = u'ABC'
def for_in_bytes(bytes s):
"""
>>> for_in_bytes(bytes_abc)
'X'
>>> for_in_bytes(bytes_ABC)
'C'
"""
for c in s:
if c == 'C':
return 'C'
else:
return 'X'
def for_char_in_bytes(bytes s):
"""
>>> for_char_in_bytes(bytes_abc)
'X'
>>> for_char_in_bytes(bytes_ABC)
'C'
"""
cdef char c
for c in s:
if c == 'C':
return 'C'
else:
return 'X'
def for_int_in_unicode(unicode s):
"""
>>> for_int_in_unicode(unicode_abc)
'X'
>>> for_int_in_unicode(unicode_ABC)
'C'
"""
cdef int c
for c in s:
if c == 'C':
return 'C'
else:
return 'X'
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment