Commit bf702727 authored by Stefan Behnel's avatar Stefan Behnel

rewrite of sliced char* decoding as utility functions, implemented efficient...

rewrite of sliced char* decoding as utility functions, implemented efficient sliced decoding for C++ std::string

--HG--
extra : rebase_source : e134f5595be98eb990ab2195e8208940efb171fe
parent 01adf000
...@@ -2355,9 +2355,11 @@ class PyTempNode(TempNode): ...@@ -2355,9 +2355,11 @@ class PyTempNode(TempNode):
class RawCNameExprNode(ExprNode): class RawCNameExprNode(ExprNode):
subexprs = [] subexprs = []
def __init__(self, pos, type=None): def __init__(self, pos, type=None, cname=None):
self.pos = pos self.pos = pos
self.type = type self.type = type
if cname is not None:
self.cname = cname
def analyse_types(self, env): def analyse_types(self, env):
return self.type return self.type
...@@ -3319,7 +3321,7 @@ class SliceIndexNode(ExprNode): ...@@ -3319,7 +3321,7 @@ class SliceIndexNode(ExprNode):
def infer_type(self, env): def infer_type(self, env):
base_type = self.base.infer_type(env) base_type = self.base.infer_type(env)
if base_type.is_string: if base_type.is_string or base_type.is_cpp_class:
return bytes_type return bytes_type
elif base_type in (bytes_type, str_type, unicode_type, elif base_type in (bytes_type, str_type, unicode_type,
list_type, tuple_type): list_type, tuple_type):
...@@ -3383,7 +3385,7 @@ class SliceIndexNode(ExprNode): ...@@ -3383,7 +3385,7 @@ class SliceIndexNode(ExprNode):
if self.stop: if self.stop:
self.stop.analyse_types(env) self.stop.analyse_types(env)
base_type = self.base.type base_type = self.base.type
if base_type.is_string: if base_type.is_string or base_type.is_cpp_string:
self.type = bytes_type self.type = bytes_type
elif base_type.is_ptr: elif base_type.is_ptr:
self.type = base_type self.type = base_type
...@@ -9357,7 +9359,7 @@ class CoerceToPyTypeNode(CoercionNode): ...@@ -9357,7 +9359,7 @@ class CoerceToPyTypeNode(CoercionNode):
CoercionNode.__init__(self, arg) CoercionNode.__init__(self, arg)
if type is py_object_type: if type is py_object_type:
# be specific about some known types # be specific about some known types
if arg.type.is_string: if arg.type.is_string or arg.type.is_cpp_string:
self.type = bytes_type self.type = bytes_type
elif arg.type.is_unicode_char: elif arg.type.is_unicode_char:
self.type = unicode_type self.type = unicode_type
......
...@@ -2699,106 +2699,116 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): ...@@ -2699,106 +2699,116 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
'encode', is_unbound_method, 'encode', is_unbound_method,
[string_node, encoding_node, error_handling_node]) [string_node, encoding_node, error_handling_node])
PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType( PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType(
Builtin.unicode_type, [ Builtin.unicode_type, [
PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None), PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None), PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None), PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
]) ]))
PyUnicode_Decode_func_type = PyrexTypes.CFuncType( _decode_c_string_func_type = PyrexTypes.CFuncType(
Builtin.unicode_type, [ Builtin.unicode_type, [
PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None), PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None), PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None), PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None), PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
]) ])
_decode_cpp_string_func_type = None # lazy init
def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method): def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
"""Replace char*.decode() by a direct C-API call to the """Replace char*.decode() by a direct C-API call to the
corresponding codec, possibly resoving a slice on the char*. corresponding codec, possibly resoving a slice on the char*.
""" """
if len(args) < 1 or len(args) > 3: if not (1 <= len(args) <= 3):
self._error_wrong_arg_count('bytes.decode', node, args, '1-3') self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
return node return node
temps = []
# normalise input nodes
if isinstance(args[0], ExprNodes.SliceIndexNode): if isinstance(args[0], ExprNodes.SliceIndexNode):
index_node = args[0] index_node = args[0]
string_node = index_node.base string_node = index_node.base
if not string_node.type.is_string:
# nothing to optimise here
return node
start, stop = index_node.start, index_node.stop start, stop = index_node.start, index_node.stop
if not start or start.constant_result == 0: if not start or start.constant_result == 0:
start = None start = None
else: elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode):
if start.type.is_pyobject:
start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
if stop:
start = UtilNodes.LetRefNode(start)
temps.append(start)
string_node = ExprNodes.AddNode(pos=start.pos,
operand1=string_node,
operator='+',
operand2=start,
is_temp=False,
type=string_node.type
)
if stop and stop.type.is_pyobject:
stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode) \
and args[0].arg.type.is_string:
# use strlen() to find the string length, just as CPython would
start = stop = None
string_node = args[0].arg string_node = args[0].arg
start = stop = None
else: else:
# let Python do its job
return node return node
if not stop: if not string_node.type.is_string and not string_node.type.is_cpp_string:
if start or not string_node.is_name: # nothing to optimise here
string_node = UtilNodes.LetRefNode(string_node) return node
temps.append(string_node)
stop = ExprNodes.PythonCapiCallNode(
string_node.pos, "strlen", self.Pyx_strlen_func_type,
args = [string_node],
is_temp = False,
utility_code = UtilityCode.load_cached("IncludeStringH", "StringTools.c"),
).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
elif start:
stop = ExprNodes.SubNode(
pos = stop.pos,
operand1 = stop,
operator = '-',
operand2 = start,
is_temp = False,
type = PyrexTypes.c_py_ssize_t_type
)
parameters = self._unpack_encoding_and_error_mode(node.pos, args) parameters = self._unpack_encoding_and_error_mode(node.pos, args)
if parameters is None: if parameters is None:
return node return node
encoding, encoding_node, error_handling, error_handling_node = parameters encoding, encoding_node, error_handling, error_handling_node = parameters
if not start:
start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
elif not start.type.is_int:
start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
if stop and not stop.type.is_int:
stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
# try to find a specific encoder function # try to find a specific encoder function
codec_name = None codec_name = None
if encoding is not None: if encoding is not None:
codec_name = self._find_special_codec_name(encoding) codec_name = self._find_special_codec_name(encoding)
if codec_name is not None: if codec_name is not None:
decode_function = "PyUnicode_Decode%s" % codec_name decode_function = ExprNodes.RawCNameExprNode(
node = ExprNodes.PythonCapiCallNode( node.pos, type=self.PyUnicode_DecodeXyz_func_ptr_type,
node.pos, decode_function, cname="PyUnicode_Decode%s" % codec_name)
self.PyUnicode_DecodeXyz_func_type, encoding_node = ExprNodes.NullNode(node.pos)
args = [string_node, stop, error_handling_node],
is_temp = node.is_temp,
)
else: else:
node = ExprNodes.PythonCapiCallNode( decode_function = ExprNodes.NullNode(node.pos)
node.pos, "PyUnicode_Decode",
self.PyUnicode_Decode_func_type, # build the helper function call
args = [string_node, stop, encoding_node, error_handling_node], temps = []
is_temp = node.is_temp, if string_node.type.is_string:
) # C string
if not stop:
# use strlen() to find the string length, just as CPython would
if not string_node.is_name:
string_node = UtilNodes.LetRefNode(string_node) # used twice
temps.append(string_node)
stop = ExprNodes.PythonCapiCallNode(
string_node.pos, "strlen", self.Pyx_strlen_func_type,
args = [string_node],
is_temp = False,
utility_code = UtilityCode.load_cached("IncludeStringH", "StringTools.c"),
).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
helper_func_type = self._decode_c_string_func_type
utility_code_name = 'decode_c_string'
else:
# C++ std::string
if not stop:
stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
constant_result=ExprNodes.not_a_constant)
if self._decode_cpp_string_func_type is None:
# lazy init to reuse the C++ string type
self._decode_cpp_string_func_type = PyrexTypes.CFuncType(
Builtin.unicode_type, [
PyrexTypes.CFuncTypeArg("string", string_node.type, None),
PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
PyrexTypes.CFuncTypeArg("decode_func", self.PyUnicode_DecodeXyz_func_ptr_type, None),
])
helper_func_type = self._decode_cpp_string_func_type
utility_code_name = 'decode_cpp_string'
node = ExprNodes.PythonCapiCallNode(
node.pos, '__Pyx_%s' % utility_code_name, helper_func_type,
args = [string_node, start, stop, encoding_node, error_handling_node, decode_function],
is_temp = node.is_temp,
utility_code=UtilityCode.load_cached(utility_code_name, 'Optimize.c'),
)
for temp in temps[::-1]: for temp in temps[::-1]:
node = UtilNodes.EvalWithTempExprNode(temp, node) node = UtilNodes.EvalWithTempExprNode(temp, node)
......
...@@ -144,7 +144,7 @@ class PyrexType(BaseType): ...@@ -144,7 +144,7 @@ class PyrexType(BaseType):
# is_enum boolean Is a C enum type # is_enum boolean Is a C enum type
# is_typedef boolean Is a typedef type # is_typedef boolean Is a typedef type
# is_string boolean Is a C char * type # is_string boolean Is a C char * type
# is_unicode boolean Is a UTF-8 encoded C char * type # is_cpp_string boolean Is a C++ std::string type
# is_unicode_char boolean Is either Py_UCS4 or Py_UNICODE # is_unicode_char boolean Is either Py_UCS4 or Py_UNICODE
# is_returncode boolean Is used only to signal exceptions # is_returncode boolean Is used only to signal exceptions
# is_error boolean Is the dummy error type # is_error boolean Is the dummy error type
...@@ -195,11 +195,11 @@ class PyrexType(BaseType): ...@@ -195,11 +195,11 @@ class PyrexType(BaseType):
is_cfunction = 0 is_cfunction = 0
is_struct_or_union = 0 is_struct_or_union = 0
is_cpp_class = 0 is_cpp_class = 0
is_cpp_string = 0
is_struct = 0 is_struct = 0
is_enum = 0 is_enum = 0
is_typedef = 0 is_typedef = 0
is_string = 0 is_string = 0
is_unicode = 0
is_unicode_char = 0 is_unicode_char = 0
is_returncode = 0 is_returncode = 0
is_error = 0 is_error = 0
...@@ -3011,6 +3011,7 @@ class CppClassType(CType): ...@@ -3011,6 +3011,7 @@ class CppClassType(CType):
self.templates = templates self.templates = templates
self.template_type = template_type self.template_type = template_type
self.specializations = {} self.specializations = {}
self.is_cpp_string = cname == 'std::string'
def use_conversion_utility(self, from_or_to): def use_conversion_utility(self, from_or_to):
pass pass
......
...@@ -487,3 +487,61 @@ static double __Pyx__PyObject_AsDouble(PyObject* obj) { ...@@ -487,3 +487,61 @@ static double __Pyx__PyObject_AsDouble(PyObject* obj) {
bad: bad:
return (double)-1; return (double)-1;
} }
/////////////// decode_cpp_string.proto ///////////////
#include <string>
static CYTHON_INLINE PyObject* __Pyx_decode_cpp_string(
std::string cppstring, Py_ssize_t start, Py_ssize_t stop,
const char* encoding, const char* errors,
PyObject* (*decode_func)(const char *s, Py_ssize_t size, const char *errors));
/////////////// decode_cpp_string ///////////////
static CYTHON_INLINE PyObject* __Pyx_decode_cpp_string(
std::string cppstring, Py_ssize_t start, Py_ssize_t stop,
const char* encoding, const char* errors,
PyObject* (*decode_func)(const char *s, Py_ssize_t size, const char *errors)) {
const char* cstring = cppstring.data();
Py_ssize_t length = cppstring.size();
if (start < 0)
start += length;
if (stop < 0)
length += stop;
else if (length > stop)
length = stop;
if ((start < 0) | (start >= length))
return PyUnicode_FromUnicode(NULL, 0);
cstring += start;
length -= start;
if (decode_func) {
return decode_func(cstring, length, errors);
} else {
return PyUnicode_Decode(cstring, length, encoding, errors);
}
}
/////////////// decode_c_string.proto ///////////////
static CYTHON_INLINE PyObject* __Pyx_decode_c_string(
const char* cstring, Py_ssize_t start, Py_ssize_t stop,
const char* encoding, const char* errors,
PyObject* (*decode_func)(const char *s, Py_ssize_t size, const char *errors));
/////////////// decode_c_string ///////////////
static CYTHON_INLINE PyObject* __Pyx_decode_c_string(
const char* cstring, Py_ssize_t start, Py_ssize_t stop,
const char* encoding, const char* errors,
PyObject* (*decode_func)(const char *s, Py_ssize_t size, const char *errors)) {
Py_ssize_t length = stop - start;
cstring += start;
if (decode_func) {
return decode_func(cstring, length, errors);
} else {
return PyUnicode_Decode(cstring, length, encoding, errors);
}
}
...@@ -100,6 +100,21 @@ def slice_charptr_dynamic_bounds(): ...@@ -100,6 +100,21 @@ def slice_charptr_dynamic_bounds():
cstring[return1():return5()].decode('UTF-8'), cstring[return1():return5()].decode('UTF-8'),
cstring[return4():return9()].decode('UTF-8')) cstring[return4():return9()].decode('UTF-8'))
@cython.test_assert_path_exists("//PythonCapiCallNode")
@cython.test_fail_if_path_exists("//AttributeNode")
def slice_charptr_dynamic_bounds_non_name():
"""
>>> print(str(slice_charptr_dynamic_bounds_non_name()).replace("u'", "'"))
('bcA', 'bcA', 'BCqtp', 'ABCqtp', 'bcABCqtp', 'bcABCqtp', 'cABC')
"""
return ((cstring+1)[:return3()].decode('UTF-8'),
(cstring+1)[0:return3()].decode('UTF-8'),
(cstring+1)[return3():].decode('UTF-8'),
(cstring+1)[2:].decode('UTF-8'),
(cstring+1)[0:].decode('UTF-8'),
(cstring+1)[:].decode('UTF-8'),
(cstring+1)[return1():return5()].decode('UTF-8'))
cdef return1(): return 1 cdef return1(): return 1
cdef return3(): return 3 cdef return3(): return 3
cdef return4(): return 4 cdef return4(): return 4
......
# tag: cpp # tag: cpp
cimport cython
from libcpp.string cimport string from libcpp.string cimport string
b_asdf = b'asdf' b_asdf = b'asdf'
...@@ -143,14 +145,46 @@ def test_cstr(char *a): ...@@ -143,14 +145,46 @@ def test_cstr(char *a):
cdef string b = string(a) cdef string b = string(a)
return b.c_str() return b.c_str()
@cython.test_assert_path_exists("//PythonCapiCallNode")
@cython.test_fail_if_path_exists("//AttributeNode")
def test_decode(char* a): def test_decode(char* a):
""" """
>>> test_decode(b_asdf) == 'asdf' >>> print(test_decode(b_asdf))
True asdf
""" """
cdef string b = string(a) cdef string b = string(a)
return b.decode('ascii') return b.decode('ascii')
@cython.test_assert_path_exists("//PythonCapiCallNode")
@cython.test_fail_if_path_exists("//AttributeNode")
def test_decode_sliced(char* a):
"""
>>> print(test_decode_sliced(b_asdf))
sd
"""
cdef string b = string(a)
return b[1:3].decode('ascii')
@cython.test_assert_path_exists("//PythonCapiCallNode")
@cython.test_fail_if_path_exists("//AttributeNode")
def test_decode_sliced_end(char* a):
"""
>>> print(test_decode_sliced_end(b_asdf))
asd
"""
cdef string b = string(a)
return b[:3].decode('ascii')
@cython.test_assert_path_exists("//PythonCapiCallNode")
@cython.test_fail_if_path_exists("//AttributeNode")
def test_decode_sliced_start(char* a):
"""
>>> print(test_decode_sliced_start(b_asdf))
df
"""
cdef string b = string(a)
return b[2:].decode('ascii')
def test_equals_operator(char *a, char *b): def test_equals_operator(char *a, char *b):
""" """
>>> test_equals_operator(b_asdf, b_asdf) >>> test_equals_operator(b_asdf, b_asdf)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment