Commit d231cdff authored by Denis Bilenko's avatar Denis Bilenko

monkey: use __implements__ of the module to decide what to patch (issue #48)

parent 68ca8552
......@@ -69,6 +69,29 @@ __all__ = ['patch_all',
'patch_thread']
class original(object):
pass
def patch_module(name, items=None):
dest = __import__(name)
source = getattr(__import__('gevent.' + name), name)
monkey_original = dest.monkey_original = original()
count = 0
if items is None:
items = getattr(source, '__implements__', None)
if items is None:
raise AttributeError('%r does not have __implements__' % source)
for attr in items:
olditem = getattr(dest, attr, None)
newitem = getattr(source, attr)
if olditem is not newitem:
setattr(monkey_original, attr, olditem)
setattr(dest, attr, getattr(source, attr))
count += 1
return count
def patch_os():
"""Replace :func:`os.fork` with :func:`gevent.fork`."""
try:
......@@ -91,26 +114,22 @@ def patch_thread(threading=True, _threading_local=True):
If *threading* is true (the default), also patch ``threading.local``.
If *_threading_local* is true (the default), also patch ``_threading_local.local``.
"""
from gevent import thread as green_thread
if not patch_module('thread'):
return
from gevent.local import local
thread = __import__('thread')
if thread.exit is not green_thread.exit:
thread.get_ident = green_thread.get_ident
thread.start_new_thread = green_thread.start_new_thread
thread.LockType = green_thread.LockType
thread.allocate_lock = green_thread.allocate_lock
thread.exit = green_thread.exit
if hasattr(green_thread, 'stack_size'):
thread.stack_size = green_thread.stack_size
from gevent.local import local
thread._local = local
if threading:
if noisy and 'threading' in sys.modules:
sys.stderr.write("gevent.monkey's warning: 'threading' is already imported\n\n")
threading = __import__('threading')
threading.local = local
if _threading_local:
_threading_local = __import__('_threading_local')
_threading_local.local = local
thread._local = local
if threading:
if noisy and 'threading' in sys.modules:
sys.stderr.write("gevent.monkey's warning: 'threading' is already imported\n\n")
threading = __import__('threading')
threading.local = local
if _threading_local:
_threading_local = __import__('_threading_local')
_threading_local.local = local
dns_functions = ['getaddrinfo', 'getnameinfo', 'gethostbyname', 'gethostbyname_ex', 'gethostbyaddr']
def patch_socket(dns=True, aggressive=True):
......@@ -119,45 +138,24 @@ def patch_socket(dns=True, aggressive=True):
If *dns* is true, also patch dns functions in :mod:`socket`.
"""
from gevent import socket
_socket = __import__('socket')
_socket.socket = socket.socket
_socket.SocketType = socket.SocketType
_socket.create_connection = socket.create_connection
if hasattr(socket, 'socketpair'):
_socket.socketpair = socket.socketpair
if hasattr(socket, 'fromfd'):
_socket.fromfd = socket.fromfd
try:
from gevent.socket import ssl, sslerror
_socket.ssl = ssl
_socket.sslerror = sslerror
except ImportError:
if aggressive:
try:
del _socket.ssl
except AttributeError:
pass
if dns:
patch_dns()
items = None
else:
items = socket.__implements__[:]
for function in dns_functions:
items.remove(function)
patch_module('socket', items=items)
if aggressive:
if 'ssl' not in socket.__implements__:
socket.__dict__.pop('ssl', None)
def patch_dns():
from gevent.socket import gethostbyname, getaddrinfo
_socket = __import__('socket')
_socket.getaddrinfo = getaddrinfo
_socket.gethostbyname = gethostbyname
patch_module('socket', items=dns_functions)
def patch_ssl():
try:
_ssl = __import__('ssl')
except ImportError:
return
from gevent.ssl import SSLSocket, wrap_socket, get_server_certificate, sslwrap_simple
_ssl.SSLSocket = SSLSocket
_ssl.wrap_socket = wrap_socket
_ssl.get_server_certificate = get_server_certificate
_ssl.sslwrap_simple = sslwrap_simple
patch_module('ssl')
def patch_select(aggressive=False):
......@@ -165,18 +163,16 @@ def patch_select(aggressive=False):
If aggressive is true (the default), also remove other blocking functions the :mod:`select`.
"""
from gevent.select import select
_select = __import__('select')
globals()['_select_select'] = _select.select
_select.select = select
patch_module('select')
if aggressive:
select = __import__('select')
# since these are blocking and don't work with the libevent's event loop
# we're removing them here. This makes some other modules (e.g. asyncore)
# non-blocking, as they use select that we provide when none of these are available.
_select.__dict__.pop('poll', None)
_select.__dict__.pop('epoll', None)
_select.__dict__.pop('kqueue', None)
_select.__dict__.pop('kevent', None)
select.__dict__.pop('poll', None)
select.__dict__.pop('epoll', None)
select.__dict__.pop('kqueue', None)
select.__dict__.pop('kevent', None)
def patch_httplib():
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment