Commit 9bef8a3b authored by Jason Madden's avatar Jason Madden Committed by GitHub

Merge pull request #1109 from gevent/issue1108

Be more careful about issuing the SSL warning on Py2.
parents 811837cb eadd6ba0
...@@ -77,6 +77,9 @@ ...@@ -77,6 +77,9 @@
- Update c-ares to 1.14.0. See :issue:`1105`. - Update c-ares to 1.14.0. See :issue:`1105`.
- Be more careful about issuing a warning about patching SSL on
Python 2. See :issue:`1108`.
1.3a1 (2018-01-27) 1.3a1 (2018-01-27)
================== ==================
......
...@@ -90,19 +90,42 @@ else: ...@@ -90,19 +90,42 @@ else:
WIN = sys.platform.startswith("win") WIN = sys.platform.startswith("win")
class MonkeyPatchWarning(RuntimeWarning):
"""
The type of warnings we issue.
.. versionadded:: 1.3a2
"""
# maps module name -> {attribute name: original item} # maps module name -> {attribute name: original item}
# e.g. "time" -> {"sleep": built-in function sleep} # e.g. "time" -> {"sleep": built-in function sleep}
saved = {} saved = {}
def is_module_patched(modname): def is_module_patched(mod_name):
"""Check if a module has been replaced with a cooperative version.""" """
return modname in saved Check if a module has been replaced with a cooperative version.
:param str mod_name: The name of the standard library module,
e.g., ``'socket'``.
"""
return mod_name in saved
def is_object_patched(mod_name, item_name):
"""
Check if an object in a module has been replaced with a
cooperative version.
:param str mod_name: The name of the standard library module,
e.g., ``'socket'``.
:param str item_name: The name of the attribute in the module,
e.g., ``'create_connection'``.
def is_object_patched(modname, objname): """
"""Check if an object in a module has been replaced with a cooperative version.""" return is_module_patched(mod_name) and item_name in saved[mod_name]
return is_module_patched(modname) and objname in saved[modname]
def _get_original(name, items): def _get_original(name, items):
...@@ -120,14 +143,20 @@ def _get_original(name, items): ...@@ -120,14 +143,20 @@ def _get_original(name, items):
def get_original(mod_name, item_name): def get_original(mod_name, item_name):
"""Retrieve the original object from a module. """
Retrieve the original object from a module.
If the object has not been patched, then that object will still be
retrieved.
If the object has not been patched, then that object will still be retrieved. :param str mod_name: The name of the standard library module,
e.g., ``'socket'``.
:param item_name: A string or sequence of strings naming the
attribute(s) on the module ``mod_name`` to return.
:param item_name: A string or sequence of strings naming the attribute(s) on the module :return: The original value if a string was given for
``mod_name`` to return. ``item_name`` or a sequence of original values if a
:return: The original value if a string was given for ``item_name`` or a sequence sequence was passed.
of original values if a sequence was passed.
""" """
if isinstance(item_name, string_types): if isinstance(item_name, string_types):
return _get_original(mod_name, [item_name])[0] return _get_original(mod_name, [item_name])[0]
...@@ -195,7 +224,7 @@ def _queue_warning(message, _warnings): ...@@ -195,7 +224,7 @@ def _queue_warning(message, _warnings):
def _process_warnings(_warnings): def _process_warnings(_warnings):
import warnings import warnings
for warning in _warnings: for warning in _warnings:
warnings.warn(warning, RuntimeWarning, stacklevel=3) warnings.warn(warning, MonkeyPatchWarning, stacklevel=3)
def _patch_sys_std(name): def _patch_sys_std(name):
...@@ -206,15 +235,19 @@ def _patch_sys_std(name): ...@@ -206,15 +235,19 @@ def _patch_sys_std(name):
def patch_sys(stdin=True, stdout=True, stderr=True): def patch_sys(stdin=True, stdout=True, stderr=True):
"""Patch sys.std[in,out,err] to use a cooperative IO via a threadpool. """
Patch sys.std[in,out,err] to use a cooperative IO via a
threadpool.
This is relatively dangerous and can have unintended consequences such as hanging This is relatively dangerous and can have unintended consequences
the process or `misinterpreting control keys`_ when ``input`` and ``raw_input`` such as hanging the process or `misinterpreting control keys`_
are used. when :func:`input` and :func:`raw_input` are used. :func:`patch_all`
does *not* call this function by default.
This method does nothing on Python 3. The Python 3 interpreter wants to flush This method does nothing on Python 3. The Python 3 interpreter
the TextIOWrapper objects that make up stderr/stdout at shutdown time, but wants to flush the TextIOWrapper objects that make up
using a threadpool at that time leads to a hang. stderr/stdout at shutdown time, but using a threadpool at that
time leads to a hang.
.. _`misinterpreting control keys`: https://github.com/gevent/gevent/issues/274 .. _`misinterpreting control keys`: https://github.com/gevent/gevent/issues/274
""" """
...@@ -238,18 +271,20 @@ def patch_os(): ...@@ -238,18 +271,20 @@ def patch_os():
environment variable ``GEVENT_NOWAITPID`` is not defined). Does environment variable ``GEVENT_NOWAITPID`` is not defined). Does
nothing if fork is not available. nothing if fork is not available.
.. caution:: This method must be used with :func:`patch_signal` to have proper SIGCHLD .. caution:: This method must be used with :func:`patch_signal` to have proper `SIGCHLD`
handling and thus correct results from ``waitpid``. handling and thus correct results from ``waitpid``.
:func:`patch_all` calls both by default. :func:`patch_all` calls both by default.
.. caution:: For SIGCHLD handling to work correctly, the event loop must run. .. caution:: For `SIGCHLD` handling to work correctly, the event loop must run.
The easiest way to help ensure this is to use :func:`patch_all`. The easiest way to help ensure this is to use :func:`patch_all`.
""" """
patch_module('os') patch_module('os')
def patch_time(): def patch_time():
"""Replace :func:`time.sleep` with :func:`gevent.sleep`.""" """
Replace :func:`time.sleep` with :func:`gevent.sleep`.
"""
patch_module('time') patch_module('time')
...@@ -295,14 +330,20 @@ def patch_thread(threading=True, _threading_local=True, Event=False, logging=Tru ...@@ -295,14 +330,20 @@ def patch_thread(threading=True, _threading_local=True, Event=False, logging=Tru
existing_locks=True, existing_locks=True,
_warnings=None): _warnings=None):
""" """
patch_thread(threading=True, _threading_local=True, Event=False, logging=True, existing_locks=True) -> None
Replace the standard :mod:`thread` module to make it greenlet-based. Replace the standard :mod:`thread` module to make it greenlet-based.
- If *threading* is true (the default), also patch ``threading``. :keyword bool threading: When True (the default),
- If *_threading_local* is true (the default), also patch ``_threading_local.local``. also patch :mod:`threading`.
- If *logging* is True (the default), also patch locks taken if the logging module has :keyword bool _threading_local: When True (the default),
been configured. also patch :class:`_threading_local.local`.
- If *existing_locks* is True (the default), and the process is still single threaded, :keyword bool logging: When True (the default), also patch locks
make sure than any :class:`threading.RLock` (and, under Python 3, :class:`importlib._bootstrap._ModuleLock`) taken if the logging module has been configured.
:keyword bool existing_locks: When True (the default), and the
process is still single threaded, make sure that any
:class:`threading.RLock` (and, under Python 3, :class:`importlib._bootstrap._ModuleLock`)
instances that are currently locked can be properly unlocked. instances that are currently locked can be properly unlocked.
.. caution:: .. caution::
...@@ -444,9 +485,12 @@ def patch_thread(threading=True, _threading_local=True, Event=False, logging=Tru ...@@ -444,9 +485,12 @@ def patch_thread(threading=True, _threading_local=True, Event=False, logging=Tru
def patch_socket(dns=True, aggressive=True): def patch_socket(dns=True, aggressive=True):
"""Replace the standard socket object with gevent's cooperative sockets. """
Replace the standard socket object with gevent's cooperative
sockets.
If ``dns`` is true, also patch dns functions in :mod:`socket`. :keyword bool dns: When true (the default), also patch address
resolution functions in :mod:`socket`. See :doc:`dns` for details.
""" """
from gevent import socket from gevent import socket
# Note: although it seems like it's not strictly necessary to monkey patch 'create_connection', # Note: although it seems like it's not strictly necessary to monkey patch 'create_connection',
...@@ -465,21 +509,31 @@ def patch_socket(dns=True, aggressive=True): ...@@ -465,21 +509,31 @@ def patch_socket(dns=True, aggressive=True):
def patch_dns(): def patch_dns():
"""Replace DNS functions in :mod:`socket` with cooperative versions. """
Replace :doc:`DNS functions <dns>` in :mod:`socket` with
cooperative versions.
This is only useful if :func:`patch_socket` has been called and is done automatically This is only useful if :func:`patch_socket` has been called and is
by that method if requested. done automatically by that method if requested.
""" """
from gevent import socket from gevent import socket
patch_module('socket', items=socket.__dns__) # pylint:disable=no-member patch_module('socket', items=socket.__dns__) # pylint:disable=no-member
def patch_ssl(_warnings=None): def patch_ssl(_warnings=None, _first_time=True):
"""Replace SSLSocket object and socket wrapping functions in :mod:`ssl` with cooperative versions. """
patch_ssl() -> None
Replace :class:`ssl.SSLSocket` object and socket wrapping functions in
:mod:`ssl` with cooperative versions.
This is only useful if :func:`patch_socket` has been called. This is only useful if :func:`patch_socket` has been called.
""" """
if 'ssl' in sys.modules and hasattr(sys.modules['ssl'], 'SSLContext'): if _first_time and 'ssl' in sys.modules and hasattr(sys.modules['ssl'], 'SSLContext'):
if sys.version_info[0] > 2 or ('pkg_resources' not in sys.modules):
# Don't warn on Python 2 if pkg_resources has been imported
# because that imports ssl and it's commonly used for namespace packages,
# which typically means we're still in some early part of the import cycle
_queue_warning('Monkey-patching ssl after ssl has already been imported ' _queue_warning('Monkey-patching ssl after ssl has already been imported '
'may lead to errors, including RecursionError on Python 3.6. ' 'may lead to errors, including RecursionError on Python 3.6. '
'Please monkey-patch earlier. ' 'Please monkey-patch earlier. '
...@@ -570,7 +624,7 @@ def patch_subprocess(): ...@@ -570,7 +624,7 @@ def patch_subprocess():
def patch_builtins(): def patch_builtins():
""" """
Make the builtin __import__ function `greenlet safe`_ under Python 2. Make the builtin :func:`__import__` function `greenlet safe`_ under Python 2.
.. note:: .. note::
This does nothing under Python 3 as it is not necessary. Python 3 features This does nothing under Python 3 as it is not necessary. Python 3 features
...@@ -585,12 +639,12 @@ def patch_builtins(): ...@@ -585,12 +639,12 @@ def patch_builtins():
def patch_signal(): def patch_signal():
""" """
Make the signal.signal function work with a monkey-patched os. Make the :func:`signal.signal` function work with a :func:`monkey-patched os <patch_os>`.
.. caution:: This method must be used with :func:`patch_os` to have proper SIGCHLD .. caution:: This method must be used with :func:`patch_os` to have proper ``SIGCHLD``
handling. :func:`patch_all` calls both by default. handling. :func:`patch_all` calls both by default.
.. caution:: For proper SIGCHLD handling, you must yield to the event loop. .. caution:: For proper ``SIGCHLD`` handling, you must yield to the event loop.
Using :func:`patch_all` is the easiest way to ensure this. Using :func:`patch_all` is the easiest way to ensure this.
.. seealso:: :mod:`gevent.signal` .. seealso:: :mod:`gevent.signal`
...@@ -652,7 +706,7 @@ def patch_all(socket=True, dns=True, time=True, select=True, thread=True, os=Tru ...@@ -652,7 +706,7 @@ def patch_all(socket=True, dns=True, time=True, select=True, thread=True, os=Tru
if select: if select:
patch_select(aggressive=aggressive) patch_select(aggressive=aggressive)
if ssl: if ssl:
patch_ssl(_warnings=_warnings) patch_ssl(_warnings=_warnings, _first_time=first_time)
if httplib: if httplib:
raise ValueError('gevent.httplib is no longer provided, httplib must be False') raise ValueError('gevent.httplib is no longer provided, httplib must be False')
if subprocess: if subprocess:
......
from subprocess import Popen
from gevent import monkey from gevent import monkey
monkey.patch_all() monkey.patch_all()
import sys import sys
import unittest
class TestMonkey(unittest.TestCase):
maxDiff = None
import time def test_time(self):
assert 'built-in' not in repr(time.sleep), repr(time.sleep) import time
from gevent import time as gtime
self.assertIs(time.sleep, gtime.sleep)
try: def test_thread(self):
try:
import thread import thread
except ImportError: except ImportError:
import _thread as thread import _thread as thread
import threading import threading
assert 'built-in' not in repr(thread.start_new_thread), repr(thread.start_new_thread)
assert 'built-in' not in repr(threading._start_new_thread), repr(threading._start_new_thread) from gevent import thread as gthread
if sys.version_info[0] == 2: self.assertIs(thread.start_new_thread, gthread.start_new_thread)
assert 'built-in' not in repr(threading._sleep), repr(threading._sleep) self.assertIs(threading._start_new_thread, gthread.start_new_thread)
import socket if sys.version_info[0] == 2:
from gevent import socket as gevent_socket from gevent import threading as gthreading
assert socket.create_connection is gevent_socket.create_connection self.assertIs(threading._sleep, gthreading._sleep)
import os
import types
for name in ('fork', 'forkpty'): self.assertFalse(monkey.is_object_patched('threading', 'Event'))
monkey.patch_thread(Event=True)
self.assertTrue(monkey.is_object_patched('threading', 'Event'))
def test_socket(self):
import socket
from gevent import socket as gevent_socket
self.assertIs(socket.create_connection, gevent_socket.create_connection)
def test_os(self):
import os
import types
from gevent import os as gos
for name in ('fork', 'forkpty'):
if hasattr(os, name): if hasattr(os, name):
attr = getattr(os, name) attr = getattr(os, name)
assert 'built-in' not in repr(attr), repr(attr) assert 'built-in' not in repr(attr), repr(attr)
assert not isinstance(attr, types.BuiltinFunctionType), repr(attr) assert not isinstance(attr, types.BuiltinFunctionType), repr(attr)
assert isinstance(attr, types.FunctionType), repr(attr) assert isinstance(attr, types.FunctionType), repr(attr)
self.assertIs(attr, getattr(gos, name))
assert monkey.saved def test_saved(self):
self.assertTrue(monkey.saved)
for modname in monkey.saved:
self.assertTrue(monkey.is_module_patched(modname))
assert not monkey.is_object_patched('threading', 'Event') for objname in monkey.saved[modname]:
monkey.patch_thread(Event=True) self.assertTrue(monkey.is_object_patched(modname, objname))
assert monkey.is_object_patched('threading', 'Event')
for modname in monkey.saved: def test_patch_subprocess_twice(self):
assert monkey.is_module_patched(modname) self.assertNotIn('gevent', repr(Popen))
self.assertIs(Popen, monkey.get_original('subprocess', 'Popen'))
monkey.patch_subprocess()
self.assertIs(Popen, monkey.get_original('subprocess', 'Popen'))
for objname in monkey.saved[modname]: def test_patch_twice(self):
assert monkey.is_object_patched(modname, objname) import warnings
orig_saved = {} orig_saved = {}
for k, v in monkey.saved.items(): for k, v in monkey.saved.items():
orig_saved[k] = v.copy() orig_saved[k] = v.copy()
import warnings with warnings.catch_warnings(record=True) as issued_warnings:
with warnings.catch_warnings(record=True) as issued_warnings:
# Patch again, triggering three warnings, one for os=False/signal=True, # Patch again, triggering three warnings, one for os=False/signal=True,
# one for repeated monkey-patching, one for patching after ssl (on python >= 2.7.9) # one for repeated monkey-patching, one for patching after ssl (on python >= 2.7.9)
monkey.patch_all(os=False) monkey.patch_all(os=False)
assert len(issued_warnings) >= 2, [str(x) for x in issued_warnings] self.assertGreaterEqual(len(issued_warnings), 2)
assert 'SIGCHLD' in str(issued_warnings[-1].message), issued_warnings[-1] self.assertIn('SIGCHLD', str(issued_warnings[-1].message))
assert 'more than once' in str(issued_warnings[0].message), issued_warnings[0] self.assertIn('more than once', str(issued_warnings[0].message))
# Patching with the exact same argument doesn't issue a second warning. # Patching with the exact same argument doesn't issue a second warning.
# in fact, it doesn't do anything # in fact, it doesn't do anything
...@@ -59,13 +88,20 @@ with warnings.catch_warnings(record=True) as issued_warnings: ...@@ -59,13 +88,20 @@ with warnings.catch_warnings(record=True) as issued_warnings:
monkey.patch_all(os=False) monkey.patch_all(os=False)
orig_saved['_gevent_saved_patch_all'] = monkey.saved['_gevent_saved_patch_all'] orig_saved['_gevent_saved_patch_all'] = monkey.saved['_gevent_saved_patch_all']
assert not issued_warnings, [str(x) for x in issued_warnings] self.assertFalse(issued_warnings)
# Make sure that re-patching did not change the monkey.saved
# attribute, overwriting the original functions.
if 'logging' in monkey.saved and 'logging' not in orig_saved:
# some part of the warning or unittest machinery imports logging
orig_saved['logging'] = monkey.saved['logging']
self.assertEqual(orig_saved, monkey.saved)
# Make sure some problematic attributes stayed correct.
# NOTE: This was only a problem if threading was not previously imported.
for k, v in monkey.saved['threading'].items():
self.assertNotIn('gevent', str(v))
# Make sure that re-patching did not change the monkey.saved
# attribute, overwriting the original functions.
assert orig_saved == monkey.saved, (orig_saved, monkey.saved)
# Make sure some problematic attributes stayed correct. if __name__ == '__main__':
# NOTE: This was only a problem if threading was not previously imported. unittest.main()
for k, v in monkey.saved['threading'].items():
assert 'gevent' not in str(v), (k, v)
import unittest
import warnings
# This file should only have this one test in it
# because we have to be careful about our imports
# and because we need to be careful about our patching.
class Test(unittest.TestCase):
def test_with_pkg_resources(self):
# Issue 1108: Python 2, importing pkg_resources,
# as is done for namespace packages, imports ssl,
# leading to an unwanted SSL warning.
__import__('pkg_resources')
from gevent import monkey
self.assertFalse(monkey.saved)
with warnings.catch_warnings(record=True) as issued_warnings:
warnings.simplefilter('always')
monkey.patch_all()
monkey.patch_all()
issued_warnings = [x for x in issued_warnings
if isinstance(x.message, monkey.MonkeyPatchWarning)]
self.assertFalse(issued_warnings, [str(i) for i in issued_warnings])
self.assertEqual(0, len(issued_warnings))
if __name__ == '__main__':
unittest.main()
test___monkey_patching.py test___monkey_patching.py
test__monkey_ssl_warning.py
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment