Commit 499d2622 authored by Robert Bradshaw's avatar Robert Bradshaw

Another 10x reduction in overhead for cython.inline.

parent 0dc185de
......@@ -27,6 +27,8 @@ Features added
* ``libc/math.pxd`` provides ``e`` and ``pi`` as alias constants to simplify
usage as a drop-in replacement for Python's math module.
* Speed up cython.inline().
Bugs fixed
----------
......
......@@ -66,7 +66,7 @@ def unbound_symbols(code, context=None):
import builtins
except ImportError:
import __builtin__ as builtins
return UnboundSymbols()(tree) - set(dir(builtins))
return tuple(UnboundSymbols()(tree) - set(dir(builtins)))
def unsafe_type(arg, context=None):
......@@ -79,7 +79,7 @@ def unsafe_type(arg, context=None):
def safe_type(arg, context=None):
py_type = type(arg)
if py_type in [list, tuple, dict, str]:
if py_type in (list, tuple, dict, str):
return py_type.__name__
elif py_type is complex:
return 'double complex'
......@@ -117,31 +117,54 @@ def _create_context(cython_include_dirs):
return Context(list(cython_include_dirs), default_options)
_cython_inline_cache = {}
_cython_inline_default_context = _create_context(('.',))
def _populate_unbound(kwds, unbound_symbols, locals=None, globals=None):
for symbol in unbound_symbols:
if symbol not in kwds:
if locals is None or globals is None:
calling_frame = inspect.currentframe().f_back.f_back.f_back
if locals is None:
locals = calling_frame.f_locals
if globals is None:
globals = calling_frame.f_globals
if symbol in locals:
kwds[symbol] = locals[symbol]
elif symbol in globals:
kwds[symbol] = globals[symbol]
else:
print("Couldn't find %r" % symbol)
def cython_inline(code, get_type=unsafe_type, lib_dir=os.path.join(get_cython_cache_dir(), 'inline'),
cython_include_dirs=None, force=False, quiet=False, locals=None, globals=None, **kwds):
if cython_include_dirs is None:
cython_include_dirs = ['.']
if get_type is None:
get_type = lambda x: 'object'
code = to_unicode(code)
ctx = _create_context(tuple(cython_include_dirs)) if cython_include_dirs else _cython_inline_default_context
# Fast path if this has been called in this session.
_unbound_symbols = _cython_inline_cache.get(code)
if _unbound_symbols is not None:
_populate_unbound(kwds, _unbound_symbols, locals, globals)
args = sorted(kwds.items())
arg_sigs = tuple([(get_type(value, ctx), arg) for arg, value in args])
invoke = _cython_inline_cache.get((code, arg_sigs))
if invoke is not None:
arg_list = [arg[1] for arg in args]
return invoke(*arg_list)
orig_code = code
code = to_unicode(code)
code, literals = strip_string_literals(code)
code = strip_common_indent(code)
ctx = _create_context(tuple(cython_include_dirs))
if locals is None:
locals = inspect.currentframe().f_back.f_back.f_locals
if globals is None:
globals = inspect.currentframe().f_back.f_back.f_globals
try:
for symbol in unbound_symbols(code):
if symbol in kwds:
continue
elif symbol in locals:
kwds[symbol] = locals[symbol]
elif symbol in globals:
kwds[symbol] = globals[symbol]
else:
print("Couldn't find %r" % symbol)
_cython_inline_cache[orig_code] = _unbound_symbols = unbound_symbols(code)
_populate_unbound(kwds, _unbound_symbols, locals, globals)
except AssertionError:
if not quiet:
# Parsing from strings not fully supported (e.g. cimports).
......@@ -210,13 +233,14 @@ def __invoke(%(params)s):
extra_compile_args = cflags)
if build_extension is None:
build_extension = _get_build_extension()
build_extension.extensions = cythonize([extension], include_path=cython_include_dirs, quiet=quiet)
build_extension.extensions = cythonize([extension], include_path=cython_include_dirs or ['.'], quiet=quiet)
build_extension.build_temp = os.path.dirname(pyx_file)
build_extension.build_lib = lib_dir
build_extension.run()
module = imp.load_dynamic(module_name, module_path)
_cython_inline_cache[orig_code, arg_sigs] = module.__invoke
arg_list = [kwds[arg] for arg in arg_names]
return module.__invoke(*arg_list)
......
......@@ -121,10 +121,13 @@ overflowcheck.fold = optimization.use_switch = \
final = internal = type_version_tag = no_gc_clear = no_gc = _empty_decorator
_cython_inline = None
def inline(f, *args, **kwds):
if isinstance(f, basestring):
from Cython.Build.Inline import cython_inline
return cython_inline(f, *args, **kwds)
global _cython_inline
if _cython_inline is None:
from Cython.Build.Inline import cython_inline as _cython_inline
return _cython_inline(f, *args, **kwds)
else:
assert len(args) == len(kwds) == 0
return f
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment