Commit 61399d52 authored by Jason Madden's avatar Jason Madden

Fix #1044 by always closing opened sockets before raising

Also enable ResoureWarnings by default in the test suite and fix a
bunch that showed up.
parent be540604
...@@ -90,6 +90,12 @@ ...@@ -90,6 +90,12 @@
- ``socket.send()`` now catches ``EPROTYPE`` on macOS to handle a race - ``socket.send()`` now catches ``EPROTYPE`` on macOS to handle a race
condition during shutdown. Fixed in :pr:`1035` by Jay Oster. condition during shutdown. Fixed in :pr:`1035` by Jay Oster.
- :func:`gevent.socket.create_connection` now properly cleans up open
sockets if connecting or binding raises a :exc:`BaseException` like
:exc:`KeyboardInterrupt`, :exc:`greenlet.GreenletExit` or
:exc:`gevent.timeout.Timeout`. Reported in :issue:`1044` by
kochelmonster.
- Update c-ares to 1.13.0. See :issue:`990`. - Update c-ares to 1.13.0. See :issue:`990`.
1.2.2 (2017-06-05) 1.2.2 (2017-06-05)
......
setuptools setuptools
wheel wheel
cython>=0.27 cython>=0.27.3
greenlet>=0.4.10 greenlet>=0.4.10
pylint>=1.7.1 pylint>=1.7.1
prospector[with_pyroma] prospector[with_pyroma]
......
...@@ -716,8 +716,13 @@ def main(): ...@@ -716,8 +716,13 @@ def main():
def _get_script_help(): def _get_script_help():
from inspect import getargspec # pylint:disable=deprecated-method
patch_all_args = getargspec(patch_all)[0] # pylint:disable=deprecated-method import inspect
try:
getter = inspect.getfullargspec # deprecated in 3.5, un-deprecated in 3.6
except AttributeError:
getter = inspect.getargspec
patch_all_args = getter(patch_all)[0]
modules = [x for x in patch_all_args if 'patch_' + x in globals()] modules = [x for x in patch_all_args if 'patch_' + x in globals()]
script_help = """gevent.monkey - monkey patch the standard modules to use gevent. script_help = """gevent.monkey - monkey patch the standard modules to use gevent.
......
...@@ -74,7 +74,13 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=N ...@@ -74,7 +74,13 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=N
host, port = address host, port = address
err = None err = None
for res in getaddrinfo(host, port, 0 if has_ipv6 else AF_INET, SOCK_STREAM): # getaddrinfo is documented as returning a list, but our interface
# is pluggable, so be sure it does.
addrs = list(getaddrinfo(host, port, 0 if has_ipv6 else AF_INET, SOCK_STREAM))
if not addrs:
raise error("getaddrinfo returns an empty list")
for res in addrs:
af, socktype, proto, _, sa = res af, socktype, proto, _, sa = res
sock = None sock = None
try: try:
...@@ -84,24 +90,34 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=N ...@@ -84,24 +90,34 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=N
if source_address: if source_address:
sock.bind(source_address) sock.bind(source_address)
sock.connect(sa) sock.connect(sa)
return sock except error:
except error as ex:
# without exc_clear(), if connect() fails once, the socket is referenced by the frame in exc_info
# and the next bind() fails (see test__socket.TestCreateConnection)
# that does not happen with regular sockets though, because _socket.socket.connect() is a built-in.
# this is similar to "getnameinfo loses a reference" failure in test_socket.py
if not PY3:
sys.exc_clear() # pylint:disable=no-member,useless-suppression
if sock is not None: if sock is not None:
sock.close() sock.close()
err = ex if res is addrs[-1]:
if err is not None: raise
# without exc_clear(), if connect() fails once, the socket
# is referenced by the frame in exc_info and the next
# bind() fails (see test__socket.TestCreateConnection)
# that does not happen with regular sockets though,
# because _socket.socket.connect() is a built-in. this is
# similar to "getnameinfo loses a reference" failure in
# test_socket.py
try: try:
raise err # pylint:disable=raising-bad-type c = sys.exc_clear
finally: except AttributeError:
err = None pass # Python 3 doesn't have this
else: else:
raise error("getaddrinfo returns an empty list") c()
except BaseException:
# Things like GreenletExit, Timeout and KeyboardInterrupt.
# These get raised immediately, being sure to
# close the socket
if sock is not None:
sock.close()
raise
else:
return sock
# This is promised to be in the __all__ of the _source, but, for circularity reasons, # This is promised to be in the __all__ of the _source, but, for circularity reasons,
# we implement it in this module. Mostly for documentation purposes, put it # we implement it in this module. Mostly for documentation purposes, put it
......
...@@ -430,6 +430,10 @@ class TestCase(TestCaseMetaClass("NewBase", (BaseTestCase,), {})): ...@@ -430,6 +430,10 @@ class TestCase(TestCaseMetaClass("NewBase", (BaseTestCase,), {})):
super(TestCase, cls).tearDownClass() super(TestCase, cls).tearDownClass()
def _close_on_teardown(self, resource): def _close_on_teardown(self, resource):
"""
*resource* either has a ``close`` method, or is a
callable.
"""
if 'close_on_teardown' not in self.__dict__: if 'close_on_teardown' not in self.__dict__:
self.close_on_teardown = [] self.close_on_teardown = []
self.close_on_teardown.append(resource) self.close_on_teardown.append(resource)
......
...@@ -47,17 +47,18 @@ def TESTRUNNER(tests=None): ...@@ -47,17 +47,18 @@ def TESTRUNNER(tests=None):
if tests: if tests:
atexit.register(os.system, 'rm -f */@test*') atexit.register(os.system, 'rm -f */@test*')
basic_args = [sys.executable, '-u', '-W', 'ignore', '-m' 'monkey_test']
for filename in tests: for filename in tests:
if filename in version_tests: if filename in version_tests:
util.log("Overriding %s from %s with file from %s", filename, directory, full_directory) util.log("Overriding %s from %s with file from %s", filename, directory, full_directory)
continue continue
yield [sys.executable, '-u', '-m', 'monkey_test', filename], options.copy() yield basic_args + [filename], options.copy()
yield [sys.executable, '-u', '-m', 'monkey_test', '--Event', filename], options.copy() yield basic_args + ['--Event', filename], options.copy()
options['cwd'] = full_directory options['cwd'] = full_directory
for filename in version_tests: for filename in version_tests:
yield [sys.executable, '-u', '-m', 'monkey_test', filename], options.copy() yield basic_args + [filename], options.copy()
yield [sys.executable, '-u', '-m', 'monkey_test', '--Event', filename], options.copy() yield basic_args + ['--Event', filename], options.copy()
def main(): def main():
......
...@@ -17,7 +17,8 @@ class Test_udp_client(TestCase): ...@@ -17,7 +17,8 @@ class Test_udp_client(TestCase):
server = DatagramServer('127.0.0.1:9000', handle) server = DatagramServer('127.0.0.1:9000', handle)
server.start() server.start()
try: try:
run([sys.executable, '-u', 'udp_client.py', 'Test_udp_client'], timeout=10, cwd='../../examples/') run([sys.executable, '-W', 'ignore' '-u', 'udp_client.py', 'Test_udp_client'],
timeout=10, cwd='../../examples/')
finally: finally:
server.close() server.close()
self.assertEqual(log, [b'Test_udp_client']) self.assertEqual(log, [b'Test_udp_client'])
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
# THE SOFTWARE. # THE SOFTWARE.
from greentest import TestCase, main, tcp_listener from greentest import TestCase, main, tcp_listener
from greentest import skipOnPyPy
import gevent import gevent
from gevent import socket from gevent import socket
import sys import sys
...@@ -87,17 +88,19 @@ class TestGreenIo(TestCase): ...@@ -87,17 +88,19 @@ class TestGreenIo(TestCase):
did_it_work(server) did_it_work(server)
server_greenlet.kill() server_greenlet.kill()
@skipOnPyPy("GC is different")
def test_del_closes_socket(self): def test_del_closes_socket(self):
if PYPY:
return
timer = gevent.Timeout.start_new(0.5) timer = gevent.Timeout.start_new(0.5)
def accept_once(listener): def accept_once(listener):
# delete/overwrite the original conn # delete/overwrite the original conn
# object, only keeping the file object around # object, only keeping the file object around
# closing the file object should close everything # closing the file object should close everything
# XXX: This is not exactly true on Python 3.
# This produces a ResourceWarning.
try: try:
conn, addr = listener.accept() conn, _ = listener.accept()
conn = conn.makefile(mode='wb') conn = conn.makefile(mode='wb')
conn.write(b'hello\n') conn.write(b'hello\n')
conn.close() conn.close()
......
...@@ -32,14 +32,13 @@ DELAY = 0.1 ...@@ -32,14 +32,13 @@ DELAY = 0.1
class TestCloseSocketWhilePolling(greentest.TestCase): class TestCloseSocketWhilePolling(greentest.TestCase):
def test(self): def test(self):
try: with self.assertRaises(Exception):
sock = socket.socket() sock = socket.socket()
self._close_on_teardown(sock)
get_hub().loop.timer(0, sock.close) get_hub().loop.timer(0, sock.close)
sock.connect(('python.org', 81)) sock.connect(('python.org', 81))
except Exception:
gevent.sleep(0) gevent.sleep(0)
else:
assert False, 'expected an error here'
class TestExceptionInMainloop(greentest.TestCase): class TestExceptionInMainloop(greentest.TestCase):
......
from __future__ import print_function from __future__ import print_function
import os import os
from gevent import monkey; monkey.patch_all() from gevent import monkey; monkey.patch_all()
import re
import socket import socket
import ssl import ssl
import threading import threading
import unittest import unittest
import errno import errno
from greentest import TestCase
dirname = os.path.dirname(os.path.abspath(__file__)) dirname = os.path.dirname(os.path.abspath(__file__))
certfile = os.path.join(dirname, '2.7/keycert.pem') certfile = os.path.join(dirname, '2.7/keycert.pem')
pid = os.getpid() pid = os.getpid()
...@@ -27,13 +28,13 @@ except ImportError: ...@@ -27,13 +28,13 @@ except ImportError:
psutil = None psutil = None
class Test(unittest.TestCase): class Test(TestCase):
extra_allowed_open_states = () extra_allowed_open_states = ()
def tearDown(self): def tearDown(self):
self.extra_allowed_open_states = () self.extra_allowed_open_states = ()
unittest.TestCase.tearDown(self) super(Test, self).tearDown()
def assert_raises_EBADF(self, func): def assert_raises_EBADF(self, func):
try: try:
...@@ -156,6 +157,7 @@ class TestSocket(Test): ...@@ -156,6 +157,7 @@ class TestSocket(Test):
listener.listen(1) listener.listen(1)
connector = socket.socket() connector = socket.socket()
self._close_on_teardown(connector)
def connect(): def connect():
connector.connect(('127.0.0.1', port)) connector.connect(('127.0.0.1', port))
...@@ -180,6 +182,7 @@ class TestSocket(Test): ...@@ -180,6 +182,7 @@ class TestSocket(Test):
listener.listen(1) listener.listen(1)
connector = socket.socket() connector = socket.socket()
self._close_on_teardown(connector)
def connect(): def connect():
connector.connect(('127.0.0.1', port)) connector.connect(('127.0.0.1', port))
...@@ -213,6 +216,7 @@ class TestSocket(Test): ...@@ -213,6 +216,7 @@ class TestSocket(Test):
listener.listen(1) listener.listen(1)
connector = socket.socket() connector = socket.socket()
self._close_on_teardown(connector)
def connect(): def connect():
connector.connect(('127.0.0.1', port)) connector.connect(('127.0.0.1', port))
...@@ -282,10 +286,12 @@ class TestSSL(Test): ...@@ -282,10 +286,12 @@ class TestSSL(Test):
listener.listen(1) listener.listen(1)
connector = socket.socket() connector = socket.socket()
self._close_on_teardown(connector)
def connect(): def connect():
connector.connect(('127.0.0.1', port)) connector.connect(('127.0.0.1', port))
ssl.wrap_socket(connector) x = ssl.wrap_socket(connector)
self._close_on_teardown(x)
t = threading.Thread(target=connect) t = threading.Thread(target=connect)
t.start() t.start()
...@@ -303,15 +309,18 @@ class TestSSL(Test): ...@@ -303,15 +309,18 @@ class TestSSL(Test):
def test_server_makefile1(self): def test_server_makefile1(self):
listener = socket.socket() listener = socket.socket()
self._close_on_teardown(listener)
listener.bind(('127.0.0.1', 0)) listener.bind(('127.0.0.1', 0))
port = listener.getsockname()[1] port = listener.getsockname()[1]
listener.listen(1) listener.listen(1)
connector = socket.socket() connector = socket.socket()
self._close_on_teardown(connector)
def connect(): def connect():
connector.connect(('127.0.0.1', port)) connector.connect(('127.0.0.1', port))
ssl.wrap_socket(connector) x = ssl.wrap_socket(connector)
self._close_on_teardown(x)
t = threading.Thread(target=connect) t = threading.Thread(target=connect)
t.start() t.start()
...@@ -338,10 +347,12 @@ class TestSSL(Test): ...@@ -338,10 +347,12 @@ class TestSSL(Test):
listener.listen(1) listener.listen(1)
connector = socket.socket() connector = socket.socket()
self._close_on_teardown(connector)
def connect(): def connect():
connector.connect(('127.0.0.1', port)) connector.connect(('127.0.0.1', port))
ssl.wrap_socket(connector) x = ssl.wrap_socket(connector)
self._close_on_teardown(x)
t = threading.Thread(target=connect) t = threading.Thread(target=connect)
t.start() t.start()
...@@ -372,10 +383,12 @@ class TestSSL(Test): ...@@ -372,10 +383,12 @@ class TestSSL(Test):
listener = ssl.wrap_socket(listener, keyfile=certfile, certfile=certfile) listener = ssl.wrap_socket(listener, keyfile=certfile, certfile=certfile)
connector = socket.socket() connector = socket.socket()
self._close_on_teardown(connector)
def connect(): def connect():
connector.connect(('127.0.0.1', port)) connector.connect(('127.0.0.1', port))
ssl.wrap_socket(connector) x = ssl.wrap_socket(connector)
self._close_on_teardown(x)
t = threading.Thread(target=connect) t = threading.Thread(target=connect)
t.start() t.start()
......
...@@ -70,7 +70,7 @@ def init_server(): ...@@ -70,7 +70,7 @@ def init_server():
def handle_request(s, raise_on_timeout): def handle_request(s, raise_on_timeout):
try: try:
conn, address = s.accept() conn, _ = s.accept()
except socket.timeout: except socket.timeout:
if raise_on_timeout: if raise_on_timeout:
raise raise
...@@ -83,7 +83,7 @@ def handle_request(s, raise_on_timeout): ...@@ -83,7 +83,7 @@ def handle_request(s, raise_on_timeout):
res = conn.send(b'bye') res = conn.send(b'bye')
#print('handle_request - sent %r' % res) #print('handle_request - sent %r' % res)
#print('handle_request - conn refcount: %s' % sys.getrefcount(conn)) #print('handle_request - conn refcount: %s' % sys.getrefcount(conn))
#conn.close() conn.close()
def make_request(port): def make_request(port):
...@@ -96,7 +96,7 @@ def make_request(port): ...@@ -96,7 +96,7 @@ def make_request(port):
res = s.recv(100) res = s.recv(100)
assert res == b'bye', repr(res) assert res == b'bye', repr(res)
#print('make_request - recvd %r' % res) #print('make_request - recvd %r' % res)
#s.close() s.close()
def run_interaction(run_client): def run_interaction(run_client):
......
...@@ -9,8 +9,11 @@ try: ...@@ -9,8 +9,11 @@ try:
assert weakref.ref(Dummy())() is None assert weakref.ref(Dummy())() is None
from gevent import socket from gevent import socket
s = socket.socket()
assert weakref.ref(socket.socket())() is None r = weakref.ref(s)
s.close()
del s
assert r() is None
except AssertionError: except AssertionError:
import sys import sys
if hasattr(sys, 'pypy_version_info'): if hasattr(sys, 'pypy_version_info'):
......
...@@ -322,19 +322,59 @@ def get_port(): ...@@ -322,19 +322,59 @@ def get_port():
class TestCreateConnection(greentest.TestCase): class TestCreateConnection(greentest.TestCase):
__timeout__ = 5 __timeout__ = 5000
def test(self): def test_refuses(self):
try: with self.assertRaises(socket.error) as cm:
socket.create_connection((greentest.DEFAULT_BIND_ADDR, get_port()), socket.create_connection((greentest.DEFAULT_BIND_ADDR, get_port()),
timeout=30, timeout=30,
source_address=('', get_port())) source_address=('', get_port()))
except socket.error as ex: ex = cm.exception
if 'refused' not in str(ex).lower(): self.assertIn('refused', str(ex).lower())
raise
else: def test_base_exception(self):
raise AssertionError('create_connection did not raise socket.error as expected') # such as a GreenletExit or a gevent.timeout.Timeout
class E(BaseException):
pass
class MockSocket(object):
created = ()
closed = False
def __init__(self, *_):
MockSocket.created += (self,)
def connect(self, _):
raise E()
def close(self):
self.closed = True
def mockgetaddrinfo(*_):
return [(1, 2, 3, 3, 5),]
import gevent.socket as gsocket
# Make sure we're monkey patched
self.assertEqual(gsocket.create_connection, socket.create_connection)
orig_socket = gsocket.socket
orig_getaddrinfo = gsocket.getaddrinfo
try:
gsocket.socket = MockSocket
gsocket.getaddrinfo = mockgetaddrinfo
with self.assertRaises(E):
socket.create_connection(('host', 'port'))
self.assertEqual(1, len(MockSocket.created))
self.assertTrue(MockSocket.created[0].closed)
finally:
MockSocket.created = ()
gsocket.socket = orig_socket
gsocket.getaddrinfo = orig_getaddrinfo
class TestFunctions(greentest.TestCase): class TestFunctions(greentest.TestCase):
......
...@@ -34,6 +34,7 @@ class TestSocketErrors(greentest.TestCase): ...@@ -34,6 +34,7 @@ class TestSocketErrors(greentest.TestCase):
def test_connection_refused(self): def test_connection_refused(self):
s = socket() s = socket()
self._close_on_teardown(s)
try: try:
s.connect(('127.0.0.1', 81)) s.connect(('127.0.0.1', 81))
except error as ex: except error as ex:
......
...@@ -12,12 +12,13 @@ def _send(socket): ...@@ -12,12 +12,13 @@ def _send(socket):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.connect(('127.0.0.1', 12345)) sock.connect(('127.0.0.1', 12345))
getattr(sock, meth)(anStructure) getattr(sock, meth)(anStructure)
sock.close()
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.connect(('127.0.0.1', 12345)) sock.connect(('127.0.0.1', 12345))
sock.settimeout(1.0) sock.settimeout(1.0)
getattr(sock, meth)(anStructure) getattr(sock, meth)(anStructure)
sock.close()
def TestSendBuiltinSocket(): def TestSendBuiltinSocket():
import socket import socket
......
...@@ -5,36 +5,42 @@ import greentest ...@@ -5,36 +5,42 @@ import greentest
class Test(greentest.TestCase): class Test(greentest.TestCase):
def start(self): server = None
acceptor = None
server_port = None
def _accept(self):
conn, _ = self.server.accept()
self._close_on_teardown(conn)
def setUp(self):
super(Test, self).setUp()
self.server = socket.socket() self.server = socket.socket()
self._close_on_teardown(self.server)
self.server.bind(('127.0.0.1', 0)) self.server.bind(('127.0.0.1', 0))
self.server.listen(1) self.server.listen(1)
self.server_port = self.server.getsockname()[1] self.server_port = self.server.getsockname()[1]
self.acceptor = gevent.spawn(self.server.accept) self.acceptor = gevent.spawn(self._accept)
def stop(self): def tearDown(self):
self.server.close()
self.acceptor.kill() self.acceptor.kill()
self.server.close()
del self.acceptor del self.acceptor
del self.server del self.server
super(Test, self).tearDown()
def test(self): def test(self):
self.start()
try:
sock = socket.socket() sock = socket.socket()
self._close_on_teardown(sock)
sock.connect(('127.0.0.1', self.server_port)) sock.connect(('127.0.0.1', self.server_port))
try:
sock.settimeout(0.1) sock.settimeout(0.1)
try: with self.assertRaises(socket.error) as cm:
result = sock.recv(1024) sock.recv(1024)
raise AssertionError('Expected timeout to be raised, instead recv() returned %r' % (result, ))
except socket.error as ex: ex = cm.exception
self.assertEqual(ex.args, ('timed out',)) self.assertEqual(ex.args, ('timed out',))
self.assertEqual(str(ex), 'timed out') self.assertEqual(str(ex), 'timed out')
finally:
sock.close()
finally:
self.stop()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -71,7 +71,8 @@ class Test(greentest.TestCase): ...@@ -71,7 +71,8 @@ class Test(greentest.TestCase):
def test_communicate(self): def test_communicate(self):
p = subprocess.Popen([sys.executable, "-c", p = subprocess.Popen([sys.executable, "-W", "ignore",
"-c",
'import sys,os;' 'import sys,os;'
'sys.stderr.write("pineapple");' 'sys.stderr.write("pineapple");'
'sys.stdout.write(sys.stdin.read())'], 'sys.stdout.write(sys.stdin.read())'],
...@@ -91,7 +92,9 @@ class Test(greentest.TestCase): ...@@ -91,7 +92,9 @@ class Test(greentest.TestCase):
# Native string all the things. See https://github.com/gevent/gevent/issues/1039 # Native string all the things. See https://github.com/gevent/gevent/issues/1039
p = subprocess.Popen( p = subprocess.Popen(
[ [
sys.executable, "-c", sys.executable,
"-W", "ignore",
"-c",
'import sys,os;' 'import sys,os;'
'sys.stderr.write("pineapple\\r\\n\\xff\\xff\\xf2\\xf9\\r\\n");' 'sys.stderr.write("pineapple\\r\\n\\xff\\xff\\xf2\\xf9\\r\\n");'
'sys.stdout.write(sys.stdin.read())' 'sys.stdout.write(sys.stdin.read())'
......
...@@ -5,11 +5,7 @@ from gevent import queue as Queue ...@@ -5,11 +5,7 @@ from gevent import queue as Queue
import threading import threading
import time import time
import unittest import unittest
try:
from test import support as test_support
except ImportError:
from test import test_support
from _six import xrange
QUEUE_SIZE = 5 QUEUE_SIZE = 5
...@@ -48,7 +44,7 @@ class _TriggerThread(threading.Thread): ...@@ -48,7 +44,7 @@ class _TriggerThread(threading.Thread):
# is supposed to raise an exception, call do_exceptional_blocking_test() # is supposed to raise an exception, call do_exceptional_blocking_test()
# instead. # instead.
class BlockingTestMixin: class BlockingTestMixin(object):
def do_blocking_test(self, block_func, block_args, trigger_func, trigger_args): def do_blocking_test(self, block_func, block_args, trigger_func, trigger_args):
self.t = _TriggerThread(trigger_func, trigger_args) self.t = _TriggerThread(trigger_func, trigger_args)
...@@ -65,18 +61,13 @@ class BlockingTestMixin: ...@@ -65,18 +61,13 @@ class BlockingTestMixin:
return self.result return self.result
# Call this instead if block_func is supposed to raise an exception. # Call this instead if block_func is supposed to raise an exception.
def do_exceptional_blocking_test(self,block_func, block_args, trigger_func, def do_exceptional_blocking_test(self, block_func, block_args, trigger_func,
trigger_args, expected_exception_class): trigger_args, expected_exception_class):
self.t = _TriggerThread(trigger_func, trigger_args) self.t = _TriggerThread(trigger_func, trigger_args)
self.t.start() self.t.start()
try: try:
try: with self.assertRaises(expected_exception_class):
block_func(*block_args) block_func(*block_args)
except expected_exception_class:
raise
else:
self.fail("expected exception of kind %r" %
expected_exception_class)
finally: finally:
self.t.join(10) # make sure the thread terminates self.t.join(10) # make sure the thread terminates
if self.t.isAlive(): if self.t.isAlive():
...@@ -87,6 +78,8 @@ class BlockingTestMixin: ...@@ -87,6 +78,8 @@ class BlockingTestMixin:
class BaseQueueTest(unittest.TestCase, BlockingTestMixin): class BaseQueueTest(unittest.TestCase, BlockingTestMixin):
type2test = Queue.Queue
def setUp(self): def setUp(self):
self.cum = 0 self.cum = 0
self.cumlock = threading.Lock() self.cumlock = threading.Lock()
...@@ -100,26 +93,26 @@ class BaseQueueTest(unittest.TestCase, BlockingTestMixin): ...@@ -100,26 +93,26 @@ class BaseQueueTest(unittest.TestCase, BlockingTestMixin):
q.put(222) q.put(222)
q.put(444) q.put(444)
target_first_items = dict( target_first_items = dict(
Queue = 111, Queue=111,
LifoQueue = 444, LifoQueue=444,
PriorityQueue = 111) PriorityQueue=111)
actual_first_item = (q.peek(), q.get()) actual_first_item = (q.peek(), q.get())
self.assertEquals(actual_first_item, self.assertEqual(actual_first_item,
(target_first_items[q.__class__.__name__], (target_first_items[q.__class__.__name__],
target_first_items[q.__class__.__name__]), target_first_items[q.__class__.__name__]),
"q.peek() and q.get() are not equal!") "q.peek() and q.get() are not equal!")
target_order = dict(Queue = [333, 222, 444], target_order = dict(Queue=[333, 222, 444],
LifoQueue = [222, 333, 111], LifoQueue=[222, 333, 111],
PriorityQueue = [222, 333, 444]) PriorityQueue=[222, 333, 444])
actual_order = [q.get(), q.get(), q.get()] actual_order = [q.get(), q.get(), q.get()]
self.assertEquals(actual_order, target_order[q.__class__.__name__], self.assertEqual(actual_order, target_order[q.__class__.__name__],
"Didn't seem to queue the correct data!") "Didn't seem to queue the correct data!")
for i in range(QUEUE_SIZE-1): for i in range(QUEUE_SIZE-1):
q.put(i) q.put(i)
self.assert_(not q.empty(), "Queue should not be empty") self.assertFalse(q.empty(), "Queue should not be empty")
self.assert_(not q.full(), "Queue should not be full") self.assertFalse(q.full(), "Queue should not be full")
q.put(999) q.put(999)
self.assert_(q.full(), "Queue should be full") self.assertTrue(q.full(), "Queue should be full")
try: try:
q.put(888, block=0) q.put(888, block=0)
self.fail("Didn't appear to block with a full queue") self.fail("Didn't appear to block with a full queue")
...@@ -130,14 +123,14 @@ class BaseQueueTest(unittest.TestCase, BlockingTestMixin): ...@@ -130,14 +123,14 @@ class BaseQueueTest(unittest.TestCase, BlockingTestMixin):
self.fail("Didn't appear to time-out with a full queue") self.fail("Didn't appear to time-out with a full queue")
except Queue.Full: except Queue.Full:
pass pass
self.assertEquals(q.qsize(), QUEUE_SIZE) self.assertEqual(q.qsize(), QUEUE_SIZE)
# Test a blocking put # Test a blocking put
self.do_blocking_test(q.put, (888,), q.get, ()) self.do_blocking_test(q.put, (888,), q.get, ())
self.do_blocking_test(q.put, (888, True, 10), q.get, ()) self.do_blocking_test(q.put, (888, True, 10), q.get, ())
# Empty it # Empty it
for i in range(QUEUE_SIZE): for i in range(QUEUE_SIZE):
q.get() q.get()
self.assert_(q.empty(), "Queue should be empty") self.assertTrue(q.empty(), "Queue should be empty")
try: try:
q.get(block=0) q.get(block=0)
self.fail("Didn't appear to block with an empty queue") self.fail("Didn't appear to block with an empty queue")
...@@ -164,14 +157,14 @@ class BaseQueueTest(unittest.TestCase, BlockingTestMixin): ...@@ -164,14 +157,14 @@ class BaseQueueTest(unittest.TestCase, BlockingTestMixin):
def queue_join_test(self, q): def queue_join_test(self, q):
self.cum = 0 self.cum = 0
for i in (0,1): for i in (0, 1):
threading.Thread(target=self.worker, args=(q,)).start() threading.Thread(target=self.worker, args=(q,)).start()
for i in xrange(100): for i in range(100):
q.put(i) q.put(i)
q.join() q.join()
self.assertEquals(self.cum, sum(range(100)), self.assertEqual(self.cum, sum(range(100)),
"q.join() did not block until all tasks were done") "q.join() did not block until all tasks were done")
for i in (0,1): for i in (0, 1):
q.put(None) # instruct the threads to close q.put(None) # instruct the threads to close
q.join() # verify that you can join twice q.join() # verify that you can join twice
...@@ -227,10 +220,6 @@ class BaseQueueTest(unittest.TestCase, BlockingTestMixin): ...@@ -227,10 +220,6 @@ class BaseQueueTest(unittest.TestCase, BlockingTestMixin):
self.simple_queue_test(q) self.simple_queue_test(q)
self.simple_queue_test(q) self.simple_queue_test(q)
class QueueTest(BaseQueueTest):
type2test = Queue.Queue
class LifoQueueTest(BaseQueueTest): class LifoQueueTest(BaseQueueTest):
type2test = Queue.LifoQueue type2test = Queue.LifoQueue
...@@ -274,79 +263,59 @@ class FailingQueueTest(unittest.TestCase, BlockingTestMixin): ...@@ -274,79 +263,59 @@ class FailingQueueTest(unittest.TestCase, BlockingTestMixin):
q.put(i) q.put(i)
# Test a failing non-blocking put. # Test a failing non-blocking put.
q.fail_next_put = True q.fail_next_put = True
try: with self.assertRaises(FailingQueueException):
q.put("oops", block=0) q.put("oops", block=0)
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
q.fail_next_put = True q.fail_next_put = True
try: with self.assertRaises(FailingQueueException):
q.put("oops", timeout=0.1) q.put("oops", timeout=0.1)
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
q.put(999) q.put(999)
self.assert_(q.full(), "Queue should be full") self.assertTrue(q.full(), "Queue should be full")
# Test a failing blocking put # Test a failing blocking put
q.fail_next_put = True q.fail_next_put = True
try: with self.assertRaises(FailingQueueException):
self.do_blocking_test(q.put, (888,), q.get, ()) self.do_blocking_test(q.put, (888,), q.get, ())
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
# Check the Queue isn't damaged. # Check the Queue isn't damaged.
# put failed, but get succeeded - re-add # put failed, but get succeeded - re-add
q.put(999) q.put(999)
# Test a failing timeout put # Test a failing timeout put
q.fail_next_put = True q.fail_next_put = True
try:
self.do_exceptional_blocking_test(q.put, (888, True, 10), q.get, (), self.do_exceptional_blocking_test(q.put, (888, True, 10), q.get, (),
FailingQueueException) FailingQueueException)
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
# Check the Queue isn't damaged. # Check the Queue isn't damaged.
# put failed, but get succeeded - re-add # put failed, but get succeeded - re-add
q.put(999) q.put(999)
self.assert_(q.full(), "Queue should be full") self.assertTrue(q.full(), "Queue should be full")
q.get() q.get()
self.assert_(not q.full(), "Queue should not be full") self.assertFalse(q.full(), "Queue should not be full")
q.put(999) q.put(999)
self.assert_(q.full(), "Queue should be full") self.assertTrue(q.full(), "Queue should be full")
# Test a blocking put # Test a blocking put
self.do_blocking_test(q.put, (888,), q.get, ()) self.do_blocking_test(q.put, (888,), q.get, ())
# Empty it # Empty it
for i in range(QUEUE_SIZE): for i in range(QUEUE_SIZE):
q.get() q.get()
self.assert_(q.empty(), "Queue should be empty") self.assertTrue(q.empty(), "Queue should be empty")
q.put("first") q.put("first")
q.fail_next_get = True q.fail_next_get = True
try: with self.assertRaises(FailingQueueException):
q.get() q.get()
self.fail("The queue didn't fail when it should have")
except FailingQueueException: self.assertFalse(q.empty(), "Queue should not be empty")
pass
self.assert_(not q.empty(), "Queue should not be empty")
q.fail_next_get = True q.fail_next_get = True
try: with self.assertRaises(FailingQueueException):
q.get(timeout=0.1) q.get(timeout=0.1)
self.fail("The queue didn't fail when it should have") self.assertFalse(q.empty(), "Queue should not be empty")
except FailingQueueException:
pass
self.assert_(not q.empty(), "Queue should not be empty")
q.get() q.get()
self.assert_(q.empty(), "Queue should be empty") self.assertTrue(q.empty(), "Queue should be empty")
q.fail_next_get = True q.fail_next_get = True
try:
self.do_exceptional_blocking_test(q.get, (), q.put, ('empty',), self.do_exceptional_blocking_test(q.get, (), q.put, ('empty',),
FailingQueueException) FailingQueueException)
self.fail("The queue didn't fail when it should have")
except FailingQueueException:
pass
# put succeeded, but get failed. # put succeeded, but get failed.
self.assert_(not q.empty(), "Queue should not be empty") self.assertFalse(q.empty(), "Queue should not be empty")
q.get() q.get()
self.assert_(q.empty(), "Queue should be empty") self.assertTrue(q.empty(), "Queue should be empty")
def test_failing_queue(self): def test_failing_queue(self):
# Test to make sure a queue is functioning correctly. # Test to make sure a queue is functioning correctly.
...@@ -356,10 +325,5 @@ class FailingQueueTest(unittest.TestCase, BlockingTestMixin): ...@@ -356,10 +325,5 @@ class FailingQueueTest(unittest.TestCase, BlockingTestMixin):
self.failing_queue_test(q) self.failing_queue_test(q)
def test_main():
test_support.run_unittest(QueueTest, LifoQueueTest, PriorityQueueTest,
FailingQueueTest)
if __name__ == "__main__": if __name__ == "__main__":
test_main() unittest.main()
...@@ -297,6 +297,16 @@ def main(): ...@@ -297,6 +297,16 @@ def main():
config_data = f.read() config_data = f.read()
six.exec_(config_data, config) six.exec_(config_data, config)
FAILING_TESTS = config['FAILING_TESTS'] FAILING_TESTS = config['FAILING_TESTS']
if 'PYTHONWARNINGS' not in os.environ and not sys.warnoptions:
# Enable default warnings such as ResourceWarning.
# On Python 3[.6], the system site.py module has
# "open(fullname, 'rU')" which produces the warning that
# 'U' is deprecated, so ignore warnings from site.py
os.environ['PYTHONWARNINGS'] = 'default,ignore:::site:'
if 'PYTHONFAULTHANDLER' not in os.environ:
os.environ['PYTHONFAULTHANDLER'] = 'true'
tests = discover(options.tests, options.ignore, coverage) tests = discover(options.tests, options.ignore, coverage)
if options.discover: if options.discover:
for cmd, options in tests: for cmd, options in tests:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment