Commit bbb6680e authored by Brett Cannon's avatar Brett Cannon

Have importlib take advantage of ImportError's new 'name' and 'path'

attributes.
parent 79ec55e9
......@@ -227,7 +227,7 @@ def _check_name(method):
"""
def _check_name_wrapper(self, name, *args, **kwargs):
if self._name != name:
raise ImportError("loader cannot handle %s" % name)
raise ImportError("loader cannot handle %s" % name, name=name)
return method(self, name, *args, **kwargs)
_wrap(_check_name_wrapper, method)
return _check_name_wrapper
......@@ -237,7 +237,8 @@ def _requires_builtin(fxn):
"""Decorator to verify the named module is built-in."""
def _requires_builtin_wrapper(self, fullname):
if fullname not in sys.builtin_module_names:
raise ImportError("{0} is not a built-in module".format(fullname))
raise ImportError("{0} is not a built-in module".format(fullname),
name=fullname)
return fxn(self, fullname)
_wrap(_requires_builtin_wrapper, fxn)
return _requires_builtin_wrapper
......@@ -247,7 +248,8 @@ def _requires_frozen(fxn):
"""Decorator to verify the named module is frozen."""
def _requires_frozen_wrapper(self, fullname):
if not imp.is_frozen(fullname):
raise ImportError("{0} is not a frozen module".format(fullname))
raise ImportError("{0} is not a frozen module".format(fullname),
name=fullname)
return fxn(self, fullname)
_wrap(_requires_frozen_wrapper, fxn)
return _requires_frozen_wrapper
......@@ -372,7 +374,7 @@ class _LoaderBasics:
filename = self.get_filename(fullname).rpartition(path_sep)[2]
return filename.rsplit('.', 1)[0] == '__init__'
def _bytes_from_bytecode(self, fullname, data, source_stats):
def _bytes_from_bytecode(self, fullname, data, bytecode_path, source_stats):
"""Return the marshalled bytes from bytecode, verifying the magic
number, timestamp and source size along the way.
......@@ -383,7 +385,8 @@ class _LoaderBasics:
raw_timestamp = data[4:8]
raw_size = data[8:12]
if len(magic) != 4 or magic != imp.get_magic():
raise ImportError("bad magic number in {}".format(fullname))
raise ImportError("bad magic number in {}".format(fullname),
name=fullname, path=bytecode_path)
elif len(raw_timestamp) != 4:
raise EOFError("bad timestamp in {}".format(fullname))
elif len(raw_size) != 4:
......@@ -396,7 +399,8 @@ class _LoaderBasics:
else:
if _r_long(raw_timestamp) != source_mtime:
raise ImportError(
"bytecode is stale for {}".format(fullname))
"bytecode is stale for {}".format(fullname),
name=fullname, path=bytecode_path)
try:
source_size = source_stats['size'] & 0xFFFFFFFF
except KeyError:
......@@ -404,7 +408,8 @@ class _LoaderBasics:
else:
if _r_long(raw_size) != source_size:
raise ImportError(
"bytecode is stale for {}".format(fullname))
"bytecode is stale for {}".format(fullname),
name=fullname, path=bytecode_path)
# Can't return the code object as errors from marshal loading need to
# propagate even when source is available.
return data[12:]
......@@ -466,7 +471,8 @@ class SourceLoader(_LoaderBasics):
try:
source_bytes = self.get_data(path)
except IOError:
raise ImportError("source not available through get_data()")
raise ImportError("source not available through get_data()",
name=fullname)
encoding = tokenize.detect_encoding(_io.BytesIO(source_bytes).readline)
newline_decoder = _io.IncrementalNewlineDecoder(None, True)
return newline_decoder.decode(source_bytes.decode(encoding[0]))
......@@ -495,6 +501,7 @@ class SourceLoader(_LoaderBasics):
else:
try:
bytes_data = self._bytes_from_bytecode(fullname, data,
bytecode_path,
st)
except (ImportError, EOFError):
pass
......@@ -505,7 +512,8 @@ class SourceLoader(_LoaderBasics):
return found
else:
msg = "Non-code object in {}"
raise ImportError(msg.format(bytecode_path))
raise ImportError(msg.format(bytecode_path),
name=fullname, path=bytecode_path)
source_bytes = self.get_data(source_path)
code_object = compile(source_bytes, source_path, 'exec',
dont_inherit=True)
......@@ -604,12 +612,13 @@ class _SourcelessFileLoader(_FileLoader, _LoaderBasics):
def get_code(self, fullname):
path = self.get_filename(fullname)
data = self.get_data(path)
bytes_data = self._bytes_from_bytecode(fullname, data, None)
bytes_data = self._bytes_from_bytecode(fullname, data, path, None)
found = marshal.loads(bytes_data)
if isinstance(found, code_type):
return found
else:
raise ImportError("Non-code object in {}".format(path))
raise ImportError("Non-code object in {}".format(path),
name=fullname, path=path)
def get_source(self, fullname):
"""Return None as there is no source code."""
......@@ -678,7 +687,8 @@ class PathFinder:
except ImportError:
continue
else:
raise ImportError("no path hook found for {0}".format(path))
raise ImportError("no path hook found for {0}".format(path),
path=path)
@classmethod
def _path_importer_cache(cls, path, default=None):
......@@ -836,7 +846,7 @@ def _file_path_hook(path):
_SourceFinderDetails(),
_SourcelessFinderDetails())
else:
raise ImportError("only directories are supported")
raise ImportError("only directories are supported", path=path)
_DEFAULT_PATH_HOOK = _file_path_hook
......@@ -936,10 +946,10 @@ def _find_and_load(name, import_):
path = parent_module.__path__
except AttributeError:
msg = (_ERR_MSG + '; {} is not a package').format(name, parent)
raise ImportError(msg)
raise ImportError(msg, name=name)
loader = _find_module(name, path)
if loader is None:
raise ImportError(_ERR_MSG.format(name))
raise ImportError(_ERR_MSG.format(name), name=name)
elif name not in sys.modules:
# The parent import may have already imported this module.
loader.load_module(name)
......@@ -978,7 +988,7 @@ def _gcd_import(name, package=None, level=0):
if module is None:
message = ("import of {} halted; "
"None in sys.modules".format(name))
raise ImportError(message)
raise ImportError(message, name=name)
return module
except KeyError:
pass # Don't want to chain the exception
......
......@@ -207,7 +207,7 @@ class PyLoader(SourceLoader):
DeprecationWarning)
path = self.source_path(fullname)
if path is None:
raise ImportError
raise ImportError(name=fullname)
else:
return path
......@@ -235,7 +235,7 @@ class PyPycLoader(PyLoader):
if path is not None:
return path
raise ImportError("no source or bytecode path available for "
"{0!r}".format(fullname))
"{0!r}".format(fullname), name=fullname)
def get_code(self, fullname):
"""Get a code object from source or bytecode."""
......@@ -253,7 +253,8 @@ class PyPycLoader(PyLoader):
magic = data[:4]
if len(magic) < 4:
raise ImportError(
"bad magic number in {}".format(fullname))
"bad magic number in {}".format(fullname),
name=fullname, path=bytecode_path)
raw_timestamp = data[4:8]
if len(raw_timestamp) < 4:
raise EOFError("bad timestamp in {}".format(fullname))
......@@ -262,12 +263,14 @@ class PyPycLoader(PyLoader):
# Verify that the magic number is valid.
if imp.get_magic() != magic:
raise ImportError(
"bad magic number in {}".format(fullname))
"bad magic number in {}".format(fullname),
name=fullname, path=bytecode_path)
# Verify that the bytecode is not stale (only matters when
# there is source to fall back on.
if source_timestamp:
if pyc_timestamp < source_timestamp:
raise ImportError("bytecode is stale")
raise ImportError("bytecode is stale", name=fullname,
path=bytecode_path)
except (ImportError, EOFError):
# If source is available give it a shot.
if source_timestamp is not None:
......@@ -279,12 +282,13 @@ class PyPycLoader(PyLoader):
return marshal.loads(bytecode)
elif source_timestamp is None:
raise ImportError("no source or bytecode available to create code "
"object for {0!r}".format(fullname))
"object for {0!r}".format(fullname),
name=fullname)
# Use the source.
source_path = self.source_path(fullname)
if source_path is None:
message = "a source path must exist to load {0}".format(fullname)
raise ImportError(message)
raise ImportError(message, name=fullname)
source = self.get_data(source_path)
code_object = compile(source, source_path, 'exec', dont_inherit=True)
# Generate bytecode and write it out.
......
......@@ -54,15 +54,17 @@ class LoaderTests(abc.LoaderTests):
def test_unloadable(self):
name = 'dssdsdfff'
assert name not in sys.builtin_module_names
with self.assertRaises(ImportError):
with self.assertRaises(ImportError) as cm:
self.load_module(name)
self.assertEqual(cm.exception.name, name)
def test_already_imported(self):
# Using the name of a module already imported but not a built-in should
# still fail.
assert hasattr(importlib, '__file__') # Not a built-in.
with self.assertRaises(ImportError):
with self.assertRaises(ImportError) as cm:
self.load_module('importlib')
self.assertEqual(cm.exception.name, 'importlib')
class InspectLoaderTests(unittest.TestCase):
......@@ -88,8 +90,9 @@ class InspectLoaderTests(unittest.TestCase):
# Modules not built-in should raise ImportError.
for meth_name in ('get_code', 'get_source', 'is_package'):
method = getattr(machinery.BuiltinImporter, meth_name)
with self.assertRaises(ImportError):
with self.assertRaises(ImportError) as cm:
method(builtin_util.BAD_NAME)
self.assertRaises(builtin_util.BAD_NAME)
......
......@@ -46,8 +46,10 @@ class LoaderTests(abc.LoaderTests):
pass
def test_unloadable(self):
with self.assertRaises(ImportError):
self.load_module('asdfjkl;')
name = 'asdfjkl;'
with self.assertRaises(ImportError) as cm:
self.load_module(name)
self.assertEqual(cm.exception.name, name)
def test_main():
......
......@@ -57,8 +57,9 @@ class LoaderTests(abc.LoaderTests):
def test_unloadable(self):
assert machinery.FrozenImporter.find_module('_not_real') is None
with self.assertRaises(ImportError):
with self.assertRaises(ImportError) as cm:
machinery.FrozenImporter.load_module('_not_real')
self.assertEqual(cm.exception.name, '_not_real')
class InspectLoaderTests(unittest.TestCase):
......@@ -92,8 +93,9 @@ class InspectLoaderTests(unittest.TestCase):
# Raise ImportError for modules that are not frozen.
for meth_name in ('get_code', 'get_source', 'is_package'):
method = getattr(machinery.FrozenImporter, meth_name)
with self.assertRaises(ImportError):
with self.assertRaises(ImportError) as cm:
method('importlib')
self.assertEqual(cm.exception.name, 'importlib')
def test_main():
......
......@@ -34,8 +34,9 @@ class UseCache(unittest.TestCase):
name = 'using_None'
with util.uncache(name):
sys.modules[name] = None
with self.assertRaises(ImportError):
with self.assertRaises(ImportError) as cm:
import_util.import_(name)
self.assertEqual(cm.exception.name, name)
def create_mock(self, *names, return_=None):
mock = util.mock_modules(*names)
......
......@@ -19,14 +19,16 @@ class ParentModuleTests(unittest.TestCase):
def test_bad_parent(self):
with util.mock_modules('pkg.module') as mock:
with util.import_state(meta_path=[mock]):
with self.assertRaises(ImportError):
with self.assertRaises(ImportError) as cm:
import_util.import_('pkg.module')
self.assertEqual(cm.exception.name, 'pkg')
def test_module_not_package(self):
# Try to import a submodule from a non-package should raise ImportError.
assert not hasattr(sys, '__path__')
with self.assertRaises(ImportError):
with self.assertRaises(ImportError) as cm:
import_util.import_('sys.no_submodules_here')
self.assertEqual(cm.exception.name, 'sys.no_submodules_here')
def test_module_not_package_but_side_effects(self):
# If a module injects something into sys.modules as a side-effect, then
......
......@@ -471,8 +471,9 @@ class BadBytecodeFailureTests(unittest.TestCase):
{'path': os.path.join('path', 'to', 'mod'),
'magic': bad_magic}}
mock = PyPycLoaderMock({name: None}, bc)
with util.uncache(name), self.assertRaises(ImportError):
with util.uncache(name), self.assertRaises(ImportError) as cm:
mock.load_module(name)
self.assertEqual(cm.exception.name, name)
def test_no_bytecode(self):
# Missing code object bytecode should lead to an EOFError.
......@@ -516,8 +517,9 @@ class MissingPathsTests(unittest.TestCase):
# If all *_path methods return None, raise ImportError.
name = 'mod'
mock = PyPycLoaderMock({name: None})
with util.uncache(name), self.assertRaises(ImportError):
with util.uncache(name), self.assertRaises(ImportError) as cm:
mock.load_module(name)
self.assertEqual(cm.exception.name, name)
def test_source_path_ImportError(self):
# An ImportError from source_path should trigger an ImportError.
......@@ -533,7 +535,7 @@ class MissingPathsTests(unittest.TestCase):
mock = PyPycLoaderMock({name: os.path.join('path', 'to', 'mod')})
bad_meth = types.MethodType(raise_ImportError, mock)
mock.bytecode_path = bad_meth
with util.uncache(name), self.assertRaises(ImportError):
with util.uncache(name), self.assertRaises(ImportError) as cm:
mock.load_module(name)
......@@ -594,8 +596,9 @@ class SourceOnlyLoaderTests(SourceLoaderTestHarness):
def raise_IOError(path):
raise IOError
self.loader.get_data = raise_IOError
with self.assertRaises(ImportError):
with self.assertRaises(ImportError) as cm:
self.loader.get_source(self.name)
self.assertEqual(cm.exception.name, self.name)
def test_is_package(self):
# Properly detect when loading a package.
......
......@@ -232,8 +232,10 @@ class BadBytecodeTest(unittest.TestCase):
lambda bc: bc[:12] + marshal.dumps(b'abcd'),
del_source=del_source)
file_path = mapping['_temp'] if not del_source else bytecode_path
with self.assertRaises(ImportError):
with self.assertRaises(ImportError) as cm:
self.import_(file_path, '_temp')
self.assertEqual(cm.exception.name, '_temp')
self.assertEqual(cm.exception.path, bytecode_path)
def _test_bad_marshal(self, *, del_source=False):
with source_util.create_modules('_temp') as mapping:
......@@ -381,15 +383,19 @@ class SourcelessLoaderBadBytecodeTest(BadBytecodeTest):
def test_empty_file(self):
def test(name, mapping, bytecode_path):
with self.assertRaises(ImportError):
with self.assertRaises(ImportError) as cm:
self.import_(bytecode_path, name)
self.assertEqual(cm.exception.name, name)
self.assertEqual(cm.exception.path, bytecode_path)
self._test_empty_file(test, del_source=True)
def test_partial_magic(self):
def test(name, mapping, bytecode_path):
with self.assertRaises(ImportError):
with self.assertRaises(ImportError) as cm:
self.import_(bytecode_path, name)
self.assertEqual(cm.exception.name, name)
self.assertEqual(cm.exception.path, bytecode_path)
self._test_partial_magic(test, del_source=True)
def test_magic_only(self):
......@@ -401,8 +407,10 @@ class SourcelessLoaderBadBytecodeTest(BadBytecodeTest):
def test_bad_magic(self):
def test(name, mapping, bytecode_path):
with self.assertRaises(ImportError):
with self.assertRaises(ImportError) as cm:
self.import_(bytecode_path, name)
self.assertEqual(cm.exception.name, name)
self.assertEqual(cm.exception.path, bytecode_path)
self._test_bad_magic(test, del_source=True)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment