Commit 14635186 authored by Victor Stinner's avatar Victor Stinner Committed by GitHub

[2.7] bpo-31234: Join threads explicitly in tests (#7406)

* Add support.wait_threads_exit(): context manager looping at exit
  until the number of threads decreases to its original number.
* Add some missing thread.join()
* test_asyncore.test_send(): call explicitly t.join() because the cleanup
  function is only called outside the test method, whereas the method
  has a @test_support.reap_threads decorator
* test_hashlib: replace threading.Event with thread.join()
* test_thread:

  * Use wait_threads_exit() context manager
  * Replace test_support with support
  * test_forkinthread(): check child process exit status in the
    main thread to better handle error.
parent fadcd445
...@@ -1722,6 +1722,43 @@ def reap_threads(func): ...@@ -1722,6 +1722,43 @@ def reap_threads(func):
threading_cleanup(*key) threading_cleanup(*key)
return decorator return decorator
@contextlib.contextmanager
def wait_threads_exit(timeout=60.0):
"""
bpo-31234: Context manager to wait until all threads created in the with
statement exit.
Use thread.count() to check if threads exited. Indirectly, wait until
threads exit the internal t_bootstrap() C function of the thread module.
threading_setup() and threading_cleanup() are designed to emit a warning
if a test leaves running threads in the background. This context manager
is designed to cleanup threads started by the thread.start_new_thread()
which doesn't allow to wait for thread exit, whereas thread.Thread has a
join() method.
"""
old_count = thread._count()
try:
yield
finally:
start_time = time.time()
deadline = start_time + timeout
while True:
count = thread._count()
if count <= old_count:
break
if time.time() > deadline:
dt = time.time() - start_time
msg = ("wait_threads() failed to cleanup %s "
"threads after %.1f seconds "
"(count: %s, old count: %s)"
% (count - old_count, dt, count, old_count))
raise AssertionError(msg)
time.sleep(0.010)
gc_collect()
def reap_children(): def reap_children():
"""Use this function at the end of test_main() whenever sub-processes """Use this function at the end of test_main() whenever sub-processes
are started. This will help ensure that no extra children (zombies) are started. This will help ensure that no extra children (zombies)
......
...@@ -727,19 +727,20 @@ class BaseTestAPI(unittest.TestCase): ...@@ -727,19 +727,20 @@ class BaseTestAPI(unittest.TestCase):
server = TCPServer() server = TCPServer()
t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=500)) t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=500))
t.start() t.start()
self.addCleanup(t.join) try:
for x in xrange(20):
for x in xrange(20): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.settimeout(.2)
s.settimeout(.2) s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, struct.pack('ii', 1, 0))
struct.pack('ii', 1, 0)) try:
try: s.connect(server.address)
s.connect(server.address) except socket.error:
except socket.error: pass
pass finally:
finally: s.close()
s.close() finally:
t.join()
class TestAPI_UseSelect(BaseTestAPI): class TestAPI_UseSelect(BaseTestAPI):
......
...@@ -371,25 +371,25 @@ class HashLibTestCase(unittest.TestCase): ...@@ -371,25 +371,25 @@ class HashLibTestCase(unittest.TestCase):
data = smallest_data*200000 data = smallest_data*200000
expected_hash = hashlib.sha1(data*num_threads).hexdigest() expected_hash = hashlib.sha1(data*num_threads).hexdigest()
def hash_in_chunks(chunk_size, event): def hash_in_chunks(chunk_size):
index = 0 index = 0
while index < len(data): while index < len(data):
hasher.update(data[index:index+chunk_size]) hasher.update(data[index:index+chunk_size])
index += chunk_size index += chunk_size
event.set()
events = [] threads = []
for threadnum in xrange(num_threads): for threadnum in xrange(num_threads):
chunk_size = len(data) // (10**threadnum) chunk_size = len(data) // (10**threadnum)
assert chunk_size > 0 assert chunk_size > 0
assert chunk_size % len(smallest_data) == 0 assert chunk_size % len(smallest_data) == 0
event = threading.Event() thread = threading.Thread(target=hash_in_chunks,
events.append(event) args=(chunk_size,))
threading.Thread(target=hash_in_chunks, threads.append(thread)
args=(chunk_size, event)).start()
for thread in threads:
for event in events: thread.start()
event.wait() for thread in threads:
thread.join()
self.assertEqual(expected_hash, hasher.hexdigest()) self.assertEqual(expected_hash, hasher.hexdigest())
......
...@@ -66,6 +66,7 @@ class TestServerThread(threading.Thread): ...@@ -66,6 +66,7 @@ class TestServerThread(threading.Thread):
def stop(self): def stop(self):
self.server.shutdown() self.server.shutdown()
self.join()
class BaseTestCase(unittest.TestCase): class BaseTestCase(unittest.TestCase):
......
...@@ -306,12 +306,14 @@ class TooLongLineTests(unittest.TestCase): ...@@ -306,12 +306,14 @@ class TooLongLineTests(unittest.TestCase):
self.sock.settimeout(15) self.sock.settimeout(15)
self.port = test_support.bind_port(self.sock) self.port = test_support.bind_port(self.sock)
servargs = (self.evt, self.respdata, self.sock) servargs = (self.evt, self.respdata, self.sock)
threading.Thread(target=server, args=servargs).start() self.thread = threading.Thread(target=server, args=servargs)
self.thread.start()
self.evt.wait() self.evt.wait()
self.evt.clear() self.evt.clear()
def tearDown(self): def tearDown(self):
self.evt.wait() self.evt.wait()
self.thread.join()
sys.stdout = self.old_stdout sys.stdout = self.old_stdout
def testLineTooLong(self): def testLineTooLong(self):
......
import os import os
import unittest import unittest
import random import random
from test import test_support from test import support
thread = test_support.import_module('thread') thread = support.import_module('thread')
import time import time
import sys import sys
import weakref import weakref
...@@ -17,7 +17,7 @@ _print_mutex = thread.allocate_lock() ...@@ -17,7 +17,7 @@ _print_mutex = thread.allocate_lock()
def verbose_print(arg): def verbose_print(arg):
"""Helper function for printing out debugging output.""" """Helper function for printing out debugging output."""
if test_support.verbose: if support.verbose:
with _print_mutex: with _print_mutex:
print arg print arg
...@@ -34,8 +34,8 @@ class BasicThreadTest(unittest.TestCase): ...@@ -34,8 +34,8 @@ class BasicThreadTest(unittest.TestCase):
self.running = 0 self.running = 0
self.next_ident = 0 self.next_ident = 0
key = test_support.threading_setup() key = support.threading_setup()
self.addCleanup(test_support.threading_cleanup, *key) self.addCleanup(support.threading_cleanup, *key)
class ThreadRunningTests(BasicThreadTest): class ThreadRunningTests(BasicThreadTest):
...@@ -60,12 +60,13 @@ class ThreadRunningTests(BasicThreadTest): ...@@ -60,12 +60,13 @@ class ThreadRunningTests(BasicThreadTest):
self.done_mutex.release() self.done_mutex.release()
def test_starting_threads(self): def test_starting_threads(self):
# Basic test for thread creation. with support.wait_threads_exit():
for i in range(NUMTASKS): # Basic test for thread creation.
self.newtask() for i in range(NUMTASKS):
verbose_print("waiting for tasks to complete...") self.newtask()
self.done_mutex.acquire() verbose_print("waiting for tasks to complete...")
verbose_print("all tasks done") self.done_mutex.acquire()
verbose_print("all tasks done")
def test_stack_size(self): def test_stack_size(self):
# Various stack size tests. # Various stack size tests.
...@@ -95,12 +96,13 @@ class ThreadRunningTests(BasicThreadTest): ...@@ -95,12 +96,13 @@ class ThreadRunningTests(BasicThreadTest):
verbose_print("trying stack_size = (%d)" % tss) verbose_print("trying stack_size = (%d)" % tss)
self.next_ident = 0 self.next_ident = 0
self.created = 0 self.created = 0
for i in range(NUMTASKS): with support.wait_threads_exit():
self.newtask() for i in range(NUMTASKS):
self.newtask()
verbose_print("waiting for all tasks to complete") verbose_print("waiting for all tasks to complete")
self.done_mutex.acquire() self.done_mutex.acquire()
verbose_print("all tasks done") verbose_print("all tasks done")
thread.stack_size(0) thread.stack_size(0)
...@@ -110,25 +112,28 @@ class ThreadRunningTests(BasicThreadTest): ...@@ -110,25 +112,28 @@ class ThreadRunningTests(BasicThreadTest):
mut = thread.allocate_lock() mut = thread.allocate_lock()
mut.acquire() mut.acquire()
started = [] started = []
def task(): def task():
started.append(None) started.append(None)
mut.acquire() mut.acquire()
mut.release() mut.release()
thread.start_new_thread(task, ())
while not started: with support.wait_threads_exit():
time.sleep(0.01) thread.start_new_thread(task, ())
self.assertEqual(thread._count(), orig + 1) while not started:
# Allow the task to finish. time.sleep(0.01)
mut.release() self.assertEqual(thread._count(), orig + 1)
# The only reliable way to be sure that the thread ended from the # Allow the task to finish.
# interpreter's point of view is to wait for the function object to be mut.release()
# destroyed. # The only reliable way to be sure that the thread ended from the
done = [] # interpreter's point of view is to wait for the function object to be
wr = weakref.ref(task, lambda _: done.append(None)) # destroyed.
del task done = []
while not done: wr = weakref.ref(task, lambda _: done.append(None))
time.sleep(0.01) del task
self.assertEqual(thread._count(), orig) while not done:
time.sleep(0.01)
self.assertEqual(thread._count(), orig)
def test_save_exception_state_on_error(self): def test_save_exception_state_on_error(self):
# See issue #14474 # See issue #14474
...@@ -143,14 +148,13 @@ class ThreadRunningTests(BasicThreadTest): ...@@ -143,14 +148,13 @@ class ThreadRunningTests(BasicThreadTest):
real_write(self, *args) real_write(self, *args)
c = thread._count() c = thread._count()
started = thread.allocate_lock() started = thread.allocate_lock()
with test_support.captured_output("stderr") as stderr: with support.captured_output("stderr") as stderr:
real_write = stderr.write real_write = stderr.write
stderr.write = mywrite stderr.write = mywrite
started.acquire() started.acquire()
thread.start_new_thread(task, ()) with support.wait_threads_exit():
started.acquire() thread.start_new_thread(task, ())
while thread._count() > c: started.acquire()
time.sleep(0.01)
self.assertIn("Traceback", stderr.getvalue()) self.assertIn("Traceback", stderr.getvalue())
...@@ -182,13 +186,14 @@ class Barrier: ...@@ -182,13 +186,14 @@ class Barrier:
class BarrierTest(BasicThreadTest): class BarrierTest(BasicThreadTest):
def test_barrier(self): def test_barrier(self):
self.bar = Barrier(NUMTASKS) with support.wait_threads_exit():
self.running = NUMTASKS self.bar = Barrier(NUMTASKS)
for i in range(NUMTASKS): self.running = NUMTASKS
thread.start_new_thread(self.task2, (i,)) for i in range(NUMTASKS):
verbose_print("waiting for tasks to end") thread.start_new_thread(self.task2, (i,))
self.done_mutex.acquire() verbose_print("waiting for tasks to end")
verbose_print("tasks done") self.done_mutex.acquire()
verbose_print("tasks done")
def task2(self, ident): def task2(self, ident):
for i in range(NUMTRIPS): for i in range(NUMTRIPS):
...@@ -226,8 +231,9 @@ class TestForkInThread(unittest.TestCase): ...@@ -226,8 +231,9 @@ class TestForkInThread(unittest.TestCase):
@unittest.skipIf(sys.platform.startswith('win'), @unittest.skipIf(sys.platform.startswith('win'),
"This test is only appropriate for POSIX-like systems.") "This test is only appropriate for POSIX-like systems.")
@test_support.reap_threads @support.reap_threads
def test_forkinthread(self): def test_forkinthread(self):
non_local = {'status': None}
def thread1(): def thread1():
try: try:
pid = os.fork() # fork in a thread pid = os.fork() # fork in a thread
...@@ -246,11 +252,13 @@ class TestForkInThread(unittest.TestCase): ...@@ -246,11 +252,13 @@ class TestForkInThread(unittest.TestCase):
else: # parent else: # parent
os.close(self.write_fd) os.close(self.write_fd)
pid, status = os.waitpid(pid, 0) pid, status = os.waitpid(pid, 0)
self.assertEqual(status, 0) non_local['status'] = status
thread.start_new_thread(thread1, ()) with support.wait_threads_exit():
self.assertEqual(os.read(self.read_fd, 2), "OK", thread.start_new_thread(thread1, ())
"Unable to fork() in thread") self.assertEqual(os.read(self.read_fd, 2), "OK",
"Unable to fork() in thread")
self.assertEqual(non_local['status'], 0)
def tearDown(self): def tearDown(self):
try: try:
...@@ -265,7 +273,7 @@ class TestForkInThread(unittest.TestCase): ...@@ -265,7 +273,7 @@ class TestForkInThread(unittest.TestCase):
def test_main(): def test_main():
test_support.run_unittest(ThreadRunningTests, BarrierTest, LockTests, support.run_unittest(ThreadRunningTests, BarrierTest, LockTests,
TestForkInThread) TestForkInThread)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment