Commit c38c4f60 authored by Denis Bilenko's avatar Denis Bilenko

initial impl for socket/ssl for python3 #38

parent 85e94e89
# Copyright (c) 2009-2014 Denis Bilenko and gevent contributors. See LICENSE for details.
import time
from gevent import _socketcommon
for key in _socketcommon.__dict__:
if key.startswith('__'):
continue
globals()[key] = getattr(_socketcommon, key)
__socket__ = _socketcommon.__socket__
__implements__ = _socketcommon._implements
__extensions__ = _socketcommon.__extensions__
__imports__ = _socketcommon.__imports__
__dns__ = _socketcommon.__dns__
_fileobject = __socket__._fileobject
if sys.version_info[:2] < (2, 7):
_get_memory = buffer
else:
def _get_memory(string, offset):
try:
return memoryview(string)[offset:]
except TypeError:
# fixes "python2.7 array.array doesn't support memoryview used in
# gevent.socket.send" issue
# (http://code.google.com/p/gevent/issues/detail?id=94)
return buffer(string, offset)
class _closedsocket(object):
__slots__ = []
def _dummy(*args, **kwargs):
raise error(EBADF, 'Bad file descriptor')
# All _delegate_methods must also be initialized here.
send = recv = recv_into = sendto = recvfrom = recvfrom_into = _dummy
__getattr__ = _dummy
timeout_default = object()
class socket(object):
def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None):
if _sock is None:
self._sock = _realsocket(family, type, proto)
self.timeout = _socket.getdefaulttimeout()
else:
if hasattr(_sock, '_sock'):
self._sock = _sock._sock
self.timeout = getattr(_sock, 'timeout', False)
if self.timeout is False:
self.timeout = _socket.getdefaulttimeout()
else:
self._sock = _sock
self.timeout = _socket.getdefaulttimeout()
self._sock.setblocking(0)
fileno = self._sock.fileno()
self.hub = get_hub()
io = self.hub.loop.io
self._read_event = io(fileno, 1)
self._write_event = io(fileno, 2)
def __repr__(self):
return '<%s at %s %s>' % (type(self).__name__, hex(id(self)), self._formatinfo())
def __str__(self):
return '<%s %s>' % (type(self).__name__, self._formatinfo())
def _formatinfo(self):
try:
fileno = self.fileno()
except Exception as ex:
fileno = str(ex)
try:
sockname = self.getsockname()
sockname = '%s:%s' % sockname
except Exception:
sockname = None
try:
peername = self.getpeername()
peername = '%s:%s' % peername
except Exception:
peername = None
result = 'fileno=%s' % fileno
if sockname is not None:
result += ' sock=' + str(sockname)
if peername is not None:
result += ' peer=' + str(peername)
if getattr(self, 'timeout', None) is not None:
result += ' timeout=' + str(self.timeout)
return result
def _get_ref(self):
return self._read_event.ref or self._write_event.ref
def _set_ref(self, value):
self._read_event.ref = value
self._write_event.ref = value
ref = property(_get_ref, _set_ref)
def _wait(self, watcher, timeout_exc=timeout('timed out')):
"""Block the current greenlet until *watcher* has pending events.
If *timeout* is non-negative, then *timeout_exc* is raised after *timeout* second has passed.
By default *timeout_exc* is ``socket.timeout('timed out')``.
If :func:`cancel_wait` is called, raise ``socket.error(EBADF, 'File descriptor was closed in another greenlet')``.
"""
assert watcher.callback is None, 'This socket is already used by another greenlet: %r' % (watcher.callback, )
if self.timeout is not None:
timeout = Timeout.start_new(self.timeout, timeout_exc, ref=False)
else:
timeout = None
try:
self.hub.wait(watcher)
finally:
if timeout is not None:
timeout.cancel()
def accept(self):
sock = self._sock
while True:
try:
client_socket, address = sock.accept()
break
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event)
return socket(_sock=client_socket), address
def close(self, _closedsocket=_closedsocket, cancel_wait_ex=cancel_wait_ex):
# This function should not reference any globals. See Python issue #808164.
self.hub.cancel_wait(self._read_event, cancel_wait_ex)
self.hub.cancel_wait(self._write_event, cancel_wait_ex)
self._sock = _closedsocket()
@property
def closed(self):
return isinstance(self._sock, _closedsocket)
def connect(self, address):
if self.timeout == 0.0:
return self._sock.connect(address)
sock = self._sock
if isinstance(address, tuple):
r = getaddrinfo(address[0], address[1], sock.family, sock.type, sock.proto)
address = r[0][-1]
if self.timeout is not None:
timer = Timeout.start_new(self.timeout, timeout('timed out'))
else:
timer = None
try:
while True:
err = sock.getsockopt(SOL_SOCKET, SO_ERROR)
if err:
raise error(err, strerror(err))
result = sock.connect_ex(address)
if not result or result == EISCONN:
break
elif (result in (EWOULDBLOCK, EINPROGRESS, EALREADY)) or (result == EINVAL and is_windows):
self._wait(self._write_event)
else:
raise error(result, strerror(result))
finally:
if timer is not None:
timer.cancel()
def connect_ex(self, address):
try:
return self.connect(address) or 0
except timeout:
return EAGAIN
except error as ex:
if type(ex) is error:
return ex.args[0]
else:
raise # gaierror is not silented by connect_ex
def dup(self):
"""dup() -> socket object
Return a new socket object connected to the same system resource.
Note, that the new socket does not inherit the timeout."""
return socket(_sock=self._sock)
def makefile(self, mode='r', bufsize=-1):
# Two things to look out for:
# 1) Closing the original socket object should not close the
# socket (hence creating a new instance)
# 2) The resulting fileobject must keep the timeout in order
# to be compatible with the stdlib's socket.makefile.
return _fileobject(type(self)(_sock=self), mode, bufsize)
def recv(self, *args):
sock = self._sock # keeping the reference so that fd is not closed during waiting
while True:
try:
return sock.recv(*args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
# QQQ without clearing exc_info test__refcount.test_clean_exit fails
sys.exc_clear()
self._wait(self._read_event)
def recvfrom(self, *args):
sock = self._sock
while True:
try:
return sock.recvfrom(*args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event)
def recvfrom_into(self, *args):
sock = self._sock
while True:
try:
return sock.recvfrom_into(*args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event)
def recv_into(self, *args):
sock = self._sock
while True:
try:
return sock.recv_into(*args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event)
def send(self, data, flags=0, timeout=timeout_default):
sock = self._sock
if timeout is timeout_default:
timeout = self.timeout
try:
return sock.send(data, flags)
except error as ex:
if ex.args[0] != EWOULDBLOCK or timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._write_event)
try:
return sock.send(data, flags)
except error as ex2:
if ex2.args[0] == EWOULDBLOCK:
return 0
raise
def sendall(self, data, flags=0):
if isinstance(data, unicode):
data = data.encode()
# this sendall is also reused by gevent.ssl.SSLSocket subclass,
# so it should not call self._sock methods directly
if self.timeout is None:
data_sent = 0
while data_sent < len(data):
data_sent += self.send(_get_memory(data, data_sent), flags)
else:
timeleft = self.timeout
end = time.time() + timeleft
data_sent = 0
while True:
data_sent += self.send(_get_memory(data, data_sent), flags, timeout=timeleft)
if data_sent >= len(data):
break
timeleft = end - time.time()
if timeleft <= 0:
raise timeout('timed out')
def sendto(self, *args):
sock = self._sock
try:
return sock.sendto(*args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._write_event)
try:
return sock.sendto(*args)
except error as ex2:
if ex2.args[0] == EWOULDBLOCK:
return 0
raise
def setblocking(self, flag):
if flag:
self.timeout = None
else:
self.timeout = 0.0
def settimeout(self, howlong):
if howlong is not None:
try:
f = howlong.__float__
except AttributeError:
raise TypeError('a float is required')
howlong = f()
if howlong < 0.0:
raise ValueError('Timeout value out of range')
self.timeout = howlong
def gettimeout(self):
return self.timeout
def shutdown(self, how):
if how == 0: # SHUT_RD
self.hub.cancel_wait(self._read_event, cancel_wait_ex)
elif how == 1: # SHUT_WR
self.hub.cancel_wait(self._write_event, cancel_wait_ex)
else:
self.hub.cancel_wait(self._read_event, cancel_wait_ex)
self.hub.cancel_wait(self._write_event, cancel_wait_ex)
self._sock.shutdown(how)
family = property(lambda self: self._sock.family, doc="the socket family")
type = property(lambda self: self._sock.type, doc="the socket type")
proto = property(lambda self: self._sock.proto, doc="the socket protocol")
# delegate the functions that we haven't implemented to the real socket object
_s = ("def %s(self, *args): return self._sock.%s(*args)\n\n"
"%s.__doc__ = _realsocket.%s.__doc__\n")
for _m in set(__socket__._socketmethods) - set(locals()):
exec(_s % (_m, _m, _m, _m))
del _m, _s
SocketType = socket
if hasattr(_socket, 'socketpair'):
def socketpair(*args):
one, two = _socket.socketpair(*args)
return socket(_sock=one), socket(_sock=two)
else:
__implements__.remove('socketpair')
if hasattr(_socket, 'fromfd'):
def fromfd(*args):
return socket(_sock=_socket.fromfd(*args))
else:
__implements__.remove('fromfd')
__all__ = __implements__ + __extensions__ + __imports__
# Port of Python 3.3's socket module to gevent
import io
import time
from gevent import _socketcommon
from gevent.hub import text_type
import _socket
from io import BlockingIOError
for key in _socketcommon.__dict__:
if key.startswith('__'):
continue
globals()[key] = getattr(_socketcommon, key)
__socket__ = _socketcommon.__socket__
__implements__ = _socketcommon._implements
__extensions__ = _socketcommon.__extensions__
__imports__ = _socketcommon.__imports__
__dns__ = _socketcommon.__dns__
SocketIO = __socket__.SocketIO
def _get_memory(string, offset):
return memoryview(string)[offset:]
timeout_default = object()
class socket(_socket.socket):
__slots__ = ["__weakref__", "_io_refs", "_closed", "hub", "_read_event", "_write_event", "timeout"]
def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None):
_socket.socket.__init__(self, family, type, proto, fileno)
self._io_refs = 0
self._closed = False
_socket.socket.setblocking(self, False)
fileno = _socket.socket.fileno(self)
self.hub = get_hub()
io_class = self.hub.loop.io
self._read_event = io_class(fileno, 1)
self._write_event = io_class(fileno, 2)
self.timeout = _socket.getdefaulttimeout()
@property
def type(self):
return _socket.socket.type.__get__(self) & ~_socket.SOCK_NONBLOCK
def __enter__(self):
return self
def __exit__(self, *args):
if not self._closed:
self.close()
def __repr__(self):
"""Wrap __repr__() to reveal the real class name."""
s = _socket.socket.__repr__(self)
if s.startswith("<socket object"):
s = "<%s.%s%s%s" % (self.__class__.__module__,
self.__class__.__name__,
getattr(self, '_closed', False) and " [closed] " or "",
s[7:])
return s
def __getstate__(self):
raise TypeError("Cannot serialize socket object")
def _get_ref(self):
return self._read_event.ref or self._write_event.ref
def _set_ref(self, value):
self._read_event.ref = value
self._write_event.ref = value
ref = property(_get_ref, _set_ref)
def _wait(self, watcher, timeout_exc=timeout('timed out')):
"""Block the current greenlet until *watcher* has pending events.
If *timeout* is non-negative, then *timeout_exc* is raised after *timeout* second has passed.
By default *timeout_exc* is ``socket.timeout('timed out')``.
If :func:`cancel_wait` is called, raise ``socket.error(EBADF, 'File descriptor was closed in another greenlet')``.
"""
assert watcher.callback is None, 'This socket is already used by another greenlet: %r' % (watcher.callback, )
if self.timeout is not None:
timeout = Timeout.start_new(self.timeout, timeout_exc, ref=False)
else:
timeout = None
try:
self.hub.wait(watcher)
finally:
if timeout is not None:
timeout.cancel()
def dup(self):
"""dup() -> socket object
Return a new socket object connected to the same system resource.
"""
fd = dup(self.fileno())
sock = self.__class__(self.family, self.type, self.proto, fileno=fd)
sock.settimeout(self.gettimeout())
return sock
def accept(self):
"""accept() -> (socket object, address info)
Wait for an incoming connection. Return a new socket
representing the connection, and the address of the client.
For IP sockets, the address info is a pair (hostaddr, port).
"""
while True:
try:
fd, addr = self._accept()
break
except BlockingIOError:
if self.timeout == 0.0:
raise
self._wait(self._read_event)
sock = socket(self.family, self.type, self.proto, fileno=fd)
# Python Issue #7995: if no default timeout is set and the listening
# socket had a (non-zero) timeout, force the new socket in blocking
# mode to override platform-specific socket flags inheritance.
if getdefaulttimeout() is None and self.gettimeout():
sock.setblocking(True)
return sock, addr
def makefile(self, mode="r", buffering=None, *,
encoding=None, errors=None, newline=None):
"""makefile(...) -> an I/O stream connected to the socket
The arguments are as for io.open() after the filename,
except the only mode characters supported are 'r', 'w' and 'b'.
The semantics are similar too. (XXX refactor to share code?)
"""
for c in mode:
if c not in {"r", "w", "b"}:
raise ValueError("invalid mode %r (only r, w, b allowed)")
writing = "w" in mode
reading = "r" in mode or not writing
assert reading or writing
binary = "b" in mode
rawmode = ""
if reading:
rawmode += "r"
if writing:
rawmode += "w"
raw = SocketIO(self, rawmode)
self._io_refs += 1
if buffering is None:
buffering = -1
if buffering < 0:
buffering = io.DEFAULT_BUFFER_SIZE
if buffering == 0:
if not binary:
raise ValueError("unbuffered streams must be binary")
return raw
if reading and writing:
buffer = io.BufferedRWPair(raw, raw, buffering)
elif reading:
buffer = io.BufferedReader(raw, buffering)
else:
assert writing
buffer = io.BufferedWriter(raw, buffering)
if binary:
return buffer
text = io.TextIOWrapper(buffer, encoding, errors, newline)
text.mode = mode
return text
def _decref_socketios(self):
if self._io_refs > 0:
self._io_refs -= 1
if self._closed:
self.close()
def _real_close(self, _ss=_socket.socket, cancel_wait_ex=cancel_wait_ex):
# This function should not reference any globals. See Python issue #808164.
self.hub.cancel_wait(self._read_event, cancel_wait_ex)
self.hub.cancel_wait(self._write_event, cancel_wait_ex)
_ss.close(self)
def close(self):
# This function should not reference any globals. See Python issue #808164.
self._closed = True
if self._io_refs <= 0:
self._real_close()
@property
def closed(self):
return self._closed
def detach(self):
"""detach() -> file descriptor
Close the socket object without closing the underlying file descriptor.
The object cannot be used after this call, but the file descriptor
can be reused for other purposes. The file descriptor is returned.
"""
self._closed = True
return super().detach()
def connect(self, address):
if self.timeout == 0.0:
return _socket.socket.connect(self, address)
if isinstance(address, tuple):
r = getaddrinfo(address[0], address[1], self.family, self.type, self.proto)
address = r[0][-1]
if self.timeout is not None:
timer = Timeout.start_new(self.timeout, timeout('timed out'))
else:
timer = None
try:
while True:
err = self.getsockopt(SOL_SOCKET, SO_ERROR)
if err:
raise error(err, strerror(err))
result = _socket.socket.connect_ex(self, address)
if not result or result == EISCONN:
break
elif (result in (EWOULDBLOCK, EINPROGRESS, EALREADY)) or (result == EINVAL and is_windows):
self._wait(self._write_event)
else:
raise error(result, strerror(result))
finally:
if timer is not None:
timer.cancel()
def connect_ex(self, address):
try:
return self.connect(address) or 0
except timeout:
return EAGAIN
except error as ex:
if type(ex) is error:
return ex.args[0]
else:
raise # gaierror is not silented by connect_ex
def recv(self, *args):
while True:
try:
return _socket.socket.recv(self, *args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
self._wait(self._read_event)
def recvfrom(self, *args):
while True:
try:
return _socket.socket.recvfrom(self, *args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
self._wait(self._read_event)
def recvfrom_into(self, *args):
while True:
try:
return _socket.socket.recvfrom_into(self, *args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
self._wait(self._read_event)
def recv_into(self, *args):
while True:
try:
return _socket.socket.recv_into(self, *args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
self._wait(self._read_event)
def send(self, data, flags=0, timeout=timeout_default):
if timeout is timeout_default:
timeout = self.timeout
try:
return _socket.socket.send(self, data, flags)
except error as ex:
if ex.args[0] != EWOULDBLOCK or timeout == 0.0:
raise
self._wait(self._write_event)
try:
return sock.send(data, flags)
except error as ex2:
if ex2.args[0] == EWOULDBLOCK:
return 0
raise
def sendall(self, data, flags=0):
if isinstance(data, text_type):
data = data.encode()
if self.timeout is None:
data_sent = 0
while data_sent < len(data):
data_sent += self.send(_get_memory(data, data_sent), flags)
else:
timeleft = self.timeout
end = time.time() + timeleft
data_sent = 0
while True:
data_sent += self.send(_get_memory(data, data_sent), flags, timeout=timeleft)
if data_sent >= len(data):
break
timeleft = end - time.time()
if timeleft <= 0:
raise timeout('timed out')
def sendto(self, *args):
try:
return _socket.socket.sendto(self, *args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or timeout == 0.0:
raise
self._wait(self._write_event)
try:
return _socket.socket.sendto(self, *args)
except error as ex2:
if ex2.args[0] == EWOULDBLOCK:
return 0
raise
def setblocking(self, flag):
if flag:
self.timeout = None
else:
self.timeout = 0.0
def settimeout(self, howlong):
if howlong is not None:
try:
f = howlong.__float__
except AttributeError:
raise TypeError('a float is required')
howlong = f()
if howlong < 0.0:
raise ValueError('Timeout value out of range')
self.timeout = howlong
def gettimeout(self):
return self.timeout
def shutdown(self, how):
if how == 0: # SHUT_RD
self.hub.cancel_wait(self._read_event, cancel_wait_ex)
elif how == 1: # SHUT_WR
self.hub.cancel_wait(self._write_event, cancel_wait_ex)
else:
self.hub.cancel_wait(self._read_event, cancel_wait_ex)
self.hub.cancel_wait(self._write_event, cancel_wait_ex)
self._sock.shutdown(how)
SocketType = socket
def fromfd(fd, family, type, proto=0):
""" fromfd(fd, family, type[, proto]) -> socket object
Create a socket object from a duplicate of the given file
descriptor. The remaining arguments are the same as for socket().
"""
nfd = dup(fd)
return socket(family, type, proto, nfd)
if hasattr(_socket.socket, "share"):
def fromshare(info):
""" fromshare(info) -> socket object
Create a socket object from a the bytes object returned by
socket.share(pid).
"""
return socket(0, 0, 0, info)
if hasattr(_socket, "socketpair"):
def socketpair(family=None, type=SOCK_STREAM, proto=0):
"""socketpair([family[, type[, proto]]]) -> (socket object, socket object)
Create a pair of socket objects from the sockets returned by the platform
socketpair() function.
The arguments are the same as for socket() except the default family is
AF_UNIX if defined on the platform; otherwise, the default is AF_INET.
"""
if family is None:
try:
family = AF_UNIX
except NameError:
family = AF_INET
a, b = _socket.socketpair(family, type, proto)
a = socket(family, type, proto, a.detach())
b = socket(family, type, proto, b.detach())
return a, b
__all__ = __implements__ + __extensions__ + __imports__
# Copyright (c) 2009-2014 Denis Bilenko and gevent contributors. See LICENSE for details.
from __future__ import absolute_import
# standard functions and classes that this module re-implements in a gevent-aware way:
_implements = ['create_connection',
'socket',
'SocketType',
'fromfd',
'socketpair']
__dns__ = ['getaddrinfo',
'gethostbyname',
'gethostbyname_ex',
'gethostbyaddr',
'getnameinfo',
'getfqdn']
_implements += __dns__
# non-standard functions that this module provides:
__extensions__ = ['wait_read',
'wait_write',
'wait_readwrite']
# standard functions and classes that this module re-imports
__imports__ = ['error',
'gaierror',
'herror',
'htonl',
'htons',
'ntohl',
'ntohs',
'inet_aton',
'inet_ntoa',
'inet_pton',
'inet_ntop',
'timeout',
'gethostname',
'getprotobyname',
'getservbyname',
'getservbyport',
'getdefaulttimeout',
'setdefaulttimeout',
# Windows:
'errorTab']
import sys
from gevent.hub import get_hub, string_types, integer_types
from gevent.timeout import Timeout
is_windows = sys.platform == 'win32'
if is_windows:
# no such thing as WSAEPERM or error code 10001 according to winsock.h or MSDN
from errno import WSAEINVAL as EINVAL
from errno import WSAEWOULDBLOCK as EWOULDBLOCK
from errno import WSAEINPROGRESS as EINPROGRESS
from errno import WSAEALREADY as EALREADY
from errno import WSAEISCONN as EISCONN
from gevent.win32util import formatError as strerror
EAGAIN = EWOULDBLOCK
else:
from errno import EINVAL
from errno import EWOULDBLOCK
from errno import EINPROGRESS
from errno import EALREADY
from errno import EAGAIN
from errno import EISCONN
from os import strerror
try:
from errno import EBADF
except ImportError:
EBADF = 9
import _socket
_realsocket = _socket.socket
import socket as __socket__
for name in __imports__[:]:
try:
value = getattr(__socket__, name)
globals()[name] = value
except AttributeError:
__imports__.remove(name)
for name in __socket__.__all__:
value = getattr(__socket__, name)
if isinstance(value, integer_types) or isinstance(value, string_types):
globals()[name] = value
__imports__.append(name)
del name, value
def wait(io, timeout=None, timeout_exc=timeout('timed out')):
"""Block the current greenlet until *io* is ready.
If *timeout* is non-negative, then *timeout_exc* is raised after *timeout* second has passed.
By default *timeout_exc* is ``socket.timeout('timed out')``.
If :func:`cancel_wait` is called, raise ``socket.error(EBADF, 'File descriptor was closed in another greenlet')``.
"""
assert io.callback is None, 'This socket is already used by another greenlet: %r' % (io.callback, )
if timeout is not None:
timeout = Timeout.start_new(timeout, timeout_exc)
try:
return get_hub().wait(io)
finally:
if timeout is not None:
timeout.cancel()
# rename "io" to "watcher" because wait() works with any watcher
def wait_read(fileno, timeout=None, timeout_exc=timeout('timed out')):
"""Block the current greenlet until *fileno* is ready to read.
If *timeout* is non-negative, then *timeout_exc* is raised after *timeout* second has passed.
By default *timeout_exc* is ``socket.timeout('timed out')``.
If :func:`cancel_wait` is called, raise ``socket.error(EBADF, 'File descriptor was closed in another greenlet')``.
"""
io = get_hub().loop.io(fileno, 1)
return wait(io, timeout, timeout_exc)
def wait_write(fileno, timeout=None, timeout_exc=timeout('timed out'), event=None):
"""Block the current greenlet until *fileno* is ready to write.
If *timeout* is non-negative, then *timeout_exc* is raised after *timeout* second has passed.
By default *timeout_exc* is ``socket.timeout('timed out')``.
If :func:`cancel_wait` is called, raise ``socket.error(EBADF, 'File descriptor was closed in another greenlet')``.
"""
io = get_hub().loop.io(fileno, 2)
return wait(io, timeout, timeout_exc)
def wait_readwrite(fileno, timeout=None, timeout_exc=timeout('timed out'), event=None):
"""Block the current greenlet until *fileno* is ready to read or write.
If *timeout* is non-negative, then *timeout_exc* is raised after *timeout* second has passed.
By default *timeout_exc* is ``socket.timeout('timed out')``.
If :func:`cancel_wait` is called, raise ``socket.error(EBADF, 'File descriptor was closed in another greenlet')``.
"""
io = get_hub().loop.io(fileno, 3)
return wait(io, timeout, timeout_exc)
cancel_wait_ex = error(EBADF, 'File descriptor was closed in another greenlet')
def cancel_wait(watcher):
get_hub().cancel_wait(watcher, cancel_wait_ex)
class BlockingResolver(object):
def __init__(self, hub=None):
pass
def close(self):
pass
for method in ['gethostbyname',
'gethostbyname_ex',
'getaddrinfo',
'gethostbyaddr',
'getnameinfo']:
locals()[method] = staticmethod(getattr(_socket, method))
def gethostbyname(hostname):
return get_hub().resolver.gethostbyname(hostname)
def gethostbyname_ex(hostname):
return get_hub().resolver.gethostbyname_ex(hostname)
def getaddrinfo(host, port, family=0, socktype=0, proto=0, flags=0):
return get_hub().resolver.getaddrinfo(host, port, family, socktype, proto, flags)
def gethostbyaddr(ip_address):
return get_hub().resolver.gethostbyaddr(ip_address)
def getnameinfo(sockaddr, flags):
return get_hub().resolver.getnameinfo(sockaddr, flags)
def getfqdn(name=''):
"""Get fully qualified domain name from name.
An empty argument is interpreted as meaning the local host.
First the hostname returned by gethostbyaddr() is checked, then
possibly existing aliases. In case no FQDN is available, hostname
from gethostname() is returned.
"""
name = name.strip()
if not name or name == '0.0.0.0':
name = gethostname()
try:
hostname, aliases, ipaddrs = gethostbyaddr(name)
except error:
pass
else:
aliases.insert(0, hostname)
for name in aliases:
if '.' in name:
break
else:
name = hostname
return name
# Wrapper module for _ssl. Written by Bill Janssen.
# Ported to gevent by Denis Bilenko.
"""SSL wrapper for socket objects.
For the documentation, refer to :mod:`ssl` module manual.
This module implements cooperative SSL socket wrappers.
"""
from __future__ import absolute_import
import ssl as __ssl__
try:
_ssl = __ssl__._ssl
except AttributeError:
_ssl = __ssl__._ssl2
import sys
import errno
from gevent.socket import socket, _fileobject, timeout_default
from gevent.socket import error as socket_error
from gevent.hub import string_types
__implements__ = ['SSLSocket',
'wrap_socket',
'get_server_certificate',
'sslwrap_simple']
__imports__ = ['SSLError',
'RAND_status',
'RAND_egd',
'RAND_add',
'cert_time_to_seconds',
'get_protocol_name',
'DER_cert_to_PEM_cert',
'PEM_cert_to_DER_cert']
for name in __imports__[:]:
try:
value = getattr(__ssl__, name)
globals()[name] = value
except AttributeError:
__imports__.remove(name)
for name in dir(__ssl__):
if not name.startswith('_'):
value = getattr(__ssl__, name)
if isinstance(value, (int, long, tuple)) or isinstance(value, string_types):
globals()[name] = value
__imports__.append(name)
del name, value
__all__ = __implements__ + __imports__
class SSLSocket(socket):
def __init__(self, sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
ciphers=None):
socket.__init__(self, _sock=sock)
if certfile and not keyfile:
keyfile = certfile
# see if it's connected
try:
socket.getpeername(self)
except socket_error as e:
if e.args[0] != errno.ENOTCONN:
raise
# no, no connection yet
self._sslobj = None
else:
# yes, create the SSL object
if ciphers is None:
self._sslobj = _ssl.sslwrap(self._sock, server_side,
keyfile, certfile,
cert_reqs, ssl_version, ca_certs)
else:
self._sslobj = _ssl.sslwrap(self._sock, server_side,
keyfile, certfile,
cert_reqs, ssl_version, ca_certs,
ciphers)
if do_handshake_on_connect:
self.do_handshake()
self.keyfile = keyfile
self.certfile = certfile
self.cert_reqs = cert_reqs
self.ssl_version = ssl_version
self.ca_certs = ca_certs
self.ciphers = ciphers
self.do_handshake_on_connect = do_handshake_on_connect
self.suppress_ragged_eofs = suppress_ragged_eofs
self._makefile_refs = 0
def read(self, len=1024):
"""Read up to LEN bytes and return them.
Return zero-length string on EOF."""
while True:
try:
return self._sslobj.read(len)
except SSLError as ex:
if ex.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
return ''
elif ex.args[0] == SSL_ERROR_WANT_READ:
if self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event, timeout_exc=_SSLErrorReadTimeout)
elif ex.args[0] == SSL_ERROR_WANT_WRITE:
if self.timeout == 0.0:
raise
sys.exc_clear()
# note: using _SSLErrorReadTimeout rather than _SSLErrorWriteTimeout below is intentional
self._wait(self._write_event, timeout_exc=_SSLErrorReadTimeout)
else:
raise
def write(self, data):
"""Write DATA to the underlying SSL channel. Returns
number of bytes of DATA actually transmitted."""
while True:
try:
return self._sslobj.write(data)
except SSLError as ex:
if ex.args[0] == SSL_ERROR_WANT_READ:
if self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event, timeout_exc=_SSLErrorWriteTimeout)
elif ex.args[0] == SSL_ERROR_WANT_WRITE:
if self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._write_event, timeout_exc=_SSLErrorWriteTimeout)
else:
raise
def getpeercert(self, binary_form=False):
"""Returns a formatted version of the data in the
certificate provided by the other end of the SSL channel.
Return None if no certificate was provided, {} if a
certificate was provided, but not validated."""
return self._sslobj.peer_certificate(binary_form)
def cipher(self):
if not self._sslobj:
return None
else:
return self._sslobj.cipher()
def send(self, data, flags=0, timeout=timeout_default):
if timeout is timeout_default:
timeout = self.timeout
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to send() on %s" %
self.__class__)
while True:
try:
v = self._sslobj.write(data)
except SSLError as x:
if x.args[0] == SSL_ERROR_WANT_READ:
if self.timeout == 0.0:
return 0
sys.exc_clear()
self._wait(self._read_event)
elif x.args[0] == SSL_ERROR_WANT_WRITE:
if self.timeout == 0.0:
return 0
sys.exc_clear()
self._wait(self._write_event)
else:
raise
else:
return v
else:
return socket.send(self, data, flags, timeout)
# is it possible for sendall() to send some data without encryption if another end shut down SSL?
def sendto(self, *args):
if self._sslobj:
raise ValueError("sendto not allowed on instances of %s" %
self.__class__)
else:
return socket.sendto(self, *args)
def recv(self, buflen=1024, flags=0):
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to recv() on %s" %
self.__class__)
# QQQ Shouldn't we wrap the SSL_WANT_READ errors as socket.timeout errors to match socket.recv's behavior?
return self.read(buflen)
else:
return socket.recv(self, buflen, flags)
def recv_into(self, buffer, nbytes=None, flags=0):
if buffer and (nbytes is None):
nbytes = len(buffer)
elif nbytes is None:
nbytes = 1024
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to recv_into() on %s" %
self.__class__)
while True:
try:
tmp_buffer = self.read(nbytes)
v = len(tmp_buffer)
buffer[:v] = tmp_buffer
return v
except SSLError as x:
if x.args[0] == SSL_ERROR_WANT_READ:
if self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event)
continue
else:
raise
else:
return socket.recv_into(self, buffer, nbytes, flags)
def recvfrom(self, *args):
if self._sslobj:
raise ValueError("recvfrom not allowed on instances of %s" %
self.__class__)
else:
return socket.recvfrom(self, *args)
def recvfrom_into(self, *args):
if self._sslobj:
raise ValueError("recvfrom_into not allowed on instances of %s" %
self.__class__)
else:
return socket.recvfrom_into(self, *args)
def pending(self):
if self._sslobj:
return self._sslobj.pending()
else:
return 0
def _sslobj_shutdown(self):
while True:
try:
return self._sslobj.shutdown()
except SSLError as ex:
if ex.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
return ''
elif ex.args[0] == SSL_ERROR_WANT_READ:
if self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event, timeout_exc=_SSLErrorReadTimeout)
elif ex.args[0] == SSL_ERROR_WANT_WRITE:
if self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._write_event, timeout_exc=_SSLErrorWriteTimeout)
else:
raise
def unwrap(self):
if self._sslobj:
s = self._sslobj_shutdown()
self._sslobj = None
return socket(_sock=s)
else:
raise ValueError("No SSL wrapper around " + str(self))
def shutdown(self, how):
self._sslobj = None
socket.shutdown(self, how)
def close(self):
if self._makefile_refs < 1:
self._sslobj = None
socket.close(self)
else:
self._makefile_refs -= 1
def do_handshake(self):
"""Perform a TLS/SSL handshake."""
while True:
try:
return self._sslobj.do_handshake()
except SSLError as ex:
if ex.args[0] == SSL_ERROR_WANT_READ:
if self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event, timeout_exc=_SSLErrorHandshakeTimeout)
elif ex.args[0] == SSL_ERROR_WANT_WRITE:
if self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._write_event, timeout_exc=_SSLErrorHandshakeTimeout)
else:
raise
def connect(self, addr):
"""Connects to remote ADDR, and then wraps the connection in
an SSL channel."""
# Here we assume that the socket is client-side, and not
# connected at the time of the call. We connect it, then wrap it.
if self._sslobj:
raise ValueError("attempt to connect already-connected SSLSocket!")
socket.connect(self, addr)
if self.ciphers is None:
self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
self.cert_reqs, self.ssl_version,
self.ca_certs)
else:
self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
self.cert_reqs, self.ssl_version,
self.ca_certs, self.ciphers)
if self.do_handshake_on_connect:
self.do_handshake()
def accept(self):
"""Accepts a new connection from a remote client, and returns
a tuple containing that new connection wrapped with a server-side
SSL channel, and the address of the remote client."""
newsock, addr = socket.accept(self)
return (SSLSocket(newsock._sock,
keyfile=self.keyfile,
certfile=self.certfile,
server_side=True,
cert_reqs=self.cert_reqs,
ssl_version=self.ssl_version,
ca_certs=self.ca_certs,
do_handshake_on_connect=self.do_handshake_on_connect,
suppress_ragged_eofs=self.suppress_ragged_eofs,
ciphers=self.ciphers),
addr)
def makefile(self, mode='r', bufsize=-1):
"""Make and return a file-like object that
works with the SSL connection. Just use the code
from the socket module."""
self._makefile_refs += 1
# close=True so as to decrement the reference count when done with
# the file-like object.
return _fileobject(self, mode, bufsize, close=True)
_SSLErrorReadTimeout = SSLError('The read operation timed out')
_SSLErrorWriteTimeout = SSLError('The write operation timed out')
_SSLErrorHandshakeTimeout = SSLError('The handshake operation timed out')
def wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None):
"""Create a new :class:`SSLSocket` instance."""
return SSLSocket(sock, keyfile=keyfile, certfile=certfile,
server_side=server_side, cert_reqs=cert_reqs,
ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers)
def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
"""Retrieve the certificate from the server at the specified address,
and return it as a PEM-encoded string.
If 'ca_certs' is specified, validate the server cert against it.
If 'ssl_version' is specified, use it in the connection attempt."""
host, port = addr
if (ca_certs is not None):
cert_reqs = CERT_REQUIRED
else:
cert_reqs = CERT_NONE
s = wrap_socket(socket(), ssl_version=ssl_version,
cert_reqs=cert_reqs, ca_certs=ca_certs)
s.connect(addr)
dercert = s.getpeercert(True)
s.close()
return DER_cert_to_PEM_cert(dercert)
def sslwrap_simple(sock, keyfile=None, certfile=None):
"""A replacement for the old socket.ssl function. Designed
for compability with Python 2.5 and earlier. Will disappear in
Python 3.0."""
return SSLSocket(sock, keyfile, certfile)
# Wrapper module for _ssl. Written by Bill Janssen.
# Ported to gevent by Denis Bilenko.
"""SSL wrapper for socket objects.
For the documentation, refer to :mod:`ssl` module manual.
This module implements cooperative SSL socket wrappers.
"""
from __future__ import absolute_import
import ssl as __ssl__
_ssl = __ssl__._ssl
import errno
from gevent.socket import socket, timeout_default
from gevent.socket import error as socket_error
__implements__ = ['SSLContext',
'SSLSocket',
'wrap_socket',
'get_server_certificate']
for name in dir(__ssl__):
if name in __implements__:
continue
if name.startswith('__'):
continue
value = getattr(__ssl__, name)
globals()[name] = value
del name, value
orig_SSLContext = __ssl__.SSLContext
class SSLContext(orig_SSLContext):
__slots__ = ('protocol', '__weakref__')
def wrap_socket(self, sock, server_side=False,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
server_hostname=None):
return SSLSocket(sock=sock, server_side=server_side,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
server_hostname=server_hostname,
_context=self)
class SSLSocket(socket):
def __init__(self, sock=None, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,
server_hostname=None,
_context=None):
if _context:
self.context = _context
else:
if server_side and not certfile:
raise ValueError("certfile must be specified for server-side "
"operations")
if keyfile and not certfile:
raise ValueError("certfile must be specified")
if certfile and not keyfile:
keyfile = certfile
self.context = SSLContext(ssl_version)
self.context.verify_mode = cert_reqs
if ca_certs:
self.context.load_verify_locations(ca_certs)
if certfile:
self.context.load_cert_chain(certfile, keyfile)
if npn_protocols:
self.context.set_npn_protocols(npn_protocols)
if ciphers:
self.context.set_ciphers(ciphers)
self.keyfile = keyfile
self.certfile = certfile
self.cert_reqs = cert_reqs
self.ssl_version = ssl_version
self.ca_certs = ca_certs
self.ciphers = ciphers
# Can't use sock.type as other flags (such as SOCK_NONBLOCK) get
# mixed in.
if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM:
raise NotImplementedError("only stream sockets are supported")
if server_side and server_hostname:
raise ValueError("server_hostname can only be specified "
"in client mode")
self.server_side = server_side
self.server_hostname = server_hostname
self.do_handshake_on_connect = do_handshake_on_connect
self.suppress_ragged_eofs = suppress_ragged_eofs
connected = False
if sock is not None:
socket.__init__(self,
family=sock.family,
type=sock.type,
proto=sock.proto,
fileno=sock.fileno())
self.settimeout(sock.gettimeout())
# see if it's connected
try:
sock.getpeername()
except socket_error as e:
if e.errno != errno.ENOTCONN:
raise
else:
connected = True
sock.detach()
elif fileno is not None:
socket.__init__(self, fileno=fileno)
else:
socket.__init__(self, family=family, type=type, proto=proto)
self._closed = False
self._sslobj = None
self._connected = connected
if connected:
# create the SSL object
try:
self._sslobj = self.context._wrap_socket(self, server_side,
server_hostname)
if do_handshake_on_connect:
timeout = self.gettimeout()
if timeout == 0.0:
# non-blocking
raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
self.do_handshake()
except socket_error as x:
self.close()
raise x
def dup(self):
raise NotImplemented("Can't dup() %s instances" %
self.__class__.__name__)
def _checkClosed(self, msg=None):
# raise an exception here if you wish to check for spurious closes
pass
def read(self, len=1024, buffer=None):
"""Read up to LEN bytes and return them.
Return zero-length string on EOF."""
while True:
try:
if buffer is not None:
return self._sslobj.read(len, buffer)
else:
return self._sslobj.read(len or 1024)
except SSLWantReadError:
if self.timeout == 0.0:
raise
self._wait(self._read_event, timeout_exc=_SSLErrorReadTimeout)
except SSLWantWriteError:
if self.timeout == 0.0:
raise
# note: using _SSLErrorReadTimeout rather than _SSLErrorWriteTimeout below is intentional
self._wait(self._write_event, timeout_exc=_SSLErrorReadTimeout)
except SSLError as ex:
if ex.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
if buffer is None:
return ''
else:
return 0
else:
raise
def write(self, data):
"""Write DATA to the underlying SSL channel. Returns
number of bytes of DATA actually transmitted."""
while True:
try:
return self._sslobj.write(data)
except SSLError as ex:
if ex.args[0] == SSL_ERROR_WANT_READ:
if self.timeout == 0.0:
raise
self._wait(self._read_event, timeout_exc=_SSLErrorWriteTimeout)
elif ex.args[0] == SSL_ERROR_WANT_WRITE:
if self.timeout == 0.0:
raise
self._wait(self._write_event, timeout_exc=_SSLErrorWriteTimeout)
else:
raise
def getpeercert(self, binary_form=False):
"""Returns a formatted version of the data in the
certificate provided by the other end of the SSL channel.
Return None if no certificate was provided, {} if a
certificate was provided, but not validated."""
self._checkClosed()
return self._sslobj.peer_certificate(binary_form)
def selected_npn_protocol(self):
self._checkClosed()
if not self._sslobj or not _ssl.HAS_NPN:
return None
else:
return self._sslobj.selected_npn_protocol()
def cipher(self):
self._checkClosed()
if not self._sslobj:
return None
else:
return self._sslobj.cipher()
def compression(self):
self._checkClosed()
if not self._sslobj:
return None
else:
return self._sslobj.compression()
def send(self, data, flags=0, timeout=timeout_default):
self._checkClosed()
if timeout is timeout_default:
timeout = self.timeout
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to send() on %s" %
self.__class__)
while True:
try:
return self._sslobj.write(data)
except SSLWantReadError:
if self.timeout == 0.0:
return 0
self._wait(self._read_event)
except SSLWantWriteError:
if self.timeout == 0.0:
return 0
self._wait(self._write_event)
else:
return socket.send(self, data, flags, timeout)
def sendto(self, data, flags_or_addr, addr=None):
self._checkClosed()
if self._sslobj:
raise ValueError("sendto not allowed on instances of %s" %
self.__class__)
elif addr is None:
return socket.sendto(self, data, flags_or_addr)
else:
return socket.sendto(self, data, flags_or_addr, addr)
def sendmsg(self, *args, **kwargs):
# Ensure programs don't send data unencrypted if they try to
# use this method.
raise NotImplementedError("sendmsg not allowed on instances of %s" %
self.__class__)
def sendall(self, data, flags=0):
self._checkClosed()
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to sendall() on %s" %
self.__class__)
amount = len(data)
count = 0
while (count < amount):
v = self.send(data[count:])
count += v
return amount
else:
return socket.sendall(self, data, flags)
def recv(self, buflen=1024, flags=0):
self._checkClosed()
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to recv() on %s" %
self.__class__)
return self.read(buflen)
else:
return socket.recv(self, buflen, flags)
def recv_into(self, buffer, nbytes=None, flags=0):
self._checkClosed()
if buffer and (nbytes is None):
nbytes = len(buffer)
elif nbytes is None:
nbytes = 1024
if self._sslobj:
if flags != 0:
raise ValueError("non-zero flags not allowed in calls to recv_into() on %s" % self.__class__)
return self.read(nbytes, buffer)
else:
return socket.recv_into(self, buffer, nbytes, flags)
def recvfrom(self, buflen=1024, flags=0):
self._checkClosed()
if self._sslobj:
raise ValueError("recvfrom not allowed on instances of %s" %
self.__class__)
else:
return socket.recvfrom(self, buflen, flags)
def recvfrom_into(self, buffer, nbytes=None, flags=0):
self._checkClosed()
if self._sslobj:
raise ValueError("recvfrom_into not allowed on instances of %s" %
self.__class__)
else:
return socket.recvfrom_into(self, buffer, nbytes, flags)
def recvmsg(self, *args, **kwargs):
raise NotImplementedError("recvmsg not allowed on instances of %s" %
self.__class__)
def recvmsg_into(self, *args, **kwargs):
raise NotImplementedError("recvmsg_into not allowed on instances of "
"%s" % self.__class__)
def pending(self):
self._checkClosed()
if self._sslobj:
return self._sslobj.pending()
else:
return 0
def shutdown(self, how):
self._checkClosed()
self._sslobj = None
socket.shutdown(self, how)
def unwrap(self):
if self._sslobj:
s = self._sslobj.shutdown()
self._sslobj = None
return s
else:
raise ValueError("No SSL wrapper around " + str(self))
def _real_close(self):
self._sslobj = None
# self._closed = True
socket._real_close(self)
def do_handshake(self):
"""Perform a TLS/SSL handshake."""
while True:
try:
return self._sslobj.do_handshake()
except SSLWantReadError:
if self.timeout == 0.0:
raise
self._wait(self._read_event, timeout_exc=_SSLErrorHandshakeTimeout)
except SSLWantWriteError:
if self.timeout == 0.0:
raise
self._wait(self._write_event, timeout_exc=_SSLErrorHandshakeTimeout)
def _real_connect(self, addr, connect_ex):
if self.server_side:
raise ValueError("can't connect in server-side mode")
# Here we assume that the socket is client-side, and not
# connected at the time of the call. We connect it, then wrap it.
if self._connected:
raise ValueError("attempt to connect already-connected SSLSocket!")
self._sslobj = self.context._wrap_socket(self, False, self.server_hostname)
try:
if connect_ex:
rc = socket.connect_ex(self, addr)
else:
rc = None
socket.connect(self, addr)
if not rc:
if self.do_handshake_on_connect:
self.do_handshake()
self._connected = True
return rc
except socket_error:
self._sslobj = None
raise
def connect(self, addr):
"""Connects to remote ADDR, and then wraps the connection in
an SSL channel."""
self._real_connect(addr, False)
def connect_ex(self, addr):
"""Connects to remote ADDR, and then wraps the connection in
an SSL channel."""
return self._real_connect(addr, True)
def accept(self):
"""Accepts a new connection from a remote client, and returns
a tuple containing that new connection wrapped with a server-side
SSL channel, and the address of the remote client."""
newsock, addr = socket.accept(self)
newsock = self.context.wrap_socket(newsock,
do_handshake_on_connect=self.do_handshake_on_connect,
suppress_ragged_eofs=self.suppress_ragged_eofs,
server_side=True)
return newsock, addr
def get_channel_binding(self, cb_type="tls-unique"):
"""Get channel binding data for current connection. Raise ValueError
if the requested `cb_type` is not supported. Return bytes of the data
or None if the data is not available (e.g. before the handshake).
"""
if cb_type not in CHANNEL_BINDING_TYPES:
raise ValueError("Unsupported channel binding type")
if cb_type != "tls-unique":
raise NotImplementedError("{0} channel binding type not implemented".format(cb_type))
if self._sslobj is None:
return None
return self._sslobj.tls_unique_cb()
_SSLErrorReadTimeout = SSLError('The read operation timed out')
_SSLErrorWriteTimeout = SSLError('The write operation timed out')
_SSLErrorHandshakeTimeout = SSLError('The handshake operation timed out')
def wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
ciphers=None):
return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,
server_side=server_side, cert_reqs=cert_reqs,
ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers)
def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
"""Retrieve the certificate from the server at the specified address,
and return it as a PEM-encoded string.
If 'ca_certs' is specified, validate the server cert against it.
If 'ssl_version' is specified, use it in the connection attempt."""
host, port = addr
if (ca_certs is not None):
cert_reqs = CERT_REQUIRED
else:
cert_reqs = CERT_NONE
s = create_connection(addr)
s = wrap_socket(s, ssl_version=ssl_version,
cert_reqs=cert_reqs, ca_certs=ca_certs)
dercert = s.getpeercert(True)
s.close()
return DER_cert_to_PEM_cert(dercert)
# Copyright (c) 2005-2006, Bob Ippolito
# Copyright (c) 2007, Linden Research, Inc.
# Copyright (c) 2009-2012 Denis Bilenko
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
# Copyright (c) 2009-2014 Denis Bilenko and gevent contributors. See LICENSE for details.
"""Cooperative socket module.
......@@ -31,505 +11,20 @@ For convenience, exceptions (like :class:`error <socket.error>` and :class:`time
as well as the constants from :mod:`socket` module are imported into this module.
"""
from __future__ import absolute_import
# standard functions and classes that this module re-implements in a gevent-aware way:
__implements__ = ['create_connection',
'socket',
'SocketType',
'fromfd',
'socketpair']
__dns__ = ['getaddrinfo',
'gethostbyname',
'gethostbyname_ex',
'gethostbyaddr',
'getnameinfo',
'getfqdn']
__implements__ += __dns__
# non-standard functions that this module provides:
__extensions__ = ['wait_read',
'wait_write',
'wait_readwrite']
# standard functions and classes that this module re-imports
__imports__ = ['error',
'gaierror',
'herror',
'htonl',
'htons',
'ntohl',
'ntohs',
'inet_aton',
'inet_ntoa',
'inet_pton',
'inet_ntop',
'timeout',
'gethostname',
'getprotobyname',
'getservbyname',
'getservbyport',
'getdefaulttimeout',
'setdefaulttimeout',
# Windows:
'errorTab']
import sys
import time
from gevent.hub import get_hub, string_types, integer_types
from gevent.timeout import Timeout
is_windows = sys.platform == 'win32'
if is_windows:
# no such thing as WSAEPERM or error code 10001 according to winsock.h or MSDN
from errno import WSAEINVAL as EINVAL
from errno import WSAEWOULDBLOCK as EWOULDBLOCK
from errno import WSAEINPROGRESS as EINPROGRESS
from errno import WSAEALREADY as EALREADY
from errno import WSAEISCONN as EISCONN
from gevent.win32util import formatError as strerror
EAGAIN = EWOULDBLOCK
else:
from errno import EINVAL
from errno import EWOULDBLOCK
from errno import EINPROGRESS
from errno import EALREADY
from errno import EAGAIN
from errno import EISCONN
from os import strerror
try:
from errno import EBADF
except ImportError:
EBADF = 9
import _socket
_realsocket = _socket.socket
import socket as __socket__
_fileobject = __socket__._fileobject
for name in __imports__[:]:
try:
value = getattr(__socket__, name)
globals()[name] = value
except AttributeError:
__imports__.remove(name)
for name in __socket__.__all__:
value = getattr(__socket__, name)
if isinstance(value, integer_types) or isinstance(value, string_types):
globals()[name] = value
__imports__.append(name)
del name, value
def wait(io, timeout=None, timeout_exc=timeout('timed out')):
"""Block the current greenlet until *io* is ready.
If *timeout* is non-negative, then *timeout_exc* is raised after *timeout* second has passed.
By default *timeout_exc* is ``socket.timeout('timed out')``.
If :func:`cancel_wait` is called, raise ``socket.error(EBADF, 'File descriptor was closed in another greenlet')``.
"""
assert io.callback is None, 'This socket is already used by another greenlet: %r' % (io.callback, )
if timeout is not None:
timeout = Timeout.start_new(timeout, timeout_exc)
try:
return get_hub().wait(io)
finally:
if timeout is not None:
timeout.cancel()
# rename "io" to "watcher" because wait() works with any watcher
def wait_read(fileno, timeout=None, timeout_exc=timeout('timed out')):
"""Block the current greenlet until *fileno* is ready to read.
If *timeout* is non-negative, then *timeout_exc* is raised after *timeout* second has passed.
By default *timeout_exc* is ``socket.timeout('timed out')``.
If :func:`cancel_wait` is called, raise ``socket.error(EBADF, 'File descriptor was closed in another greenlet')``.
"""
io = get_hub().loop.io(fileno, 1)
return wait(io, timeout, timeout_exc)
def wait_write(fileno, timeout=None, timeout_exc=timeout('timed out'), event=None):
"""Block the current greenlet until *fileno* is ready to write.
If *timeout* is non-negative, then *timeout_exc* is raised after *timeout* second has passed.
By default *timeout_exc* is ``socket.timeout('timed out')``.
If :func:`cancel_wait` is called, raise ``socket.error(EBADF, 'File descriptor was closed in another greenlet')``.
"""
io = get_hub().loop.io(fileno, 2)
return wait(io, timeout, timeout_exc)
def wait_readwrite(fileno, timeout=None, timeout_exc=timeout('timed out'), event=None):
"""Block the current greenlet until *fileno* is ready to read or write.
If *timeout* is non-negative, then *timeout_exc* is raised after *timeout* second has passed.
By default *timeout_exc* is ``socket.timeout('timed out')``.
If :func:`cancel_wait` is called, raise ``socket.error(EBADF, 'File descriptor was closed in another greenlet')``.
"""
io = get_hub().loop.io(fileno, 3)
return wait(io, timeout, timeout_exc)
cancel_wait_ex = error(EBADF, 'File descriptor was closed in another greenlet')
def cancel_wait(watcher):
get_hub().cancel_wait(watcher, cancel_wait_ex)
if sys.version_info[:2] < (2, 7):
_get_memory = buffer
elif sys.version_info[:2] < (3, 0):
def _get_memory(string, offset):
try:
return memoryview(string)[offset:]
except TypeError:
return buffer(string, offset)
else:
def _get_memory(string, offset):
return memoryview(string)[offset:]
class _closedsocket(object):
__slots__ = []
def _dummy(*args, **kwargs):
raise error(EBADF, 'Bad file descriptor')
# All _delegate_methods must also be initialized here.
send = recv = recv_into = sendto = recvfrom = recvfrom_into = _dummy
__getattr__ = _dummy
timeout_default = object()
class socket(object):
def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None):
if _sock is None:
self._sock = _realsocket(family, type, proto)
self.timeout = _socket.getdefaulttimeout()
else:
if hasattr(_sock, '_sock'):
self._sock = _sock._sock
self.timeout = getattr(_sock, 'timeout', False)
if self.timeout is False:
self.timeout = _socket.getdefaulttimeout()
else:
self._sock = _sock
self.timeout = _socket.getdefaulttimeout()
self._sock.setblocking(0)
fileno = self._sock.fileno()
self.hub = get_hub()
io = self.hub.loop.io
self._read_event = io(fileno, 1)
self._write_event = io(fileno, 2)
def __repr__(self):
return '<%s at %s %s>' % (type(self).__name__, hex(id(self)), self._formatinfo())
def __str__(self):
return '<%s %s>' % (type(self).__name__, self._formatinfo())
def _formatinfo(self):
try:
fileno = self.fileno()
except Exception as ex:
fileno = str(ex)
try:
sockname = self.getsockname()
sockname = '%s:%s' % sockname
except Exception:
sockname = None
try:
peername = self.getpeername()
peername = '%s:%s' % peername
except Exception:
peername = None
result = 'fileno=%s' % fileno
if sockname is not None:
result += ' sock=' + str(sockname)
if peername is not None:
result += ' peer=' + str(peername)
if getattr(self, 'timeout', None) is not None:
result += ' timeout=' + str(self.timeout)
return result
def _get_ref(self):
return self._read_event.ref or self._write_event.ref
def _set_ref(self, value):
self._read_event.ref = value
self._write_event.ref = value
ref = property(_get_ref, _set_ref)
def _wait(self, watcher, timeout_exc=timeout('timed out')):
"""Block the current greenlet until *watcher* has pending events.
If *timeout* is non-negative, then *timeout_exc* is raised after *timeout* second has passed.
By default *timeout_exc* is ``socket.timeout('timed out')``.
from gevent.hub import PY3
If :func:`cancel_wait` is called, raise ``socket.error(EBADF, 'File descriptor was closed in another greenlet')``.
"""
assert watcher.callback is None, 'This socket is already used by another greenlet: %r' % (watcher.callback, )
if self.timeout is not None:
timeout = Timeout.start_new(self.timeout, timeout_exc, ref=False)
else:
timeout = None
try:
self.hub.wait(watcher)
finally:
if timeout is not None:
timeout.cancel()
def accept(self):
sock = self._sock
while True:
try:
client_socket, address = sock.accept()
break
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event)
return socket(_sock=client_socket), address
def close(self, _closedsocket=_closedsocket, cancel_wait_ex=cancel_wait_ex):
# This function should not reference any globals. See Python issue #808164.
self.hub.cancel_wait(self._read_event, cancel_wait_ex)
self.hub.cancel_wait(self._write_event, cancel_wait_ex)
self._sock = _closedsocket()
@property
def closed(self):
return isinstance(self._sock, _closedsocket)
def connect(self, address):
if self.timeout == 0.0:
return self._sock.connect(address)
sock = self._sock
if isinstance(address, tuple):
r = getaddrinfo(address[0], address[1], sock.family, sock.type, sock.proto)
address = r[0][-1]
if self.timeout is not None:
timer = Timeout.start_new(self.timeout, timeout('timed out'))
else:
timer = None
try:
while True:
err = sock.getsockopt(SOL_SOCKET, SO_ERROR)
if err:
raise error(err, strerror(err))
result = sock.connect_ex(address)
if not result or result == EISCONN:
break
elif (result in (EWOULDBLOCK, EINPROGRESS, EALREADY)) or (result == EINVAL and is_windows):
self._wait(self._write_event)
else:
raise error(result, strerror(result))
finally:
if timer is not None:
timer.cancel()
def connect_ex(self, address):
try:
return self.connect(address) or 0
except timeout:
return EAGAIN
except error as ex:
if type(ex) is error:
return ex.args[0]
else:
raise # gaierror is not silented by connect_ex
def dup(self):
"""dup() -> socket object
Return a new socket object connected to the same system resource.
Note, that the new socket does not inherit the timeout."""
return socket(_sock=self._sock)
def makefile(self, mode='r', bufsize=-1):
# Two things to look out for:
# 1) Closing the original socket object should not close the
# socket (hence creating a new instance)
# 2) The resulting fileobject must keep the timeout in order
# to be compatible with the stdlib's socket.makefile.
return _fileobject(type(self)(_sock=self), mode, bufsize)
def recv(self, *args):
sock = self._sock # keeping the reference so that fd is not closed during waiting
while True:
try:
return sock.recv(*args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
# QQQ without clearing exc_info test__refcount.test_clean_exit fails
sys.exc_clear()
self._wait(self._read_event)
def recvfrom(self, *args):
sock = self._sock
while True:
try:
return sock.recvfrom(*args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event)
def recvfrom_into(self, *args):
sock = self._sock
while True:
try:
return sock.recvfrom_into(*args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event)
def recv_into(self, *args):
sock = self._sock
while True:
try:
return sock.recv_into(*args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event)
def send(self, data, flags=0, timeout=timeout_default):
sock = self._sock
if timeout is timeout_default:
timeout = self.timeout
try:
return sock.send(data, flags)
except error as ex:
if ex.args[0] != EWOULDBLOCK or timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._write_event)
try:
return sock.send(data, flags)
except error as ex2:
if ex2.args[0] == EWOULDBLOCK:
return 0
raise
def sendall(self, data, flags=0):
if isinstance(data, unicode):
data = data.encode()
# this sendall is also reused by gevent.ssl.SSLSocket subclass,
# so it should not call self._sock methods directly
if self.timeout is None:
data_sent = 0
while data_sent < len(data):
data_sent += self.send(_get_memory(data, data_sent), flags)
else:
timeleft = self.timeout
end = time.time() + timeleft
data_sent = 0
while True:
data_sent += self.send(_get_memory(data, data_sent), flags, timeout=timeleft)
if data_sent >= len(data):
break
timeleft = end - time.time()
if timeleft <= 0:
raise timeout('timed out')
def sendto(self, *args):
sock = self._sock
try:
return sock.sendto(*args)
except error as ex:
if ex.args[0] != EWOULDBLOCK or timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._write_event)
try:
return sock.sendto(*args)
except error as ex2:
if ex2.args[0] == EWOULDBLOCK:
return 0
raise
def setblocking(self, flag):
if flag:
self.timeout = None
else:
self.timeout = 0.0
def settimeout(self, howlong):
if howlong is not None:
try:
f = howlong.__float__
except AttributeError:
raise TypeError('a float is required')
howlong = f()
if howlong < 0.0:
raise ValueError('Timeout value out of range')
self.timeout = howlong
def gettimeout(self):
return self.timeout
def shutdown(self, how):
if how == 0: # SHUT_RD
self.hub.cancel_wait(self._read_event, cancel_wait_ex)
elif how == 1: # SHUT_WR
self.hub.cancel_wait(self._write_event, cancel_wait_ex)
else:
self.hub.cancel_wait(self._read_event, cancel_wait_ex)
self.hub.cancel_wait(self._write_event, cancel_wait_ex)
self._sock.shutdown(how)
family = property(lambda self: self._sock.family, doc="the socket family")
type = property(lambda self: self._sock.type, doc="the socket type")
proto = property(lambda self: self._sock.proto, doc="the socket protocol")
# delegate the functions that we haven't implemented to the real socket object
_s = ("def %s(self, *args): return self._sock.%s(*args)\n\n"
"%s.__doc__ = _realsocket.%s.__doc__\n")
for _m in set(__socket__._socketmethods) - set(locals()):
exec(_s % (_m, _m, _m, _m))
del _m, _s
SocketType = socket
if hasattr(_socket, 'socketpair'):
def socketpair(*args):
one, two = _socket.socketpair(*args)
return socket(_sock=one), socket(_sock=two)
if PY3:
from gevent import _socket3 as _source
else:
__implements__.remove('socketpair')
from gevent import _socket2 as _source
if hasattr(_socket, 'fromfd'):
def fromfd(*args):
return socket(_sock=_socket.fromfd(*args))
else:
__implements__.remove('fromfd')
for key in _source.__dict__:
if key.startswith('__') and key not in '__implements__ __dns__ __all__ __extensions__ __imports__ __socket__'.split():
continue
globals()[key] = getattr(_source, key)
try:
......@@ -564,80 +59,17 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=N
sock.bind(source_address)
sock.connect(sa)
return sock
except error as err:
except error as ex:
# without exc_clear(), if connect() fails once, the socket is referenced by the frame in exc_info
# and the next bind() fails (see test__socket.TestCreateConnection)
# that does not happen with regular sockets though, because _socket.socket.connect() is a built-in.
# this is similar to "getnameinfo loses a reference" failure in test_socket.py
sys.exc_clear()
if not PY3:
sys.exc_clear()
if sock is not None:
sock.close()
err = ex
if err is not None:
raise err
else:
raise error("getaddrinfo returns an empty list")
class BlockingResolver(object):
def __init__(self, hub=None):
pass
def close(self):
pass
for method in ['gethostbyname',
'gethostbyname_ex',
'getaddrinfo',
'gethostbyaddr',
'getnameinfo']:
locals()[method] = staticmethod(getattr(_socket, method))
def gethostbyname(hostname):
return get_hub().resolver.gethostbyname(hostname)
def gethostbyname_ex(hostname):
return get_hub().resolver.gethostbyname_ex(hostname)
def getaddrinfo(host, port, family=0, socktype=0, proto=0, flags=0):
return get_hub().resolver.getaddrinfo(host, port, family, socktype, proto, flags)
def gethostbyaddr(ip_address):
return get_hub().resolver.gethostbyaddr(ip_address)
def getnameinfo(sockaddr, flags):
return get_hub().resolver.getnameinfo(sockaddr, flags)
def getfqdn(name=''):
"""Get fully qualified domain name from name.
An empty argument is interpreted as meaning the local host.
First the hostname returned by gethostbyaddr() is checked, then
possibly existing aliases. In case no FQDN is available, hostname
from gethostname() is returned.
"""
name = name.strip()
if not name or name == '0.0.0.0':
name = gethostname()
try:
hostname, aliases, ipaddrs = gethostbyaddr(name)
except error:
pass
else:
aliases.insert(0, hostname)
for name in aliases:
if '.' in name:
break
else:
name = hostname
return name
__all__ = __implements__ + __extensions__ + __imports__
# Wrapper module for _ssl. Written by Bill Janssen.
# Ported to gevent by Denis Bilenko.
"""SSL wrapper for socket objects.
from gevent.hub import PY3
For the documentation, refer to :mod:`ssl` module manual.
This module implements cooperative SSL socket wrappers.
"""
if PY3:
from gevent import _ssl3 as _source
else:
from gevent import _ssl2 as _source
from __future__ import absolute_import
import ssl as __ssl__
try:
_ssl = __ssl__._ssl
except AttributeError:
_ssl = __ssl__._ssl2
import sys
import errno
from gevent.socket import socket, _fileobject, timeout_default
from gevent.socket import error as socket_error
from gevent.hub import string_types
__implements__ = ['SSLSocket',
'wrap_socket',
'get_server_certificate',
'sslwrap_simple']
__imports__ = ['SSLError',
'RAND_status',
'RAND_egd',
'RAND_add',
'cert_time_to_seconds',
'get_protocol_name',
'DER_cert_to_PEM_cert',
'PEM_cert_to_DER_cert']
for name in __imports__[:]:
try:
value = getattr(__ssl__, name)
globals()[name] = value
except AttributeError:
__imports__.remove(name)
for name in dir(__ssl__):
if not name.startswith('_'):
value = getattr(__ssl__, name)
if isinstance(value, (int, long, tuple)) or isinstance(value, string_types):
globals()[name] = value
__imports__.append(name)
del name, value
__all__ = __implements__ + __imports__
class SSLSocket(socket):
def __init__(self, sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
ciphers=None):
socket.__init__(self, _sock=sock)
if certfile and not keyfile:
keyfile = certfile
# see if it's connected
try:
socket.getpeername(self)
except socket_error as e:
if e.args[0] != errno.ENOTCONN:
raise
# no, no connection yet
self._sslobj = None
else:
# yes, create the SSL object
if ciphers is None:
self._sslobj = _ssl.sslwrap(self._sock, server_side,
keyfile, certfile,
cert_reqs, ssl_version, ca_certs)
else:
self._sslobj = _ssl.sslwrap(self._sock, server_side,
keyfile, certfile,
cert_reqs, ssl_version, ca_certs,
ciphers)
if do_handshake_on_connect:
self.do_handshake()
self.keyfile = keyfile
self.certfile = certfile
self.cert_reqs = cert_reqs
self.ssl_version = ssl_version
self.ca_certs = ca_certs
self.ciphers = ciphers
self.do_handshake_on_connect = do_handshake_on_connect
self.suppress_ragged_eofs = suppress_ragged_eofs
self._makefile_refs = 0
def read(self, len=1024):
"""Read up to LEN bytes and return them.
Return zero-length string on EOF."""
while True:
try:
return self._sslobj.read(len)
except SSLError as ex:
if ex.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
return ''
elif ex.args[0] == SSL_ERROR_WANT_READ:
if self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event, timeout_exc=_SSLErrorReadTimeout)
elif ex.args[0] == SSL_ERROR_WANT_WRITE:
if self.timeout == 0.0:
raise
sys.exc_clear()
# note: using _SSLErrorReadTimeout rather than _SSLErrorWriteTimeout below is intentional
self._wait(self._write_event, timeout_exc=_SSLErrorReadTimeout)
else:
raise
def write(self, data):
"""Write DATA to the underlying SSL channel. Returns
number of bytes of DATA actually transmitted."""
while True:
try:
return self._sslobj.write(data)
except SSLError as ex:
if ex.args[0] == SSL_ERROR_WANT_READ:
if self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event, timeout_exc=_SSLErrorWriteTimeout)
elif ex.args[0] == SSL_ERROR_WANT_WRITE:
if self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._write_event, timeout_exc=_SSLErrorWriteTimeout)
else:
raise
def getpeercert(self, binary_form=False):
"""Returns a formatted version of the data in the
certificate provided by the other end of the SSL channel.
Return None if no certificate was provided, {} if a
certificate was provided, but not validated."""
return self._sslobj.peer_certificate(binary_form)
def cipher(self):
if not self._sslobj:
return None
else:
return self._sslobj.cipher()
def send(self, data, flags=0, timeout=timeout_default):
if timeout is timeout_default:
timeout = self.timeout
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to send() on %s" %
self.__class__)
while True:
try:
v = self._sslobj.write(data)
except SSLError as x:
if x.args[0] == SSL_ERROR_WANT_READ:
if self.timeout == 0.0:
return 0
sys.exc_clear()
self._wait(self._read_event)
elif x.args[0] == SSL_ERROR_WANT_WRITE:
if self.timeout == 0.0:
return 0
sys.exc_clear()
self._wait(self._write_event)
else:
raise
else:
return v
else:
return socket.send(self, data, flags, timeout)
# is it possible for sendall() to send some data without encryption if another end shut down SSL?
def sendto(self, *args):
if self._sslobj:
raise ValueError("sendto not allowed on instances of %s" %
self.__class__)
else:
return socket.sendto(self, *args)
def recv(self, buflen=1024, flags=0):
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to recv() on %s" %
self.__class__)
# QQQ Shouldn't we wrap the SSL_WANT_READ errors as socket.timeout errors to match socket.recv's behavior?
return self.read(buflen)
else:
return socket.recv(self, buflen, flags)
def recv_into(self, buffer, nbytes=None, flags=0):
if buffer and (nbytes is None):
nbytes = len(buffer)
elif nbytes is None:
nbytes = 1024
if self._sslobj:
if flags != 0:
raise ValueError(
"non-zero flags not allowed in calls to recv_into() on %s" %
self.__class__)
while True:
try:
tmp_buffer = self.read(nbytes)
v = len(tmp_buffer)
buffer[:v] = tmp_buffer
return v
except SSLError as x:
if x.args[0] == SSL_ERROR_WANT_READ:
if self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event)
continue
else:
raise
else:
return socket.recv_into(self, buffer, nbytes, flags)
def recvfrom(self, *args):
if self._sslobj:
raise ValueError("recvfrom not allowed on instances of %s" %
self.__class__)
else:
return socket.recvfrom(self, *args)
def recvfrom_into(self, *args):
if self._sslobj:
raise ValueError("recvfrom_into not allowed on instances of %s" %
self.__class__)
else:
return socket.recvfrom_into(self, *args)
def pending(self):
if self._sslobj:
return self._sslobj.pending()
else:
return 0
def _sslobj_shutdown(self):
while True:
try:
return self._sslobj.shutdown()
except SSLError as ex:
if ex.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
return ''
elif ex.args[0] == SSL_ERROR_WANT_READ:
if self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event, timeout_exc=_SSLErrorReadTimeout)
elif ex.args[0] == SSL_ERROR_WANT_WRITE:
if self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._write_event, timeout_exc=_SSLErrorWriteTimeout)
else:
raise
def unwrap(self):
if self._sslobj:
s = self._sslobj_shutdown()
self._sslobj = None
return socket(_sock=s)
else:
raise ValueError("No SSL wrapper around " + str(self))
def shutdown(self, how):
self._sslobj = None
socket.shutdown(self, how)
def close(self):
if self._makefile_refs < 1:
self._sslobj = None
socket.close(self)
else:
self._makefile_refs -= 1
def do_handshake(self):
"""Perform a TLS/SSL handshake."""
while True:
try:
return self._sslobj.do_handshake()
except SSLError as ex:
if ex.args[0] == SSL_ERROR_WANT_READ:
if self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._read_event, timeout_exc=_SSLErrorHandshakeTimeout)
elif ex.args[0] == SSL_ERROR_WANT_WRITE:
if self.timeout == 0.0:
raise
sys.exc_clear()
self._wait(self._write_event, timeout_exc=_SSLErrorHandshakeTimeout)
else:
raise
def connect(self, addr):
"""Connects to remote ADDR, and then wraps the connection in
an SSL channel."""
# Here we assume that the socket is client-side, and not
# connected at the time of the call. We connect it, then wrap it.
if self._sslobj:
raise ValueError("attempt to connect already-connected SSLSocket!")
socket.connect(self, addr)
if self.ciphers is None:
self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
self.cert_reqs, self.ssl_version,
self.ca_certs)
else:
self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
self.cert_reqs, self.ssl_version,
self.ca_certs, self.ciphers)
if self.do_handshake_on_connect:
self.do_handshake()
def accept(self):
"""Accepts a new connection from a remote client, and returns
a tuple containing that new connection wrapped with a server-side
SSL channel, and the address of the remote client."""
newsock, addr = socket.accept(self)
return (SSLSocket(newsock._sock,
keyfile=self.keyfile,
certfile=self.certfile,
server_side=True,
cert_reqs=self.cert_reqs,
ssl_version=self.ssl_version,
ca_certs=self.ca_certs,
do_handshake_on_connect=self.do_handshake_on_connect,
suppress_ragged_eofs=self.suppress_ragged_eofs,
ciphers=self.ciphers),
addr)
def makefile(self, mode='r', bufsize=-1):
"""Make and return a file-like object that
works with the SSL connection. Just use the code
from the socket module."""
self._makefile_refs += 1
# close=True so as to decrement the reference count when done with
# the file-like object.
return _fileobject(self, mode, bufsize, close=True)
_SSLErrorReadTimeout = SSLError('The read operation timed out')
_SSLErrorWriteTimeout = SSLError('The write operation timed out')
_SSLErrorHandshakeTimeout = SSLError('The handshake operation timed out')
def wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
do_handshake_on_connect=True,
suppress_ragged_eofs=True, ciphers=None):
"""Create a new :class:`SSLSocket` instance."""
return SSLSocket(sock, keyfile=keyfile, certfile=certfile,
server_side=server_side, cert_reqs=cert_reqs,
ssl_version=ssl_version, ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers)
def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
"""Retrieve the certificate from the server at the specified address,
and return it as a PEM-encoded string.
If 'ca_certs' is specified, validate the server cert against it.
If 'ssl_version' is specified, use it in the connection attempt."""
host, port = addr
if (ca_certs is not None):
cert_reqs = CERT_REQUIRED
else:
cert_reqs = CERT_NONE
s = wrap_socket(socket(), ssl_version=ssl_version,
cert_reqs=cert_reqs, ca_certs=ca_certs)
s.connect(addr)
dercert = s.getpeercert(True)
s.close()
return DER_cert_to_PEM_cert(dercert)
def sslwrap_simple(sock, keyfile=None, certfile=None):
"""A replacement for the old socket.ssl function. Designed
for compability with Python 2.5 and earlier. Will disappear in
Python 3.0."""
return SSLSocket(sock, keyfile, certfile)
for key in _source.__dict__:
if key.startswith('__') and key not in '__implements__ __all__ __imports__'.split():
continue
globals()[key] = getattr(_source, key)
......@@ -33,7 +33,7 @@ NOT_IMPLEMENTED = {
COULD_BE_MISSING = {
'socket': ['create_connection', 'RAND_add', 'RAND_egd', 'RAND_status']}
NO_ALL = ['gevent.threading', 'gevent._util']
NO_ALL = ['gevent.threading', 'gevent._util', 'gevent._socketcommon']
class Test(unittest.TestCase):
......@@ -140,6 +140,10 @@ are missing from %r:
raise AssertionError(msg)
def _test(self, modname):
if modname.endswith('2'):
return
if modname.endswith('3'):
return
self.modname = modname
six.exec_("import %s" % modname, {})
self.module = sys.modules[modname]
......
import sys
from greentest import walk_modules, BaseTestCase, main
import six
......@@ -21,6 +22,10 @@ def make_exec_test(path, module):
for path, module in walk_modules():
if sys.version_info[0] == 3 and path.endswith('2.py'):
continue
if sys.version_info[0] == 2 and path.endswith('3.py'):
continue
make_exec_test(path, module)
......
......@@ -147,7 +147,7 @@ class TestTCP(greentest.TestCase):
def accept_once():
conn, addr = self.listener.accept()
fd = conn.makefile()
fd = conn.makefile(mode='w')
fd.write('hello\n')
fd.close()
......
......@@ -8,8 +8,11 @@ import glob
IGNORED = r'''
gevent/socket.py:\d+: undefined name
gevent/_socket[23].py:\d+: undefined name
gevent/_socketcommon.py:\d+: undefined name
gevent/_socketcommon.py:\d+: .*imported but unused
gevent/subprocess.py:\d+: undefined name
gevent/ssl.py:\d+: undefined name
gevent/_?ssl[23]?.py:\d+: undefined name
gevent/__init__.py:\d+:.*imported but unused
gevent/__init__.py:\d+: redefinition of unused 'signal' from line
gevent/coros.py:\d+: 'from gevent.lock import *' used; unable to detect undefined names
......@@ -68,5 +71,11 @@ def pyflakes(args):
pyflakes('examples/ greentest/*.py util/ *.py')
py = set(glob.glob('gevent/*.py')) - set(['gevent/_util_py2.py'])
if sys.version_info[0] == 3:
ignored_files = ['gevent/_util_py2.py', 'gevent/_socket2.py']
else:
ignored_files = ['gevent/_socket3.py']
py = set(glob.glob('gevent/*.py')) - set(ignored_files)
pyflakes(' '.join(py))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment