Commit 00980f04 authored by Stefan Behnel

optimise str/bytes.join() and infer the result type; improve type inference for called builtins

parent 066c5b04
...@@ -74,12 +74,13 @@ builtin_utility_code = { ...@@ -74,12 +74,13 @@ builtin_utility_code = {
# mapping from builtins to their C-level equivalents # mapping from builtins to their C-level equivalents
class _BuiltinOverride(object): class _BuiltinOverride(object):
def __init__(self, py_name, args, ret_type, cname, py_equiv = "*", def __init__(self, py_name, args, ret_type, cname, py_equiv="*",
utility_code = None, sig = None, func_type = None, utility_code=None, sig=None, func_type=None,
is_strict_signature = False): is_strict_signature=False, builtin_return_type=None):
self.py_name, self.cname, self.py_equiv = py_name, cname, py_equiv self.py_name, self.cname, self.py_equiv = py_name, cname, py_equiv
self.args, self.ret_type = args, ret_type self.args, self.ret_type = args, ret_type
self.func_type, self.sig = func_type, sig self.func_type, self.sig = func_type, sig
self.builtin_return_type = builtin_return_type
self.is_strict_signature = is_strict_signature self.is_strict_signature = is_strict_signature
self.utility_code = utility_code self.utility_code = utility_code
...@@ -89,6 +90,8 @@ class _BuiltinOverride(object): ...@@ -89,6 +90,8 @@ class _BuiltinOverride(object):
func_type = sig.function_type(self_arg) func_type = sig.function_type(self_arg)
if self.is_strict_signature: if self.is_strict_signature:
func_type.is_strict_signature = True func_type.is_strict_signature = True
if self.builtin_return_type:
func_type.return_type = builtin_types[self.builtin_return_type]
return func_type return func_type
...@@ -212,7 +215,7 @@ builtin_function_table = [ ...@@ -212,7 +215,7 @@ builtin_function_table = [
#('raw_input', "", "", ""), #('raw_input', "", "", ""),
#('reduce', "", "", ""), #('reduce', "", "", ""),
BuiltinFunction('reload', "O", "O", "PyImport_ReloadModule"), BuiltinFunction('reload', "O", "O", "PyImport_ReloadModule"),
BuiltinFunction('repr', "O", "O", "PyObject_Repr"), BuiltinFunction('repr', "O", "O", "PyObject_Repr", builtin_return_type='basestring'),
#('round', "", "", ""), #('round', "", "", ""),
BuiltinFunction('setattr', "OOO", "r", "PyObject_SetAttr"), BuiltinFunction('setattr', "OOO", "r", "PyObject_SetAttr"),
#('sum', "", "", ""), #('sum', "", "", ""),
...@@ -276,8 +279,13 @@ builtin_types_table = [ ...@@ -276,8 +279,13 @@ builtin_types_table = [
("basestring", "PyBaseString_Type", []), ("basestring", "PyBaseString_Type", []),
("bytearray", "PyByteArray_Type", []), ("bytearray", "PyByteArray_Type", []),
("bytes", "PyBytes_Type", [BuiltinMethod("__contains__", "TO", "b", "PySequence_Contains"), ("bytes", "PyBytes_Type", [BuiltinMethod("__contains__", "TO", "b", "PySequence_Contains"),
BuiltinMethod("join", "TO", "O", "__Pyx_PyBytes_Join",
utility_code=UtilityCode.load("StringJoin", "StringTools.c")),
]), ]),
("str", "PyString_Type", [BuiltinMethod("__contains__", "TO", "b", "PySequence_Contains"), ("str", "PyString_Type", [BuiltinMethod("__contains__", "TO", "b", "PySequence_Contains"),
BuiltinMethod("join", "TO", "O", "__Pyx_PyString_Join",
builtin_return_type='basestring',
utility_code=UtilityCode.load("StringJoin", "StringTools.c")),
]), ]),
("unicode", "PyUnicode_Type", [BuiltinMethod("__contains__", "TO", "b", "PyUnicode_Contains"), ("unicode", "PyUnicode_Type", [BuiltinMethod("__contains__", "TO", "b", "PyUnicode_Contains"),
BuiltinMethod("join", "TO", "T", "PyUnicode_Join"), BuiltinMethod("join", "TO", "T", "PyUnicode_Join"),
...@@ -404,10 +412,11 @@ def init_builtin_structs(): ...@@ -404,10 +412,11 @@ def init_builtin_structs():
builtin_scope.declare_struct_or_union( builtin_scope.declare_struct_or_union(
name, "struct", scope, 1, None, cname = cname) name, "struct", scope, 1, None, cname = cname)
def init_builtins(): def init_builtins():
init_builtin_structs() init_builtin_structs()
init_builtin_funcs()
init_builtin_types() init_builtin_types()
init_builtin_funcs()
builtin_scope.declare_var( builtin_scope.declare_var(
'__debug__', PyrexTypes.c_const_type(PyrexTypes.c_bint_type), '__debug__', PyrexTypes.c_const_type(PyrexTypes.c_bint_type),
pos=None, cname='(!Py_OptimizeFlag)', is_cdef=True) pos=None, cname='(!Py_OptimizeFlag)', is_cdef=True)
...@@ -429,4 +438,5 @@ def init_builtins(): ...@@ -429,4 +438,5 @@ def init_builtins():
bool_type = builtin_scope.lookup('bool').type bool_type = builtin_scope.lookup('bool').type
complex_type = builtin_scope.lookup('complex').type complex_type = builtin_scope.lookup('complex').type
init_builtins() init_builtins()
...@@ -4163,9 +4163,14 @@ class CallNode(ExprNode): ...@@ -4163,9 +4163,14 @@ class CallNode(ExprNode):
def infer_type(self, env): def infer_type(self, env):
function = self.function function = self.function
if isinstance(function, NewExprNode):
return PyrexTypes.CPtrType(function.class_type)
func_type = function.infer_type(env) func_type = function.infer_type(env)
if isinstance(self.function, NewExprNode): if func_type is py_object_type:
return PyrexTypes.CPtrType(self.function.class_type) # function might have lied for safety => try to find better type
entry = getattr(function, 'entry', None)
if entry is not None:
func_type = entry.type or func_type
if func_type.is_ptr: if func_type.is_ptr:
func_type = func_type.base_type func_type = func_type.base_type
if func_type.is_cfunction: if func_type.is_cfunction:
......
...@@ -648,3 +648,27 @@ static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t i ...@@ -648,3 +648,27 @@ static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t i
index += PyBytes_GET_SIZE(bytes); index += PyBytes_GET_SIZE(bytes);
return PyBytes_AS_STRING(bytes)[index]; return PyBytes_AS_STRING(bytes)[index];
} }
//////////////////// StringJoin.proto ////////////////////
/* str.join(): before Python 3 the name is an alias for the bytes join helper,
   otherwise it is an alias for PyUnicode_Join. */
#if PY_MAJOR_VERSION < 3
#define __Pyx_PyString_Join __Pyx_PyBytes_Join
#else
#define __Pyx_PyString_Join PyUnicode_Join
#endif
/* bytes.join(): when compiling in CPython the name is an alias for _PyBytes_Join;
   otherwise only a prototype is declared here and the function body is provided
   by the StringJoin implementation section. */
#if CYTHON_COMPILING_IN_CPYTHON
#define __Pyx_PyBytes_Join _PyBytes_Join
#else
static CYTHON_INLINE PyObject* __Pyx_PyBytes_Join(PyObject* sep, PyObject* values); /*proto*/
#endif
//////////////////// StringJoin ////////////////////
#if !CYTHON_COMPILING_IN_CPYTHON
/* Fallback used when not compiling in CPython: join `values` with the separator
 * `sep` by calling the separator's "join" method through the generic object
 * protocol.  Returns whatever PyObject_CallMethodObjArgs() returns.
 * Fix: the return statement was missing its terminating semicolon, which made
 * this branch fail to compile. */
static CYTHON_INLINE PyObject* __Pyx_PyBytes_Join(PyObject* sep, PyObject* values) {
    return PyObject_CallMethodObjArgs(sep, PYIDENT("join"), values, NULL);
}
#endif
...@@ -190,3 +190,33 @@ def bytes_decode_unbound_method(bytes s, start=None, stop=None): ...@@ -190,3 +190,33 @@ def bytes_decode_unbound_method(bytes s, start=None, stop=None):
return bytes.decode(s[start:], 'utf8') return bytes.decode(s[start:], 'utf8')
else: else:
return bytes.decode(s[start:stop], 'utf8') return bytes.decode(s[start:stop], 'utf8')
# Test: join() called on an argument declared as 'bytes'.
# The decorator lists tree paths that must exist for this function: a
# SimpleCallNode that contains a NoneCheckNode and an AttributeNode whose
# is_py_attr flag is false.
@cython.test_assert_path_exists(
"//SimpleCallNode",
"//SimpleCallNode//NoneCheckNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = false]")
def bytes_join(bytes s, *args):
"""
>>> print(bytes_join(b_a, b_b, b_b, b_b).decode('utf8'))
babab
"""
result = s.join(args)
# The type reported for the join result must be the generic 'Python object';
# on failure the assertion message shows the type that was reported instead.
assert cython.typeof(result) == 'Python object', cython.typeof(result)
return result
# Test: join() called on the bytes literal b'|'.
# Unlike the typed-argument case, the first decorator requires that NO
# NoneCheckNode appears inside the call; the second still requires a
# SimpleCallNode with an AttributeNode whose is_py_attr flag is false.
@cython.test_fail_if_path_exists(
"//SimpleCallNode//NoneCheckNode",
)
@cython.test_assert_path_exists(
"//SimpleCallNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = false]")
def literal_join(*args):
"""
>>> print(literal_join(b_b, b_b, b_b, b_b).decode('utf8'))
b|b|b|b
"""
result = b'|'.join(args)
# The type reported for the join result must be the generic 'Python object'.
assert cython.typeof(result) == 'Python object', cython.typeof(result)
return result
...@@ -53,3 +53,33 @@ def str_endswith(str s, sub, start=None, stop=None): ...@@ -53,3 +53,33 @@ def str_endswith(str s, sub, start=None, stop=None):
return s.endswith(sub, start) return s.endswith(sub, start)
else: else:
return s.endswith(sub, start, stop) return s.endswith(sub, start, stop)
# Test: join() called on an argument declared as 'str'.
# The decorator lists tree paths that must exist for this function: a
# SimpleCallNode that contains a NoneCheckNode and an AttributeNode whose
# is_py_attr flag is false.
@cython.test_assert_path_exists(
"//SimpleCallNode",
"//SimpleCallNode//NoneCheckNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = false]")
def str_join(str s, args):
"""
>>> print(str_join('a', list('bbb')))
babab
"""
result = s.join(args)
# The type reported for the join result must be 'basestring object';
# on failure the assertion message shows the type that was reported instead.
assert cython.typeof(result) == 'basestring object', cython.typeof(result)
return result
# Test: join() called on the string literal '|'.
# The first decorator requires that NO NoneCheckNode appears inside the call;
# the second requires a SimpleCallNode with an AttributeNode whose is_py_attr
# flag is false.
@cython.test_fail_if_path_exists(
"//SimpleCallNode//NoneCheckNode",
)
@cython.test_assert_path_exists(
"//SimpleCallNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = false]")
def literal_join(args):
"""
>>> print(literal_join(list('abcdefg')))
a|b|c|d|e|f|g
"""
result = '|'.join(args)
# The type reported for the join result must be 'basestring object'.
assert cython.typeof(result) == 'basestring object', cython.typeof(result)
return result
...@@ -218,7 +218,9 @@ def join_sep(l): ...@@ -218,7 +218,9 @@ def join_sep(l):
>>> print( join_sep(l) ) >>> print( join_sep(l) )
ab|jd|sdflk|as|sa|sadas|asdas|fsdf ab|jd|sdflk|as|sa|sadas|asdas|fsdf
""" """
return u'|'.join(l) result = u'|'.join(l)
assert cython.typeof(result) == 'unicode object', cython.typeof(result)
return result
@cython.test_assert_path_exists( @cython.test_assert_path_exists(
"//SimpleCallNode", "//SimpleCallNode",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment