Commit 880dee0a authored by Robert Bradshaw's avatar Robert Bradshaw

Use unbound symbols from local/global scope.

parent 2d2fba93
print "Warning: Using prototype cython.inline code..."
import tempfile import tempfile
import sys, os, re import sys, os, re, inspect
try: try:
import hashlib import hashlib
...@@ -12,12 +14,44 @@ from Cython.Distutils import build_ext ...@@ -12,12 +14,44 @@ from Cython.Distutils import build_ext
from Cython.Compiler.Main import Context, CompilationOptions, default_options from Cython.Compiler.Main import Context, CompilationOptions, default_options
code_cache = {} from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform
from Cython.Compiler.TreeFragment import parse_from_strings
_code_cache = {}
class AllSymbols(CythonTransform, SkipDeclarations):
def __init__(self):
CythonTransform.__init__(self, None)
self.names = set()
def visit_NameNode(self, node):
self.names.add(node.name)
def unbound_symbols(code, context=None):
if context is None:
context = Context([], default_options)
from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform
if isinstance(code, str):
code = code.decode('ascii')
tree = parse_from_strings('(tree fragment)', code)
for phase in context.create_pipeline(pxd=False):
if phase is None:
continue
tree = phase(tree)
if isinstance(phase, AnalyseDeclarationsTransform):
break
symbol_collector = AllSymbols()
symbol_collector(tree)
unbound = []
import __builtin__
for name in symbol_collector.names:
if not tree.scope.lookup(name) and not hasattr(__builtin__, name):
unbound.append(name)
return unbound
def get_type(arg, context=None): def get_type(arg, context=None):
py_type = type(arg) py_type = type(arg)
# TODO: extension types
if py_type in [list, tuple, dict, str]: if py_type in [list, tuple, dict, str]:
return py_type.__name__ return py_type.__name__
elif py_type is float: elif py_type is float:
...@@ -40,21 +74,43 @@ def get_type(arg, context=None): ...@@ -40,21 +74,43 @@ def get_type(arg, context=None):
return 'object' return 'object'
# TODO: use locals/globals for unbound variables # TODO: use locals/globals for unbound variables
def cython_inline(code, types='aggressive', lib_dir=os.path.expanduser('~/.cython/inline'), include_dirs=['.'], **kwds): def cython_inline(code,
types='aggressive',
lib_dir=os.path.expanduser('~/.cython/inline'),
include_dirs=['.'],
locals=None,
globals=None,
**kwds):
ctx = Context(include_dirs, default_options) ctx = Context(include_dirs, default_options)
_, pyx_file = tempfile.mkstemp('.pyx') if locals is None:
locals = inspect.currentframe().f_back.f_back.f_locals
if globals is None:
globals = inspect.currentframe().f_back.f_back.f_globals
try:
for symbol in unbound_symbols(code):
if symbol in kwds:
continue
elif symbol in locals:
kwds[symbol] = locals[symbol]
elif symbol in globals:
kwds[symbol] = globals[symbol]
else:
print "Couldn't find ", symbol
except AssertionError:
# Parsing from strings not fully supported (e.g. cimports).
print "Could not parse code as a string (to extract unbound symbols)."
arg_names = kwds.keys() arg_names = kwds.keys()
arg_names.sort() arg_names.sort()
arg_sigs = tuple((get_type(kwds[arg], ctx), arg) for arg in arg_names) arg_sigs = tuple((get_type(kwds[arg], ctx), arg) for arg in arg_names)
key = code, arg_sigs key = code, arg_sigs
module = code_cache.get(key) module = _code_cache.get(key)
if not module: if not module:
cimports = '' cimports = []
qualified = re.compile(r'([.\w]+)[.]') qualified = re.compile(r'([.\w]+)[.]')
for type, _ in arg_sigs: for type, _ in arg_sigs:
m = qualified.match(type) m = qualified.match(type)
if m: if m:
cimports += '\ncimport %s' % m.groups()[0] cimports.append('\ncimport %s' % m.groups()[0])
module_body, func_body = extract_func_code(code) module_body, func_body = extract_func_code(code)
params = ', '.join('%s %s' % a for a in arg_sigs) params = ', '.join('%s %s' % a for a in arg_sigs)
module_code = """ module_code = """
...@@ -62,8 +118,9 @@ def cython_inline(code, types='aggressive', lib_dir=os.path.expanduser('~/.cytho ...@@ -62,8 +118,9 @@ def cython_inline(code, types='aggressive', lib_dir=os.path.expanduser('~/.cytho
%(module_body)s %(module_body)s
def __invoke(%(params)s): def __invoke(%(params)s):
%(func_body)s %(func_body)s
""" % locals() """ % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
print module_code # print module_code
_, pyx_file = tempfile.mkstemp('.pyx')
open(pyx_file, 'w').write(module_code) open(pyx_file, 'w').write(module_code)
module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest() module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest()
extension = Extension( extension = Extension(
...@@ -78,7 +135,7 @@ def __invoke(%(params)s): ...@@ -78,7 +135,7 @@ def __invoke(%(params)s):
sys.path.append(lib_dir) sys.path.append(lib_dir)
build_extension.build_lib = lib_dir build_extension.build_lib = lib_dir
build_extension.run() build_extension.run()
code_cache[key] = module _code_cache[key] = module
arg_list = [kwds[arg] for arg in arg_names] arg_list = [kwds[arg] for arg in arg_names]
return __import__(module).__invoke(*arg_list) return __import__(module).__invoke(*arg_list)
......
...@@ -60,6 +60,7 @@ def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None): ...@@ -60,6 +60,7 @@ def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None):
scope = scope, context = context, initial_pos = initial_pos) scope = scope, context = context, initial_pos = initial_pos)
if level is None: if level is None:
tree = Parsing.p_module(scanner, 0, module_name) tree = Parsing.p_module(scanner, 0, module_name)
tree.scope = scope
else: else:
tree = Parsing.p_code(scanner, level=level) tree = Parsing.p_code(scanner, level=level)
return tree return tree
...@@ -201,6 +202,8 @@ class TreeFragment(object): ...@@ -201,6 +202,8 @@ class TreeFragment(object):
if not isinstance(t, StatListNode): if not isinstance(t, StatListNode):
t = StatListNode(pos=mod.pos, stats=[t]) t = StatListNode(pos=mod.pos, stats=[t])
for transform in pipeline: for transform in pipeline:
if transform is None:
continue
t = transform(t) t = transform(t)
self.root = t self.root = t
elif isinstance(code, Node): elif isinstance(code, Node):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment