Commit 9967fdb3 authored by Denis Bilenko's avatar Denis Bilenko

simplify GreenSocket and GreenSSL

- drop GreenFile (makefile() uses _fileobject), GreenPipe and GreenPipeSocket
parent eb5a537e
......@@ -32,9 +32,7 @@ _fileobject = __socket__._fileobject
sslerror = __socket__.sslerror
import errno
import os
import time
from errno import EAGAIN
from gevent.greenlet import wait_reader, wait_writer, spawn
......@@ -56,50 +54,6 @@ except ImportError:
class SysCallError(object):
pass
def higher_order_recv(recv_func):
def recv(self, buflen):
if self.act_non_blocking:
return self.fd.recv(buflen)
buf = self.recvbuffer
if buf:
chunk, self.recvbuffer = buf[:buflen], buf[buflen:]
return chunk
fd = self.fd
bytes = recv_func(fd, buflen)
if self.gettimeout():
end = time.time()+self.gettimeout()
else:
end = None
timeout_seconds = None
while bytes is None:
try:
if end:
timeout_seconds = end - time.time()
wait_reader(fd.fileno(), timeout=timeout_seconds, timeout_exc=timeout)
except timeout:
raise
except error, e:
if e[0] == errno.EPIPE:
bytes = ''
else:
raise
else:
bytes = recv_func(fd, buflen)
return bytes
return recv
def higher_order_send(send_func):
def send(self, data):
if self.act_non_blocking:
return self.fd.send(data)
count = send_func(self.fd, data)
if not count:
return 0
return count
return send
CONNECT_ERR = (errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK)
CONNECT_SUCCESS = (0, errno.EISCONN)
def socket_connect(descriptor, address):
......@@ -111,157 +65,70 @@ def socket_connect(descriptor, address):
return descriptor
def socket_accept(descriptor):
try:
return descriptor.accept()
except error, e:
if e[0] == errno.EWOULDBLOCK:
return None
raise
def socket_send(descriptor, data):
try:
return descriptor.send(data)
except error, e:
if e[0] == errno.EWOULDBLOCK or e[0] == errno.ENOTCONN:
return 0
raise
except SSL.WantWriteError:
return 0
except SSL.WantReadError:
return 0
# winsock sometimes throws ENOTCONN
SOCKET_CLOSED = (errno.ECONNRESET, errno.ENOTCONN, errno.ESHUTDOWN)
def socket_recv(descriptor, buflen):
try:
return descriptor.recv(buflen)
except error, e:
if e[0] == errno.EWOULDBLOCK:
return None
if e[0] in SOCKET_CLOSED:
return ''
raise
except SSL.WantReadError:
return None
except SSL.ZeroReturnError:
return ''
except SSL.SysCallError, e:
if e[0] == -1 or e[0] > 0:
return ''
raise
class _closedsocket(object):
__slots__ = []
def _dummy(*args):
raise error(errno.EBADF, 'Bad file descriptor')
# All _delegate_methods must also be initialized here.
send = recv = recv_into = sendto = recvfrom = recvfrom_into = _dummy
__getattr__ = _dummy
def file_recv(fd, buflen):
try:
return fd.read(buflen)
except IOError, e:
if e[0] == EAGAIN:
return None
return ''
except error, e:
if e[0] == errno.EPIPE:
return ''
raise
def file_send(fd, data):
try:
fd.write(data)
fd.flush()
return len(data)
except IOError, e:
if e[0] == EAGAIN:
return 0
except ValueError, e:
written = 0
except error, e:
if e[0] == errno.EPIPE:
written = 0
def set_nonblocking(fd):
try:
setblocking = fd.setblocking
except AttributeError:
# This version of Python predates socket.setblocking()
import fcntl
fileno = fd.fileno()
flags = fcntl.fcntl(fileno, fcntl.F_GETFL)
fcntl.fcntl(fileno, fcntl.F_SETFL, flags | os.O_NONBLOCK)
else:
# socket supports setblocking()
setblocking(0)
_delegate_methods = ("recv", "recvfrom", "recv_into", "recvfrom_into", "send", "sendto", 'sendall')
class GreenSocket(object):
is_secure = False
timeout = None
is_secure = False # XXX remove this
def __init__(self, family_or_realsock=_socket.AF_INET, *args, **kwargs):
if isinstance(family_or_realsock, (int, long)):
fd = _original_socket(family_or_realsock, *args, **kwargs)
self.fd = _original_socket(family_or_realsock, *args, **kwargs)
self.timeout = _socket.getdefaulttimeout()
else:
fd = family_or_realsock
if hasattr(family_or_realsock, '_sock'):
family_or_realsock = family_or_realsock.sock
self.fd = family_or_realsock
self.timeout = self.fd.gettimeout()
assert not args, args
assert not kwargs, kwargs
set_nonblocking(fd)
self.fd = fd
self._fileno = fd.fileno()
self.recvbuffer = ''
self.closed = False
self.timeout = _socket.getdefaulttimeout()
# when client calls setblocking(0) or settimeout(0) the socket must
# act non-blocking
self.act_non_blocking = False
self.fd.setblocking(0)
def __repr__(self):
return '<%s at %s fileno=%s>' % (type(self).__name__, hex(id(self)), self.fileno())
@property
def family(self):
return self.fd.family
@property
def type(self):
return self.fd.type
try:
fileno = self.fileno()
except Exception, ex:
fileno = str(ex)
return '<%s at %s fileno=%s timeout=%s>' % (type(self).__name__, hex(id(self)), fileno, self.timeout)
@property
def proto(self):
return self.fd.proto
def __getattr__(self, item):
return getattr(self.fd, item)
def accept(self):
if self.act_non_blocking:
if self.timeout==0.0:
return self.fd.accept()
fd = self.fd
while True:
res = socket_accept(fd)
try:
res = self.fd.accept()
except error, e:
if e[0] == errno.EWOULDBLOCK:
res = None
else:
raise
if res is not None:
client, addr = res
set_nonblocking(client)
return type(self)(client), addr
wait_reader(fd.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
def bind(self, *args, **kw):
fn = self.bind = self.fd.bind
return fn(*args, **kw)
def close(self, *args, **kw):
if self.closed:
return
self.closed = True
if self.is_secure:
# *NOTE: This is not quite the correct SSL shutdown sequence.
# We should actually be checking the return value of shutdown.
# Note also that this is not the same as calling self.shutdown().
self.fd.shutdown()
return self.fd.close(*args, **kw)
def close(self):
self.fd = _closedsocket()
dummy = self.fd._dummy
for method in _delegate_methods:
setattr(self, method, dummy)
def connect(self, address):
if self.act_non_blocking:
if self.timeout==0.0:
return self.fd.connect(address)
fd = self.fd
if self.gettimeout() is None:
......@@ -277,7 +144,7 @@ class GreenSocket(object):
wait_writer(fd.fileno(), timeout=end-time.time(), timeout_exc=timeout)
def connect_ex(self, address):
if self.act_non_blocking:
if self.timeout==0.0:
return self.fd.connect_ex(address)
fd = self.fd
if self.gettimeout() is None:
......@@ -300,86 +167,57 @@ class GreenSocket(object):
def dup(self, *args, **kw):
sock = self.fd.dup(*args, **kw)
set_nonblocking(sock)
newsock = type(self)(sock)
newsock.settimeout(self.timeout)
return newsock
def fileno(self, *args, **kw):
fn = self.fileno = self.fd.fileno
return fn(*args, **kw)
def getpeername(self, *args, **kw):
fn = self.getpeername = self.fd.getpeername
return fn(*args, **kw)
def getsockname(self, *args, **kw):
fn = self.getsockname = self.fd.getsockname
return fn(*args, **kw)
def getsockopt(self, *args, **kw):
fn = self.getsockopt = self.fd.getsockopt
return fn(*args, **kw)
def listen(self, *args, **kw):
fn = self.listen = self.fd.listen
return fn(*args, **kw)
def makefile(self, mode='r', bufsize=-1):
return _fileobject(self.dup(), mode, bufsize)
def makeGreenFile(self, mode='r', bufsize=-1):
return GreenFile(self.dup())
recv = higher_order_recv(socket_recv)
def recv(self, *args):
if self.timeout!=0.0:
wait_reader(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
res = self.fd.recv(*args)
return res
def recvfrom(self, *args):
if not self.act_non_blocking:
if self.timeout!=0.0:
wait_reader(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
return self.fd.recvfrom(*args)
def recvfrom_into(self, *args):
if not self.act_non_blocking:
if self.timeout!=0.0:
wait_reader(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
return self.fd.recvfrom_into(*args)
def recv_into(self, *args):
if not self.act_non_blocking:
if self.timeout!=0.0:
wait_reader(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
return self.fd.recv_into(*args)
send = higher_order_send(socket_send)
def send(self, *args):
if self.timeout!=0.0:
wait_writer(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
return self.fd.send(*args)
def sendall(self, data):
fd = self.fd
# XXX does not respect timeout
tail = self.send(data)
while tail < len(data):
wait_writer(self.fileno(), timeout_exc=timeout)
tail += self.send(data[tail:])
def sendto(self, *args):
wait_writer(self.fileno(), timeout_exc=timeout)
if self.timeout!=0.0:
wait_writer(self.fileno(), timeout=self.timeout, timeout_exc=timeout)
return self.fd.sendto(*args)
def setblocking(self, flag):
if flag:
self.act_non_blocking = False
self.timeout = None
else:
self.act_non_blocking = True
self.timeout = 0.0
def setsockopt(self, *args, **kw):
fn = self.setsockopt = self.fd.setsockopt
return fn(*args, **kw)
def shutdown(self, *args, **kw):
if self.is_secure:
fn = self.shutdown = self.fd.sock_shutdown
else:
fn = self.shutdown = self.fd.shutdown
return fn(*args, **kw)
def settimeout(self, howlong):
if howlong is None:
self.setblocking(True)
......@@ -391,190 +229,90 @@ class GreenSocket(object):
howlong = f()
if howlong < 0.0:
raise ValueError('Timeout value out of range')
if howlong == 0.0:
self.setblocking(howlong)
else:
self.timeout = howlong
self.timeout = howlong
def gettimeout(self):
return self.timeout
def read(self, size=None):
if size is not None and not isinstance(size, (int, long)):
raise TypeError('Expecting an int or long for size, got %s: %s' % (type(size), repr(size)))
buf, self.sock.recvbuffer = self.sock.recvbuffer, ''
lst = [buf]
if size is None:
while True:
d = self.sock.recv(BUFFER_SIZE)
if not d:
break
lst.append(d)
else:
buflen = len(buf)
while buflen < size:
d = self.sock.recv(BUFFER_SIZE)
if not d:
break
buflen += len(d)
lst.append(d)
else:
d = lst[-1]
overbite = buflen - size
if overbite:
lst[-1], self.sock.recvbuffer = d[:-overbite], d[-overbite:]
else:
lst[-1], self.sock.recvbuffer = d, ''
return ''.join(lst)
class GreenFile(object):
newlines = '\r\n'
mode = 'wb+'
def __init__(self, fd):
if isinstance(fd, GreenSocket):
set_nonblocking(fd.fd)
else:
set_nonblocking(fd)
self.sock = fd
self.closed = False
def close(self):
self.sock.close()
self.closed = True
def fileno(self):
return self.sock.fileno()
# TODO next
class GreenSSL(GreenSocket):
is_secure = True
def flush(self):
pass
def __init__(self, fd, do_handshake_on_connect=True):
GreenSocket.__init__(self, fd)
self._makefile_refs = 0
def write(self, data):
return self.sock.sendall(data)
def accept(self):
if self.timeout==0.0:
return self.fd.accept()
fd = self.fd
while True:
try:
res = self.fd.accept()
except error, e:
if e[0] == errno.EWOULDBLOCK:
res = None
else:
raise
if res is not None:
client, addr = res
accepted = type(self)(client)
accepted.do_handshake() # XXX
return accepted, addr
wait_reader(fd.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
def readuntil(self, terminator, size=None):
buf, self.sock.recvbuffer = self.sock.recvbuffer, ''
checked = 0
if size is None:
while True:
found = buf.find(terminator, checked)
if found != -1:
found += len(terminator)
chunk, self.sock.recvbuffer = buf[:found], buf[found:]
return chunk
checked = max(0, len(buf) - (len(terminator) - 1))
d = self.sock.recv(BUFFER_SIZE)
if not d:
break
buf += d
return buf
while len(buf) < size:
found = buf.find(terminator, checked)
if found != -1:
found += len(terminator)
chunk, self.sock.recvbuffer = buf[:found], buf[found:]
return chunk
checked = len(buf)
d = self.sock.recv(BUFFER_SIZE)
if not d:
def do_handshake(self):
while True:
try:
self.fd.do_handshake()
break
buf += d
chunk, self.sock.recvbuffer = buf[:size], buf[size:]
return chunk
except SSL.WantReadError:
wait_reader(self.fileno())
except SSL.WantWriteError:
wait_writer(self.fileno())
except SSL.SysCallError, ex:
# XXX fix ex[0]
raise sslerror(str(ex))
def connect(self, *args):
GreenSocket.connect(self, *args)
self.do_handshake()
def readline(self, size=None):
return self.readuntil(self.newlines, size=size)
def __iter__(self):
return self.xreadlines()
def readlines(self, size=None):
return list(self.xreadlines(size=size))
def send(self, data):
if self.timeout!=0.0:
wait_writer(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
try:
return self.fd.send(data)
except SSL.WantWriteError:
return 0
except SSL.WantReadError:
return 0
except SSL.SysCallError, e:
if e[0] == -1 and data == "":
# errors when writing empty strings are expected
# and can be ignored
return 0
def xreadlines(self, size=None):
if size is None:
def recv(self, buflen):
pending = self.fd.pending()
if pending:
return self.fd.recv(min(pending, buflen))
if self.timeout!=0.0:
wait_reader(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
try:
while True:
line = self.readline()
if not line:
break
yield line
else:
while size > 0:
line = self.readline(size)
if not line:
break
yield line
size -= len(line)
def writelines(self, lines):
for line in lines:
self.write(line)
read = read
class GreenPipeSocket(GreenSocket):
""" This is a weird class that looks like a socket but expects a file descriptor as an argument instead of a socket.
"""
recv = higher_order_recv(file_recv)
send = higher_order_send(file_send)
class GreenPipe(GreenFile):
def __init__(self, fd):
set_nonblocking(fd)
self.fd = GreenPipeSocket(fd)
super(GreenPipe, self).__init__(self.fd)
def recv(self, *args, **kw):
fn = self.recv = self.fd.recv
return fn(*args, **kw)
def send(self, *args, **kw):
fn = self.send = self.fd.send
return fn(*args, **kw)
def flush(self):
self.fd.fd.flush()
class RefCount(object):
""" Reference counting class only to be used with GreenSSL objects """
def __init__(self):
self._count = 1
def increment(self):
self._count += 1
def decrement(self):
self._count -= 1
assert self._count >= 0
def is_referenced(self):
return self._count > 0
class GreenSSL(GreenSocket):
def __init__(self, fd, refcount = None):
GreenSocket.__init__(self, fd)
self.sock = self
self._refcount = refcount
if refcount is None:
self._refcount = RefCount()
read = read
return self.fd.recv(buflen)
except SSL.ZeroReturnError:
return ''
except SSL.SysCallError, e:
if e[0] == -1 or e[0] > 0:
return ''
raise sslerror(str(e))
def sendall(self, data):
# overriding sendall because ssl sockets behave badly when asked to
# send empty strings; 'normal' sockets don't have a problem
if not data:
return
super(GreenSSL, self).sendall(data)
# NOTE: read() in SSLObject does not have the semantics of file.read
# reading here until we have buflen bytes or hit EOF is an error
def read(self, buflen=1024):
return self.recv(buflen)
def write(self, data):
try:
......@@ -582,26 +320,15 @@ class GreenSSL(GreenSocket):
except SSL.Error, ex:
raise sslerror(str(ex))
def server(self):
return self.fd.server()
def issuer(self):
return self.fd.issuer()
def dup(self):
raise NotImplementedError("Dup not supported on SSL sockets")
def makefile(self, *args, **kw):
self._refcount.increment()
return GreenFile(type(self)(self.fd, refcount = self._refcount))
makeGreenFile = makefile
def makefile(self, mode='r', bufsize=-1):
self._makefile_refs += 1
return _fileobject(self, mode, bufsize)
def close(self):
self._refcount.decrement()
if self._refcount.is_referenced():
return
super(GreenSSL, self).close()
def close (self):
if self._makefile_refs < 1:
GreenSocket.close(self)
else:
self._makefile_refs -= 1
def socketpair(*args):
......@@ -652,9 +379,9 @@ def ssl_listener(address, private_key, certificate):
which accepts connections forever and spawns greenlets for
each incoming connection.
"""
sock = wrap_ssl(_original_socket(), private_key, certificate)
r = _original_socket()
sock = wrap_ssl000(r, private_key, certificate)
socket_bind_and_listen(sock, address)
sock.is_secure = True
return sock
# XXX merge this into create_connection
......@@ -726,7 +453,8 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT):
raise error, msg
def wrap_ssl(sock, keyfile=None, certfile=None):
# get rid of this
def wrap_ssl000(sock, keyfile=None, certfile=None):
from OpenSSL import SSL
context = SSL.Context(SSL.SSLv23_METHOD)
if certfile is not None:
......@@ -734,7 +462,44 @@ def wrap_ssl(sock, keyfile=None, certfile=None):
if keyfile is not None:
context.use_privatekey_file(keyfile)
context.set_verify(SSL.VERIFY_NONE, lambda *x: True)
timeout = sock.gettimeout()
connection = SSL.Connection(context, sock)
connection.set_connect_state()
return GreenSSL(connection)
ssl_sock = GreenSSL(connection)
try:
sock.getpeername()
except:
# no, no connection yet
pass
else:
# yes, do the handshake
ssl_sock.do_handshake()
return ssl_sock
def wrap_ssl(sock, keyfile=None, certfile=None):
from OpenSSL import SSL
context = SSL.Context(SSL.SSLv23_METHOD)
if certfile is not None:
context.use_certificate_file(certfile)
if keyfile is not None:
context.use_privatekey_file(keyfile)
context.set_verify(SSL.VERIFY_NONE, lambda *x: True)
connection = SSL.Connection(context, sock.fd)
connection.set_connect_state()
ssl_sock = GreenSSL(connection)
ssl_sock.settimeout(sock.gettimeout())
try:
sock.getpeername()
except:
# no, no connection yet
pass
else:
# yes, do the handshake
ssl_sock.do_handshake()
return ssl_sock
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment