Commit 75474ef1 authored by Stefan Behnel's avatar Stefan Behnel

optimise indexing and slicing of bytearray

parent 38ea166b
...@@ -2723,7 +2723,8 @@ class IndexNode(ExprNode): ...@@ -2723,7 +2723,8 @@ class IndexNode(ExprNode):
elif base_type.is_pyunicode_ptr: elif base_type.is_pyunicode_ptr:
# sliced Py_UNICODE* strings must coerce to Python # sliced Py_UNICODE* strings must coerce to Python
return unicode_type return unicode_type
elif base_type in (unicode_type, bytes_type, str_type, list_type, tuple_type): elif base_type in (unicode_type, bytes_type, str_type,
bytearray_type, list_type, tuple_type):
# slicing these returns the same type # slicing these returns the same type
return base_type return base_type
else: else:
...@@ -2745,6 +2746,8 @@ class IndexNode(ExprNode): ...@@ -2745,6 +2746,8 @@ class IndexNode(ExprNode):
elif base_type is str_type: elif base_type is str_type:
# always returns str - Py2: bytes, Py3: unicode # always returns str - Py2: bytes, Py3: unicode
return base_type return base_type
elif base_type is bytearray_type:
return PyrexTypes.c_uchar_type
elif isinstance(self.base, BytesNode): elif isinstance(self.base, BytesNode):
#if env.global_scope().context.language_level >= 3: #if env.global_scope().context.language_level >= 3:
# # inferring 'char' can be made to work in Python 3 mode # # inferring 'char' can be made to work in Python 3 mode
...@@ -3014,7 +3017,7 @@ class IndexNode(ExprNode): ...@@ -3014,7 +3017,7 @@ class IndexNode(ExprNode):
if base_type.is_pyobject: if base_type.is_pyobject:
if self.index.type.is_int: if self.index.type.is_int:
if (not setting if (not setting
and (base_type in (list_type, tuple_type)) and (base_type in (list_type, tuple_type, bytearray_type))
and (not self.index.type.signed and (not self.index.type.signed
or not env.directives['wraparound'] or not env.directives['wraparound']
or (isinstance(self.index, IntNode) and or (isinstance(self.index, IntNode) and
...@@ -3032,6 +3035,9 @@ class IndexNode(ExprNode): ...@@ -3032,6 +3035,9 @@ class IndexNode(ExprNode):
# Py_UNICODE/Py_UCS4 will automatically coerce to a unicode string # Py_UNICODE/Py_UCS4 will automatically coerce to a unicode string
# if required, so this is fast and safe # if required, so this is fast and safe
self.type = PyrexTypes.c_py_ucs4_type self.type = PyrexTypes.c_py_ucs4_type
elif self.index.type.is_int and base_type is bytearray_type:
# not using uchar here to enable error reporting as '-1'
self.type = PyrexTypes.c_int_type
elif is_slice and base_type in (bytes_type, str_type, unicode_type, list_type, tuple_type): elif is_slice and base_type in (bytes_type, str_type, unicode_type, list_type, tuple_type):
self.type = base_type self.type = base_type
else: else:
...@@ -3230,15 +3236,21 @@ class IndexNode(ExprNode): ...@@ -3230,15 +3236,21 @@ class IndexNode(ExprNode):
return "(*%s)" % self.buffer_ptr_code return "(*%s)" % self.buffer_ptr_code
elif self.is_memslice_copy: elif self.is_memslice_copy:
return self.base.result() return self.base.result()
elif self.base.type is list_type: elif self.base.type in (list_type, tuple_type, bytearray_type):
return "PyList_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result()) if self.base.type is list_type:
index_code = "PyList_GET_ITEM(%s, %s)"
elif self.base.type is tuple_type: elif self.base.type is tuple_type:
return "PyTuple_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result()) index_code = "PyTuple_GET_ITEM(%s, %s)"
elif (self.type.is_ptr or self.type.is_array) and self.type == self.base.type: elif self.base.type is bytearray_type:
error(self.pos, "Invalid use of pointer slice") index_code = "((unsigned char)(PyByteArray_AS_STRING(%s)[%s]))"
else:
assert False, "unexpected base type in indexing: %s" % self.base.type
else: else:
return "(%s[%s])" % ( if (self.type.is_ptr or self.type.is_array) and self.type == self.base.type:
self.base.result(), self.index.result()) error(self.pos, "Invalid use of pointer slice")
return
index_code = "(%s[%s])"
return index_code % (self.base.result(), self.index.result())
def extra_index_params(self, code): def extra_index_params(self, code):
if self.index.type.is_int: if self.index.type.is_int:
...@@ -3344,6 +3356,22 @@ class IndexNode(ExprNode): ...@@ -3344,6 +3356,22 @@ class IndexNode(ExprNode):
self.extra_index_params(code), self.extra_index_params(code),
self.result(), self.result(),
code.error_goto(self.pos))) code.error_goto(self.pos)))
elif self.base.type is bytearray_type:
assert self.index.type.is_int
assert self.type.is_int
index_code = self.index.result()
function = "__Pyx_GetItemInt_ByteArray"
code.globalstate.use_utility_code(
UtilityCode.load_cached("GetItemIntByteArray", "StringTools.c"))
code.putln(
"%s = %s(%s, %s%s); if (unlikely(%s == -1)) %s;" % (
self.result(),
function,
self.base.py_result(),
index_code,
self.extra_index_params(code),
self.result(),
code.error_goto(self.pos)))
def generate_setitem_code(self, value_code, code): def generate_setitem_code(self, value_code, code):
if self.index.type.is_int: if self.index.type.is_int:
......
...@@ -227,6 +227,49 @@ static CYTHON_INLINE int __Pyx_PyBytes_Equals(PyObject* s1, PyObject* s2, int eq ...@@ -227,6 +227,49 @@ static CYTHON_INLINE int __Pyx_PyBytes_Equals(PyObject* s1, PyObject* s2, int eq
#endif #endif
} }
//////////////////// GetItemIntByteArray.proto ////////////////////
#define __Pyx_GetItemInt_ByteArray(o, i, size, to_py_func, is_list, wraparound, boundscheck) \
(((size) <= sizeof(Py_ssize_t)) ? \
__Pyx_GetItemInt_ByteArray_Fast(o, i, wraparound, boundscheck) : \
__Pyx_GetItemInt_ByteArray_Generic(o, to_py_func(i)))
static CYTHON_INLINE int __Pyx_GetItemInt_ByteArray_Fast(PyObject* string, Py_ssize_t i,
int wraparound, int boundscheck);
static CYTHON_INLINE int __Pyx_GetItemInt_ByteArray_Generic(PyObject* string, PyObject* j);
//////////////////// GetItemIntByteArray ////////////////////
static CYTHON_INLINE int __Pyx_GetItemInt_ByteArray_Fast(PyObject* string, Py_ssize_t i,
int wraparound, int boundscheck) {
Py_ssize_t length;
if (wraparound | boundscheck) {
length = PyByteArray_GET_SIZE(string);
if (wraparound & unlikely(i < 0)) i += length;
if ((!boundscheck) || likely((0 <= i) & (i < length))) {
return (unsigned char) (PyByteArray_AS_STRING(string)[i]);
} else {
PyErr_SetString(PyExc_IndexError, "bytearray index out of range");
return -1;
}
} else {
return (unsigned char) (PyByteArray_AS_STRING(string)[i]);
}
}
static CYTHON_INLINE int __Pyx_GetItemInt_ByteArray_Generic(PyObject* string, PyObject* j) {
unsigned char bchar;
PyObject *bchar_string;
if (!j) return -1;
bchar_string = PyObject_GetItem(string, j);
Py_DECREF(j);
if (!bchar_string) return -1;
bchar = (unsigned char) (PyByteArray_AS_STRING(bchar_string)[0]);
Py_DECREF(bchar_string);
return bchar;
}
//////////////////// GetItemIntUnicode.proto //////////////////// //////////////////// GetItemIntUnicode.proto ////////////////////
#define __Pyx_GetItemInt_Unicode(o, i, size, to_py_func, is_list, wraparound, boundscheck) \ #define __Pyx_GetItemInt_Unicode(o, i, size, to_py_func, is_list, wraparound, boundscheck) \
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
# NOTE: Py2.6+ only # NOTE: Py2.6+ only
cimport cython
cpdef bytearray coerce_to_charptr(char* b): cpdef bytearray coerce_to_charptr(char* b):
""" """
>>> b = bytearray(b'abc') >>> b = bytearray(b'abc')
...@@ -35,3 +37,29 @@ cpdef bytearray coerce_charptr_slice(char* b): ...@@ -35,3 +37,29 @@ cpdef bytearray coerce_charptr_slice(char* b):
True True
""" """
return b[:2] return b[:2]
def infer_index_types(bytearray b):
"""
>>> b = bytearray(b'a\\xFEc')
>>> print(infer_index_types(b))
(254, 254, 254, 'unsigned char', 'unsigned char', 'unsigned char', 'int')
"""
c = b[1]
with cython.wraparound(False):
d = b[1]
with cython.boundscheck(False):
e = b[1]
return c, d, e, cython.typeof(c), cython.typeof(d), cython.typeof(e), cython.typeof(b[1])
def infer_slice_types(bytearray b):
"""
>>> b = bytearray(b'abc')
>>> print(infer_slice_types(b))
(bytearray(b'bc'), bytearray(b'bc'), bytearray(b'bc'), 'Python object', 'Python object', 'Python object', 'bytearray object')
"""
c = b[1:]
with cython.boundscheck(False):
d = b[1:]
with cython.boundscheck(False), cython.wraparound(False):
e = b[1:]
return c, d, e, cython.typeof(c), cython.typeof(d), cython.typeof(e), cython.typeof(b[1:])
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment