Commit 80638c1a authored by Robert Bradshaw's avatar Robert Bradshaw

Let cython.inline return all defined variables by default.

parent 87049841
...@@ -150,10 +150,10 @@ def cython_inline(code, ...@@ -150,10 +150,10 @@ def cython_inline(code,
arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names]) arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
key = orig_code, arg_sigs, sys.version_info, sys.executable, Cython.__version__ key = orig_code, arg_sigs, sys.version_info, sys.executable, Cython.__version__
module_name = "_cython_inline_" + hashlib.md5(str(key).encode('utf-8')).hexdigest() module_name = "_cython_inline_" + hashlib.md5(str(key).encode('utf-8')).hexdigest()
if module_name in sys.modules: if module_name in sys.modules:
module = sys.modules[module_name] module = sys.modules[module_name]
else: else:
build_extension = None build_extension = None
if cython_inline.so_ext is None: if cython_inline.so_ext is None:
...@@ -185,12 +185,16 @@ def cython_inline(code, ...@@ -185,12 +185,16 @@ def cython_inline(code,
%(cimports)s %(cimports)s
def __invoke(%(params)s): def __invoke(%(params)s):
%(func_body)s %(func_body)s
""" % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body } return locals()
""" % {'cimports': '\n'.join(cimports),
'module_body': module_body,
'params': params,
'func_body': func_body }
for key, value in literals.items(): for key, value in literals.items():
module_code = module_code.replace(key, value) module_code = module_code.replace(key, value)
pyx_file = os.path.join(lib_dir, module_name + '.pyx') pyx_file = os.path.join(lib_dir, module_name + '.pyx')
fh = open(pyx_file, 'w') fh = open(pyx_file, 'w')
try: try:
fh.write(module_code) fh.write(module_code)
finally: finally:
fh.close() fh.close()
......
...@@ -40,7 +40,14 @@ class TestInline(CythonTest): ...@@ -40,7 +40,14 @@ class TestInline(CythonTest):
def test_globals(self): def test_globals(self):
self.assertEquals(inline("return global_value + 1", **self.test_kwds), global_value + 1) self.assertEquals(inline("return global_value + 1", **self.test_kwds), global_value + 1)
def test_pure(self): def test_no_return(self):
self.assertEquals(inline("""
a = 1
cdef double b = 2
cdef c = []
"""), dict(a=1, b=2.0, c=[]))
def test_pure(self):
import cython as cy import cython as cy
b = inline(""" b = inline("""
b = cy.declare(float, a) b = cy.declare(float, a)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment