Commit 8ec6ae1d authored by David Wilson's avatar David Wilson

importer: module whitelist/blacklist support

Hoped to avoid it, but it's the obvious solution for Ansible.
parent 43ba1c76
......@@ -394,7 +394,7 @@ class Importer(object):
:param context: Context to communicate via.
"""
def __init__(self, router, context, core_src):
def __init__(self, router, context, core_src, whitelist=(), blacklist=()):
self._context = context
self._present = {'mitogen': [
'mitogen.compat',
......@@ -407,6 +407,15 @@ class Importer(object):
'mitogen.utils',
]}
self._lock = threading.Lock()
self.whitelist = whitelist or ['']
self.blacklist = list(blacklist) + [
# 2.x generates needless imports for 'builtins', while 3.x does the
# same for '__builtin__'. The correct one is built-in, the other
# always a negative round-trip.
'builtins',
'__builtin__',
]
# Presence of an entry in this map indicates in-flight GET_MODULE.
self._callbacks = {}
router.add_handler(self._on_load_module, LOAD_MODULE)
......@@ -451,12 +460,9 @@ class Importer(object):
finally:
del _tls.running
def _load_module_hacks(self, fullname):
if fullname in ('builtins', '__builtin__'):
# Python 2.x will generate needless imports for 'builtins', while
# Python 3.x will generate needless imports for '__builtin__'. The
# correct one is already present in sys.modules, the other is
# always a negative round-trip.
def _refuse_imports(self, fullname):
if ((not any(fullname.startswith(s) for s in self.whitelist)) or
(any(fullname.startswith(s) for s in self.blacklist))):
raise ImportError('Refused')
f = sys._getframe(2)
......@@ -515,7 +521,7 @@ class Importer(object):
def load_module(self, fullname):
_v and LOG.debug('Importer.load_module(%r)', fullname)
self._load_module_hacks(fullname)
self._refuse_imports(fullname)
event = threading.Event()
self._request_module(fullname, event.set)
......@@ -1260,7 +1266,7 @@ class ExternalContext(object):
if debug:
enable_debug_logging()
def _setup_importer(self, core_src_fd):
def _setup_importer(self, core_src_fd, whitelist, blacklist):
if core_src_fd:
with os.fdopen(101, 'r', 1) as fp:
core_size = int(fp.readline())
......@@ -1271,7 +1277,9 @@ class ExternalContext(object):
else:
core_src = None
self.importer = Importer(self.router, self.parent, core_src)
self.importer = Importer(self.router, self.parent, core_src,
whitelist, blacklist)
self.router.importer = self.importer
sys.meta_path.append(self.importer)
def _setup_package(self, context_id, parent_ids):
......@@ -1328,12 +1336,13 @@ class ExternalContext(object):
self.dispatch_stopped = True
def main(self, parent_ids, context_id, debug, profiling, log_level,
in_fd=100, out_fd=1, core_src_fd=101, setup_stdio=True):
in_fd=100, out_fd=1, core_src_fd=101, setup_stdio=True,
whitelist=(), blacklist=()):
self._setup_master(profiling, parent_ids[0], context_id, in_fd, out_fd)
try:
try:
self._setup_logging(debug, log_level)
self._setup_importer(core_src_fd)
self._setup_importer(core_src_fd, whitelist, blacklist)
self._setup_package(context_id, parent_ids)
if setup_stdio:
self._setup_stdio()
......@@ -1342,7 +1351,7 @@ class ExternalContext(object):
sys.executable = os.environ.pop('ARGV0', sys.executable)
_v and LOG.debug('Connected to %s; my ID is %r, PID is %r',
self.parent, context_id, os.getpid())
self.parent, context_id, os.getpid())
_v and LOG.debug('Recovered sys.executable: %r', sys.executable)
_profile_hook('main', self._dispatch_calls)
......
......@@ -341,17 +341,17 @@ def run(dest, router, args, deadline=None, econtext=None):
fp.write('#!%s\n' % (sys.executable,))
fp.write(inspect.getsource(mitogen.core))
fp.write('\n')
fp.write('ExternalContext().main%r\n' % ((
parent_ids, # parent_ids
context_id, # context_id
router.debug, # debug
router.profiling, # profiling
logging.getLogger().level, # log_level
sock2.fileno(), # in_fd
sock2.fileno(), # out_fd
None, # core_src_fd
False, # setup_stdio
),))
fp.write('ExternalContext().main(**%r)\n' % ({
'parent_ids': parent_ids,
'context_id': context_id,
'debug': router.debug,
'profiling': router.profiling,
'log_level': mitogen.parent.get_log_level(),
'in_fd': sock2.fileno(),
'out_fd': sock2.fileno(),
'core_src_fd': None,
'setup_stdio': False,
},))
finally:
fp.close()
......
......@@ -441,6 +441,8 @@ class ModuleResponder(object):
self._router = router
self._finder = ModuleFinder()
self._cache = {} # fullname -> pickled
self.blacklist = []
self.whitelist = []
router.add_handler(self._on_get_module, mitogen.core.GET_MODULE)
def __repr__(self):
......@@ -448,6 +450,12 @@ class ModuleResponder(object):
MAIN_RE = re.compile(r'^if\s+__name__\s*==\s*.__main__.\s*:', re.M)
def whitelist_prefix(self, fullname):
self.whitelist.append(fullname)
def blacklist_prefix(self, fullname):
self.blacklist.append(fullname)
def neutralize_main(self, src):
"""Given the source for the __main__ module, try to find where it
begins conditional execution based on a "if __name__ == '__main__'"
......@@ -458,6 +466,9 @@ class ModuleResponder(object):
return src
def _build_tuple(self, fullname):
if fullname in self._blacklist:
raise ImportError('blacklisted')
if fullname in self._cache:
return self._cache[fullname]
......
......@@ -63,6 +63,10 @@ class Argv(object):
return ' '.join(map(self.escape, self.argv))
def get_log_level():
return (LOG.level or logging.getLogger().level or logging.INFO)
def minimize_source(source):
subber = lambda match: '""' + ('\n' * match.group(0).count('\n'))
source = DOCSTRING_RE.sub(subber, source)
......@@ -336,14 +340,17 @@ class Stream(mitogen.core.Stream):
def get_preamble(self):
parent_ids = mitogen.parent_ids[:]
parent_ids.insert(0, mitogen.context_id)
source = inspect.getsource(mitogen.core)
source += '\nExternalContext().main%r\n' % ((
parent_ids, # parent_ids
self.remote_id, # context_id
self.debug,
self.profiling,
LOG.level or logging.getLogger().level or logging.INFO,
),)
source += '\nExternalContext().main(**%r)\n' % ({
'parent_ids': parent_ids,
'context_id': self.remote_id,
'debug': self.debug,
'profiling': self.profiling,
'log_level': get_log_level(),
'whitelist': self._router.get_module_whitelist(),
'blacklist': self._router.get_module_blacklist(),
},)
compressed = zlib.compress(minimize_source(source))
return str(len(compressed)) + '\n' + compressed
......@@ -385,6 +392,16 @@ class ChildIdAllocator(object):
class Router(mitogen.core.Router):
context_class = mitogen.core.Context
def get_module_blacklist(self):
if mitogen.context_id == 0:
return self.responder.blacklist
return self.importer.blacklist
def get_module_whitelist(self):
if mitogen.context_id == 0:
return self.responder.whitelist
return self.importer.whitelist
def allocate_id(self):
return self.id_allocator.allocate()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment