Commit 6a30fecf authored by will-ca's avatar will-ca Committed by GitHub

Make `Shadow.inline()` caching account for language version and compilation environment. (GH-3440)

Closes https://github.com/cython/cython/issues/3419
parent 7cc572f5
......@@ -141,6 +141,10 @@ def _populate_unbound(kwds, unbound_symbols, locals=None, globals=None):
else:
print("Couldn't find %r" % symbol)
def _inline_key(orig_code, arg_sigs, language_level):
key = orig_code, arg_sigs, sys.version_info, sys.executable, language_level, Cython.__version__
return hashlib.sha1(_unicode(key).encode('utf-8')).hexdigest()
def cython_inline(code, get_type=unsafe_type,
lib_dir=os.path.join(get_cython_cache_dir(), 'inline'),
cython_include_dirs=None, cython_compiler_directives=None,
......@@ -150,13 +154,20 @@ def cython_inline(code, get_type=unsafe_type,
get_type = lambda x: 'object'
ctx = _create_context(tuple(cython_include_dirs)) if cython_include_dirs else _cython_inline_default_context
cython_compiler_directives = dict(cython_compiler_directives or {})
if language_level is None and 'language_level' not in cython_compiler_directives:
language_level = '3str'
if language_level is not None:
cython_compiler_directives['language_level'] = language_level
# Fast path if this has been called in this session.
_unbound_symbols = _cython_inline_cache.get(code)
if _unbound_symbols is not None:
_populate_unbound(kwds, _unbound_symbols, locals, globals)
args = sorted(kwds.items())
arg_sigs = tuple([(get_type(value, ctx), arg) for arg, value in args])
invoke = _cython_inline_cache.get((code, arg_sigs))
key_hash = _inline_key(code, arg_sigs, language_level)
invoke = _cython_inline_cache.get((code, arg_sigs, key_hash))
if invoke is not None:
arg_list = [arg[1] for arg in args]
return invoke(*arg_list)
......@@ -177,12 +188,6 @@ def cython_inline(code, get_type=unsafe_type,
# Parsing from strings not fully supported (e.g. cimports).
print("Could not parse code as a string (to extract unbound symbols).")
cython_compiler_directives = dict(cython_compiler_directives or {})
if language_level is None and 'language_level' not in cython_compiler_directives:
language_level = '3str'
if language_level is not None:
cython_compiler_directives['language_level'] = language_level
cimports = []
for name, arg in list(kwds.items()):
if arg is cython_module:
......@@ -190,8 +195,8 @@ def cython_inline(code, get_type=unsafe_type,
del kwds[name]
arg_names = sorted(kwds)
arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
key = orig_code, arg_sigs, sys.version_info, sys.executable, language_level, Cython.__version__
module_name = "_cython_inline_" + hashlib.sha1(_unicode(key).encode('utf-8')).hexdigest()
key_hash = _inline_key(orig_code, arg_sigs, language_level)
module_name = "_cython_inline_" + key_hash
if module_name in sys.modules:
module = sys.modules[module_name]
......@@ -258,7 +263,7 @@ def __invoke(%(params)s):
module = load_dynamic(module_name, module_path)
_cython_inline_cache[orig_code, arg_sigs] = module.__invoke
_cython_inline_cache[orig_code, arg_sigs, key_hash] = module.__invoke
arg_list = [kwds[arg] for arg in arg_names]
return module.__invoke(*arg_list)
......
......@@ -74,6 +74,18 @@ class TestInline(CythonTest):
6
)
def test_lang_version(self):
# GH-3419. Caching for inline code didn't always respect compiler directives.
inline_divcode = "def f(int a, int b): return a/b"
self.assertEqual(
inline(inline_divcode, language_level=2)['f'](5,2),
2
)
self.assertEqual(
inline(inline_divcode, language_level=3)['f'](5,2),
2.5
)
if has_numpy:
def test_numpy(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment