Commit 3d782405 authored by John Ehresman

Tailmatch optimization for unicode, str, and bytes

--HG--
extra : rebase_source : 57c2e54efe6b57be41942eb182a49ccd4723158d
parent 1cf2bd8c
......@@ -2424,20 +2424,22 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
exception_value = '-1')
def _handle_simple_method_unicode_endswith(self, node, args, is_unbound_method):
return self._inject_unicode_tailmatch(
node, args, is_unbound_method, 'endswith', +1)
return self._inject_tailmatch(
node, args, is_unbound_method, 'unicode', 'endswith',
self.PyUnicode_Tailmatch_func_type, unicode_tailmatch_utility_code, +1)
def _handle_simple_method_unicode_startswith(self, node, args, is_unbound_method):
return self._inject_unicode_tailmatch(
node, args, is_unbound_method, 'startswith', -1)
return self._inject_tailmatch(
node, args, is_unbound_method, 'unicode', 'startswith',
self.PyUnicode_Tailmatch_func_type, unicode_tailmatch_utility_code, -1)
def _inject_unicode_tailmatch(self, node, args, is_unbound_method,
method_name, direction):
def _inject_tailmatch(self, node, args, is_unbound_method, type_name,
method_name, func_type, utility_code, direction):
"""Replace unicode.startswith(...) and unicode.endswith(...)
by a direct call to the corresponding C-API function.
"""
if len(args) not in (2,3,4):
self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
self._error_wrong_arg_count('%s.%s' % (type_name, method_name), node, args, "2-4")
return node
self._inject_int_default_argument(
node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
......@@ -2447,21 +2449,11 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
node.pos, value=str(direction), type=PyrexTypes.c_int_type))
method_call = self._substitute_method_call(
node, "__Pyx_PyUnicode_Tailmatch", self.PyUnicode_Tailmatch_func_type,
node, "__Pyx_Py%s_Tailmatch" % type_name.capitalize(), func_type,
method_name, is_unbound_method, args,
utility_code = unicode_tailmatch_utility_code)
utility_code = utility_code)
return method_call.coerce_to(Builtin.bool_type, self.current_env())
PyBytes_Tailmatch_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_bint_type, [
PyrexTypes.CFuncTypeArg("str", Builtin.str_type, None),
PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
],
exception_value = '-1')
PyUnicode_Find_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_py_ssize_t_type, [
PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
......@@ -2787,34 +2779,45 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
return (encoding, encoding_node, error_handling, error_handling_node)
# C signature used for the `str` startswith()/endswith() optimisation:
# takes a str object, the substring argument (any Python object), the
# Py_ssize_t start/end bounds and an int direction, and returns a C bint.
# '-1' is declared as the exception value of the call.
PyStr_Tailmatch_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_bint_type, [
PyrexTypes.CFuncTypeArg("str", Builtin.str_type, None),
PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
],
exception_value = '-1')
def _handle_simple_method_str_endswith(self, node, args, is_unbound_method):
return self._inject_str_tailmatch(
node, args, is_unbound_method, 'endswith', +1)
return self._inject_tailmatch(
node, args, is_unbound_method, 'str', 'endswith',
self.PyStr_Tailmatch_func_type, str_tailmatch_utility_code, +1)
def _handle_simple_method_str_startswith(self, node, args, is_unbound_method):
return self._inject_str_tailmatch(
node, args, is_unbound_method, 'startswith', -1)
def _inject_str_tailmatch(self, node, args, is_unbound_method,
method_name, direction):
"""Replace unicode.startswith(...) and unicode.endswith(...)
by a direct call to the corresponding C-API function.
"""
if len(args) not in (2,3,4):
self._error_wrong_arg_count('str.%s' % method_name, node, args, "2-4")
return node
self._inject_int_default_argument(
node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
self._inject_int_default_argument(
node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
args.append(ExprNodes.IntNode(
node.pos, value=str(direction), type=PyrexTypes.c_int_type))
return self._inject_tailmatch(
node, args, is_unbound_method, 'str', 'startswith',
self.PyStr_Tailmatch_func_type, str_tailmatch_utility_code,-1)
method_call = self._substitute_method_call(
node, "__Pyx_PyBytes_Tailmatch", self.PyBytes_Tailmatch_func_type,
method_name, is_unbound_method, args,
utility_code = bytes_tailmatch_utility_code)
return method_call.coerce_to(Builtin.bool_type, self.current_env())
# C signature used for the `bytes` startswith()/endswith() optimisation:
# takes a bytes object (the argument is named "str" but typed as bytes),
# the substring argument (any Python object), the Py_ssize_t start/end
# bounds and an int direction, and returns a C bint.
# '-1' is declared as the exception value of the call.
PyBytes_Tailmatch_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_bint_type, [
PyrexTypes.CFuncTypeArg("str", Builtin.bytes_type, None),
PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
],
exception_value = '-1')
def _handle_simple_method_bytes_endswith(self, node, args, is_unbound_method):
    """Optimise a bytes.endswith(...) call.

    Delegates to _inject_tailmatch() with the bytes function type and
    utility code; endswith is requested with direction +1.
    """
    return self._inject_tailmatch(
        node, args, is_unbound_method,
        type_name='bytes',
        method_name='endswith',
        func_type=self.PyBytes_Tailmatch_func_type,
        utility_code=bytes_tailmatch_utility_code,
        direction=+1)
def _handle_simple_method_bytes_startswith(self, node, args, is_unbound_method):
    """Optimise a bytes.startswith(...) call.

    Delegates to _inject_tailmatch() with the bytes function type and
    utility code; startswith is requested with direction -1.
    """
    return self._inject_tailmatch(
        node, args, is_unbound_method,
        type_name='bytes',
        method_name='startswith',
        func_type=self.PyBytes_Tailmatch_func_type,
        utility_code=bytes_tailmatch_utility_code,
        direction=-1)
### helpers
......@@ -2991,6 +2994,25 @@ static int __Pyx_PyBytes_Tailmatch(PyObject* self, PyObject* substr, Py_ssize_t
""")
# Utility code declaring and implementing __Pyx_PyStr_Tailmatch().
# The C function forwards to __Pyx_PyBytes_Tailmatch() when
# PY_MAJOR_VERSION < 3 and to __Pyx_PyUnicode_Tailmatch() otherwise,
# which is why both of those utility codes are listed in `requires`.
str_tailmatch_utility_code = UtilityCode(
proto = '''
static int __Pyx_PyStr_Tailmatch(PyObject* self, PyObject* arg, Py_ssize_t start,
Py_ssize_t end, int direction);
''',
impl = '''
static int __Pyx_PyStr_Tailmatch(PyObject* self, PyObject* arg, Py_ssize_t start,
Py_ssize_t end, int direction)
{
#if PY_MAJOR_VERSION < 3
return __Pyx_PyBytes_Tailmatch(self, arg, start, end, direction);
#else
return __Pyx_PyUnicode_Tailmatch(self, arg, start, end, direction);
#endif
}
''',
requires=[unicode_tailmatch_utility_code, bytes_tailmatch_utility_code]
)
dict_getitem_default_utility_code = UtilityCode(
proto = '''
static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
......
cimport cython
@cython.test_assert_path_exists(
    "//PythonCapiCallNode")
def bytes_startswith(bytes s, sub, start=None, stop=None):
    """
    Call s.startswith() passing only the bounds that were supplied.

    >>> bytes_startswith(b'a', b'a')
    True
    >>> bytes_startswith(b'a', b'b')
    False
    >>> bytes_startswith(b'a', (b'a', b'b'))
    True
    >>> bytes_startswith(b'a', b'a', 1)
    False
    >>> bytes_startswith(b'a', b'a', 0, 0)
    False
    """
    # Each arity is spelled out as its own call so that the one-, two- and
    # three-argument forms of the method are all exercised.
    if start is None:
        return s.startswith(sub)
    if stop is None:
        return s.startswith(sub, start)
    return s.startswith(sub, start, stop)
@cython.test_assert_path_exists(
    "//PythonCapiCallNode")
def bytes_endswith(bytes s, sub, start=None, stop=None):
    """
    Call s.endswith() passing only the bounds that were supplied.

    >>> bytes_endswith(b'a', b'a')
    True
    >>> bytes_endswith(b'a', b'b')
    False
    >>> bytes_endswith(b'a', (b'a', b'b'))
    True
    >>> bytes_endswith(b'a', b'a', 1)
    False
    >>> bytes_endswith(b'a', b'a', 0, 0)
    False
    """
    # Each arity is spelled out as its own call so that the one-, two- and
    # three-argument forms of the method are all exercised.
    if start is None:
        return s.endswith(sub)
    if stop is None:
        return s.endswith(sub, start)
    return s.endswith(sub, start, stop)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment