Commit 42572ca9 authored by Stefan Behnel's avatar Stefan Behnel

implement ticket #535: fast index access into unicode strings

parent 737c9641
...@@ -1892,11 +1892,15 @@ class IndexNode(ExprNode): ...@@ -1892,11 +1892,15 @@ class IndexNode(ExprNode):
return self.base.type_dependencies(env) return self.base.type_dependencies(env)
def infer_type(self, env): def infer_type(self, env):
if isinstance(self.base, (StringNode, UnicodeNode)): # FIXME: BytesNode? if isinstance(self.base, StringNode): # FIXME: BytesNode?
return py_object_type return py_object_type
base_type = self.base.infer_type(env) base_type = self.base.infer_type(env)
if base_type.is_ptr or base_type.is_array: if base_type.is_ptr or base_type.is_array:
return base_type.base_type return base_type.base_type
elif base_type is Builtin.unicode_type:
# Py_UNICODE will automatically coerce to a unicode string
# if required, so this is safe
return PyrexTypes.c_py_unicode_type
else: else:
# TODO: Handle buffers (hopefully without too much redundancy). # TODO: Handle buffers (hopefully without too much redundancy).
return py_object_type return py_object_type
...@@ -1965,15 +1969,16 @@ class IndexNode(ExprNode): ...@@ -1965,15 +1969,16 @@ class IndexNode(ExprNode):
else: else:
self.base.entry.buffer_aux.writable_needed = True self.base.entry.buffer_aux.writable_needed = True
else: else:
base_type = self.base.type
if isinstance(self.index, TupleNode): if isinstance(self.index, TupleNode):
self.index.analyse_types(env, skip_children=skip_child_analysis) self.index.analyse_types(env, skip_children=skip_child_analysis)
elif not skip_child_analysis: elif not skip_child_analysis:
self.index.analyse_types(env) self.index.analyse_types(env)
self.original_index_type = self.index.type self.original_index_type = self.index.type
if self.base.type.is_pyobject: if base_type.is_pyobject:
if self.index.type.is_int: if self.index.type.is_int:
if (not setting if (not setting
and (self.base.type is list_type or self.base.type is tuple_type) and (base_type in (list_type, tuple_type, unicode_type))
and (not self.index.type.signed or isinstance(self.index, IntNode) and int(self.index.value) >= 0) and (not self.index.type.signed or isinstance(self.index, IntNode) and int(self.index.value) >= 0)
and not env.directives['boundscheck']): and not env.directives['boundscheck']):
self.is_temp = 0 self.is_temp = 0
...@@ -1983,10 +1988,15 @@ class IndexNode(ExprNode): ...@@ -1983,10 +1988,15 @@ class IndexNode(ExprNode):
else: else:
self.index = self.index.coerce_to_pyobject(env) self.index = self.index.coerce_to_pyobject(env)
self.is_temp = 1 self.is_temp = 1
self.type = py_object_type if base_type is unicode_type:
# Py_UNICODE will automatically coerce to a unicode string
# if required, so this is safe
self.type = PyrexTypes.c_py_unicode_type
else:
self.type = py_object_type
else: else:
if self.base.type.is_ptr or self.base.type.is_array: if base_type.is_ptr or base_type.is_array:
self.type = self.base.type.base_type self.type = base_type.base_type
if self.index.type.is_pyobject: if self.index.type.is_pyobject:
self.index = self.index.coerce_to( self.index = self.index.coerce_to(
PyrexTypes.c_py_ssize_t_type, env) PyrexTypes.c_py_ssize_t_type, env)
...@@ -1994,10 +2004,10 @@ class IndexNode(ExprNode): ...@@ -1994,10 +2004,10 @@ class IndexNode(ExprNode):
error(self.pos, error(self.pos,
"Invalid index type '%s'" % "Invalid index type '%s'" %
self.index.type) self.index.type)
elif self.base.type.is_cpp_class: elif base_type.is_cpp_class:
function = env.lookup_operator("[]", [self.base, self.index]) function = env.lookup_operator("[]", [self.base, self.index])
if function is None: if function is None:
error(self.pos, "Indexing '%s' not supported for index type '%s'" % (self.base.type, self.index.type)) error(self.pos, "Indexing '%s' not supported for index type '%s'" % (base_type, self.index.type))
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
self.result_code = "<error>" self.result_code = "<error>"
return return
...@@ -2011,7 +2021,7 @@ class IndexNode(ExprNode): ...@@ -2011,7 +2021,7 @@ class IndexNode(ExprNode):
else: else:
error(self.pos, error(self.pos,
"Attempting to index non-array type '%s'" % "Attempting to index non-array type '%s'" %
self.base.type) base_type)
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
gil_message = "Indexing Python object" gil_message = "Indexing Python object"
...@@ -2040,6 +2050,8 @@ class IndexNode(ExprNode): ...@@ -2040,6 +2050,8 @@ class IndexNode(ExprNode):
return "PyList_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result()) return "PyList_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result())
elif self.base.type is tuple_type: elif self.base.type is tuple_type:
return "PyTuple_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result()) return "PyTuple_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result())
elif self.base.type is unicode_type and self.type is PyrexTypes.c_py_unicode_type:
return "PyUnicode_AS_UNICODE(%s)[%s]" % (self.base.result(), self.index.result())
else: else:
return "(%s[%s])" % ( return "(%s[%s])" % (
self.base.result(), self.index.result()) self.base.result(), self.index.result())
...@@ -2087,34 +2099,51 @@ class IndexNode(ExprNode): ...@@ -2087,34 +2099,51 @@ class IndexNode(ExprNode):
# is_temp is True, so must pull out value and incref it. # is_temp is True, so must pull out value and incref it.
code.putln("%s = *%s;" % (self.result(), self.buffer_ptr_code)) code.putln("%s = *%s;" % (self.result(), self.buffer_ptr_code))
code.putln("__Pyx_INCREF((PyObject*)%s);" % self.result()) code.putln("__Pyx_INCREF((PyObject*)%s);" % self.result())
elif self.type.is_pyobject and self.is_temp: elif self.is_temp:
if self.index.type.is_int: if self.type.is_pyobject:
index_code = self.index.result() if self.index.type.is_int:
if self.base.type is list_type: index_code = self.index.result()
function = "__Pyx_GetItemInt_List" if self.base.type is list_type:
elif self.base.type is tuple_type: function = "__Pyx_GetItemInt_List"
function = "__Pyx_GetItemInt_Tuple" elif self.base.type is tuple_type:
function = "__Pyx_GetItemInt_Tuple"
else:
function = "__Pyx_GetItemInt"
code.globalstate.use_utility_code(getitem_int_utility_code)
else: else:
function = "__Pyx_GetItemInt" index_code = self.index.py_result()
code.globalstate.use_utility_code(getitem_int_utility_code) if self.base.type is dict_type:
else: function = "__Pyx_PyDict_GetItem"
if self.base.type is dict_type: code.globalstate.use_utility_code(getitem_dict_utility_code)
function = "__Pyx_PyDict_GetItem" else:
code.globalstate.use_utility_code(getitem_dict_utility_code) function = "PyObject_GetItem"
code.putln(
"%s = %s(%s, %s%s); if (!%s) %s" % (
self.result(),
function,
self.base.py_result(),
index_code,
self.extra_index_params(),
self.result(),
code.error_goto(self.pos)))
code.put_gotref(self.py_result())
elif self.type is PyrexTypes.c_py_unicode_type and self.base.type is unicode_type:
code.globalstate.use_utility_code(getitem_int_pyunicode_utility_code)
if self.index.type.is_int:
index_code = self.index.result()
function = "__Pyx_GetItemInt_Unicode"
else: else:
function = "PyObject_GetItem" index_code = self.index.py_result()
index_code = self.index.py_result() function = "__Pyx_GetItemInt_Unicode_Generic"
sign_code = "" code.putln(
code.putln( "%s = %s(%s, %s%s); if (unlikely(%s == (Py_UNICODE)-1)) %s;" % (
"%s = %s(%s, %s%s); if (!%s) %s" % ( self.result(),
self.result(), function,
function, self.base.py_result(),
self.base.py_result(), index_code,
index_code, self.extra_index_params(),
self.extra_index_params(), self.result(),
self.result(), code.error_goto(self.pos)))
code.error_goto(self.pos)))
code.put_gotref(self.py_result())
def generate_setitem_code(self, value_code, code): def generate_setitem_code(self, value_code, code):
if self.index.type.is_int: if self.index.type.is_int:
...@@ -6731,6 +6760,38 @@ requires = [raise_noneindex_error_utility_code]) ...@@ -6731,6 +6760,38 @@ requires = [raise_noneindex_error_utility_code])
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
getitem_int_pyunicode_utility_code = UtilityCode(
proto = '''
#define __Pyx_GetItemInt_Unicode(o, i, size, to_py_func) (((size) <= sizeof(Py_ssize_t)) ? \\
__Pyx_GetItemInt_Unicode_Fast(o, i) : \\
__Pyx_GetItemInt_Generic(o, to_py_func(i)))
static CYTHON_INLINE Py_UNICODE __Pyx_GetItemInt_Unicode_Fast(PyObject* ustring, Py_ssize_t i) {
if (likely((0 <= i) & (i < PyUnicode_GET_SIZE(ustring)))) {
return PyUnicode_AS_UNICODE(ustring)[i];
} else if ((-PyUnicode_GET_SIZE(ustring) <= i) & (i < 0)) {
i += PyUnicode_GET_SIZE(ustring);
return PyUnicode_AS_UNICODE(ustring)[i];
} else {
PyErr_SetString(PyExc_IndexError, "string index out of range");
return (Py_UNICODE)-1;
}
}
static CYTHON_INLINE Py_UNICODE __Pyx_GetItemInt_Unicode_Generic(PyObject* ustring, PyObject* j) {
PyObject *r;
Py_UNICODE uchar;
if (!j) return (Py_UNICODE)-1;
r = PyObject_GetItem(ustring, j);
Py_DECREF(j);
if (!r) return (Py_UNICODE)-1;
uchar = PyUnicode_AS_UNICODE(r)[0];
Py_DECREF(r);
return uchar;
}
''',
)
getitem_int_utility_code = UtilityCode( getitem_int_utility_code = UtilityCode(
proto = """ proto = """
......
cimport cython
cdef unicode _ustring = u'azerty123456'
ustring = _ustring
@cython.test_assert_path_exists("//CoerceToPyTypeNode",
"//IndexNode")
@cython.test_fail_if_path_exists("//IndexNode//CoerceToPyTypeNode")
def index(unicode ustring, Py_ssize_t i):
"""
>>> index(ustring, 0)
u'a'
>>> index(ustring, 2)
u'e'
>>> index(ustring, -1)
u'6'
>>> index(ustring, -len(ustring))
u'a'
>>> index(ustring, len(ustring))
Traceback (most recent call last):
IndexError: string index out of range
"""
return ustring[i]
@cython.test_assert_path_exists("//CoerceToPyTypeNode",
"//IndexNode")
@cython.test_fail_if_path_exists("//IndexNode//CoerceToPyTypeNode")
def index_literal(Py_ssize_t i):
"""
>>> index_literal(0)
u'a'
>>> index_literal(2)
u'e'
>>> index_literal(-1)
u'6'
>>> index_literal(-len('azerty123456'))
u'a'
>>> index_literal(len(ustring))
Traceback (most recent call last):
IndexError: string index out of range
"""
return u'azerty123456'[i]
@cython.test_assert_path_exists("//CoerceToPyTypeNode",
"//IndexNode")
@cython.test_fail_if_path_exists("//IndexNode//CoerceToPyTypeNode")
@cython.boundscheck(False)
def index_no_boundscheck(unicode ustring, Py_ssize_t i):
"""
>>> index_no_boundscheck(ustring, 0)
u'a'
>>> index_no_boundscheck(ustring, 2)
u'e'
>>> index_no_boundscheck(ustring, -1)
u'6'
>>> index_no_boundscheck(ustring, len(ustring)-1)
u'6'
>>> index_no_boundscheck(ustring, -len(ustring))
u'a'
"""
return ustring[i]
@cython.test_assert_path_exists("//CoerceToPyTypeNode",
"//IndexNode")
@cython.test_fail_if_path_exists("//IndexNode//CoerceToPyTypeNode")
@cython.boundscheck(False)
def unsigned_index_no_boundscheck(unicode ustring, unsigned int i):
"""
>>> unsigned_index_no_boundscheck(ustring, 0)
u'a'
>>> unsigned_index_no_boundscheck(ustring, 2)
u'e'
>>> unsigned_index_no_boundscheck(ustring, len(ustring)-1)
u'6'
"""
return ustring[i]
@cython.test_assert_path_exists("//CoerceToPyTypeNode",
"//IndexNode",
"//PrimaryCmpNode")
@cython.test_fail_if_path_exists("//IndexNode//CoerceToPyTypeNode")
def index_compare(unicode ustring, Py_ssize_t i):
"""
>>> index_compare(ustring, 0)
True
>>> index_compare(ustring, 1)
False
>>> index_compare(ustring, -1)
False
>>> index_compare(ustring, -len(ustring))
True
>>> index_compare(ustring, len(ustring))
Traceback (most recent call last):
IndexError: string index out of range
"""
return ustring[i] == u'a'
@cython.test_assert_path_exists("//CoerceToPyTypeNode",
"//IndexNode",
"//PrimaryCmpNode")
@cython.test_fail_if_path_exists("//IndexNode//CoerceToPyTypeNode")
def index_compare_string(unicode ustring, Py_ssize_t i, unicode other):
"""
>>> index_compare_string(ustring, 0, ustring[0])
True
>>> index_compare_string(ustring, 0, ustring[:4])
False
>>> index_compare_string(ustring, 1, ustring[0])
False
>>> index_compare_string(ustring, 1, ustring[1])
True
>>> index_compare_string(ustring, -1, ustring[0])
False
>>> index_compare_string(ustring, -1, ustring[-1])
True
>>> index_compare_string(ustring, -len(ustring), ustring[-len(ustring)])
True
>>> index_compare_string(ustring, len(ustring), ustring)
Traceback (most recent call last):
IndexError: string index out of range
"""
return ustring[i] == other
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment