Commit 9967fdb3 authored by Denis Bilenko's avatar Denis Bilenko

simplify GreenSocket and GreenSSL

- drop GreenFile (makefile() uses _fileobject), GreenPipe and GreenPipeSocket
parent eb5a537e
...@@ -32,9 +32,7 @@ _fileobject = __socket__._fileobject ...@@ -32,9 +32,7 @@ _fileobject = __socket__._fileobject
sslerror = __socket__.sslerror sslerror = __socket__.sslerror
import errno import errno
import os
import time import time
from errno import EAGAIN
from gevent.greenlet import wait_reader, wait_writer, spawn from gevent.greenlet import wait_reader, wait_writer, spawn
...@@ -56,50 +54,6 @@ except ImportError: ...@@ -56,50 +54,6 @@ except ImportError:
class SysCallError(object): class SysCallError(object):
pass pass
def higher_order_recv(recv_func):
def recv(self, buflen):
if self.act_non_blocking:
return self.fd.recv(buflen)
buf = self.recvbuffer
if buf:
chunk, self.recvbuffer = buf[:buflen], buf[buflen:]
return chunk
fd = self.fd
bytes = recv_func(fd, buflen)
if self.gettimeout():
end = time.time()+self.gettimeout()
else:
end = None
timeout_seconds = None
while bytes is None:
try:
if end:
timeout_seconds = end - time.time()
wait_reader(fd.fileno(), timeout=timeout_seconds, timeout_exc=timeout)
except timeout:
raise
except error, e:
if e[0] == errno.EPIPE:
bytes = ''
else:
raise
else:
bytes = recv_func(fd, buflen)
return bytes
return recv
def higher_order_send(send_func):
def send(self, data):
if self.act_non_blocking:
return self.fd.send(data)
count = send_func(self.fd, data)
if not count:
return 0
return count
return send
CONNECT_ERR = (errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK) CONNECT_ERR = (errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK)
CONNECT_SUCCESS = (0, errno.EISCONN) CONNECT_SUCCESS = (0, errno.EISCONN)
def socket_connect(descriptor, address): def socket_connect(descriptor, address):
...@@ -111,157 +65,70 @@ def socket_connect(descriptor, address): ...@@ -111,157 +65,70 @@ def socket_connect(descriptor, address):
return descriptor return descriptor
def socket_accept(descriptor): class _closedsocket(object):
try: __slots__ = []
return descriptor.accept() def _dummy(*args):
except error, e: raise error(errno.EBADF, 'Bad file descriptor')
if e[0] == errno.EWOULDBLOCK: # All _delegate_methods must also be initialized here.
return None send = recv = recv_into = sendto = recvfrom = recvfrom_into = _dummy
raise __getattr__ = _dummy
def socket_send(descriptor, data):
try:
return descriptor.send(data)
except error, e:
if e[0] == errno.EWOULDBLOCK or e[0] == errno.ENOTCONN:
return 0
raise
except SSL.WantWriteError:
return 0
except SSL.WantReadError:
return 0
# winsock sometimes throws ENOTCONN
SOCKET_CLOSED = (errno.ECONNRESET, errno.ENOTCONN, errno.ESHUTDOWN)
def socket_recv(descriptor, buflen):
try:
return descriptor.recv(buflen)
except error, e:
if e[0] == errno.EWOULDBLOCK:
return None
if e[0] in SOCKET_CLOSED:
return ''
raise
except SSL.WantReadError:
return None
except SSL.ZeroReturnError:
return ''
except SSL.SysCallError, e:
if e[0] == -1 or e[0] > 0:
return ''
raise
def file_recv(fd, buflen):
try:
return fd.read(buflen)
except IOError, e:
if e[0] == EAGAIN:
return None
return ''
except error, e:
if e[0] == errno.EPIPE:
return ''
raise
_delegate_methods = ("recv", "recvfrom", "recv_into", "recvfrom_into", "send", "sendto", 'sendall')
def file_send(fd, data):
try:
fd.write(data)
fd.flush()
return len(data)
except IOError, e:
if e[0] == EAGAIN:
return 0
except ValueError, e:
written = 0
except error, e:
if e[0] == errno.EPIPE:
written = 0
def set_nonblocking(fd):
try:
setblocking = fd.setblocking
except AttributeError:
# This version of Python predates socket.setblocking()
import fcntl
fileno = fd.fileno()
flags = fcntl.fcntl(fileno, fcntl.F_GETFL)
fcntl.fcntl(fileno, fcntl.F_SETFL, flags | os.O_NONBLOCK)
else:
# socket supports setblocking()
setblocking(0)
class GreenSocket(object): class GreenSocket(object):
is_secure = False is_secure = False # XXX remove this
timeout = None
def __init__(self, family_or_realsock=_socket.AF_INET, *args, **kwargs): def __init__(self, family_or_realsock=_socket.AF_INET, *args, **kwargs):
if isinstance(family_or_realsock, (int, long)): if isinstance(family_or_realsock, (int, long)):
fd = _original_socket(family_or_realsock, *args, **kwargs) self.fd = _original_socket(family_or_realsock, *args, **kwargs)
self.timeout = _socket.getdefaulttimeout()
else: else:
fd = family_or_realsock if hasattr(family_or_realsock, '_sock'):
family_or_realsock = family_or_realsock.sock
self.fd = family_or_realsock
self.timeout = self.fd.gettimeout()
assert not args, args assert not args, args
assert not kwargs, kwargs assert not kwargs, kwargs
set_nonblocking(fd) self.fd.setblocking(0)
self.fd = fd
self._fileno = fd.fileno()
self.recvbuffer = ''
self.closed = False
self.timeout = _socket.getdefaulttimeout()
# when client calls setblocking(0) or settimeout(0) the socket must
# act non-blocking
self.act_non_blocking = False
def __repr__(self): def __repr__(self):
return '<%s at %s fileno=%s>' % (type(self).__name__, hex(id(self)), self.fileno()) try:
fileno = self.fileno()
@property except Exception, ex:
def family(self): fileno = str(ex)
return self.fd.family return '<%s at %s fileno=%s timeout=%s>' % (type(self).__name__, hex(id(self)), fileno, self.timeout)
@property
def type(self):
return self.fd.type
@property def __getattr__(self, item):
def proto(self): return getattr(self.fd, item)
return self.fd.proto
def accept(self): def accept(self):
if self.act_non_blocking: if self.timeout==0.0:
return self.fd.accept() return self.fd.accept()
fd = self.fd fd = self.fd
while True: while True:
res = socket_accept(fd) try:
res = self.fd.accept()
except error, e:
if e[0] == errno.EWOULDBLOCK:
res = None
else:
raise
if res is not None: if res is not None:
client, addr = res client, addr = res
set_nonblocking(client)
return type(self)(client), addr return type(self)(client), addr
wait_reader(fd.fileno(), timeout=self.gettimeout(), timeout_exc=timeout) wait_reader(fd.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
def bind(self, *args, **kw): def close(self):
fn = self.bind = self.fd.bind self.fd = _closedsocket()
return fn(*args, **kw) dummy = self.fd._dummy
for method in _delegate_methods:
def close(self, *args, **kw): setattr(self, method, dummy)
if self.closed:
return
self.closed = True
if self.is_secure:
# *NOTE: This is not quite the correct SSL shutdown sequence.
# We should actually be checking the return value of shutdown.
# Note also that this is not the same as calling self.shutdown().
self.fd.shutdown()
return self.fd.close(*args, **kw)
def connect(self, address): def connect(self, address):
if self.act_non_blocking: if self.timeout==0.0:
return self.fd.connect(address) return self.fd.connect(address)
fd = self.fd fd = self.fd
if self.gettimeout() is None: if self.gettimeout() is None:
...@@ -277,7 +144,7 @@ class GreenSocket(object): ...@@ -277,7 +144,7 @@ class GreenSocket(object):
wait_writer(fd.fileno(), timeout=end-time.time(), timeout_exc=timeout) wait_writer(fd.fileno(), timeout=end-time.time(), timeout_exc=timeout)
def connect_ex(self, address): def connect_ex(self, address):
if self.act_non_blocking: if self.timeout==0.0:
return self.fd.connect_ex(address) return self.fd.connect_ex(address)
fd = self.fd fd = self.fd
if self.gettimeout() is None: if self.gettimeout() is None:
...@@ -300,86 +167,57 @@ class GreenSocket(object): ...@@ -300,86 +167,57 @@ class GreenSocket(object):
def dup(self, *args, **kw): def dup(self, *args, **kw):
sock = self.fd.dup(*args, **kw) sock = self.fd.dup(*args, **kw)
set_nonblocking(sock)
newsock = type(self)(sock) newsock = type(self)(sock)
newsock.settimeout(self.timeout) newsock.settimeout(self.timeout)
return newsock return newsock
def fileno(self, *args, **kw):
fn = self.fileno = self.fd.fileno
return fn(*args, **kw)
def getpeername(self, *args, **kw):
fn = self.getpeername = self.fd.getpeername
return fn(*args, **kw)
def getsockname(self, *args, **kw):
fn = self.getsockname = self.fd.getsockname
return fn(*args, **kw)
def getsockopt(self, *args, **kw):
fn = self.getsockopt = self.fd.getsockopt
return fn(*args, **kw)
def listen(self, *args, **kw):
fn = self.listen = self.fd.listen
return fn(*args, **kw)
def makefile(self, mode='r', bufsize=-1): def makefile(self, mode='r', bufsize=-1):
return _fileobject(self.dup(), mode, bufsize) return _fileobject(self.dup(), mode, bufsize)
def makeGreenFile(self, mode='r', bufsize=-1): def recv(self, *args):
return GreenFile(self.dup()) if self.timeout!=0.0:
wait_reader(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
recv = higher_order_recv(socket_recv) res = self.fd.recv(*args)
return res
def recvfrom(self, *args): def recvfrom(self, *args):
if not self.act_non_blocking: if self.timeout!=0.0:
wait_reader(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout) wait_reader(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
return self.fd.recvfrom(*args) return self.fd.recvfrom(*args)
def recvfrom_into(self, *args): def recvfrom_into(self, *args):
if not self.act_non_blocking: if self.timeout!=0.0:
wait_reader(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout) wait_reader(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
return self.fd.recvfrom_into(*args) return self.fd.recvfrom_into(*args)
def recv_into(self, *args): def recv_into(self, *args):
if not self.act_non_blocking: if self.timeout!=0.0:
wait_reader(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout) wait_reader(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
return self.fd.recv_into(*args) return self.fd.recv_into(*args)
send = higher_order_send(socket_send) def send(self, *args):
if self.timeout!=0.0:
wait_writer(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
return self.fd.send(*args)
def sendall(self, data): def sendall(self, data):
fd = self.fd # XXX does not respect timeout
tail = self.send(data) tail = self.send(data)
while tail < len(data): while tail < len(data):
wait_writer(self.fileno(), timeout_exc=timeout) wait_writer(self.fileno(), timeout_exc=timeout)
tail += self.send(data[tail:]) tail += self.send(data[tail:])
def sendto(self, *args): def sendto(self, *args):
wait_writer(self.fileno(), timeout_exc=timeout) if self.timeout!=0.0:
wait_writer(self.fileno(), timeout=self.timeout, timeout_exc=timeout)
return self.fd.sendto(*args) return self.fd.sendto(*args)
def setblocking(self, flag): def setblocking(self, flag):
if flag: if flag:
self.act_non_blocking = False
self.timeout = None self.timeout = None
else: else:
self.act_non_blocking = True
self.timeout = 0.0 self.timeout = 0.0
def setsockopt(self, *args, **kw):
fn = self.setsockopt = self.fd.setsockopt
return fn(*args, **kw)
def shutdown(self, *args, **kw):
if self.is_secure:
fn = self.shutdown = self.fd.sock_shutdown
else:
fn = self.shutdown = self.fd.shutdown
return fn(*args, **kw)
def settimeout(self, howlong): def settimeout(self, howlong):
if howlong is None: if howlong is None:
self.setblocking(True) self.setblocking(True)
...@@ -391,190 +229,90 @@ class GreenSocket(object): ...@@ -391,190 +229,90 @@ class GreenSocket(object):
howlong = f() howlong = f()
if howlong < 0.0: if howlong < 0.0:
raise ValueError('Timeout value out of range') raise ValueError('Timeout value out of range')
if howlong == 0.0: self.timeout = howlong
self.setblocking(howlong)
else:
self.timeout = howlong
def gettimeout(self): def gettimeout(self):
return self.timeout return self.timeout
def read(self, size=None):
if size is not None and not isinstance(size, (int, long)):
raise TypeError('Expecting an int or long for size, got %s: %s' % (type(size), repr(size)))
buf, self.sock.recvbuffer = self.sock.recvbuffer, ''
lst = [buf]
if size is None:
while True:
d = self.sock.recv(BUFFER_SIZE)
if not d:
break
lst.append(d)
else:
buflen = len(buf)
while buflen < size:
d = self.sock.recv(BUFFER_SIZE)
if not d:
break
buflen += len(d)
lst.append(d)
else:
d = lst[-1]
overbite = buflen - size
if overbite:
lst[-1], self.sock.recvbuffer = d[:-overbite], d[-overbite:]
else:
lst[-1], self.sock.recvbuffer = d, ''
return ''.join(lst)
class GreenFile(object):
newlines = '\r\n'
mode = 'wb+'
def __init__(self, fd):
if isinstance(fd, GreenSocket):
set_nonblocking(fd.fd)
else:
set_nonblocking(fd)
self.sock = fd
self.closed = False
def close(self):
self.sock.close()
self.closed = True
def fileno(self): class GreenSSL(GreenSocket):
return self.sock.fileno() is_secure = True
# TODO next
def flush(self): def __init__(self, fd, do_handshake_on_connect=True):
pass GreenSocket.__init__(self, fd)
self._makefile_refs = 0
def write(self, data): def accept(self):
return self.sock.sendall(data) if self.timeout==0.0:
return self.fd.accept()
fd = self.fd
while True:
try:
res = self.fd.accept()
except error, e:
if e[0] == errno.EWOULDBLOCK:
res = None
else:
raise
if res is not None:
client, addr = res
accepted = type(self)(client)
accepted.do_handshake() # XXX
return accepted, addr
wait_reader(fd.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
def readuntil(self, terminator, size=None): def do_handshake(self):
buf, self.sock.recvbuffer = self.sock.recvbuffer, '' while True:
checked = 0 try:
if size is None: self.fd.do_handshake()
while True:
found = buf.find(terminator, checked)
if found != -1:
found += len(terminator)
chunk, self.sock.recvbuffer = buf[:found], buf[found:]
return chunk
checked = max(0, len(buf) - (len(terminator) - 1))
d = self.sock.recv(BUFFER_SIZE)
if not d:
break
buf += d
return buf
while len(buf) < size:
found = buf.find(terminator, checked)
if found != -1:
found += len(terminator)
chunk, self.sock.recvbuffer = buf[:found], buf[found:]
return chunk
checked = len(buf)
d = self.sock.recv(BUFFER_SIZE)
if not d:
break break
buf += d except SSL.WantReadError:
chunk, self.sock.recvbuffer = buf[:size], buf[size:] wait_reader(self.fileno())
return chunk except SSL.WantWriteError:
wait_writer(self.fileno())
except SSL.SysCallError, ex:
# XXX fix ex[0]
raise sslerror(str(ex))
def connect(self, *args):
GreenSocket.connect(self, *args)
self.do_handshake()
def readline(self, size=None): def send(self, data):
return self.readuntil(self.newlines, size=size) if self.timeout!=0.0:
wait_writer(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
def __iter__(self): try:
return self.xreadlines() return self.fd.send(data)
except SSL.WantWriteError:
def readlines(self, size=None): return 0
return list(self.xreadlines(size=size)) except SSL.WantReadError:
return 0
except SSL.SysCallError, e:
if e[0] == -1 and data == "":
# errors when writing empty strings are expected
# and can be ignored
return 0
def xreadlines(self, size=None): def recv(self, buflen):
if size is None: pending = self.fd.pending()
if pending:
return self.fd.recv(min(pending, buflen))
if self.timeout!=0.0:
wait_reader(self.fileno(), timeout=self.gettimeout(), timeout_exc=timeout)
try:
while True: while True:
line = self.readline() return self.fd.recv(buflen)
if not line: except SSL.ZeroReturnError:
break return ''
yield line except SSL.SysCallError, e:
else: if e[0] == -1 or e[0] > 0:
while size > 0: return ''
line = self.readline(size) raise sslerror(str(e))
if not line:
break
yield line
size -= len(line)
def writelines(self, lines):
for line in lines:
self.write(line)
read = read
class GreenPipeSocket(GreenSocket):
""" This is a weird class that looks like a socket but expects a file descriptor as an argument instead of a socket.
"""
recv = higher_order_recv(file_recv)
send = higher_order_send(file_send)
class GreenPipe(GreenFile):
def __init__(self, fd):
set_nonblocking(fd)
self.fd = GreenPipeSocket(fd)
super(GreenPipe, self).__init__(self.fd)
def recv(self, *args, **kw):
fn = self.recv = self.fd.recv
return fn(*args, **kw)
def send(self, *args, **kw):
fn = self.send = self.fd.send
return fn(*args, **kw)
def flush(self):
self.fd.fd.flush()
class RefCount(object):
""" Reference counting class only to be used with GreenSSL objects """
def __init__(self):
self._count = 1
def increment(self):
self._count += 1
def decrement(self):
self._count -= 1
assert self._count >= 0
def is_referenced(self):
return self._count > 0
class GreenSSL(GreenSocket):
def __init__(self, fd, refcount = None):
GreenSocket.__init__(self, fd)
self.sock = self
self._refcount = refcount
if refcount is None:
self._refcount = RefCount()
read = read
def sendall(self, data): # NOTE: read() in SSLObject does not have the semantics of file.read
# overriding sendall because ssl sockets behave badly when asked to # reading here until we have buflen bytes or hit EOF is an error
# send empty strings; 'normal' sockets don't have a problem def read(self, buflen=1024):
if not data: return self.recv(buflen)
return
super(GreenSSL, self).sendall(data)
def write(self, data): def write(self, data):
try: try:
...@@ -582,26 +320,15 @@ class GreenSSL(GreenSocket): ...@@ -582,26 +320,15 @@ class GreenSSL(GreenSocket):
except SSL.Error, ex: except SSL.Error, ex:
raise sslerror(str(ex)) raise sslerror(str(ex))
def server(self): def makefile(self, mode='r', bufsize=-1):
return self.fd.server() self._makefile_refs += 1
return _fileobject(self, mode, bufsize)
def issuer(self):
return self.fd.issuer()
def dup(self):
raise NotImplementedError("Dup not supported on SSL sockets")
def makefile(self, *args, **kw):
self._refcount.increment()
return GreenFile(type(self)(self.fd, refcount = self._refcount))
makeGreenFile = makefile
def close(self): def close (self):
self._refcount.decrement() if self._makefile_refs < 1:
if self._refcount.is_referenced(): GreenSocket.close(self)
return else:
super(GreenSSL, self).close() self._makefile_refs -= 1
def socketpair(*args): def socketpair(*args):
...@@ -652,9 +379,9 @@ def ssl_listener(address, private_key, certificate): ...@@ -652,9 +379,9 @@ def ssl_listener(address, private_key, certificate):
which accepts connections forever and spawns greenlets for which accepts connections forever and spawns greenlets for
each incoming connection. each incoming connection.
""" """
sock = wrap_ssl(_original_socket(), private_key, certificate) r = _original_socket()
sock = wrap_ssl000(r, private_key, certificate)
socket_bind_and_listen(sock, address) socket_bind_and_listen(sock, address)
sock.is_secure = True
return sock return sock
# XXX merge this into create_connection # XXX merge this into create_connection
...@@ -726,7 +453,8 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT): ...@@ -726,7 +453,8 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT):
raise error, msg raise error, msg
def wrap_ssl(sock, keyfile=None, certfile=None): # get rid of this
def wrap_ssl000(sock, keyfile=None, certfile=None):
from OpenSSL import SSL from OpenSSL import SSL
context = SSL.Context(SSL.SSLv23_METHOD) context = SSL.Context(SSL.SSLv23_METHOD)
if certfile is not None: if certfile is not None:
...@@ -734,7 +462,44 @@ def wrap_ssl(sock, keyfile=None, certfile=None): ...@@ -734,7 +462,44 @@ def wrap_ssl(sock, keyfile=None, certfile=None):
if keyfile is not None: if keyfile is not None:
context.use_privatekey_file(keyfile) context.use_privatekey_file(keyfile)
context.set_verify(SSL.VERIFY_NONE, lambda *x: True) context.set_verify(SSL.VERIFY_NONE, lambda *x: True)
timeout = sock.gettimeout()
connection = SSL.Connection(context, sock) connection = SSL.Connection(context, sock)
connection.set_connect_state() connection.set_connect_state()
return GreenSSL(connection) ssl_sock = GreenSSL(connection)
try:
sock.getpeername()
except:
# no, no connection yet
pass
else:
# yes, do the handshake
ssl_sock.do_handshake()
return ssl_sock
def wrap_ssl(sock, keyfile=None, certfile=None):
from OpenSSL import SSL
context = SSL.Context(SSL.SSLv23_METHOD)
if certfile is not None:
context.use_certificate_file(certfile)
if keyfile is not None:
context.use_privatekey_file(keyfile)
context.set_verify(SSL.VERIFY_NONE, lambda *x: True)
connection = SSL.Connection(context, sock.fd)
connection.set_connect_state()
ssl_sock = GreenSSL(connection)
ssl_sock.settimeout(sock.gettimeout())
try:
sock.getpeername()
except:
# no, no connection yet
pass
else:
# yes, do the handshake
ssl_sock.do_handshake()
return ssl_sock
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment