Commit 858709b6 authored by Robert Bradshaw's avatar Robert Bradshaw

String literal parsing in inline mode, hook up to cythonize.

parent 8d722faf
...@@ -94,7 +94,7 @@ class DistutilsInfo(object): ...@@ -94,7 +94,7 @@ class DistutilsInfo(object):
value = [tuple(macro.split('=')) for macro in value] value = [tuple(macro.split('=')) for macro in value]
self.values[key] = value self.values[key] = value
elif exn is not None: elif exn is not None:
for key in self.distutils_settings: for key in distutils_settings:
if key in ('name', 'sources'): if key in ('name', 'sources'):
pass pass
value = getattr(exn, key, None) value = getattr(exn, key, None)
...@@ -154,19 +154,32 @@ def strip_string_literals(code, prefix='__Pyx_L'): ...@@ -154,19 +154,32 @@ def strip_string_literals(code, prefix='__Pyx_L'):
in_quote = False in_quote = False
raw = False raw = False
while True: while True:
hash_mark = code.find('#', q)
single_q = code.find("'", q) single_q = code.find("'", q)
double_q = code.find('"', q) double_q = code.find('"', q)
q = min(single_q, double_q) q = min(single_q, double_q)
if q == -1: q = max(single_q, double_q) if q == -1: q = max(single_q, double_q)
if q == -1:
if in_quote: # Process comment.
if hash_mark < q or hash_mark > -1 == q:
end = code.find('\n', hash_mark)
if end == -1:
end = None
new_code.append(code[start:hash_mark+1])
counter += 1 counter += 1
label = "'%s%s" % (prefix, counter) label = "%s%s" % (prefix, counter)
literals[label] = code[start:] literals[label] = code[hash_mark+1:end]
new_code.append(label) new_code.append(label)
else: if end is None:
break
q = end
# We're done.
elif q == -1:
new_code.append(code[start:]) new_code.append(code[start:])
break break
# Try to close the quote.
elif in_quote: elif in_quote:
if code[q-1] == '\\': if code[q-1] == '\\':
k = 2 k = 2
...@@ -179,12 +192,14 @@ def strip_string_literals(code, prefix='__Pyx_L'): ...@@ -179,12 +192,14 @@ def strip_string_literals(code, prefix='__Pyx_L'):
counter += 1 counter += 1
label = "%s%s" % (prefix, counter) label = "%s%s" % (prefix, counter)
literals[label] = code[start+len(in_quote):q] literals[label] = code[start+len(in_quote):q]
new_code.append("'%s'" % label) new_code.append("%s%s%s" % (in_quote, label, in_quote))
q += len(in_quote) q += len(in_quote)
start = q start = q
in_quote = False in_quote = False
else: else:
q += 1 q += 1
# Open the quote.
else: else:
raw = False raw = False
if len(code) >= q+3 and (code[q+1] == code[q] == code[q+2]): if len(code) >= q+3 and (code[q+1] == code[q] == code[q+2]):
...@@ -202,13 +217,13 @@ def strip_string_literals(code, prefix='__Pyx_L'): ...@@ -202,13 +217,13 @@ def strip_string_literals(code, prefix='__Pyx_L'):
return "".join(new_code), literals return "".join(new_code), literals
def parse_dependencies(source_filename): def parse_dependencies(source_filename):
# Actual parsing is way to slow, so we use regular expressions. # Actual parsing is way to slow, so we use regular expressions.
# The only catch is that we must strip comments and string # The only catch is that we must strip comments and string
# literals ahead of time. # literals ahead of time.
source = Utils.open_source_file(source_filename, "rU").read() source = Utils.open_source_file(source_filename, "rU").read()
distutils_info = DistutilsInfo(source) distutils_info = DistutilsInfo(source)
source = re.sub('#.*', '', source)
source, literals = strip_string_literals(source) source, literals = strip_string_literals(source)
source = source.replace('\\\n', ' ') source = source.replace('\\\n', ' ')
if '\t' in source: if '\t' in source:
...@@ -389,8 +404,8 @@ def create_extension_list(patterns, ctx=None, aliases=None): ...@@ -389,8 +404,8 @@ def create_extension_list(patterns, ctx=None, aliases=None):
continue continue
template = pattern template = pattern
name = template.name name = template.name
base = DistutilsInfo(template) base = DistutilsInfo(exn=template)
exn_type = type(template) exn_type = template.__class__
else: else:
raise TypeError(pattern) raise TypeError(pattern)
for file in glob(filepattern): for file in glob(filepattern):
......
...@@ -9,14 +9,14 @@ try: ...@@ -9,14 +9,14 @@ try:
except ImportError: except ImportError:
import md5 as hashlib import md5 as hashlib
from distutils.dist import Distribution from distutils.core import Distribution, Extension
from Cython.Distutils.extension import Extension from distutils.command.build_ext import build_ext
from Cython.Distutils import build_ext
from Cython.Compiler.Main import Context, CompilationOptions, default_options from Cython.Compiler.Main import Context, CompilationOptions, default_options
from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform
from Cython.Compiler.TreeFragment import parse_from_strings from Cython.Compiler.TreeFragment import parse_from_strings
from Cython.Build.Dependencies import strip_string_literals, cythonize
_code_cache = {} _code_cache = {}
...@@ -82,6 +82,7 @@ def cython_inline(code, ...@@ -82,6 +82,7 @@ def cython_inline(code,
locals=None, locals=None,
globals=None, globals=None,
**kwds): **kwds):
code, literals = strip_string_literals(code)
code = strip_common_indent(code) code = strip_common_indent(code)
ctx = Context(include_dirs, default_options) ctx = Context(include_dirs, default_options)
if locals is None: if locals is None:
...@@ -116,22 +117,23 @@ def cython_inline(code, ...@@ -116,22 +117,23 @@ def cython_inline(code,
module_body, func_body = extract_func_code(code) module_body, func_body = extract_func_code(code)
params = ', '.join(['%s %s' % a for a in arg_sigs]) params = ', '.join(['%s %s' % a for a in arg_sigs])
module_code = """ module_code = """
%(cimports)s
%(module_body)s %(module_body)s
%(cimports)s
def __invoke(%(params)s): def __invoke(%(params)s):
%(func_body)s %(func_body)s
""" % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body } """ % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
# print module_code for key, value in literals.items():
_, pyx_file = tempfile.mkstemp('.pyx') module_code = module_code.replace(key, value)
open(pyx_file, 'w').write(module_code)
module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest() module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest()
pyx_file = os.path.join(tempfile.mkdtemp(), module + '.pyx')
open(pyx_file, 'w').write(module_code)
extension = Extension( extension = Extension(
name = module, name = module,
sources = [pyx_file], sources = [pyx_file],
pyrex_include_dirs = include_dirs) pyrex_include_dirs = include_dirs)
build_extension = build_ext(Distribution()) build_extension = build_ext(Distribution())
build_extension.finalize_options() build_extension.finalize_options()
build_extension.extensions = [extension] build_extension.extensions = cythonize([extension])
build_extension.build_temp = os.path.dirname(pyx_file) build_extension.build_temp = os.path.dirname(pyx_file)
if lib_dir not in sys.path: if lib_dir not in sys.path:
sys.path.append(lib_dir) sys.path.append(lib_dir)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment