Commit 34d8e41a authored by Brett Cannon's avatar Brett Cannon

Refactor importlib to make it easier to re-implement in C.

parent b429d3b0
...@@ -918,20 +918,12 @@ def _find_module(name, path): ...@@ -918,20 +918,12 @@ def _find_module(name, path):
return None return None
def _set___package__(module):
"""Set __package__ on a module."""
# Watch out for what comes out of sys.modules to not be a module,
# e.g. an int.
try:
module.__package__ = module.__name__
if not hasattr(module, '__path__'):
module.__package__ = module.__package__.rpartition('.')[0]
except AttributeError:
pass
def _sanity_check(name, package, level): def _sanity_check(name, package, level):
"""Verify arguments are "sane".""" """Verify arguments are "sane"."""
if not hasattr(name, 'rpartition'):
raise TypeError("module name must be str, not {}".format(type(name)))
if level < 0:
raise ValueError('level must be >= 0')
if package: if package:
if not hasattr(package, 'rindex'): if not hasattr(package, 'rindex'):
raise ValueError("__package__ not set to a string") raise ValueError("__package__ not set to a string")
...@@ -943,13 +935,12 @@ def _sanity_check(name, package, level): ...@@ -943,13 +935,12 @@ def _sanity_check(name, package, level):
raise ValueError("Empty module name") raise ValueError("Empty module name")
def _find_search_path(name, import_): _IMPLICIT_META_PATH = [BuiltinImporter, FrozenImporter, _DefaultPathFinder]
"""Find the search path for a module.
import_ is expected to be a callable which takes the name of a module to _ERR_MSG = 'No module named {!r}'
import. It is required to decouple the function from importlib.
""" def _find_and_load(name, import_):
"""Find and load the module."""
path = None path = None
parent = name.rpartition('.')[0] parent = name.rpartition('.')[0]
if parent: if parent:
...@@ -962,13 +953,28 @@ def _find_search_path(name, import_): ...@@ -962,13 +953,28 @@ def _find_search_path(name, import_):
except AttributeError: except AttributeError:
msg = (_ERR_MSG + '; {} is not a package').format(name, parent) msg = (_ERR_MSG + '; {} is not a package').format(name, parent)
raise ImportError(msg) raise ImportError(msg)
return parent, path loader = _find_module(name, path)
if loader is None:
raise ImportError(_ERR_MSG.format(name))
elif name not in sys.modules:
_IMPLICIT_META_PATH = [BuiltinImporter, FrozenImporter, _DefaultPathFinder] # The parent import may have already imported this module.
loader.load_module(name)
# Backwards-compatibility; be nicer to skip the dict lookup.
module = sys.modules[name]
if parent:
# Set the module as an attribute on its parent.
parent_module = sys.modules[parent]
setattr(parent_module, name.rpartition('.')[2], module)
# Set __package__ if the loader did not.
if not hasattr(module, '__package__') or module.__package__ is None:
try:
module.__package__ = module.__name__
if not hasattr(module, '__path__'):
module.__package__ = module.__package__.rpartition('.')[0]
except AttributeError:
pass
return module
_ERR_MSG = 'No module named {!r}'
def _gcd_import(name, package=None, level=0): def _gcd_import(name, package=None, level=0):
"""Import and return the module based on its name, the package the call is """Import and return the module based on its name, the package the call is
...@@ -991,24 +997,8 @@ def _gcd_import(name, package=None, level=0): ...@@ -991,24 +997,8 @@ def _gcd_import(name, package=None, level=0):
raise ImportError(message) raise ImportError(message)
return module return module
except KeyError: except KeyError:
pass pass # Don't want to chain the exception
parent, path = _find_search_path(name, _gcd_import) return _find_and_load(name, _gcd_import)
loader = _find_module(name, path)
if loader is None:
raise ImportError(_ERR_MSG.format(name))
elif name not in sys.modules:
# The parent import may have already imported this module.
loader.load_module(name)
# Backwards-compatibility; be nicer to skip the dict lookup.
module = sys.modules[name]
if parent:
# Set the module as an attribute on its parent.
parent_module = sys.modules[parent]
setattr(parent_module, name.rpartition('.')[2], module)
# Set __package__ if the loader did not.
if not hasattr(module, '__package__') or module.__package__ is None:
_set___package__(module)
return module
def _return_module(module, name, fromlist, level, import_): def _return_module(module, name, fromlist, level, import_):
...@@ -1071,12 +1061,8 @@ def __import__(name, globals={}, locals={}, fromlist=[], level=0): ...@@ -1071,12 +1061,8 @@ def __import__(name, globals={}, locals={}, fromlist=[], level=0):
import (e.g. ``from ..pkg import mod`` would have a 'level' of 2). import (e.g. ``from ..pkg import mod`` would have a 'level' of 2).
""" """
if not hasattr(name, 'rpartition'):
raise TypeError("module name must be str, not {}".format(type(name)))
if level == 0: if level == 0:
module = _gcd_import(name) module = _gcd_import(name)
elif level < 0:
raise ValueError('level must be >= 0')
else: else:
package = _calc___package__(globals) package = _calc___package__(globals)
module = _gcd_import(name, package, level) module = _gcd_import(name, package, level)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment