Commit 1a58859d authored by Stefan Behnel's avatar Stefan Behnel

optimise bytes.decode()

parent 17fa6caf
......@@ -8,6 +8,8 @@ Cython Changelog
Features added
--------------
* ``py_bytes_string.decode(...)`` is optimised.
Bugs fixed
----------
......
......@@ -2623,6 +2623,16 @@ class OptimizeBuiltinCalls(Visitor.MethodDispatcherTransform):
PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
])
_decode_bytes_func_type = PyrexTypes.CFuncType(
Builtin.unicode_type, [
PyrexTypes.CFuncTypeArg("string", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
])
_decode_cpp_string_func_type = None # lazy init
def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
......@@ -2634,19 +2644,29 @@ class OptimizeBuiltinCalls(Visitor.MethodDispatcherTransform):
return node
# normalise input nodes
if isinstance(args[0], ExprNodes.SliceIndexNode):
index_node = args[0]
string_node = args[0]
start = stop = None
if isinstance(string_node, ExprNodes.SliceIndexNode):
index_node = string_node
string_node = index_node.base
start, stop = index_node.start, index_node.stop
if not start or start.constant_result == 0:
start = None
elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode):
string_node = args[0].arg
start = stop = None
else:
return node
if isinstance(string_node, ExprNodes.CoerceToPyTypeNode):
string_node = string_node.arg
if not string_node.type.is_string and not string_node.type.is_cpp_string:
string_type = string_node.type
if string_type is Builtin.bytes_type:
if is_unbound_method:
string_node = string_node.as_none_safe_node(
"descriptor '%s' requires a '%s' object but received a 'NoneType'",
format_args = ['decode', 'bytes'])
else:
string_node = string_node.as_none_safe_node(
"'NoneType' object has no attribute '%s'",
error = "PyExc_AttributeError",
format_args = ['decode'])
elif not string_type.is_string and not string_type.is_cpp_string:
# nothing to optimise here
return node
......@@ -2676,7 +2696,7 @@ class OptimizeBuiltinCalls(Visitor.MethodDispatcherTransform):
# build the helper function call
temps = []
if string_node.type.is_string:
if string_type.is_string:
# C string
if not stop:
# use strlen() to find the string length, just as CPython would
......@@ -2691,7 +2711,7 @@ class OptimizeBuiltinCalls(Visitor.MethodDispatcherTransform):
).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
helper_func_type = self._decode_c_string_func_type
utility_code_name = 'decode_c_string'
else:
elif string_type.is_cpp_string:
# C++ std::string
if not stop:
stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
......@@ -2700,7 +2720,7 @@ class OptimizeBuiltinCalls(Visitor.MethodDispatcherTransform):
# lazy init to reuse the C++ string type
self._decode_cpp_string_func_type = PyrexTypes.CFuncType(
Builtin.unicode_type, [
PyrexTypes.CFuncTypeArg("string", string_node.type, None),
PyrexTypes.CFuncTypeArg("string", string_type, None),
PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
......@@ -2709,6 +2729,13 @@ class OptimizeBuiltinCalls(Visitor.MethodDispatcherTransform):
])
helper_func_type = self._decode_cpp_string_func_type
utility_code_name = 'decode_cpp_string'
else:
# Python bytes object
if not stop:
stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
constant_result=ExprNodes.not_a_constant)
helper_func_type = self._decode_bytes_func_type
utility_code_name = 'decode_bytes'
node = ExprNodes.PythonCapiCallNode(
node.pos, '__Pyx_%s' % utility_code_name, helper_func_type,
......
......@@ -301,3 +301,40 @@ static CYTHON_INLINE PyObject* __Pyx_decode_c_string(
return PyUnicode_Decode(cstring, length, encoding, errors);
}
}
/////////////// decode_bytes.proto ///////////////
static CYTHON_INLINE PyObject* __Pyx_decode_bytes(
PyObject* string, Py_ssize_t start, Py_ssize_t stop,
const char* encoding, const char* errors,
PyObject* (*decode_func)(const char *s, Py_ssize_t size, const char *errors));
/////////////// decode_bytes ///////////////
static CYTHON_INLINE PyObject* __Pyx_decode_bytes(
PyObject* string, Py_ssize_t start, Py_ssize_t stop,
const char* encoding, const char* errors,
PyObject* (*decode_func)(const char *s, Py_ssize_t size, const char *errors)) {
char* cstring;
Py_ssize_t length = PyBytes_GET_SIZE(string);
if (unlikely((start < 0) | (stop < 0))) {
if (start < 0) {
start += length;
if (start < 0)
start = 0;
}
if (stop < 0)
stop += length;
}
if (stop > length)
stop = length;
length = stop - start;
if (unlikely(length <= 0))
return PyUnicode_FromUnicode(NULL, 0);
cstring = PyBytes_AS_STRING(string) + start;
if (decode_func) {
return decode_func(cstring, length, errors);
} else {
return PyUnicode_Decode(cstring, length, encoding, errors);
}
}
......@@ -3,8 +3,11 @@ cimport cython
b_a = b'a'
b_b = b'b'
@cython.test_assert_path_exists(
"//PythonCapiCallNode")
@cython.test_fail_if_path_exists(
"//SimpleCallNode")
def bytes_startswith(bytes s, sub, start=None, stop=None):
"""
>>> bytes_startswith(b_a, b_a)
......@@ -30,8 +33,11 @@ def bytes_startswith(bytes s, sub, start=None, stop=None):
else:
return s.startswith(sub, start, stop)
@cython.test_assert_path_exists(
"//PythonCapiCallNode")
@cython.test_fail_if_path_exists(
"//SimpleCallNode")
def bytes_endswith(bytes s, sub, start=None, stop=None):
"""
>>> bytes_endswith(b_a, b_a)
......@@ -56,3 +62,131 @@ def bytes_endswith(bytes s, sub, start=None, stop=None):
return s.endswith(sub, start)
else:
return s.endswith(sub, start, stop)
@cython.test_assert_path_exists(
"//PythonCapiCallNode")
@cython.test_fail_if_path_exists(
"//SimpleCallNode")
def bytes_decode(bytes s, start=None, stop=None):
"""
>>> s = b_a+b_b+b_a+b_a+b_b
>>> print(bytes_decode(s))
abaab
>>> print(bytes_decode(s, 2))
aab
>>> print(bytes_decode(s, -3))
aab
>>> print(bytes_decode(s, None, 4))
abaa
>>> print(bytes_decode(s, None, 400))
abaab
>>> print(bytes_decode(s, None, -2))
aba
>>> print(bytes_decode(s, None, -4))
a
>>> print(bytes_decode(s, None, -5))
<BLANKLINE>
>>> print(bytes_decode(s, None, -200))
<BLANKLINE>
>>> print(bytes_decode(s, 2, 5))
aab
>>> print(bytes_decode(s, 2, 500))
aab
>>> print(bytes_decode(s, 2, -1))
aa
>>> print(bytes_decode(s, 2, -3))
<BLANKLINE>
>>> print(bytes_decode(s, 2, -300))
<BLANKLINE>
>>> print(bytes_decode(s, -3, -1))
aa
>>> print(bytes_decode(s, -300, 300))
abaab
>>> print(bytes_decode(s, -300, -4))
a
>>> print(bytes_decode(s, -300, -5))
<BLANKLINE>
>>> print(bytes_decode(s, -300, -6))
<BLANKLINE>
>>> print(bytes_decode(s, -300, -500))
<BLANKLINE>
>>> s[:'test'] # doctest: +ELLIPSIS
Traceback (most recent call last):
TypeError:...
>>> print(bytes_decode(s, 'test')) # doctest: +ELLIPSIS
Traceback (most recent call last):
TypeError:...
>>> print(bytes_decode(s, None, 'test')) # doctest: +ELLIPSIS
Traceback (most recent call last):
TypeError:...
>>> print(bytes_decode(s, 'test', 'test')) # doctest: +ELLIPSIS
Traceback (most recent call last):
TypeError:...
>>> print(bytes_decode(None))
Traceback (most recent call last):
AttributeError: 'NoneType' object has no attribute 'decode'
>>> print(bytes_decode(None, 1))
Traceback (most recent call last):
AttributeError: 'NoneType' object has no attribute 'decode'
>>> print(bytes_decode(None, None, 1))
Traceback (most recent call last):
AttributeError: 'NoneType' object has no attribute 'decode'
>>> print(bytes_decode(None, 0, 1))
Traceback (most recent call last):
AttributeError: 'NoneType' object has no attribute 'decode'
"""
if start is None:
if stop is None:
return s.decode('utf8')
else:
return s[:stop].decode('utf8')
elif stop is None:
return s[start:].decode('utf8')
else:
return s[start:stop].decode('utf8')
@cython.test_assert_path_exists(
"//PythonCapiCallNode")
@cython.test_fail_if_path_exists(
"//SimpleCallNode")
def bytes_decode_unbound_method(bytes s, start=None, stop=None):
"""
>>> s = b_a+b_b+b_a+b_a+b_b
>>> print(bytes_decode_unbound_method(s))
abaab
>>> print(bytes_decode_unbound_method(s, 1))
baab
>>> print(bytes_decode_unbound_method(s, None, 3))
aba
>>> print(bytes_decode_unbound_method(s, 1, 4))
baa
>>> print(bytes_decode_unbound_method(None))
Traceback (most recent call last):
TypeError: descriptor 'decode' requires a 'bytes' object but received a 'NoneType'
>>> print(bytes_decode_unbound_method(None, 1))
Traceback (most recent call last):
TypeError: descriptor 'decode' requires a 'bytes' object but received a 'NoneType'
>>> print(bytes_decode_unbound_method(None, None, 1))
Traceback (most recent call last):
TypeError: descriptor 'decode' requires a 'bytes' object but received a 'NoneType'
>>> print(bytes_decode_unbound_method(None, 0, 1))
Traceback (most recent call last):
TypeError: descriptor 'decode' requires a 'bytes' object but received a 'NoneType'
"""
if start is None:
if stop is None:
return bytes.decode(s, 'utf8')
else:
return bytes.decode(s[:stop], 'utf8')
elif stop is None:
return bytes.decode(s[start:], 'utf8')
else:
return bytes.decode(s[start:stop], 'utf8')
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment