Commit b69b4a5b authored by Robert Bradshaw's avatar Robert Bradshaw

cythonize performance improvements for large codebases

parent 9560cbec
# cython: profile=True
import cython
from glob import glob from glob import glob
import re, os, sys import re, os, sys
...@@ -5,33 +9,10 @@ import re, os, sys ...@@ -5,33 +9,10 @@ import re, os, sys
from distutils.extension import Extension from distutils.extension import Extension
from Cython import Utils from Cython import Utils
from Cython.Utils import cached_function, cached_method, path_exists
from Cython.Compiler.Main import Context, CompilationOptions, default_options from Cython.Compiler.Main import Context, CompilationOptions, default_options
def cached_function(f): os.path.join = cached_function(os.path.join)
cache_name = '__%s_cache' % f.__name__
def wrapper(*args):
cache = getattr(f, cache_name, None)
if cache is None:
cache = {}
setattr(f, cache_name, cache)
if args in cache:
return cache[args]
res = cache[args] = f(*args)
return res
return wrapper
def cached_method(f):
cache_name = '__%s_cache' % f.__name__
def wrapper(self, *args):
cache = getattr(self, cache_name, None)
if cache is None:
cache = {}
setattr(self, cache_name, cache)
if args in cache:
return cache[args]
res = cache[args] = f(self, *args)
return res
return wrapper
def extended_iglob(pattern): def extended_iglob(pattern):
if '**/' in pattern: if '**/' in pattern:
...@@ -97,6 +78,7 @@ distutils_settings = { ...@@ -97,6 +78,7 @@ distutils_settings = {
'language': transitive_str, 'language': transitive_str,
} }
@cython.locals(start=long, end=long)
def line_iter(source): def line_iter(source):
start = 0 start = 0
while True: while True:
...@@ -175,7 +157,8 @@ class DistutilsInfo(object): ...@@ -175,7 +157,8 @@ class DistutilsInfo(object):
resolved.values[key] = value resolved.values[key] = value
return resolved return resolved
@cython.locals(start=long, q=long, single_q=long, double_q=long, hash_mark=long,
end=long, k=long, counter=long, quote_len=long)
def strip_string_literals(code, prefix='__Pyx_L'): def strip_string_literals(code, prefix='__Pyx_L'):
""" """
Normalizes every string literal to be of the form '__Pyx_Lxxx', Normalizes every string literal to be of the form '__Pyx_Lxxx',
...@@ -187,10 +170,15 @@ def strip_string_literals(code, prefix='__Pyx_L'): ...@@ -187,10 +170,15 @@ def strip_string_literals(code, prefix='__Pyx_L'):
counter = 0 counter = 0
start = q = 0 start = q = 0
in_quote = False in_quote = False
raw = False hash_mark = single_q = double_q = -1
code_len = len(code)
while True: while True:
if hash_mark < q:
hash_mark = code.find('#', q) hash_mark = code.find('#', q)
if single_q < q:
single_q = code.find("'", q) single_q = code.find("'", q)
if double_q < q:
double_q = code.find('"', q) double_q = code.find('"', q)
q = min(single_q, double_q) q = min(single_q, double_q)
if q == -1: q = max(single_q, double_q) if q == -1: q = max(single_q, double_q)
...@@ -202,19 +190,22 @@ def strip_string_literals(code, prefix='__Pyx_L'): ...@@ -202,19 +190,22 @@ def strip_string_literals(code, prefix='__Pyx_L'):
# Try to close the quote. # Try to close the quote.
elif in_quote: elif in_quote:
if code[q-1] == '\\' and not raw: if code[q-1] == u'\\':
k = 2 k = 2
while q >= k and code[q-k] == '\\': while q >= k and code[q-k] == u'\\':
k += 1 k += 1
if k % 2 == 0: if k % 2 == 0:
q += 1 q += 1
continue continue
if code[q:q+len(in_quote)] == in_quote: if code[q] == quote_type and (quote_len == 1 or (code_len > q + 2 and quote_type == code[q+1] == code[q+2])):
counter += 1 counter += 1
label = "%s%s_" % (prefix, counter) label = "%s%s_" % (prefix, counter)
literals[label] = code[start+len(in_quote):q] literals[label] = code[start+quote_len:q]
new_code.append("%s%s%s" % (in_quote, label, in_quote)) full_quote = code[q:q+quote_len]
q += len(in_quote) new_code.append(full_quote)
new_code.append(label)
new_code.append(full_quote)
q += quote_len
in_quote = False in_quote = False
start = q start = q
else: else:
...@@ -222,70 +213,67 @@ def strip_string_literals(code, prefix='__Pyx_L'): ...@@ -222,70 +213,67 @@ def strip_string_literals(code, prefix='__Pyx_L'):
# Process comment. # Process comment.
elif -1 != hash_mark and (hash_mark < q or q == -1): elif -1 != hash_mark and (hash_mark < q or q == -1):
end = code.find('\n', hash_mark)
if end == -1:
end = None
new_code.append(code[start:hash_mark+1]) new_code.append(code[start:hash_mark+1])
end = code.find('\n', hash_mark)
counter += 1 counter += 1
label = "%s%s_" % (prefix, counter) label = "%s%s_" % (prefix, counter)
literals[label] = code[hash_mark+1:end] if end == -1:
end_or_none = None
else:
end_or_none = end
literals[label] = code[hash_mark+1:end_or_none]
new_code.append(label) new_code.append(label)
if end is None: if end == -1:
break break
q = end start = q = end
start = q
# Open the quote. # Open the quote.
else: else:
raw = False if code_len >= q+3 and (code[q] == code[q+1] == code[q+2]):
if len(code) >= q+3 and (code[q+1] == code[q] == code[q+2]): quote_len = 3
in_quote = code[q]*3
else: else:
in_quote = code[q] quote_len = 1
end = marker = q in_quote = True
while marker > 0 and code[marker-1] in 'rRbBuU': quote_type = code[q]
if code[marker-1] in 'rR': new_code.append(code[start:q])
raw = True
marker -= 1
new_code.append(code[start:end])
start = q start = q
q += len(in_quote) q += quote_len
return "".join(new_code), literals return "".join(new_code), literals
dependancy_regex = re.compile(r"(?:^from +([0-9a-zA-Z_.]+) +cimport)|"
r"(?:^cimport +([0-9a-zA-Z_.]+)\b)|"
r"(?:^cdef +extern +from +['\"]([^'\"]+)['\"])|"
r"(?:^include +['\"]([^'\"]+)['\"])", re.M)
def parse_dependencies(source_filename): def parse_dependencies(source_filename):
# Actual parsing is way to slow, so we use regular expressions. # Actual parsing is way to slow, so we use regular expressions.
# The only catch is that we must strip comments and string # The only catch is that we must strip comments and string
# literals ahead of time. # literals ahead of time.
fh = Utils.open_source_file(source_filename, "rU") fh = Utils.open_source_file(source_filename, "rU", error_handling='ignore')
try: try:
source = fh.read() source = fh.read()
finally: finally:
fh.close() fh.close()
distutils_info = DistutilsInfo(source) distutils_info = DistutilsInfo(source)
source, literals = strip_string_literals(source) source, literals = strip_string_literals(source)
source = source.replace('\\\n', ' ') source = source.replace('\\\n', ' ').replace('\t', ' ')
if '\t' in source:
source = source.replace('\t', ' ')
# TODO: pure mode # TODO: pure mode
dependancy = re.compile(r"(cimport +([0-9a-zA-Z_.]+)\b)|"
"(from +([0-9a-zA-Z_.]+) +cimport)|"
"(include +['\"]([^'\"]+)['\"])|"
"(cdef +extern +from +['\"]([^'\"]+)['\"])")
cimports = [] cimports = []
includes = [] includes = []
externs = [] externs = []
for m in dependancy.finditer(source): for m in dependancy_regex.finditer(source):
groups = m.groups() cimport_from, cimport, extern, include = m.groups()
if groups[0]: if cimport_from:
cimports.append(groups[1]) cimports.append(cimport_from)
elif groups[2]: elif cimport:
cimports.append(groups[3]) cimports.append(cimport)
elif groups[4]: elif extern:
includes.append(literals[groups[5]]) externs.append(literals[extern])
else: else:
externs.append(literals[groups[7]]) includes.append(literals[include])
return cimports, includes, externs, distutils_info return cimports, includes, externs, distutils_info
...@@ -306,9 +294,19 @@ class DependencyTree(object): ...@@ -306,9 +294,19 @@ class DependencyTree(object):
externs = set(externs) externs = set(externs)
for include in includes: for include in includes:
include_path = os.path.join(os.path.dirname(filename), include) include_path = os.path.join(os.path.dirname(filename), include)
if not os.path.exists(include_path): if not path_exists(include_path):
include_path = self.context.find_include_file(include, None) include_path = self.context.find_include_file(include, None)
if include_path: if include_path:
if '.' + os.path.sep in include_path:
path_segments = include_path.split(os.path.sep)
while '.' in path_segments:
path_segments.remove('.')
while '..' in path_segments:
ix = path_segments.index('..')
if ix == 0:
break
del path_segments[ix-1:ix+1]
include_path = os.path.sep.join(path_segments)
a, b = self.cimports_and_externs(include_path) a, b = self.cimports_and_externs(include_path)
cimports.update(a) cimports.update(a)
externs.update(b) externs.update(b)
...@@ -322,7 +320,7 @@ class DependencyTree(object): ...@@ -322,7 +320,7 @@ class DependencyTree(object):
@cached_method @cached_method
def package(self, filename): def package(self, filename):
dir = os.path.dirname(os.path.abspath(str(filename))) dir = os.path.dirname(os.path.abspath(str(filename)))
if dir != filename and os.path.exists(os.path.join(dir, '__init__.py')): if dir != filename and path_exists(os.path.join(dir, '__init__.py')):
return self.package(dir) + (os.path.basename(dir),) return self.package(dir) + (os.path.basename(dir),)
else: else:
return () return ()
...@@ -345,17 +343,20 @@ class DependencyTree(object): ...@@ -345,17 +343,20 @@ class DependencyTree(object):
@cached_method @cached_method
def cimported_files(self, filename): def cimported_files(self, filename):
if filename[-4:] == '.pyx' and os.path.exists(filename[:-4] + '.pxd'): if filename[-4:] == '.pyx' and path_exists(filename[:-4] + '.pxd'):
self_pxd = [filename[:-4] + '.pxd'] pxd_list = [filename[:-4] + '.pxd']
else: else:
self_pxd = [] pxd_list = []
a = list(x for x in self.cimports(filename) if x.split('.')[0] != 'cython') for module in self.cimports(filename):
b = filter(None, [self.find_pxd(m, filename) for m in self.cimports(filename)]) if module[:7] == 'cython.':
if len(a) != len(b): continue
pxd_file = self.find_pxd(module, filename)
if pxd_file is None:
print("missing cimport: %s" % filename) print("missing cimport: %s" % filename)
print("\n\t".join(a)) print(module)
print("\n\t".join(b)) else:
return tuple(self_pxd + filter(None, [self.find_pxd(m, filename) for m in self.cimports(filename)])) pxd_list.append(pxd_file)
return tuple(pxd_list)
def immediate_dependencies(self, filename): def immediate_dependencies(self, filename):
all = list(self.cimported_files(filename)) all = list(self.cimported_files(filename))
......
...@@ -35,10 +35,10 @@ class TestStripLiterals(CythonTest): ...@@ -35,10 +35,10 @@ class TestStripLiterals(CythonTest):
self.t("u'abc'", "u'_L1_'") self.t("u'abc'", "u'_L1_'")
def test_raw(self): def test_raw(self):
self.t(r"r'abc\'", "r'_L1_'") self.t(r"r'abc\\'", "r'_L1_'")
def test_raw_unicode(self): def test_raw_unicode(self):
self.t(r"ru'abc\'", "ru'_L1_'") self.t(r"ru'abc\\'", "ru'_L1_'")
def test_comment(self): def test_comment(self):
self.t("abc # foo", "abc #_L1_") self.t("abc # foo", "abc #_L1_")
......
...@@ -222,65 +222,14 @@ class Context(object): ...@@ -222,65 +222,14 @@ class Context(object):
def search_include_directories(self, qualified_name, suffix, pos, def search_include_directories(self, qualified_name, suffix, pos,
include=False, sys_path=False): include=False, sys_path=False):
# Search the list of include directories for the given return Utils.search_include_directories(
# file name. If a source file position is given, first tuple(self.include_directories), qualified_name, suffix, pos, include, sys_path)
# searches the directory containing that file. Returns
# None if not found, but does not report an error.
# The 'include' option will disable package dereferencing.
# If 'sys_path' is True, also search sys.path.
dirs = self.include_directories
if sys_path:
dirs = dirs + sys.path
if pos:
file_desc = pos[0]
if not isinstance(file_desc, FileSourceDescriptor):
raise RuntimeError("Only file sources for code supported")
if include:
dirs = [os.path.dirname(file_desc.filename)] + dirs
else:
dirs = [self.find_root_package_dir(file_desc.filename)] + dirs
dotted_filename = qualified_name
if suffix:
dotted_filename += suffix
if not include:
names = qualified_name.split('.')
package_names = names[:-1]
module_name = names[-1]
module_filename = module_name + suffix
package_filename = "__init__" + suffix
for dir in dirs:
path = os.path.join(dir, dotted_filename)
if Utils.path_exists(path):
return path
if not include:
package_dir = self.check_package_dir(dir, package_names)
if package_dir is not None:
path = os.path.join(package_dir, module_filename)
if Utils.path_exists(path):
return path
path = os.path.join(dir, package_dir, module_name,
package_filename)
if Utils.path_exists(path):
return path
return None
def find_root_package_dir(self, file_path): def find_root_package_dir(self, file_path):
dir = os.path.dirname(file_path) return Utils.find_root_package_dir(file_path)
while self.is_package_dir(dir):
parent = os.path.dirname(dir)
if parent == dir:
break
dir = parent
return dir
def check_package_dir(self, dir, package_names): def check_package_dir(self, dir, package_names):
for dirname in package_names: return Utils.check_package_dir(dir, tuple(package_names))
dir = os.path.join(dir, dirname)
if not self.is_package_dir(dir):
return None
return dir
def c_file_out_of_date(self, source_path): def c_file_out_of_date(self, source_path):
c_path = Utils.replace_suffix(source_path, ".c") c_path = Utils.replace_suffix(source_path, ".c")
...@@ -309,13 +258,7 @@ class Context(object): ...@@ -309,13 +258,7 @@ class Context(object):
if kind == "cimport" ] if kind == "cimport" ]
def is_package_dir(self, dir_path): def is_package_dir(self, dir_path):
# Return true if the given directory is a package directory. return Utils.is_package_dir(dir_path)
for filename in ("__init__.py",
"__init__.pyx",
"__init__.pxd"):
path = os.path.join(dir_path, filename)
if Utils.path_exists(path):
return 1
def read_dependency_file(self, source_path): def read_dependency_file(self, source_path):
dep_path = Utils.replace_suffix(source_path, ".dep") dep_path = Utils.replace_suffix(source_path, ".dep")
......
...@@ -7,6 +7,29 @@ import os, sys, re, codecs ...@@ -7,6 +7,29 @@ import os, sys, re, codecs
modification_time = os.path.getmtime modification_time = os.path.getmtime
def cached_function(f):
cache = {}
uncomputed = object()
def wrapper(*args):
res = cache.get(args, uncomputed)
if res is uncomputed:
res = cache[args] = f(*args)
return res
return wrapper
def cached_method(f):
cache_name = '__%s_cache' % f.__name__
def wrapper(self, *args):
cache = getattr(self, cache_name, None)
if cache is None:
cache = {}
setattr(self, cache_name, cache)
if args in cache:
return cache[args]
res = cache[args] = f(self, *args)
return res
return wrapper
def replace_suffix(path, newsuf): def replace_suffix(path, newsuf):
base, _ = os.path.splitext(path) base, _ = os.path.splitext(path)
return base + newsuf return base + newsuf
...@@ -43,6 +66,82 @@ def file_newer_than(path, time): ...@@ -43,6 +66,82 @@ def file_newer_than(path, time):
ftime = modification_time(path) ftime = modification_time(path)
return ftime > time return ftime > time
@cached_function
def search_include_directories(dirs, qualified_name, suffix, pos,
include=False, sys_path=False):
# Search the list of include directories for the given
# file name. If a source file position is given, first
# searches the directory containing that file. Returns
# None if not found, but does not report an error.
# The 'include' option will disable package dereferencing.
# If 'sys_path' is True, also search sys.path.
if sys_path:
dirs = dirs + tuple(sys.path)
if pos:
file_desc = pos[0]
from Cython.Compiler.Scanning import FileSourceDescriptor
if not isinstance(file_desc, FileSourceDescriptor):
raise RuntimeError("Only file sources for code supported")
if include:
dirs = (os.path.dirname(file_desc.filename),) + dirs
else:
dirs = (find_root_package_dir(file_desc.filename),) + dirs
dotted_filename = qualified_name
if suffix:
dotted_filename += suffix
if not include:
names = qualified_name.split('.')
package_names = tuple(names[:-1])
module_name = names[-1]
module_filename = module_name + suffix
package_filename = "__init__" + suffix
for dir in dirs:
path = os.path.join(dir, dotted_filename)
if path_exists(path):
return path
if not include:
package_dir = check_package_dir(dir, package_names)
if package_dir is not None:
path = os.path.join(package_dir, module_filename)
if path_exists(path):
return path
path = os.path.join(dir, package_dir, module_name,
package_filename)
if path_exists(path):
return path
return None
@cached_function
def find_root_package_dir(file_path):
dir = os.path.dirname(file_path)
while is_package_dir(dir):
parent = os.path.dirname(dir)
if parent == dir:
break
dir = parent
return dir
@cached_function
def check_package_dir(dir, package_names):
for dirname in package_names:
dir = os.path.join(dir, dirname)
if not is_package_dir(dir):
return None
return dir
@cached_function
def is_package_dir(dir_path):
for filename in ("__init__.py",
"__init__.pyx",
"__init__.pxd"):
path = os.path.join(dir_path, filename)
if path_exists(path):
return 1
@cached_function
def path_exists(path): def path_exists(path):
# try on the filesystem first # try on the filesystem first
if os.path.exists(path): if os.path.exists(path):
...@@ -85,9 +184,26 @@ def decode_filename(filename): ...@@ -85,9 +184,26 @@ def decode_filename(filename):
_match_file_encoding = re.compile(u"coding[:=]\s*([-\w.]+)").search _match_file_encoding = re.compile(u"coding[:=]\s*([-\w.]+)").search
def detect_file_encoding(source_filename): def detect_file_encoding(source_filename):
# PEPs 263 and 3120
f = open_source_file(source_filename, encoding="UTF-8", error_handling='ignore') f = open_source_file(source_filename, encoding="UTF-8", error_handling='ignore')
try: try:
return detect_opened_file_encoding(f)
finally:
f.close()
def detect_opened_file_encoding(f):
# PEPs 263 and 3120
# Most of the time the first two lines fall in the first 250 chars,
# and this bulk read/split is much faster.
lines = f.read(250).split("\n")
if len(lines) > 2:
m = _match_file_encoding(lines[0]) or _match_file_encoding(lines[1])
if m:
return m.group(1)
else:
return "UTF-8"
else:
# Fallback to one-char-at-a-time detection.
f.seek(0)
chars = [] chars = []
for i in range(2): for i in range(2):
c = f.read(1) c = f.read(1)
...@@ -97,8 +213,6 @@ def detect_file_encoding(source_filename): ...@@ -97,8 +213,6 @@ def detect_file_encoding(source_filename):
encoding = _match_file_encoding(u''.join(chars)) encoding = _match_file_encoding(u''.join(chars))
if encoding: if encoding:
return encoding.group(1) return encoding.group(1)
finally:
f.close()
return "UTF-8" return "UTF-8"
normalise_newlines = re.compile(u'\r\n?|\n').sub normalise_newlines = re.compile(u'\r\n?|\n').sub
...@@ -111,6 +225,7 @@ class NormalisedNewlineStream(object): ...@@ -111,6 +225,7 @@ class NormalisedNewlineStream(object):
""" """
def __init__(self, stream): def __init__(self, stream):
# let's assume .read() doesn't change # let's assume .read() doesn't change
self.stream = stream
self._read = stream.read self._read = stream.read
self.close = stream.close self.close = stream.close
self.encoding = getattr(stream, 'encoding', 'UTF-8') self.encoding = getattr(stream, 'encoding', 'UTF-8')
...@@ -133,6 +248,12 @@ class NormalisedNewlineStream(object): ...@@ -133,6 +248,12 @@ class NormalisedNewlineStream(object):
return u''.join(content).splitlines(True) return u''.join(content).splitlines(True)
def seek(self, pos):
if pos == 0:
self.stream.seek(0)
else:
raise NotImplementedError
io = None io = None
if sys.version_info >= (2,6): if sys.version_info >= (2,6):
try: try:
...@@ -144,8 +265,17 @@ def open_source_file(source_filename, mode="r", ...@@ -144,8 +265,17 @@ def open_source_file(source_filename, mode="r",
encoding=None, error_handling=None, encoding=None, error_handling=None,
require_normalised_newlines=True): require_normalised_newlines=True):
if encoding is None: if encoding is None:
encoding = detect_file_encoding(source_filename) # Most of the time the coding is unspecified, so be optimistic that
# it's UTF-8.
f = open_source_file(source_filename, encoding="UTF-8", mode=mode, error_handling='ignore')
encoding = detect_opened_file_encoding(f)
if encoding == "UTF-8" and error_handling=='ignore' and require_normalised_newlines:
f.seek(0)
return f
else:
f.close()
# #
if not os.path.exists(source_filename):
try: try:
loader = __loader__ loader = __loader__
if source_filename.startswith(loader.archive): if source_filename.startswith(loader.archive):
......
...@@ -116,6 +116,7 @@ def compile_cython_modules(profile=False, compile_more=False, cython_with_refnan ...@@ -116,6 +116,7 @@ def compile_cython_modules(profile=False, compile_more=False, cython_with_refnan
] ]
if compile_more: if compile_more:
compiled_modules.extend([ compiled_modules.extend([
"Cython.Build.Dependencies",
"Cython.Compiler.ParseTreeTransforms", "Cython.Compiler.ParseTreeTransforms",
"Cython.Compiler.Nodes", "Cython.Compiler.Nodes",
"Cython.Compiler.ExprNodes", "Cython.Compiler.ExprNodes",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment