Commit 06c9d96b authored by Brett Cannon's avatar Brett Cannon

Move importlib completely over to using rpartition and accepting the empty

string for top-level modules.
parent d94e558f
to do to do
///// /////
* Use rpartition for getting the package of a module. * Extract test_path_hooks constants into a util module for extension testing.
+ Make sure there is a test for the empty string as acceptable for
__package__.
* Implement PEP 302 protocol for loaders (should just be a matter of testing). * Implement PEP 302 protocol for loaders (should just be a matter of testing).
......
...@@ -90,6 +90,18 @@ class closing: ...@@ -90,6 +90,18 @@ class closing:
self.obj.close() self.obj.close()
def set___package__(fxn):
"""Set __package__ on the returned module."""
def wrapper(*args, **kwargs):
module = fxn(*args, **kwargs)
if not hasattr(module, '__package__') or module.__package__ is None:
module.__package__ = module.__name__
if not hasattr(module, '__path__'):
module.__package__ = module.__package__.rpartition('.')[0]
return module
return wrapper
class BuiltinImporter: class BuiltinImporter:
"""Meta path loader for built-in modules. """Meta path loader for built-in modules.
...@@ -111,12 +123,12 @@ class BuiltinImporter: ...@@ -111,12 +123,12 @@ class BuiltinImporter:
return cls if imp.is_builtin(fullname) else None return cls if imp.is_builtin(fullname) else None
@classmethod @classmethod
@set___package__
def load_module(cls, fullname): def load_module(cls, fullname):
"""Load a built-in module.""" """Load a built-in module."""
if fullname not in sys.builtin_module_names: if fullname not in sys.builtin_module_names:
raise ImportError("{0} is not a built-in module".format(fullname)) raise ImportError("{0} is not a built-in module".format(fullname))
module = imp.init_builtin(fullname) module = imp.init_builtin(fullname)
module.__package__ = ''
return module return module
...@@ -135,14 +147,12 @@ class FrozenImporter: ...@@ -135,14 +147,12 @@ class FrozenImporter:
return cls if imp.is_frozen(fullname) else None return cls if imp.is_frozen(fullname) else None
@classmethod @classmethod
@set___package__
def load_module(cls, fullname): def load_module(cls, fullname):
"""Load a frozen module.""" """Load a frozen module."""
if cls.find_module(fullname) is None: if cls.find_module(fullname) is None:
raise ImportError("{0} is not a frozen module".format(fullname)) raise ImportError("{0} is not a frozen module".format(fullname))
module = imp.init_frozen(fullname) module = imp.init_frozen(fullname)
module.__package__ = module.__name__
if not hasattr(module, '__path__'):
module.__package__ = module.__package__.rpartition('.')[0]
return module return module
...@@ -230,6 +240,7 @@ class _ExtensionFileLoader(object): ...@@ -230,6 +240,7 @@ class _ExtensionFileLoader(object):
raise ValueError("extension modules cannot be packages") raise ValueError("extension modules cannot be packages")
@check_name @check_name
@set___package__
def load_module(self, fullname): def load_module(self, fullname):
"""Load an extension module.""" """Load an extension module."""
assert self._name == fullname assert self._name == fullname
...@@ -368,11 +379,9 @@ class _PyFileLoader(object): ...@@ -368,11 +379,9 @@ class _PyFileLoader(object):
module.__loader__ = self module.__loader__ = self
if self._is_pkg: if self._is_pkg:
module.__path__ = [module.__file__.rsplit(path_sep, 1)[0]] module.__path__ = [module.__file__.rsplit(path_sep, 1)[0]]
module.__package__ = module.__name__ module.__package__ = module.__name__
elif '.' in module.__name__: if not hasattr(module, '__path__'):
module.__package__ = module.__name__.rsplit('.', 1)[0] module.__package__ = module.__package__.rpartition('.')[0]
else:
module.__package__ = None
exec(code_object, module.__dict__) exec(code_object, module.__dict__)
return module return module
......
...@@ -21,7 +21,8 @@ class LoaderTests(abc.LoaderTests): ...@@ -21,7 +21,8 @@ class LoaderTests(abc.LoaderTests):
with util.uncache(test_path_hook.NAME): with util.uncache(test_path_hook.NAME):
module = self.load_module(test_path_hook.NAME) module = self.load_module(test_path_hook.NAME)
for attr, value in [('__name__', test_path_hook.NAME), for attr, value in [('__name__', test_path_hook.NAME),
('__file__', test_path_hook.FILEPATH)]: ('__file__', test_path_hook.FILEPATH),
('__package__', '')]:
self.assertEqual(getattr(module, attr), value) self.assertEqual(getattr(module, attr), value)
self.assert_(test_path_hook.NAME in sys.modules) self.assert_(test_path_hook.NAME in sys.modules)
......
...@@ -23,7 +23,7 @@ class SimpleTest(unittest.TestCase): ...@@ -23,7 +23,7 @@ class SimpleTest(unittest.TestCase):
module = loader.load_module('_temp') module = loader.load_module('_temp')
self.assert_('_temp' in sys.modules) self.assert_('_temp' in sys.modules)
check = {'__name__': '_temp', '__file__': mapping['_temp'], check = {'__name__': '_temp', '__file__': mapping['_temp'],
'__package__': None} '__package__': ''}
for attr, value in check.items(): for attr, value in check.items():
self.assertEqual(getattr(module, attr), value) self.assertEqual(getattr(module, attr), value)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment