Commit 6d2b5074 authored by Stefan Behnel's avatar Stefan Behnel

Infer some bytearray operation types in the same way as for bytes operations.

parent 4e8c9b7c
...@@ -3287,7 +3287,7 @@ class _IndexingBaseNode(ExprNode): ...@@ -3287,7 +3287,7 @@ class _IndexingBaseNode(ExprNode):
# in most cases, indexing will return a safe reference to an object in a container, # in most cases, indexing will return a safe reference to an object in a container,
# so we consider the result safe if the base object is # so we consider the result safe if the base object is
return self.base.is_ephemeral() or self.base.type in ( return self.base.is_ephemeral() or self.base.type in (
basestring_type, str_type, bytes_type, unicode_type) basestring_type, str_type, bytes_type, bytearray_type, unicode_type)
def check_const_addr(self): def check_const_addr(self):
return self.base.check_const_addr() and self.index.check_const() return self.base.check_const_addr() and self.index.check_const()
...@@ -3347,7 +3347,7 @@ class IndexNode(_IndexingBaseNode): ...@@ -3347,7 +3347,7 @@ class IndexNode(_IndexingBaseNode):
return False return False
if isinstance(self.index, SliceNode): if isinstance(self.index, SliceNode):
# slicing! # slicing!
if base_type in (bytes_type, str_type, unicode_type, if base_type in (bytes_type, bytearray_type, str_type, unicode_type,
basestring_type, list_type, tuple_type): basestring_type, list_type, tuple_type):
return False return False
return ExprNode.may_be_none(self) return ExprNode.may_be_none(self)
...@@ -3572,7 +3572,7 @@ class IndexNode(_IndexingBaseNode): ...@@ -3572,7 +3572,7 @@ class IndexNode(_IndexingBaseNode):
else: else:
# not using 'uchar' to enable fast and safe error reporting as '-1' # not using 'uchar' to enable fast and safe error reporting as '-1'
self.type = PyrexTypes.c_int_type self.type = PyrexTypes.c_int_type
elif is_slice and base_type in (bytes_type, str_type, unicode_type, list_type, tuple_type): elif is_slice and base_type in (bytes_type, bytearray_type, str_type, unicode_type, list_type, tuple_type):
self.type = base_type self.type = base_type
else: else:
item_type = None item_type = None
...@@ -4569,7 +4569,7 @@ class SliceIndexNode(ExprNode): ...@@ -4569,7 +4569,7 @@ class SliceIndexNode(ExprNode):
return bytes_type return bytes_type
elif base_type.is_pyunicode_ptr: elif base_type.is_pyunicode_ptr:
return unicode_type return unicode_type
elif base_type in (bytes_type, str_type, unicode_type, elif base_type in (bytes_type, bytearray_type, str_type, unicode_type,
basestring_type, list_type, tuple_type): basestring_type, list_type, tuple_type):
return base_type return base_type
elif base_type.is_ptr or base_type.is_array: elif base_type.is_ptr or base_type.is_array:
...@@ -11124,7 +11124,7 @@ class AddNode(NumBinopNode): ...@@ -11124,7 +11124,7 @@ class AddNode(NumBinopNode):
def infer_builtin_types_operation(self, type1, type2): def infer_builtin_types_operation(self, type1, type2):
# b'abc' + 'abc' raises an exception in Py3, # b'abc' + 'abc' raises an exception in Py3,
# so we can safely infer the Py2 type for bytes here # so we can safely infer the Py2 type for bytes here
string_types = (bytes_type, str_type, basestring_type, unicode_type) string_types = (bytes_type, bytearray_type, str_type, basestring_type, unicode_type)
if type1 in string_types and type2 in string_types: if type1 in string_types and type2 in string_types:
return string_types[max(string_types.index(type1), return string_types[max(string_types.index(type1),
string_types.index(type2))] string_types.index(type2))]
...@@ -11183,7 +11183,7 @@ class MulNode(NumBinopNode): ...@@ -11183,7 +11183,7 @@ class MulNode(NumBinopNode):
def infer_builtin_types_operation(self, type1, type2): def infer_builtin_types_operation(self, type1, type2):
# let's assume that whatever builtin type you multiply a string with # let's assume that whatever builtin type you multiply a string with
# will either return a string of the same type or fail with an exception # will either return a string of the same type or fail with an exception
string_types = (bytes_type, str_type, basestring_type, unicode_type) string_types = (bytes_type, bytearray_type, str_type, basestring_type, unicode_type)
if type1 in string_types and type2.is_builtin_type: if type1 in string_types and type2.is_builtin_type:
return type1 return type1
if type2 in string_types and type1.is_builtin_type: if type2 in string_types and type1.is_builtin_type:
...@@ -13008,6 +13008,7 @@ class CoerceToBooleanNode(CoercionNode): ...@@ -13008,6 +13008,7 @@ class CoerceToBooleanNode(CoercionNode):
Builtin.set_type: 'PySet_GET_SIZE', Builtin.set_type: 'PySet_GET_SIZE',
Builtin.frozenset_type: 'PySet_GET_SIZE', Builtin.frozenset_type: 'PySet_GET_SIZE',
Builtin.bytes_type: 'PyBytes_GET_SIZE', Builtin.bytes_type: 'PyBytes_GET_SIZE',
Builtin.bytearray_type: 'PyByteArray_GET_SIZE',
Builtin.unicode_type: '__Pyx_PyUnicode_IS_TRUE', Builtin.unicode_type: '__Pyx_PyUnicode_IS_TRUE',
} }
......
...@@ -38,6 +38,27 @@ cpdef bytearray coerce_charptr_slice(char* b): ...@@ -38,6 +38,27 @@ cpdef bytearray coerce_charptr_slice(char* b):
""" """
return b[:2] return b[:2]
def infer_concatenation_types(bytearray b):
"""
>>> b = bytearray(b'a\\xFEc')
>>> b2, c, d, e, tb, tc, td, te = infer_concatenation_types(b)
>>> tb, tc, td, te
('bytearray object', 'bytearray object', 'bytearray object', 'bytearray object')
>>> b2, c, d, e
(bytearray(b'a\\xfec'), bytearray(b'a\\xfeca\\xfec'), bytearray(b'a\\xfeca\\xfec'), bytearray(b'a\\xfeca\\xfec'))
"""
c = b[:]
c += b[:]
d = b[:]
d *= 2
e = b + b
return b, c, d, e, cython.typeof(b), cython.typeof(c), cython.typeof(d), cython.typeof(e)
def infer_index_types(bytearray b): def infer_index_types(bytearray b):
""" """
>>> b = bytearray(b'a\\xFEc') >>> b = bytearray(b'a\\xFEc')
...@@ -51,11 +72,12 @@ def infer_index_types(bytearray b): ...@@ -51,11 +72,12 @@ def infer_index_types(bytearray b):
e = b[1] e = b[1]
return c, d, e, cython.typeof(c), cython.typeof(d), cython.typeof(e), cython.typeof(b[1]) return c, d, e, cython.typeof(c), cython.typeof(d), cython.typeof(e), cython.typeof(b[1])
def infer_slice_types(bytearray b): def infer_slice_types(bytearray b):
""" """
>>> b = bytearray(b'abc') >>> b = bytearray(b'abc')
>>> print(infer_slice_types(b)) >>> print(infer_slice_types(b))
(bytearray(b'bc'), bytearray(b'bc'), bytearray(b'bc'), 'Python object', 'Python object', 'Python object', 'bytearray object') (bytearray(b'bc'), bytearray(b'bc'), bytearray(b'bc'), 'bytearray object', 'bytearray object', 'bytearray object', 'bytearray object')
""" """
c = b[1:] c = b[1:]
with cython.boundscheck(False): with cython.boundscheck(False):
...@@ -64,6 +86,7 @@ def infer_slice_types(bytearray b): ...@@ -64,6 +86,7 @@ def infer_slice_types(bytearray b):
e = b[1:] e = b[1:]
return c, d, e, cython.typeof(c), cython.typeof(d), cython.typeof(e), cython.typeof(b[1:]) return c, d, e, cython.typeof(c), cython.typeof(d), cython.typeof(e), cython.typeof(b[1:])
def assign_to_index(bytearray b, value): def assign_to_index(bytearray b, value):
""" """
>>> b = bytearray(b'0abcdefg') >>> b = bytearray(b'0abcdefg')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment