Commit 57773d75 authored by Stefan Behnel's avatar Stefan Behnel

clean up PyStr_Tailmatch() implementation

parent d74d4bd0
...@@ -2413,9 +2413,9 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): ...@@ -2413,9 +2413,9 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
node, "PyUnicode_Split", self.PyUnicode_Split_func_type, node, "PyUnicode_Split", self.PyUnicode_Split_func_type,
'split', is_unbound_method, args) 'split', is_unbound_method, args)
PyUnicode_Tailmatch_func_type = PyrexTypes.CFuncType( PyString_Tailmatch_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_bint_type, [ PyrexTypes.c_bint_type, [
PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None), # bytes/str/unicode
PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None), PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
...@@ -2426,15 +2426,15 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): ...@@ -2426,15 +2426,15 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
def _handle_simple_method_unicode_endswith(self, node, args, is_unbound_method): def _handle_simple_method_unicode_endswith(self, node, args, is_unbound_method):
return self._inject_tailmatch( return self._inject_tailmatch(
node, args, is_unbound_method, 'unicode', 'endswith', node, args, is_unbound_method, 'unicode', 'endswith',
self.PyUnicode_Tailmatch_func_type, unicode_tailmatch_utility_code, +1) unicode_tailmatch_utility_code, +1)
def _handle_simple_method_unicode_startswith(self, node, args, is_unbound_method): def _handle_simple_method_unicode_startswith(self, node, args, is_unbound_method):
return self._inject_tailmatch( return self._inject_tailmatch(
node, args, is_unbound_method, 'unicode', 'startswith', node, args, is_unbound_method, 'unicode', 'startswith',
self.PyUnicode_Tailmatch_func_type, unicode_tailmatch_utility_code, -1) unicode_tailmatch_utility_code, -1)
def _inject_tailmatch(self, node, args, is_unbound_method, type_name, def _inject_tailmatch(self, node, args, is_unbound_method, type_name,
method_name, func_type, utility_code, direction): method_name, utility_code, direction):
"""Replace unicode.startswith(...) and unicode.endswith(...) """Replace unicode.startswith(...) and unicode.endswith(...)
by a direct call to the corresponding C-API function. by a direct call to the corresponding C-API function.
""" """
...@@ -2449,7 +2449,8 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): ...@@ -2449,7 +2449,8 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
node.pos, value=str(direction), type=PyrexTypes.c_int_type)) node.pos, value=str(direction), type=PyrexTypes.c_int_type))
method_call = self._substitute_method_call( method_call = self._substitute_method_call(
node, "__Pyx_Py%s_Tailmatch" % type_name.capitalize(), func_type, node, "__Pyx_Py%s_Tailmatch" % type_name.capitalize(),
self.PyString_Tailmatch_func_type,
method_name, is_unbound_method, args, method_name, is_unbound_method, args,
utility_code = utility_code) utility_code = utility_code)
return method_call.coerce_to(Builtin.bool_type, self.current_env()) return method_call.coerce_to(Builtin.bool_type, self.current_env())
...@@ -2779,47 +2780,26 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): ...@@ -2779,47 +2780,26 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
return (encoding, encoding_node, error_handling, error_handling_node) return (encoding, encoding_node, error_handling, error_handling_node)
PyStr_Tailmatch_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_bint_type, [
PyrexTypes.CFuncTypeArg("str", Builtin.str_type, None),
PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
],
exception_value = '-1')
def _handle_simple_method_str_endswith(self, node, args, is_unbound_method): def _handle_simple_method_str_endswith(self, node, args, is_unbound_method):
return self._inject_tailmatch( return self._inject_tailmatch(
node, args, is_unbound_method, 'str', 'endswith', node, args, is_unbound_method, 'str', 'endswith',
self.PyStr_Tailmatch_func_type, str_tailmatch_utility_code, +1) str_tailmatch_utility_code, +1)
def _handle_simple_method_str_startswith(self, node, args, is_unbound_method): def _handle_simple_method_str_startswith(self, node, args, is_unbound_method):
return self._inject_tailmatch( return self._inject_tailmatch(
node, args, is_unbound_method, 'str', 'startswith', node, args, is_unbound_method, 'str', 'startswith',
self.PyStr_Tailmatch_func_type, str_tailmatch_utility_code,-1) str_tailmatch_utility_code, -1)
PyBytes_Tailmatch_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_bint_type, [
PyrexTypes.CFuncTypeArg("str", Builtin.bytes_type, None),
PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
],
exception_value = '-1')
def _handle_simple_method_bytes_endswith(self, node, args, is_unbound_method): def _handle_simple_method_bytes_endswith(self, node, args, is_unbound_method):
return self._inject_tailmatch( return self._inject_tailmatch(
node, args, is_unbound_method, 'bytes', 'endswith', node, args, is_unbound_method, 'bytes', 'endswith',
self.PyBytes_Tailmatch_func_type, bytes_tailmatch_utility_code, +1) bytes_tailmatch_utility_code, +1)
def _handle_simple_method_bytes_startswith(self, node, args, is_unbound_method): def _handle_simple_method_bytes_startswith(self, node, args, is_unbound_method):
return self._inject_tailmatch( return self._inject_tailmatch(
node, args, is_unbound_method, 'bytes', 'startswith', node, args, is_unbound_method, 'bytes', 'startswith',
self.PyBytes_Tailmatch_func_type, bytes_tailmatch_utility_code,-1) bytes_tailmatch_utility_code, -1)
### helpers ### helpers
def _substitute_method_call(self, node, name, func_type, def _substitute_method_call(self, node, name, func_type,
...@@ -2929,9 +2909,7 @@ static int __Pyx_PyBytes_SingleTailmatch(PyObject* self, PyObject* arg, Py_ssize ...@@ -2929,9 +2909,7 @@ static int __Pyx_PyBytes_SingleTailmatch(PyObject* self, PyObject* arg, Py_ssize
int retval; int retval;
#if PY_VERSION_HEX >= 0x02060000 #if PY_VERSION_HEX >= 0x02060000
PyBufferProcs *pb = NULL;
Py_buffer view; Py_buffer view;
view.obj = NULL; view.obj = NULL;
#endif #endif
...@@ -2947,19 +2925,11 @@ static int __Pyx_PyBytes_SingleTailmatch(PyObject* self, PyObject* arg, Py_ssize ...@@ -2947,19 +2925,11 @@ static int __Pyx_PyBytes_SingleTailmatch(PyObject* self, PyObject* arg, Py_ssize
#endif #endif
else { else {
#if PY_VERSION_HEX < 0x02060000 #if PY_VERSION_HEX < 0x02060000
if (PyObject_AsCharBuffer(arg, &sub_ptr, &sub_len)) if (unlikely(PyObject_AsCharBuffer(arg, &sub_ptr, &sub_len)))
return -1; return -1;
#else #else
pb = Py_TYPE(self)->tp_as_buffer; if (unlikely(PyObject_GetBuffer(self, &view, PyBUF_SIMPLE) == -1))
if (pb == NULL || pb->bf_getbuffer == NULL) {
PyErr_SetString(PyExc_TypeError,
"expected an object with the buffer interface");
return -1;
}
if ((*pb->bf_getbuffer)(self, &view, PyBUF_SIMPLE)) {
return -1; return -1;
}
sub_ptr = (const char*) view.buf; sub_ptr = (const char*) view.buf;
sub_len = view.len; sub_len = view.len;
#endif #endif
...@@ -2988,9 +2958,8 @@ static int __Pyx_PyBytes_SingleTailmatch(PyObject* self, PyObject* arg, Py_ssize ...@@ -2988,9 +2958,8 @@ static int __Pyx_PyBytes_SingleTailmatch(PyObject* self, PyObject* arg, Py_ssize
retval = 0; retval = 0;
#if PY_VERSION_HEX >= 0x02060000 #if PY_VERSION_HEX >= 0x02060000
if (pb != NULL && pb->bf_releasebuffer != NULL) if (view.obj)
(*pb->bf_releasebuffer)(self, &view); PyBuffer_Release(&view);
Py_XDECREF(view.obj);
#endif #endif
return retval; return retval;
...@@ -3019,18 +2988,21 @@ static int __Pyx_PyBytes_Tailmatch(PyObject* self, PyObject* substr, Py_ssize_t ...@@ -3019,18 +2988,21 @@ static int __Pyx_PyBytes_Tailmatch(PyObject* self, PyObject* substr, Py_ssize_t
str_tailmatch_utility_code = UtilityCode( str_tailmatch_utility_code = UtilityCode(
proto = ''' proto = '''
static int __Pyx_PyStr_Tailmatch(PyObject* self, PyObject* arg, Py_ssize_t start, static CYTHON_INLINE int __Pyx_PyStr_Tailmatch(PyObject* self, PyObject* arg, Py_ssize_t start,
Py_ssize_t end, int direction); Py_ssize_t end, int direction);
''', ''',
# We do not use a C compiler macro here to avoid "unused function"
# warnings for the *_Tailmatch() function that is not being used in
# the specific CPython version. The C compiler will generate the same
# code anyway, and will usually just remove the unused function.
impl = ''' impl = '''
static int __Pyx_PyStr_Tailmatch(PyObject* self, PyObject* arg, Py_ssize_t start, static CYTHON_INLINE int __Pyx_PyStr_Tailmatch(PyObject* self, PyObject* arg, Py_ssize_t start,
Py_ssize_t end, int direction) Py_ssize_t end, int direction)
{ {
#if PY_MAJOR_VERSION < 3 if (PY_MAJOR_VERSION < 3)
return __Pyx_PyBytes_Tailmatch(self, arg, start, end, direction); return __Pyx_PyBytes_Tailmatch(self, arg, start, end, direction);
#else else
return __Pyx_PyUnicode_Tailmatch(self, arg, start, end, direction); return __Pyx_PyUnicode_Tailmatch(self, arg, start, end, direction);
#endif
} }
''', ''',
requires=[unicode_tailmatch_utility_code, bytes_tailmatch_utility_code] requires=[unicode_tailmatch_utility_code, bytes_tailmatch_utility_code]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment