Commit 858709b6 authored by Robert Bradshaw's avatar Robert Bradshaw

String literal parsing in inline mode, hook up to cythonize.

parent 8d722faf
......@@ -94,7 +94,7 @@ class DistutilsInfo(object):
value = [tuple(macro.split('=')) for macro in value]
self.values[key] = value
elif exn is not None:
for key in self.distutils_settings:
for key in distutils_settings:
if key in ('name', 'sources'):
pass
value = getattr(exn, key, None)
......@@ -154,19 +154,32 @@ def strip_string_literals(code, prefix='__Pyx_L'):
in_quote = False
raw = False
while True:
hash_mark = code.find('#', q)
single_q = code.find("'", q)
double_q = code.find('"', q)
q = min(single_q, double_q)
if q == -1: q = max(single_q, double_q)
if q == -1:
if in_quote:
# Process comment.
if hash_mark < q or hash_mark > -1 == q:
end = code.find('\n', hash_mark)
if end == -1:
end = None
new_code.append(code[start:hash_mark+1])
counter += 1
label = "'%s%s" % (prefix, counter)
literals[label] = code[start:]
label = "%s%s" % (prefix, counter)
literals[label] = code[hash_mark+1:end]
new_code.append(label)
else:
if end is None:
break
q = end
# We're done.
elif q == -1:
new_code.append(code[start:])
break
# Try to close the quote.
elif in_quote:
if code[q-1] == '\\':
k = 2
......@@ -179,12 +192,14 @@ def strip_string_literals(code, prefix='__Pyx_L'):
counter += 1
label = "%s%s" % (prefix, counter)
literals[label] = code[start+len(in_quote):q]
new_code.append("'%s'" % label)
new_code.append("%s%s%s" % (in_quote, label, in_quote))
q += len(in_quote)
start = q
in_quote = False
else:
q += 1
# Open the quote.
else:
raw = False
if len(code) >= q+3 and (code[q+1] == code[q] == code[q+2]):
......@@ -202,13 +217,13 @@ def strip_string_literals(code, prefix='__Pyx_L'):
return "".join(new_code), literals
def parse_dependencies(source_filename):
# Actual parsing is way to slow, so we use regular expressions.
# The only catch is that we must strip comments and string
# literals ahead of time.
source = Utils.open_source_file(source_filename, "rU").read()
distutils_info = DistutilsInfo(source)
source = re.sub('#.*', '', source)
source, literals = strip_string_literals(source)
source = source.replace('\\\n', ' ')
if '\t' in source:
......@@ -389,8 +404,8 @@ def create_extension_list(patterns, ctx=None, aliases=None):
continue
template = pattern
name = template.name
base = DistutilsInfo(template)
exn_type = type(template)
base = DistutilsInfo(exn=template)
exn_type = template.__class__
else:
raise TypeError(pattern)
for file in glob(filepattern):
......
......@@ -9,14 +9,14 @@ try:
except ImportError:
import md5 as hashlib
from distutils.dist import Distribution
from Cython.Distutils.extension import Extension
from Cython.Distutils import build_ext
from distutils.core import Distribution, Extension
from distutils.command.build_ext import build_ext
from Cython.Compiler.Main import Context, CompilationOptions, default_options
from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform
from Cython.Compiler.TreeFragment import parse_from_strings
from Cython.Build.Dependencies import strip_string_literals, cythonize
_code_cache = {}
......@@ -82,6 +82,7 @@ def cython_inline(code,
locals=None,
globals=None,
**kwds):
code, literals = strip_string_literals(code)
code = strip_common_indent(code)
ctx = Context(include_dirs, default_options)
if locals is None:
......@@ -116,22 +117,23 @@ def cython_inline(code,
module_body, func_body = extract_func_code(code)
params = ', '.join(['%s %s' % a for a in arg_sigs])
module_code = """
%(cimports)s
%(module_body)s
%(cimports)s
def __invoke(%(params)s):
%(func_body)s
""" % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
# print module_code
_, pyx_file = tempfile.mkstemp('.pyx')
open(pyx_file, 'w').write(module_code)
for key, value in literals.items():
module_code = module_code.replace(key, value)
module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest()
pyx_file = os.path.join(tempfile.mkdtemp(), module + '.pyx')
open(pyx_file, 'w').write(module_code)
extension = Extension(
name = module,
sources = [pyx_file],
pyrex_include_dirs = include_dirs)
build_extension = build_ext(Distribution())
build_extension.finalize_options()
build_extension.extensions = [extension]
build_extension.extensions = cythonize([extension])
build_extension.build_temp = os.path.dirname(pyx_file)
if lib_dir not in sys.path:
sys.path.append(lib_dir)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment