Commit ba47032b authored by Stefan Behnel's avatar Stefan Behnel

remove redundant importing overhead for newly built test module in test runner...

remove redundant importing overhead for newly built test module in test runner and make sure we always import exactly the one we just built
parent 5b698407
...@@ -293,6 +293,20 @@ def parse_tags(filepath): ...@@ -293,6 +293,20 @@ def parse_tags(filepath):
list_unchanging_dir = memoize(lambda x: os.listdir(x)) list_unchanging_dir = memoize(lambda x: os.listdir(x))
def import_ext(module_name, file_path=None):
if file_path:
import imp
return imp.load_dynamic(module_name, file_path)
else:
try:
from importlib import invalidate_caches
except ImportError:
pass
else:
invalidate_caches()
return __import__(module_name, globals(), locals(), ['*'])
class build_ext(_build_ext): class build_ext(_build_ext):
def build_extension(self, ext): def build_extension(self, ext):
try: try:
...@@ -554,8 +568,9 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -554,8 +568,9 @@ class CythonCompileTestCase(unittest.TestCase):
self.success = True self.success = True
def runCompileTest(self): def runCompileTest(self):
self.compile(self.test_directory, self.module, self.workdir, return self.compile(
self.test_directory, self.expect_errors, self.annotate) self.test_directory, self.module, self.workdir,
self.test_directory, self.expect_errors, self.annotate)
def find_module_source_file(self, source_file): def find_module_source_file(self, source_file):
if not os.path.exists(source_file): if not os.path.exists(source_file):
...@@ -698,6 +713,8 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -698,6 +713,8 @@ class CythonCompileTestCase(unittest.TestCase):
finally: finally:
os.chdir(cwd) os.chdir(cwd)
return build_extension.get_ext_fullpath(module)
def compile(self, test_directory, module, workdir, incdir, def compile(self, test_directory, module, workdir, incdir,
expect_errors, annotate): expect_errors, annotate):
expected_errors = errors = () expected_errors = errors = ()
...@@ -732,9 +749,13 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -732,9 +749,13 @@ class CythonCompileTestCase(unittest.TestCase):
print('\n'.join(errors)) print('\n'.join(errors))
print('\n') print('\n')
raise raise
return None
if self.cython_only:
so_path = None
else: else:
if not self.cython_only: so_path = self.run_distutils(test_directory, module, workdir, incdir)
self.run_distutils(test_directory, module, workdir, incdir) return so_path
class CythonRunTestCase(CythonCompileTestCase): class CythonRunTestCase(CythonCompileTestCase):
def shortDescription(self): def shortDescription(self):
...@@ -751,10 +772,10 @@ class CythonRunTestCase(CythonCompileTestCase): ...@@ -751,10 +772,10 @@ class CythonRunTestCase(CythonCompileTestCase):
self.setUp() self.setUp()
try: try:
self.success = False self.success = False
self.runCompileTest() ext_so_path = self.runCompileTest()
failures, errors = len(result.failures), len(result.errors) failures, errors = len(result.failures), len(result.errors)
if not self.cython_only: if not self.cython_only:
self.run_tests(result) self.run_tests(result, ext_so_path)
if failures == len(result.failures) and errors == len(result.errors): if failures == len(result.failures) and errors == len(result.errors):
# No new errors... # No new errors...
self.success = True self.success = True
...@@ -768,12 +789,13 @@ class CythonRunTestCase(CythonCompileTestCase): ...@@ -768,12 +789,13 @@ class CythonRunTestCase(CythonCompileTestCase):
except Exception: except Exception:
pass pass
def run_tests(self, result): def run_tests(self, result, ext_so_path):
self.run_doctests(self.module, result) self.run_doctests(result, ext_so_path)
def run_doctests(self, module_name, result): def run_doctests(self, result, ext_so_path):
def run_test(result): def run_test(result):
tests = doctest.DocTestSuite(module_name) module = import_ext(self.module, ext_so_path)
tests = doctest.DocTestSuite(module)
tests.run(result) tests.run(result)
run_forked_test(result, run_test, self.shortDescription(), self.fork) run_forked_test(result, run_test, self.shortDescription(), self.fork)
...@@ -932,8 +954,9 @@ class CythonUnitTestCase(CythonRunTestCase): ...@@ -932,8 +954,9 @@ class CythonUnitTestCase(CythonRunTestCase):
def shortDescription(self): def shortDescription(self):
return "compiling (%s) tests in %s" % (self.language, self.module) return "compiling (%s) tests in %s" % (self.language, self.module)
def run_tests(self, result): def run_tests(self, result, ext_so_path):
unittest.defaultTestLoader.loadTestsFromName(self.module).run(result) module = import_ext(self.module, ext_so_path)
unittest.defaultTestLoader.loadTestsFromModule(module).run(result)
class CythonPyregrTestCase(CythonRunTestCase): class CythonPyregrTestCase(CythonRunTestCase):
...@@ -965,7 +988,7 @@ class CythonPyregrTestCase(CythonRunTestCase): ...@@ -965,7 +988,7 @@ class CythonPyregrTestCase(CythonRunTestCase):
def _run_doctest(self, result, module): def _run_doctest(self, result, module):
self.run_doctests(module, result) self.run_doctests(module, result)
def run_tests(self, result): def run_tests(self, result, ext_so_path):
try: try:
from test import support from test import support
except ImportError: # Python2.x except ImportError: # Python2.x
...@@ -984,7 +1007,7 @@ class CythonPyregrTestCase(CythonRunTestCase): ...@@ -984,7 +1007,7 @@ class CythonPyregrTestCase(CythonRunTestCase):
try: try:
try: try:
sys.stdout.flush() # helps in case of crashes sys.stdout.flush() # helps in case of crashes
module = __import__(self.module) module = import_ext(self.module, ext_so_path)
sys.stdout.flush() # helps in case of crashes sys.stdout.flush() # helps in case of crashes
if hasattr(module, 'test_main'): if hasattr(module, 'test_main'):
module.test_main() module.test_main()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment