Commit 00980f04 authored by Stefan Behnel

optimise str/bytes.join() and infer the result type; improve type inference for called builtins

parent 066c5b04
......@@ -74,12 +74,13 @@ builtin_utility_code = {
# mapping from builtins to their C-level equivalents
class _BuiltinOverride(object):
def __init__(self, py_name, args, ret_type, cname, py_equiv = "*",
utility_code = None, sig = None, func_type = None,
is_strict_signature = False):
def __init__(self, py_name, args, ret_type, cname, py_equiv="*",
utility_code=None, sig=None, func_type=None,
is_strict_signature=False, builtin_return_type=None):
self.py_name, self.cname, self.py_equiv = py_name, cname, py_equiv
self.args, self.ret_type = args, ret_type
self.func_type, self.sig = func_type, sig
self.builtin_return_type = builtin_return_type
self.is_strict_signature = is_strict_signature
self.utility_code = utility_code
......@@ -89,6 +90,8 @@ class _BuiltinOverride(object):
func_type = sig.function_type(self_arg)
if self.is_strict_signature:
func_type.is_strict_signature = True
if self.builtin_return_type:
func_type.return_type = builtin_types[self.builtin_return_type]
return func_type
......@@ -212,7 +215,7 @@ builtin_function_table = [
#('raw_input', "", "", ""),
#('reduce', "", "", ""),
BuiltinFunction('reload', "O", "O", "PyImport_ReloadModule"),
BuiltinFunction('repr', "O", "O", "PyObject_Repr"),
BuiltinFunction('repr', "O", "O", "PyObject_Repr", builtin_return_type='basestring'),
#('round', "", "", ""),
BuiltinFunction('setattr', "OOO", "r", "PyObject_SetAttr"),
#('sum', "", "", ""),
......@@ -276,8 +279,13 @@ builtin_types_table = [
("basestring", "PyBaseString_Type", []),
("bytearray", "PyByteArray_Type", []),
("bytes", "PyBytes_Type", [BuiltinMethod("__contains__", "TO", "b", "PySequence_Contains"),
BuiltinMethod("join", "TO", "O", "__Pyx_PyBytes_Join",
utility_code=UtilityCode.load("StringJoin", "StringTools.c")),
]),
("str", "PyString_Type", [BuiltinMethod("__contains__", "TO", "b", "PySequence_Contains"),
BuiltinMethod("join", "TO", "O", "__Pyx_PyString_Join",
builtin_return_type='basestring',
utility_code=UtilityCode.load("StringJoin", "StringTools.c")),
]),
("unicode", "PyUnicode_Type", [BuiltinMethod("__contains__", "TO", "b", "PyUnicode_Contains"),
BuiltinMethod("join", "TO", "T", "PyUnicode_Join"),
......@@ -404,10 +412,11 @@ def init_builtin_structs():
builtin_scope.declare_struct_or_union(
name, "struct", scope, 1, None, cname = cname)
def init_builtins():
init_builtin_structs()
init_builtin_funcs()
init_builtin_types()
init_builtin_funcs()
builtin_scope.declare_var(
'__debug__', PyrexTypes.c_const_type(PyrexTypes.c_bint_type),
pos=None, cname='(!Py_OptimizeFlag)', is_cdef=True)
......@@ -429,4 +438,5 @@ def init_builtins():
bool_type = builtin_scope.lookup('bool').type
complex_type = builtin_scope.lookup('complex').type
init_builtins()
......@@ -4163,9 +4163,14 @@ class CallNode(ExprNode):
def infer_type(self, env):
function = self.function
if isinstance(function, NewExprNode):
return PyrexTypes.CPtrType(function.class_type)
func_type = function.infer_type(env)
if isinstance(self.function, NewExprNode):
return PyrexTypes.CPtrType(self.function.class_type)
if func_type is py_object_type:
# function might have lied for safety => try to find better type
entry = getattr(function, 'entry', None)
if entry is not None:
func_type = entry.type or func_type
if func_type.is_ptr:
func_type = func_type.base_type
if func_type.is_cfunction:
......
......@@ -648,3 +648,27 @@ static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t i
index += PyBytes_GET_SIZE(bytes);
return PyBytes_AS_STRING(bytes)[index];
}
//////////////////// StringJoin.proto ////////////////////
/* Select the join implementation used for 'str' objects: before Python 3 it
 * is routed to the bytes join macro below, otherwise to PyUnicode_Join. */
#if PY_MAJOR_VERSION < 3
#define __Pyx_PyString_Join __Pyx_PyBytes_Join
#else
#define __Pyx_PyString_Join PyUnicode_Join
#endif
/* Select the join implementation used for 'bytes' objects: when compiling
 * for CPython the name maps directly to _PyBytes_Join; in every other case
 * an inline helper with the same (sep, values) signature is declared here. */
#if CYTHON_COMPILING_IN_CPYTHON
#define __Pyx_PyBytes_Join _PyBytes_Join
#else
static CYTHON_INLINE PyObject* __Pyx_PyBytes_Join(PyObject* sep, PyObject* values); /*proto*/
#endif
//////////////////// StringJoin ////////////////////
#if !CYTHON_COMPILING_IN_CPYTHON
/* Fallback bytes join used when not compiling for CPython.
 *
 * Calls the "join" method of `sep` with `values` as its single argument
 * through the generic method-call API and returns whatever that call
 * returns (NULL is passed straight through to the caller).
 */
static CYTHON_INLINE PyObject* __Pyx_PyBytes_Join(PyObject* sep, PyObject* values) {
    /* The trailing NULL terminates the variadic argument list. */
    return PyObject_CallMethodObjArgs(sep, PYIDENT("join"), values, NULL);
}
#endif
......@@ -190,3 +190,33 @@ def bytes_decode_unbound_method(bytes s, start=None, stop=None):
return bytes.decode(s[start:], 'utf8')
else:
return bytes.decode(s[start:stop], 'utf8')
# Test: join() called on a variable declared as 'bytes'.
# The decorator requires the compiled tree to contain a SimpleCallNode that has
# both a NoneCheckNode and an AttributeNode with is_py_attr == false beneath it.
# NOTE(review): the None check is presumably needed because the typed argument
# 's' could be None at runtime - confirm.
@cython.test_assert_path_exists(
"//SimpleCallNode",
"//SimpleCallNode//NoneCheckNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = false]")
def bytes_join(bytes s, *args):
"""
>>> print(bytes_join(b_a, b_b, b_b, b_b).decode('utf8'))
babab
"""
# 's' acts as the separator; 'args' is the tuple of remaining positional arguments.
result = s.join(args)
# The result of the bytes join is expected to be typed as a plain Python object;
# the actual inferred type is used as the assertion message on failure.
assert cython.typeof(result) == 'Python object', cython.typeof(result)
return result
# Test: join() called directly on a bytes literal.
# The first decorator requires that no NoneCheckNode appears beneath the
# SimpleCallNode; the second requires the SimpleCallNode itself and an
# AttributeNode with is_py_attr == false beneath it.
# NOTE(review): the None check is presumably unnecessary because a literal
# separator can never be None - confirm.
@cython.test_fail_if_path_exists(
"//SimpleCallNode//NoneCheckNode",
)
@cython.test_assert_path_exists(
"//SimpleCallNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = false]")
def literal_join(*args):
"""
>>> print(literal_join(b_b, b_b, b_b, b_b).decode('utf8'))
b|b|b|b
"""
# The separator is the bytes literal b'|'; 'args' is the tuple of positional arguments.
result = b'|'.join(args)
# The result of the bytes join is expected to be typed as a plain Python object;
# the actual inferred type is used as the assertion message on failure.
assert cython.typeof(result) == 'Python object', cython.typeof(result)
return result
......@@ -53,3 +53,33 @@ def str_endswith(str s, sub, start=None, stop=None):
return s.endswith(sub, start)
else:
return s.endswith(sub, start, stop)
# Test: join() called on a variable declared as 'str'.
# The decorator requires the compiled tree to contain a SimpleCallNode that has
# both a NoneCheckNode and an AttributeNode with is_py_attr == false beneath it.
# NOTE(review): the None check is presumably needed because the typed argument
# 's' could be None at runtime - confirm.
@cython.test_assert_path_exists(
"//SimpleCallNode",
"//SimpleCallNode//NoneCheckNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = false]")
def str_join(str s, args):
"""
>>> print(str_join('a', list('bbb')))
babab
"""
# 's' acts as the separator; 'args' is the sequence of items to join.
result = s.join(args)
# The result of the str join is expected to be typed as 'basestring object';
# the actual inferred type is used as the assertion message on failure.
assert cython.typeof(result) == 'basestring object', cython.typeof(result)
return result
# Test: join() called directly on a str literal.
# The first decorator requires that no NoneCheckNode appears beneath the
# SimpleCallNode; the second requires the SimpleCallNode itself and an
# AttributeNode with is_py_attr == false beneath it.
# NOTE(review): the None check is presumably unnecessary because a literal
# separator can never be None - confirm.
@cython.test_fail_if_path_exists(
"//SimpleCallNode//NoneCheckNode",
)
@cython.test_assert_path_exists(
"//SimpleCallNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = false]")
def literal_join(args):
"""
>>> print(literal_join(list('abcdefg')))
a|b|c|d|e|f|g
"""
# The separator is the str literal '|'; 'args' is the sequence of items to join.
result = '|'.join(args)
# The result of the str join is expected to be typed as 'basestring object';
# the actual inferred type is used as the assertion message on failure.
assert cython.typeof(result) == 'basestring object', cython.typeof(result)
return result
......@@ -218,7 +218,9 @@ def join_sep(l):
>>> print( join_sep(l) )
ab|jd|sdflk|as|sa|sadas|asdas|fsdf
"""
return u'|'.join(l)
result = u'|'.join(l)
assert cython.typeof(result) == 'unicode object', cython.typeof(result)
return result
@cython.test_assert_path_exists(
"//SimpleCallNode",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment