Commit b69b4a5b authored by Robert Bradshaw's avatar Robert Bradshaw

cythonize performance improvements for large codebases

parent 9560cbec
# cython: profile=True
import cython
from glob import glob
import re, os, sys
......@@ -5,33 +9,10 @@ import re, os, sys
from distutils.extension import Extension
from Cython import Utils
from Cython.Utils import cached_function, cached_method, path_exists
from Cython.Compiler.Main import Context, CompilationOptions, default_options
def cached_function(f):
cache_name = '__%s_cache' % f.__name__
def wrapper(*args):
cache = getattr(f, cache_name, None)
if cache is None:
cache = {}
setattr(f, cache_name, cache)
if args in cache:
return cache[args]
res = cache[args] = f(*args)
return res
return wrapper
def cached_method(f):
cache_name = '__%s_cache' % f.__name__
def wrapper(self, *args):
cache = getattr(self, cache_name, None)
if cache is None:
cache = {}
setattr(self, cache_name, cache)
if args in cache:
return cache[args]
res = cache[args] = f(self, *args)
return res
return wrapper
os.path.join = cached_function(os.path.join)
def extended_iglob(pattern):
if '**/' in pattern:
......@@ -97,6 +78,7 @@ distutils_settings = {
'language': transitive_str,
}
@cython.locals(start=long, end=long)
def line_iter(source):
start = 0
while True:
......@@ -175,7 +157,8 @@ class DistutilsInfo(object):
resolved.values[key] = value
return resolved
@cython.locals(start=long, q=long, single_q=long, double_q=long, hash_mark=long,
end=long, k=long, counter=long, quote_len=long)
def strip_string_literals(code, prefix='__Pyx_L'):
"""
Normalizes every string literal to be of the form '__Pyx_Lxxx',
......@@ -187,11 +170,16 @@ def strip_string_literals(code, prefix='__Pyx_L'):
counter = 0
start = q = 0
in_quote = False
raw = False
hash_mark = single_q = double_q = -1
code_len = len(code)
while True:
hash_mark = code.find('#', q)
single_q = code.find("'", q)
double_q = code.find('"', q)
if hash_mark < q:
hash_mark = code.find('#', q)
if single_q < q:
single_q = code.find("'", q)
if double_q < q:
double_q = code.find('"', q)
q = min(single_q, double_q)
if q == -1: q = max(single_q, double_q)
......@@ -202,19 +190,22 @@ def strip_string_literals(code, prefix='__Pyx_L'):
# Try to close the quote.
elif in_quote:
if code[q-1] == '\\' and not raw:
if code[q-1] == u'\\':
k = 2
while q >= k and code[q-k] == '\\':
while q >= k and code[q-k] == u'\\':
k += 1
if k % 2 == 0:
q += 1
continue
if code[q:q+len(in_quote)] == in_quote:
if code[q] == quote_type and (quote_len == 1 or (code_len > q + 2 and quote_type == code[q+1] == code[q+2])):
counter += 1
label = "%s%s_" % (prefix, counter)
literals[label] = code[start+len(in_quote):q]
new_code.append("%s%s%s" % (in_quote, label, in_quote))
q += len(in_quote)
literals[label] = code[start+quote_len:q]
full_quote = code[q:q+quote_len]
new_code.append(full_quote)
new_code.append(label)
new_code.append(full_quote)
q += quote_len
in_quote = False
start = q
else:
......@@ -222,70 +213,67 @@ def strip_string_literals(code, prefix='__Pyx_L'):
# Process comment.
elif -1 != hash_mark and (hash_mark < q or q == -1):
end = code.find('\n', hash_mark)
if end == -1:
end = None
new_code.append(code[start:hash_mark+1])
end = code.find('\n', hash_mark)
counter += 1
label = "%s%s_" % (prefix, counter)
literals[label] = code[hash_mark+1:end]
if end == -1:
end_or_none = None
else:
end_or_none = end
literals[label] = code[hash_mark+1:end_or_none]
new_code.append(label)
if end is None:
if end == -1:
break
q = end
start = q
start = q = end
# Open the quote.
else:
raw = False
if len(code) >= q+3 and (code[q+1] == code[q] == code[q+2]):
in_quote = code[q]*3
if code_len >= q+3 and (code[q] == code[q+1] == code[q+2]):
quote_len = 3
else:
in_quote = code[q]
end = marker = q
while marker > 0 and code[marker-1] in 'rRbBuU':
if code[marker-1] in 'rR':
raw = True
marker -= 1
new_code.append(code[start:end])
quote_len = 1
in_quote = True
quote_type = code[q]
new_code.append(code[start:q])
start = q
q += len(in_quote)
q += quote_len
return "".join(new_code), literals
dependancy_regex = re.compile(r"(?:^from +([0-9a-zA-Z_.]+) +cimport)|"
r"(?:^cimport +([0-9a-zA-Z_.]+)\b)|"
r"(?:^cdef +extern +from +['\"]([^'\"]+)['\"])|"
r"(?:^include +['\"]([^'\"]+)['\"])", re.M)
def parse_dependencies(source_filename):
# Actual parsing is way to slow, so we use regular expressions.
# The only catch is that we must strip comments and string
# literals ahead of time.
fh = Utils.open_source_file(source_filename, "rU")
fh = Utils.open_source_file(source_filename, "rU", error_handling='ignore')
try:
source = fh.read()
finally:
fh.close()
distutils_info = DistutilsInfo(source)
source, literals = strip_string_literals(source)
source = source.replace('\\\n', ' ')
if '\t' in source:
source = source.replace('\t', ' ')
source = source.replace('\\\n', ' ').replace('\t', ' ')
# TODO: pure mode
dependancy = re.compile(r"(cimport +([0-9a-zA-Z_.]+)\b)|"
"(from +([0-9a-zA-Z_.]+) +cimport)|"
"(include +['\"]([^'\"]+)['\"])|"
"(cdef +extern +from +['\"]([^'\"]+)['\"])")
cimports = []
includes = []
externs = []
for m in dependancy.finditer(source):
groups = m.groups()
if groups[0]:
cimports.append(groups[1])
elif groups[2]:
cimports.append(groups[3])
elif groups[4]:
includes.append(literals[groups[5]])
for m in dependancy_regex.finditer(source):
cimport_from, cimport, extern, include = m.groups()
if cimport_from:
cimports.append(cimport_from)
elif cimport:
cimports.append(cimport)
elif extern:
externs.append(literals[extern])
else:
externs.append(literals[groups[7]])
includes.append(literals[include])
return cimports, includes, externs, distutils_info
......@@ -306,9 +294,19 @@ class DependencyTree(object):
externs = set(externs)
for include in includes:
include_path = os.path.join(os.path.dirname(filename), include)
if not os.path.exists(include_path):
if not path_exists(include_path):
include_path = self.context.find_include_file(include, None)
if include_path:
if '.' + os.path.sep in include_path:
path_segments = include_path.split(os.path.sep)
while '.' in path_segments:
path_segments.remove('.')
while '..' in path_segments:
ix = path_segments.index('..')
if ix == 0:
break
del path_segments[ix-1:ix+1]
include_path = os.path.sep.join(path_segments)
a, b = self.cimports_and_externs(include_path)
cimports.update(a)
externs.update(b)
......@@ -322,7 +320,7 @@ class DependencyTree(object):
@cached_method
def package(self, filename):
dir = os.path.dirname(os.path.abspath(str(filename)))
if dir != filename and os.path.exists(os.path.join(dir, '__init__.py')):
if dir != filename and path_exists(os.path.join(dir, '__init__.py')):
return self.package(dir) + (os.path.basename(dir),)
else:
return ()
......@@ -345,17 +343,20 @@ class DependencyTree(object):
@cached_method
def cimported_files(self, filename):
if filename[-4:] == '.pyx' and os.path.exists(filename[:-4] + '.pxd'):
self_pxd = [filename[:-4] + '.pxd']
if filename[-4:] == '.pyx' and path_exists(filename[:-4] + '.pxd'):
pxd_list = [filename[:-4] + '.pxd']
else:
self_pxd = []
a = list(x for x in self.cimports(filename) if x.split('.')[0] != 'cython')
b = filter(None, [self.find_pxd(m, filename) for m in self.cimports(filename)])
if len(a) != len(b):
print("missing cimport: %s" % filename)
print("\n\t".join(a))
print("\n\t".join(b))
return tuple(self_pxd + filter(None, [self.find_pxd(m, filename) for m in self.cimports(filename)]))
pxd_list = []
for module in self.cimports(filename):
if module[:7] == 'cython.':
continue
pxd_file = self.find_pxd(module, filename)
if pxd_file is None:
print("missing cimport: %s" % filename)
print(module)
else:
pxd_list.append(pxd_file)
return tuple(pxd_list)
def immediate_dependencies(self, filename):
all = list(self.cimported_files(filename))
......
......@@ -35,10 +35,10 @@ class TestStripLiterals(CythonTest):
self.t("u'abc'", "u'_L1_'")
def test_raw(self):
self.t(r"r'abc\'", "r'_L1_'")
self.t(r"r'abc\\'", "r'_L1_'")
def test_raw_unicode(self):
self.t(r"ru'abc\'", "ru'_L1_'")
self.t(r"ru'abc\\'", "ru'_L1_'")
def test_comment(self):
self.t("abc # foo", "abc #_L1_")
......
......@@ -222,65 +222,14 @@ class Context(object):
def search_include_directories(self, qualified_name, suffix, pos,
include=False, sys_path=False):
# Search the list of include directories for the given
# file name. If a source file position is given, first
# searches the directory containing that file. Returns
# None if not found, but does not report an error.
# The 'include' option will disable package dereferencing.
# If 'sys_path' is True, also search sys.path.
dirs = self.include_directories
if sys_path:
dirs = dirs + sys.path
if pos:
file_desc = pos[0]
if not isinstance(file_desc, FileSourceDescriptor):
raise RuntimeError("Only file sources for code supported")
if include:
dirs = [os.path.dirname(file_desc.filename)] + dirs
else:
dirs = [self.find_root_package_dir(file_desc.filename)] + dirs
dotted_filename = qualified_name
if suffix:
dotted_filename += suffix
if not include:
names = qualified_name.split('.')
package_names = names[:-1]
module_name = names[-1]
module_filename = module_name + suffix
package_filename = "__init__" + suffix
for dir in dirs:
path = os.path.join(dir, dotted_filename)
if Utils.path_exists(path):
return path
if not include:
package_dir = self.check_package_dir(dir, package_names)
if package_dir is not None:
path = os.path.join(package_dir, module_filename)
if Utils.path_exists(path):
return path
path = os.path.join(dir, package_dir, module_name,
package_filename)
if Utils.path_exists(path):
return path
return None
return Utils.search_include_directories(
tuple(self.include_directories), qualified_name, suffix, pos, include, sys_path)
def find_root_package_dir(self, file_path):
dir = os.path.dirname(file_path)
while self.is_package_dir(dir):
parent = os.path.dirname(dir)
if parent == dir:
break
dir = parent
return dir
return Utils.find_root_package_dir(file_path)
def check_package_dir(self, dir, package_names):
for dirname in package_names:
dir = os.path.join(dir, dirname)
if not self.is_package_dir(dir):
return None
return dir
return Utils.check_package_dir(dir, tuple(package_names))
def c_file_out_of_date(self, source_path):
c_path = Utils.replace_suffix(source_path, ".c")
......@@ -309,13 +258,7 @@ class Context(object):
if kind == "cimport" ]
def is_package_dir(self, dir_path):
# Return true if the given directory is a package directory.
for filename in ("__init__.py",
"__init__.pyx",
"__init__.pxd"):
path = os.path.join(dir_path, filename)
if Utils.path_exists(path):
return 1
return Utils.is_package_dir(dir_path)
def read_dependency_file(self, source_path):
dep_path = Utils.replace_suffix(source_path, ".dep")
......
......@@ -7,6 +7,29 @@ import os, sys, re, codecs
modification_time = os.path.getmtime
def cached_function(f):
cache = {}
uncomputed = object()
def wrapper(*args):
res = cache.get(args, uncomputed)
if res is uncomputed:
res = cache[args] = f(*args)
return res
return wrapper
def cached_method(f):
cache_name = '__%s_cache' % f.__name__
def wrapper(self, *args):
cache = getattr(self, cache_name, None)
if cache is None:
cache = {}
setattr(self, cache_name, cache)
if args in cache:
return cache[args]
res = cache[args] = f(self, *args)
return res
return wrapper
def replace_suffix(path, newsuf):
base, _ = os.path.splitext(path)
return base + newsuf
......@@ -43,6 +66,82 @@ def file_newer_than(path, time):
ftime = modification_time(path)
return ftime > time
@cached_function
def search_include_directories(dirs, qualified_name, suffix, pos,
include=False, sys_path=False):
# Search the list of include directories for the given
# file name. If a source file position is given, first
# searches the directory containing that file. Returns
# None if not found, but does not report an error.
# The 'include' option will disable package dereferencing.
# If 'sys_path' is True, also search sys.path.
if sys_path:
dirs = dirs + tuple(sys.path)
if pos:
file_desc = pos[0]
from Cython.Compiler.Scanning import FileSourceDescriptor
if not isinstance(file_desc, FileSourceDescriptor):
raise RuntimeError("Only file sources for code supported")
if include:
dirs = (os.path.dirname(file_desc.filename),) + dirs
else:
dirs = (find_root_package_dir(file_desc.filename),) + dirs
dotted_filename = qualified_name
if suffix:
dotted_filename += suffix
if not include:
names = qualified_name.split('.')
package_names = tuple(names[:-1])
module_name = names[-1]
module_filename = module_name + suffix
package_filename = "__init__" + suffix
for dir in dirs:
path = os.path.join(dir, dotted_filename)
if path_exists(path):
return path
if not include:
package_dir = check_package_dir(dir, package_names)
if package_dir is not None:
path = os.path.join(package_dir, module_filename)
if path_exists(path):
return path
path = os.path.join(dir, package_dir, module_name,
package_filename)
if path_exists(path):
return path
return None
@cached_function
def find_root_package_dir(file_path):
dir = os.path.dirname(file_path)
while is_package_dir(dir):
parent = os.path.dirname(dir)
if parent == dir:
break
dir = parent
return dir
@cached_function
def check_package_dir(dir, package_names):
for dirname in package_names:
dir = os.path.join(dir, dirname)
if not is_package_dir(dir):
return None
return dir
@cached_function
def is_package_dir(dir_path):
for filename in ("__init__.py",
"__init__.pyx",
"__init__.pxd"):
path = os.path.join(dir_path, filename)
if path_exists(path):
return 1
@cached_function
def path_exists(path):
# try on the filesystem first
if os.path.exists(path):
......@@ -85,9 +184,26 @@ def decode_filename(filename):
_match_file_encoding = re.compile(u"coding[:=]\s*([-\w.]+)").search
def detect_file_encoding(source_filename):
# PEPs 263 and 3120
f = open_source_file(source_filename, encoding="UTF-8", error_handling='ignore')
try:
return detect_opened_file_encoding(f)
finally:
f.close()
def detect_opened_file_encoding(f):
# PEPs 263 and 3120
# Most of the time the first two lines fall in the first 250 chars,
# and this bulk read/split is much faster.
lines = f.read(250).split("\n")
if len(lines) > 2:
m = _match_file_encoding(lines[0]) or _match_file_encoding(lines[1])
if m:
return m.group(1)
else:
return "UTF-8"
else:
# Fallback to one-char-at-a-time detection.
f.seek(0)
chars = []
for i in range(2):
c = f.read(1)
......@@ -97,8 +213,6 @@ def detect_file_encoding(source_filename):
encoding = _match_file_encoding(u''.join(chars))
if encoding:
return encoding.group(1)
finally:
f.close()
return "UTF-8"
normalise_newlines = re.compile(u'\r\n?|\n').sub
......@@ -111,6 +225,7 @@ class NormalisedNewlineStream(object):
"""
def __init__(self, stream):
# let's assume .read() doesn't change
self.stream = stream
self._read = stream.read
self.close = stream.close
self.encoding = getattr(stream, 'encoding', 'UTF-8')
......@@ -133,6 +248,12 @@ class NormalisedNewlineStream(object):
return u''.join(content).splitlines(True)
def seek(self, pos):
if pos == 0:
self.stream.seek(0)
else:
raise NotImplementedError
io = None
if sys.version_info >= (2,6):
try:
......@@ -144,17 +265,26 @@ def open_source_file(source_filename, mode="r",
encoding=None, error_handling=None,
require_normalised_newlines=True):
if encoding is None:
encoding = detect_file_encoding(source_filename)
# Most of the time the coding is unspecified, so be optimistic that
# it's UTF-8.
f = open_source_file(source_filename, encoding="UTF-8", mode=mode, error_handling='ignore')
encoding = detect_opened_file_encoding(f)
if encoding == "UTF-8" and error_handling=='ignore' and require_normalised_newlines:
f.seek(0)
return f
else:
f.close()
#
try:
loader = __loader__
if source_filename.startswith(loader.archive):
return open_source_from_loader(
loader, source_filename,
encoding, error_handling,
require_normalised_newlines)
except (NameError, AttributeError):
pass
if not os.path.exists(source_filename):
try:
loader = __loader__
if source_filename.startswith(loader.archive):
return open_source_from_loader(
loader, source_filename,
encoding, error_handling,
require_normalised_newlines)
except (NameError, AttributeError):
pass
#
if io is not None:
return io.open(source_filename, mode=mode,
......
......@@ -116,6 +116,7 @@ def compile_cython_modules(profile=False, compile_more=False, cython_with_refnan
]
if compile_more:
compiled_modules.extend([
"Cython.Build.Dependencies",
"Cython.Compiler.ParseTreeTransforms",
"Cython.Compiler.Nodes",
"Cython.Compiler.ExprNodes",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment