"""
Test suite for SocketServer.py.
"""

import contextlib
import errno
import imp
import os
import select
import signal
import socket
import tempfile
import unittest
import SocketServer

import test.test_support
from test.test_support import reap_children, reap_threads, verbose

# threading is optional; SocketServerTest is skipped when it is missing.
try:
    import threading
except ImportError:
    threading = None

# Gate the whole module on the "network" test resource (presumably this
# skips the module when that resource is not enabled for the run).
test.test_support.requires("network")

# Payload every client sends; the echo handler must return it unchanged.
TEST_STR = "hello world\n"
HOST = test.test_support.HOST

HAVE_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
requires_unix_sockets = unittest.skipUnless(HAVE_UNIX_SOCKETS,
                                            'requires Unix sockets')
# OS/2 is excluded explicitly, even if it exposes os.fork.
HAVE_FORKING = hasattr(os, "fork") and os.name != "os2"
requires_forking = unittest.skipUnless(HAVE_FORKING, 'requires forking')

def signal_alarm(n):
    """Call signal.alarm(n) when it exists (i.e. not on Windows).

    The tests use this as a deadlock watchdog: a positive *n* schedules
    SIGALRM after *n* seconds and 0 cancels any pending alarm.  On
    platforms without signal.alarm this is a no-op.
    """
    if hasattr(signal, 'alarm'):
        signal.alarm(n)

# Remember the real select() so that receive() keeps working while a test
# temporarily replaces select.select with a mock.
_real_select = select.select


def receive(sock, n, timeout=20):
    """Receive up to *n* bytes from *sock*, waiting at most *timeout* seconds.

    Readiness is checked with the select() captured at import time, so
    mocking select.select in a test does not affect the client side.

    Raises RuntimeError if *sock* does not become readable in time.
    """
    readable, _, _ = _real_select([sock], [], [], timeout)
    if sock not in readable:
        raise RuntimeError("timed out on %r" % (sock,))
    return sock.recv(n)

if HAVE_UNIX_SOCKETS:
    class ForkingUnixStreamServer(SocketServer.ForkingMixIn,
                                  SocketServer.UnixStreamServer):
        """Unix-domain stream server that forks a child per request."""

    class ForkingUnixDatagramServer(SocketServer.ForkingMixIn,
                                    SocketServer.UnixDatagramServer):
        """Unix-domain datagram server that forks a child per request."""


@contextlib.contextmanager
def simple_subprocess(testcase):
    """Fork a child that exits at once with status 72 around the with-body.

    When the with-body completes normally, *testcase* asserts that
    waitpid() returns this child with exit status 72.  Presumably this
    checks that the forking servers under test did not reap an unrelated
    child of the test process.  The child is reaped even if the
    with-body raises, so it never lingers as a zombie.
    """
    pid = os.fork()
    if pid == 0:
        # Don't raise an exception; it would be caught by the test harness.
        os._exit(72)
    try:
        yield None
    finally:
        # Reap unconditionally; the assertions below run only on success so
        # an assertion failure cannot mask the with-body's own exception.
        pid2, status = os.waitpid(pid, 0)
    testcase.assertEqual(pid2, pid)
    testcase.assertEqual(72 << 8, status)


@unittest.skipUnless(threading, 'Threading required for this test.')
class SocketServerTest(unittest.TestCase):
    """Test all socket servers."""

    def setUp(self):
        signal_alarm(60)  # Kill deadlocks after 60 seconds.
        self.port_seed = 0
        self.test_files = []

    def tearDown(self):
        signal_alarm(0)  # Didn't deadlock.
        reap_children()

        # Best-effort removal of the AF_UNIX paths handed out by
        # pickaddr(); a path may not exist if no server bound to it.
        for fn in self.test_files:
            try:
                os.remove(fn)
            except os.error:
                pass
        self.test_files[:] = []

    def pickaddr(self, proto):
        """Return an address of family *proto* for a server to bind to.

        AF_INET gets (HOST, 0) so the OS picks a free port.  Any other
        family is treated as AF_UNIX and gets a fresh temporary path, which
        is recorded in self.test_files so tearDown() can remove it.
        """
        if proto == socket.AF_INET:
            return (HOST, 0)

        # XXX: We need a way to tell AF_UNIX to pick its own name
        # like AF_INET provides port==0.
        sock_dir = None
        if os.name == 'os2':
            sock_dir = '\\socket'
        # mktemp() only chooses a name; the server's bind() creates the file.
        fn = tempfile.mktemp(prefix='unix_socket.', dir=sock_dir)
        if os.name == 'os2':
            # AF_UNIX socket names on OS/2 require a specific prefix
            # which can't include a drive letter and must also use
            # backslashes as directory separators
            if fn[1] == ':':
                fn = fn[2:]
            if fn[0] in (os.sep, os.altsep):
                fn = fn[1:]
            if os.sep == '/':
                fn = fn.replace(os.sep, os.altsep)
            else:
                fn = fn.replace(os.altsep, os.sep)
        self.test_files.append(fn)
        return fn

    def make_server(self, addr, svrcls, hdlrbase):
        """Create a *svrcls* server on *addr* whose handler echoes one line.

        Errors while handling a request close the request and the server
        socket, then are re-raised.
        """
        class MyServer(svrcls):
            def handle_error(self, request, client_address):
                self.close_request(request)
                self.server_close()
                # Bare raise: re-raises the exception being handled, which
                # assumes SocketServer calls this from inside an except.
                raise

        class MyHandler(hdlrbase):
            def handle(self):
                line = self.rfile.readline()
                self.wfile.write(line)

        if verbose: print("creating server")
        server = MyServer(addr, MyHandler)
        self.assertEqual(server.server_address, server.socket.getsockname())
        return server

    @reap_threads
    def run_server(self, svrcls, hdlrbase, testfunc):
        """Serve in a background thread and call *testfunc* three times.

        *testfunc* receives the address family and the server's real
        address.  The server is always shut down, its thread joined and
        its socket closed, even if *testfunc* fails.
        """
        server = self.make_server(self.pickaddr(svrcls.address_family),
                                  svrcls, hdlrbase)
        # We had the OS pick a port, so pull the real address out of
        # the server.
        addr = server.server_address
        if verbose:
            print("server created")
            print("ADDR = %s" % (addr,))
            print("CLASS = %s" % (svrcls,))

        t = threading.Thread(
            name='%s serving' % svrcls,
            target=server.serve_forever,
            # Short poll interval to make the test finish quickly.
            # Time between requests is short enough that we won't wake
            # up spuriously too many times.
            kwargs={'poll_interval':0.01})
        t.daemon = True  # In case this function raises.
        t.start()
        if verbose: print("server running")
        try:
            for i in range(3):
                if verbose: print("test client %s" % (i,))
                testfunc(svrcls.address_family, addr)
        finally:
            if verbose: print("waiting for server")
            server.shutdown()
            t.join()
            # Release the listening socket; nothing else closes it.
            server.server_close()
        if verbose: print("done")

    def stream_examine(self, proto, addr):
        """Send TEST_STR over a stream socket and check it is echoed back."""
        with contextlib.closing(socket.socket(proto, socket.SOCK_STREAM)) as s:
            s.connect(addr)
            s.sendall(TEST_STR)
            buf = data = receive(s, 100)
            # Keep reading until the echoed line is complete or EOF.
            while data and '\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def dgram_examine(self, proto, addr):
        """Send TEST_STR as a datagram and check it is echoed back."""
        with contextlib.closing(socket.socket(proto, socket.SOCK_DGRAM)) as s:
            s.sendto(TEST_STR, addr)
            buf = data = receive(s, 100)
            while data and '\n' not in buf:
                data = receive(s, 100)
                buf += data
            self.assertEqual(buf, TEST_STR)

    def test_TCPServer(self):
        self.run_server(SocketServer.TCPServer,
                        SocketServer.StreamRequestHandler,
                        self.stream_examine)

    def test_ThreadingTCPServer(self):
        self.run_server(SocketServer.ThreadingTCPServer,
                        SocketServer.StreamRequestHandler,
                        self.stream_examine)

    @requires_forking
    def test_ForkingTCPServer(self):
        with simple_subprocess(self):
            self.run_server(SocketServer.ForkingTCPServer,
                            SocketServer.StreamRequestHandler,
                            self.stream_examine)

    @requires_unix_sockets
    def test_UnixStreamServer(self):
        self.run_server(SocketServer.UnixStreamServer,
                        SocketServer.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    def test_ThreadingUnixStreamServer(self):
        self.run_server(SocketServer.ThreadingUnixStreamServer,
                        SocketServer.StreamRequestHandler,
                        self.stream_examine)

    @requires_unix_sockets
    @requires_forking
    def test_ForkingUnixStreamServer(self):
        with simple_subprocess(self):
            self.run_server(ForkingUnixStreamServer,
                            SocketServer.StreamRequestHandler,
                            self.stream_examine)

    def test_UDPServer(self):
        self.run_server(SocketServer.UDPServer,
                        SocketServer.DatagramRequestHandler,
                        self.dgram_examine)

    def test_ThreadingUDPServer(self):
        self.run_server(SocketServer.ThreadingUDPServer,
                        SocketServer.DatagramRequestHandler,
                        self.dgram_examine)

    @requires_forking
    def test_ForkingUDPServer(self):
        with simple_subprocess(self):
            self.run_server(SocketServer.ForkingUDPServer,
                            SocketServer.DatagramRequestHandler,
                            self.dgram_examine)

    @contextlib.contextmanager
    def mocked_select_module(self):
        """Mocks the select.select() call to raise EINTR for first call"""
        old_select = select.select

        class MockSelect:
            def __init__(self):
                self.called = 0

            def __call__(self, *args):
                self.called += 1
                if self.called == 1:
                    # raise the exception on first call
                    raise select.error(errno.EINTR, os.strerror(errno.EINTR))
                else:
                    # Return real select value for consecutive calls
                    return old_select(*args)

        select.select = MockSelect()
        try:
            yield select.select
        finally:
            select.select = old_select

    def test_InterruptServerSelectCall(self):
        # The server's first select() fails with EINTR; serving must still
        # succeed, which requires select() to be called again.
        with self.mocked_select_module() as mock_select:
            self.run_server(SocketServer.TCPServer,
                            SocketServer.StreamRequestHandler,
                            self.stream_examine)
            # Make sure select was called again:
            self.assertGreater(mock_select.called, 1)

    # Alas, on Linux (at least) recvfrom() doesn't return a meaningful
    # client address so this cannot work:

    # @requires_unix_sockets
    # def test_UnixDatagramServer(self):
    #     self.run_server(SocketServer.UnixDatagramServer,
    #                     SocketServer.DatagramRequestHandler,
    #                     self.dgram_examine)
    #
    # @requires_unix_sockets
    # def test_ThreadingUnixDatagramServer(self):
    #     self.run_server(SocketServer.ThreadingUnixDatagramServer,
    #                     SocketServer.DatagramRequestHandler,
    #                     self.dgram_examine)
    #
    # @requires_unix_sockets
    # @requires_forking
    # def test_ForkingUnixDatagramServer(self):
    #     self.run_server(ForkingUnixDatagramServer,
    #                     SocketServer.DatagramRequestHandler,
    #                     self.dgram_examine)

    @reap_threads
    def test_shutdown(self):
        # Issue #2302: shutdown() should always succeed in making an
        # other thread leave serve_forever().
        class MyServer(SocketServer.TCPServer):
            pass

        class MyHandler(SocketServer.StreamRequestHandler):
            pass

        threads = []
        for i in range(20):
            s = MyServer((HOST, 0), MyHandler)
            t = threading.Thread(
                name='MyServer serving',
                target=s.serve_forever,
                kwargs={'poll_interval':0.01})
            t.daemon = True  # In case this function raises.
            threads.append((t, s))
        for t, s in threads:
            t.start()
            s.shutdown()
        for t, s in threads:
            t.join()
            # Close each listening socket so 20 of them are not leaked.
            s.server_close()


def test_main():
    """Run SocketServerTest, skipping when the import lock is held."""
    if imp.lock_held():
        # If the import lock is held, the threads will hang
        raise unittest.SkipTest("can't run when import lock is held")

    test.test_support.run_unittest(SocketServerTest)


if __name__ == "__main__":
    test_main()