Commit 59c3408c authored by Stefan Behnel's avatar Stefan Behnel

implement (and thus, fix) index assignments to bytearray objects

parent 18eb3fe6
...@@ -9,7 +9,7 @@ Features added ...@@ -9,7 +9,7 @@ Features added
-------------- --------------
* ``bytearray`` has become a known type and supports coercion from and * ``bytearray`` has become a known type and supports coercion from and
to C strings. to C strings. Indexing, slicing and decoding is optimised.
* Using ``cdef basestring stringvar`` and function arguments typed as * Using ``cdef basestring stringvar`` and function arguments typed as
``basestring`` is now meaningful and allows assigning exactly ``basestring`` is now meaningful and allows assigning exactly
......
...@@ -259,6 +259,7 @@ class ExprNode(Node): ...@@ -259,6 +259,7 @@ class ExprNode(Node):
is_sequence_constructor = 0 is_sequence_constructor = 0
is_string_literal = 0 is_string_literal = 0
is_attribute = 0 is_attribute = 0
is_subscript = 0
saved_subexpr_nodes = None saved_subexpr_nodes = None
is_temp = 0 is_temp = 0
...@@ -2645,6 +2646,7 @@ class IndexNode(ExprNode): ...@@ -2645,6 +2646,7 @@ class IndexNode(ExprNode):
subexprs = ['base', 'index', 'indices'] subexprs = ['base', 'index', 'indices']
indices = None indices = None
is_subscript = True
is_fused_index = False is_fused_index = False
# Whether we're assigning to a buffer (in that case it needs to be # Whether we're assigning to a buffer (in that case it needs to be
...@@ -3036,7 +3038,10 @@ class IndexNode(ExprNode): ...@@ -3036,7 +3038,10 @@ class IndexNode(ExprNode):
# if required, so this is fast and safe # if required, so this is fast and safe
self.type = PyrexTypes.c_py_ucs4_type self.type = PyrexTypes.c_py_ucs4_type
elif self.index.type.is_int and base_type is bytearray_type: elif self.index.type.is_int and base_type is bytearray_type:
# not using uchar here to enable error reporting as '-1' if setting:
self.type = PyrexTypes.c_uchar_type
else:
# not using 'uchar' to enable fast and safe error reporting as '-1'
self.type = PyrexTypes.c_int_type self.type = PyrexTypes.c_int_type
elif is_slice and base_type in (bytes_type, str_type, unicode_type, list_type, tuple_type): elif is_slice and base_type in (bytes_type, str_type, unicode_type, list_type, tuple_type):
self.type = base_type self.type = base_type
...@@ -3378,10 +3383,15 @@ class IndexNode(ExprNode): ...@@ -3378,10 +3383,15 @@ class IndexNode(ExprNode):
def generate_setitem_code(self, value_code, code): def generate_setitem_code(self, value_code, code):
if self.index.type.is_int: if self.index.type.is_int:
function = "__Pyx_SetItemInt" if self.base.type is bytearray_type:
index_code = self.index.result() code.globalstate.use_utility_code(
UtilityCode.load_cached("SetItemIntByteArray", "StringTools.c"))
function = "__Pyx_SetItemInt_ByteArray"
else:
code.globalstate.use_utility_code( code.globalstate.use_utility_code(
UtilityCode.load_cached("SetItemInt", "ObjectHandling.c")) UtilityCode.load_cached("SetItemInt", "ObjectHandling.c"))
function = "__Pyx_SetItemInt"
index_code = self.index.result()
else: else:
index_code = self.index.py_result() index_code = self.index.py_result()
if self.base.type is dict_type: if self.base.type is dict_type:
...@@ -3396,7 +3406,7 @@ class IndexNode(ExprNode): ...@@ -3396,7 +3406,7 @@ class IndexNode(ExprNode):
else: else:
function = "PyObject_SetItem" function = "PyObject_SetItem"
code.putln( code.putln(
"if (%s(%s, %s, %s%s) < 0) %s" % ( "if (unlikely(%s(%s, %s, %s%s) < 0)) %s" % (
function, function,
self.base.py_result(), self.base.py_result(),
index_code, index_code,
...@@ -3441,6 +3451,8 @@ class IndexNode(ExprNode): ...@@ -3441,6 +3451,8 @@ class IndexNode(ExprNode):
self.generate_memoryviewslice_setslice_code(rhs, code) self.generate_memoryviewslice_setslice_code(rhs, code)
elif self.type.is_pyobject: elif self.type.is_pyobject:
self.generate_setitem_code(rhs.py_result(), code) self.generate_setitem_code(rhs.py_result(), code)
elif self.base.type is bytearray_type:
self.generate_setitem_code(rhs.result(), code)
else: else:
code.putln( code.putln(
"%s = %s;" % ( "%s = %s;" % (
......
...@@ -4860,6 +4860,8 @@ class DelStatNode(StatNode): ...@@ -4860,6 +4860,8 @@ class DelStatNode(StatNode):
self.cpp_check(env) self.cpp_check(env)
elif arg.type.is_cpp_class: elif arg.type.is_cpp_class:
error(arg.pos, "Deletion of non-heap C++ object") error(arg.pos, "Deletion of non-heap C++ object")
elif arg.is_subscript and arg.base.type is Builtin.bytearray_type:
pass # del ba[i]
else: else:
error(arg.pos, "Deletion of non-Python, non-C++ object") error(arg.pos, "Deletion of non-Python, non-C++ object")
#arg.release_target_temp(env) #arg.release_target_temp(env)
...@@ -4874,7 +4876,9 @@ class DelStatNode(StatNode): ...@@ -4874,7 +4876,9 @@ class DelStatNode(StatNode):
def generate_execution_code(self, code): def generate_execution_code(self, code):
for arg in self.args: for arg in self.args:
if arg.type.is_pyobject or arg.type.is_memoryviewslice: if (arg.type.is_pyobject or
arg.type.is_memoryviewslice or
arg.is_subscript and arg.base.type is Builtin.bytearray_type):
arg.generate_deletion_code( arg.generate_deletion_code(
code, ignore_nonexisting=self.ignore_nonexisting) code, ignore_nonexisting=self.ignore_nonexisting)
elif arg.type.is_ptr and arg.type.base_type.is_cpp_class: elif arg.type.is_ptr and arg.type.base_type.is_cpp_class:
......
...@@ -270,6 +270,51 @@ static CYTHON_INLINE int __Pyx_GetItemInt_ByteArray_Generic(PyObject* string, Py ...@@ -270,6 +270,51 @@ static CYTHON_INLINE int __Pyx_GetItemInt_ByteArray_Generic(PyObject* string, Py
} }
//////////////////// SetItemIntByteArray.proto ////////////////////
#define __Pyx_SetItemInt_ByteArray(o, i, v, size, to_py_func, is_list, wraparound, boundscheck) \
(((size) <= sizeof(Py_ssize_t)) ? \
__Pyx_SetItemInt_ByteArray_Fast(o, i, v, wraparound, boundscheck) : \
__Pyx_SetItemInt_ByteArray_Generic(o, to_py_func(i), v))
static CYTHON_INLINE int __Pyx_SetItemInt_ByteArray_Fast(PyObject* string, Py_ssize_t i, unsigned char v,
int wraparound, int boundscheck);
static CYTHON_INLINE int __Pyx_SetItemInt_ByteArray_Generic(PyObject* string, PyObject* j, unsigned char v);
//////////////////// SetItemIntByteArray ////////////////////
static CYTHON_INLINE int __Pyx_SetItemInt_ByteArray_Fast(PyObject* string, Py_ssize_t i, unsigned char v,
int wraparound, int boundscheck) {
Py_ssize_t length;
if (wraparound | boundscheck) {
length = PyByteArray_GET_SIZE(string);
if (wraparound & unlikely(i < 0)) i += length;
if ((!boundscheck) || likely((0 <= i) & (i < length))) {
PyByteArray_AS_STRING(string)[i] = (char) v;
return 0;
} else {
PyErr_SetString(PyExc_IndexError, "bytearray index out of range");
return -1;
}
} else {
PyByteArray_AS_STRING(string)[i] = (char) v;
return 0;
}
}
static CYTHON_INLINE int __Pyx_SetItemInt_ByteArray_Generic(PyObject* string, PyObject* j, unsigned char v) {
unsigned char bchar;
PyObject *bchar_string;
if (!j) return -1;
bchar_string = PyObject_GetItem(string, j);
Py_DECREF(j);
if (!bchar_string) return -1;
bchar = (unsigned char) (PyByteArray_AS_STRING(bchar_string)[0]);
Py_DECREF(bchar_string);
return bchar;
}
//////////////////// GetItemIntUnicode.proto //////////////////// //////////////////// GetItemIntUnicode.proto ////////////////////
#define __Pyx_GetItemInt_Unicode(o, i, size, to_py_func, is_list, wraparound, boundscheck) \ #define __Pyx_GetItemInt_Unicode(o, i, size, to_py_func, is_list, wraparound, boundscheck) \
......
...@@ -63,3 +63,52 @@ def infer_slice_types(bytearray b): ...@@ -63,3 +63,52 @@ def infer_slice_types(bytearray b):
with cython.boundscheck(False), cython.wraparound(False): with cython.boundscheck(False), cython.wraparound(False):
e = b[1:] e = b[1:]
return c, d, e, cython.typeof(c), cython.typeof(d), cython.typeof(e), cython.typeof(b[1:]) return c, d, e, cython.typeof(c), cython.typeof(d), cython.typeof(e), cython.typeof(b[1:])
def assign_to_index(bytearray b, value):
"""
>>> b = bytearray(b'0abcdefg')
>>> assign_to_index(b, 1)
bytearray(b'xyzee\\x01h')
>>> b
bytearray(b'xyzee\\x01h')
>>> assign_to_index(bytearray(b'0ABCDEFG'), 40)
bytearray(b'xyzEE(o')
>>> assign_to_index(bytearray(b'0abcdefg'), -1)
Traceback (most recent call last):
OverflowError: can't convert negative value to unsigned char
>>> assign_to_index(bytearray(b'0abcdef\\x00'), 255)
bytearray(b'xyzee\\xff\\xff')
>>> assign_to_index(bytearray(b'0abcdef\\x01'), 255)
Traceback (most recent call last):
OverflowError: value too large to convert to unsigned char
>>> assign_to_index(bytearray(b'0abcdef\\x00'), 256)
Traceback (most recent call last):
OverflowError: value too large to convert to unsigned char
"""
b[1] = 'x'
b[2] = b'y'
b[3] = c'z'
b[4] += 1
b[5] |= 1
b[6] = value
b[7] += value
del b[0]
try:
b[7] = 1
except IndexError:
pass
else:
assert False, "IndexError not raised"
try:
b[int(str(len(b)))] = 1 # test non-int-index assignment
except IndexError:
pass
else:
assert False, "IndexError not raised"
return b
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment