Commit 6d2b5074 authored by Stefan Behnel's avatar Stefan Behnel

Infer some bytearray operation types in the same way as for bytes operations.

parent 4e8c9b7c
......@@ -3287,7 +3287,7 @@ class _IndexingBaseNode(ExprNode):
# in most cases, indexing will return a safe reference to an object in a container,
# so we consider the result safe if the base object is
return self.base.is_ephemeral() or self.base.type in (
basestring_type, str_type, bytes_type, unicode_type)
basestring_type, str_type, bytes_type, bytearray_type, unicode_type)
def check_const_addr(self):
return self.base.check_const_addr() and self.index.check_const()
......@@ -3347,7 +3347,7 @@ class IndexNode(_IndexingBaseNode):
return False
if isinstance(self.index, SliceNode):
# slicing!
if base_type in (bytes_type, str_type, unicode_type,
if base_type in (bytes_type, bytearray_type, str_type, unicode_type,
basestring_type, list_type, tuple_type):
return False
return ExprNode.may_be_none(self)
......@@ -3572,7 +3572,7 @@ class IndexNode(_IndexingBaseNode):
else:
# not using 'uchar' to enable fast and safe error reporting as '-1'
self.type = PyrexTypes.c_int_type
elif is_slice and base_type in (bytes_type, str_type, unicode_type, list_type, tuple_type):
elif is_slice and base_type in (bytes_type, bytearray_type, str_type, unicode_type, list_type, tuple_type):
self.type = base_type
else:
item_type = None
......@@ -4569,7 +4569,7 @@ class SliceIndexNode(ExprNode):
return bytes_type
elif base_type.is_pyunicode_ptr:
return unicode_type
elif base_type in (bytes_type, str_type, unicode_type,
elif base_type in (bytes_type, bytearray_type, str_type, unicode_type,
basestring_type, list_type, tuple_type):
return base_type
elif base_type.is_ptr or base_type.is_array:
......@@ -11124,7 +11124,7 @@ class AddNode(NumBinopNode):
def infer_builtin_types_operation(self, type1, type2):
# b'abc' + 'abc' raises an exception in Py3,
# so we can safely infer the Py2 type for bytes here
string_types = (bytes_type, str_type, basestring_type, unicode_type)
string_types = (bytes_type, bytearray_type, str_type, basestring_type, unicode_type)
if type1 in string_types and type2 in string_types:
return string_types[max(string_types.index(type1),
string_types.index(type2))]
......@@ -11183,7 +11183,7 @@ class MulNode(NumBinopNode):
def infer_builtin_types_operation(self, type1, type2):
# let's assume that whatever builtin type you multiply a string with
# will either return a string of the same type or fail with an exception
string_types = (bytes_type, str_type, basestring_type, unicode_type)
string_types = (bytes_type, bytearray_type, str_type, basestring_type, unicode_type)
if type1 in string_types and type2.is_builtin_type:
return type1
if type2 in string_types and type1.is_builtin_type:
......@@ -13008,6 +13008,7 @@ class CoerceToBooleanNode(CoercionNode):
Builtin.set_type: 'PySet_GET_SIZE',
Builtin.frozenset_type: 'PySet_GET_SIZE',
Builtin.bytes_type: 'PyBytes_GET_SIZE',
Builtin.bytearray_type: 'PyByteArray_GET_SIZE',
Builtin.unicode_type: '__Pyx_PyUnicode_IS_TRUE',
}
......
......@@ -38,6 +38,27 @@ cpdef bytearray coerce_charptr_slice(char* b):
"""
return b[:2]
def infer_concatenation_types(bytearray b):
"""
>>> b = bytearray(b'a\\xFEc')
>>> b2, c, d, e, tb, tc, td, te = infer_concatenation_types(b)
>>> tb, tc, td, te
('bytearray object', 'bytearray object', 'bytearray object', 'bytearray object')
>>> b2, c, d, e
(bytearray(b'a\\xfec'), bytearray(b'a\\xfeca\\xfec'), bytearray(b'a\\xfeca\\xfec'), bytearray(b'a\\xfeca\\xfec'))
"""
c = b[:]
c += b[:]
d = b[:]
d *= 2
e = b + b
return b, c, d, e, cython.typeof(b), cython.typeof(c), cython.typeof(d), cython.typeof(e)
def infer_index_types(bytearray b):
"""
>>> b = bytearray(b'a\\xFEc')
......@@ -51,11 +72,12 @@ def infer_index_types(bytearray b):
e = b[1]
return c, d, e, cython.typeof(c), cython.typeof(d), cython.typeof(e), cython.typeof(b[1])
def infer_slice_types(bytearray b):
"""
>>> b = bytearray(b'abc')
>>> print(infer_slice_types(b))
(bytearray(b'bc'), bytearray(b'bc'), bytearray(b'bc'), 'Python object', 'Python object', 'Python object', 'bytearray object')
(bytearray(b'bc'), bytearray(b'bc'), bytearray(b'bc'), 'bytearray object', 'bytearray object', 'bytearray object', 'bytearray object')
"""
c = b[1:]
with cython.boundscheck(False):
......@@ -64,6 +86,7 @@ def infer_slice_types(bytearray b):
e = b[1:]
return c, d, e, cython.typeof(c), cython.typeof(d), cython.typeof(e), cython.typeof(b[1:])
def assign_to_index(bytearray b, value):
"""
>>> b = bytearray(b'0abcdefg')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment