Commit 880dee0a authored by Robert Bradshaw's avatar Robert Bradshaw

Use unbound symbols from local/global scope.

parent 2d2fba93
print "Warning: Using prototype cython.inline code..."
import tempfile
import sys, os, re
import sys, os, re, inspect
try:
import hashlib
......@@ -12,12 +14,44 @@ from Cython.Distutils import build_ext
from Cython.Compiler.Main import Context, CompilationOptions, default_options
code_cache = {}
from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform
from Cython.Compiler.TreeFragment import parse_from_strings
_code_cache = {}
class AllSymbols(CythonTransform, SkipDeclarations):
def __init__(self):
CythonTransform.__init__(self, None)
self.names = set()
def visit_NameNode(self, node):
self.names.add(node.name)
def unbound_symbols(code, context=None):
if context is None:
context = Context([], default_options)
from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform
if isinstance(code, str):
code = code.decode('ascii')
tree = parse_from_strings('(tree fragment)', code)
for phase in context.create_pipeline(pxd=False):
if phase is None:
continue
tree = phase(tree)
if isinstance(phase, AnalyseDeclarationsTransform):
break
symbol_collector = AllSymbols()
symbol_collector(tree)
unbound = []
import __builtin__
for name in symbol_collector.names:
if not tree.scope.lookup(name) and not hasattr(__builtin__, name):
unbound.append(name)
return unbound
def get_type(arg, context=None):
py_type = type(arg)
# TODO: extension types
if py_type in [list, tuple, dict, str]:
return py_type.__name__
elif py_type is float:
......@@ -40,21 +74,43 @@ def get_type(arg, context=None):
return 'object'
# TODO: use locals/globals for unbound variables
def cython_inline(code, types='aggressive', lib_dir=os.path.expanduser('~/.cython/inline'), include_dirs=['.'], **kwds):
def cython_inline(code,
types='aggressive',
lib_dir=os.path.expanduser('~/.cython/inline'),
include_dirs=['.'],
locals=None,
globals=None,
**kwds):
ctx = Context(include_dirs, default_options)
_, pyx_file = tempfile.mkstemp('.pyx')
if locals is None:
locals = inspect.currentframe().f_back.f_back.f_locals
if globals is None:
globals = inspect.currentframe().f_back.f_back.f_globals
try:
for symbol in unbound_symbols(code):
if symbol in kwds:
continue
elif symbol in locals:
kwds[symbol] = locals[symbol]
elif symbol in globals:
kwds[symbol] = globals[symbol]
else:
print "Couldn't find ", symbol
except AssertionError:
# Parsing from strings not fully supported (e.g. cimports).
print "Could not parse code as a string (to extract unbound symbols)."
arg_names = kwds.keys()
arg_names.sort()
arg_sigs = tuple((get_type(kwds[arg], ctx), arg) for arg in arg_names)
key = code, arg_sigs
module = code_cache.get(key)
module = _code_cache.get(key)
if not module:
cimports = ''
cimports = []
qualified = re.compile(r'([.\w]+)[.]')
for type, _ in arg_sigs:
m = qualified.match(type)
if m:
cimports += '\ncimport %s' % m.groups()[0]
cimports.append('\ncimport %s' % m.groups()[0])
module_body, func_body = extract_func_code(code)
params = ', '.join('%s %s' % a for a in arg_sigs)
module_code = """
......@@ -62,8 +118,9 @@ def cython_inline(code, types='aggressive', lib_dir=os.path.expanduser('~/.cytho
%(module_body)s
def __invoke(%(params)s):
%(func_body)s
""" % locals()
print module_code
""" % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
# print module_code
_, pyx_file = tempfile.mkstemp('.pyx')
open(pyx_file, 'w').write(module_code)
module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest()
extension = Extension(
......@@ -78,7 +135,7 @@ def __invoke(%(params)s):
sys.path.append(lib_dir)
build_extension.build_lib = lib_dir
build_extension.run()
code_cache[key] = module
_code_cache[key] = module
arg_list = [kwds[arg] for arg in arg_names]
return __import__(module).__invoke(*arg_list)
......
......@@ -60,6 +60,7 @@ def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None):
scope = scope, context = context, initial_pos = initial_pos)
if level is None:
tree = Parsing.p_module(scanner, 0, module_name)
tree.scope = scope
else:
tree = Parsing.p_code(scanner, level=level)
return tree
......@@ -201,6 +202,8 @@ class TreeFragment(object):
if not isinstance(t, StatListNode):
t = StatListNode(pos=mod.pos, stats=[t])
for transform in pipeline:
if transform is None:
continue
t = transform(t)
self.root = t
elif isinstance(code, Node):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment