Commit 5accbdb8 authored by Jeremy Hylton's avatar Jeremy Hylton

Make sure socket.close() doesn't interfere with socket.makefile().

If a makefile()-generated object is open and its parent socket is
closed, the parent socket should remain open until the child is
closed, too.  The code to support this is moderately complex and
requires one extra slots in the socket object.

This change fixes httplib so that several urllib2net test cases pass
again.

Add SocketCloser class to socket.py, which encapsulates the
refcounting logic for sockets after makefile() has been called.

Move SocketIO class from io module to socket module.  It's only use is
to implement the raw I/O methods on top of a socket to support
makefile().

Add unittests to test_socket to cover various patterns of close and
makefile.
parent d2ef864f
......@@ -442,34 +442,6 @@ class FileIO(_fileio._FileIO, RawIOBase):
return self._mode
class SocketIO(RawIOBase):
"""Raw I/O implementation for stream sockets."""
# XXX More docs
def __init__(self, sock, mode):
assert mode in ("r", "w", "rw")
RawIOBase.__init__(self)
self._sock = sock
self._mode = mode
def readinto(self, b):
return self._sock.recv_into(b)
def write(self, b):
return self._sock.send(b)
def readable(self):
return "r" in self._mode
def writable(self):
return "w" in self._mode
def fileno(self):
return self._sock.fileno()
class BufferedIOBase(IOBase):
"""Base class for buffered IO objects.
......
......@@ -89,22 +89,67 @@ if sys.platform.lower().startswith("win"):
# True if os.dup() can duplicate socket descriptors.
# (On Windows at least, os.dup only works on files)
_can_dup_socket = hasattr(_socket, "dup")
_can_dup_socket = hasattr(_socket.socket, "dup")
if _can_dup_socket:
def fromfd(fd, family=AF_INET, type=SOCK_STREAM, proto=0):
nfd = os.dup(fd)
return socket(family, type, proto, fileno=nfd)
class SocketCloser:
"""Helper to manage socket close() logic for makefile().
The OS socket should not be closed until the socket and all
of its makefile-children are closed. If the refcount is zero
when socket.close() is called, this is easy: Just close the
socket. If the refcount is non-zero when socket.close() is
called, then the real close should not occur until the last
makefile-child is closed.
"""
def __init__(self, sock):
self._sock = sock
self._makefile_refs = 0
# Test whether the socket is open.
try:
sock.fileno()
self._socket_open = True
except error:
self._socket_open = False
def socket_close(self):
self._socket_open = False
self.close()
def makefile_open(self):
self._makefile_refs += 1
def makefile_close(self):
self._makefile_refs -= 1
self.close()
def close(self):
if not (self._socket_open or self._makefile_refs):
self._sock._real_close()
class socket(_socket.socket):
"""A subclass of _socket.socket adding the makefile() method."""
__slots__ = ["__weakref__"]
__slots__ = ["__weakref__", "_closer"]
if not _can_dup_socket:
__slots__.append("_base")
def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None):
if fileno is None:
_socket.socket.__init__(self, family, type, proto)
else:
_socket.socket.__init__(self, family, type, proto, fileno)
# Defer creating a SocketCloser until makefile() is actually called.
self._closer = None
def __repr__(self):
"""Wrap __repr__() to reveal the real class name."""
s = _socket.socket.__repr__(self)
......@@ -128,14 +173,6 @@ class socket(_socket.socket):
conn.close()
return wrapper, addr
if not _can_dup_socket:
def close(self):
"""Wrap close() to close the _base as well."""
_socket.socket.close(self)
base = getattr(self, "_base", None)
if base is not None:
base.close()
def makefile(self, mode="r", buffering=None, *,
encoding=None, newline=None):
"""Return an I/O stream connected to the socket.
......@@ -156,7 +193,9 @@ class socket(_socket.socket):
rawmode += "r"
if writing:
rawmode += "w"
raw = io.SocketIO(self, rawmode)
if self._closer is None:
self._closer = SocketCloser(self)
raw = SocketIO(self, rawmode, self._closer)
if buffering is None:
buffering = -1
if buffering < 0:
......@@ -183,6 +222,65 @@ class socket(_socket.socket):
text.mode = mode
return text
def close(self):
if self._closer is None:
self._real_close()
else:
self._closer.socket_close()
# _real_close calls close on the _socket.socket base class.
if not _can_dup_socket:
def _real_close(self):
_socket.socket.close(self)
base = getattr(self, "_base", None)
if base is not None:
self._base = None
base.close()
else:
def _real_close(self):
_socket.socket.close(self)
class SocketIO(io.RawIOBase):
"""Raw I/O implementation for stream sockets.
This class supports the makefile() method on sockets. It provides
the raw I/O interface on top of a socket object.
"""
# XXX More docs
def __init__(self, sock, mode, closer):
assert mode in ("r", "w", "rw")
io.RawIOBase.__init__(self)
self._sock = sock
self._mode = mode
self._closer = closer
closer.makefile_open()
def readinto(self, b):
return self._sock.recv_into(b)
def write(self, b):
return self._sock.send(b)
def readable(self):
return "r" in self._mode
def writable(self):
return "w" in self._mode
def fileno(self):
return self._sock.fileno()
def close(self):
if self.closed:
return
self._closer.makefile_close()
io.RawIOBase.close(self)
def getfqdn(name=''):
"""Get fully qualified domain name from name.
......
......@@ -163,6 +163,11 @@ class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
class SocketConnectedTest(ThreadedTCPSocketTest):
"""Socket tests for client-server connection.
self.cli_conn is a client socket connected to the server. The
setUp() method guarantees that it is connected to the server.
"""
def __init__(self, methodName='runTest'):
ThreadedTCPSocketTest.__init__(self, methodName=methodName)
......@@ -618,6 +623,10 @@ class TCPCloserTest(ThreadedTCPSocketTest):
self.assertEqual(read, [sd])
self.assertEqual(sd.recv(1), b'')
# Calling close() many times should be safe.
conn.close()
conn.close()
def _testClose(self):
self.cli.connect((HOST, PORT))
time.sleep(1.0)
......@@ -710,6 +719,16 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest):
self.cli.send(MSG)
class FileObjectClassTestCase(SocketConnectedTest):
"""Unit tests for the object returned by socket.makefile()
self.serv_file is the io object returned by makefile() on
the client connection. You can read from this file to
get output from the server.
self.cli_file is the io object returned by makefile() on the
server connection. You can write to this file to send output
to the client.
"""
bufsize = -1 # Use default buffer size
......@@ -779,6 +798,26 @@ class FileObjectClassTestCase(SocketConnectedTest):
self.cli_file.write(MSG)
self.cli_file.flush()
def testCloseAfterMakefile(self):
# The file returned by makefile should keep the socket open.
self.cli_conn.close()
# read until EOF
msg = self.serv_file.read()
self.assertEqual(msg, MSG)
def _testCloseAfterMakefile(self):
self.cli_file.write(MSG)
self.cli_file.flush()
def testMakefileAfterMakefileClose(self):
self.serv_file.close()
msg = self.cli_conn.recv(len(MSG))
self.assertEqual(msg, MSG)
def _testMakefileAfterMakefileClose(self):
self.cli_file.write(MSG)
self.cli_file.flush()
def testClosedAttr(self):
self.assert_(not self.serv_file.closed)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment