Commit 14635186 authored by Victor Stinner's avatar Victor Stinner Committed by GitHub

[2.7] bpo-31234: Join threads explicitly in tests (#7406)

* Add support.wait_threads_exit(): context manager looping at exit
  until the number of threads decreases to its original number.
* Add some missing thread.join()
* test_asyncore.test_send(): call explicitly t.join() because the cleanup
  function is only called outside the test method, whereas the method
  has a @test_support.reap_threads decorator
* test_hashlib: replace threading.Event with thread.join()
* test_thread:

  * Use wait_threads_exit() context manager
  * Replace test_support with support
  * test_forkinthread(): check child process exit status in the
    main thread to better handle error.
parent fadcd445
......@@ -1722,6 +1722,43 @@ def reap_threads(func):
threading_cleanup(*key)
return decorator
@contextlib.contextmanager
def wait_threads_exit(timeout=60.0):
"""
bpo-31234: Context manager to wait until all threads created in the with
statement exit.
Use thread.count() to check if threads exited. Indirectly, wait until
threads exit the internal t_bootstrap() C function of the thread module.
threading_setup() and threading_cleanup() are designed to emit a warning
if a test leaves running threads in the background. This context manager
is designed to cleanup threads started by the thread.start_new_thread()
which doesn't allow to wait for thread exit, whereas thread.Thread has a
join() method.
"""
old_count = thread._count()
try:
yield
finally:
start_time = time.time()
deadline = start_time + timeout
while True:
count = thread._count()
if count <= old_count:
break
if time.time() > deadline:
dt = time.time() - start_time
msg = ("wait_threads() failed to cleanup %s "
"threads after %.1f seconds "
"(count: %s, old count: %s)"
% (count - old_count, dt, count, old_count))
raise AssertionError(msg)
time.sleep(0.010)
gc_collect()
def reap_children():
"""Use this function at the end of test_main() whenever sub-processes
are started. This will help ensure that no extra children (zombies)
......
......@@ -727,19 +727,20 @@ class BaseTestAPI(unittest.TestCase):
server = TCPServer()
t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=500))
t.start()
self.addCleanup(t.join)
for x in xrange(20):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(.2)
s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
struct.pack('ii', 1, 0))
try:
s.connect(server.address)
except socket.error:
pass
finally:
s.close()
try:
for x in xrange(20):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(.2)
s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
struct.pack('ii', 1, 0))
try:
s.connect(server.address)
except socket.error:
pass
finally:
s.close()
finally:
t.join()
class TestAPI_UseSelect(BaseTestAPI):
......
......@@ -371,25 +371,25 @@ class HashLibTestCase(unittest.TestCase):
data = smallest_data*200000
expected_hash = hashlib.sha1(data*num_threads).hexdigest()
def hash_in_chunks(chunk_size, event):
def hash_in_chunks(chunk_size):
index = 0
while index < len(data):
hasher.update(data[index:index+chunk_size])
index += chunk_size
event.set()
events = []
threads = []
for threadnum in xrange(num_threads):
chunk_size = len(data) // (10**threadnum)
assert chunk_size > 0
assert chunk_size % len(smallest_data) == 0
event = threading.Event()
events.append(event)
threading.Thread(target=hash_in_chunks,
args=(chunk_size, event)).start()
for event in events:
event.wait()
thread = threading.Thread(target=hash_in_chunks,
args=(chunk_size,))
threads.append(thread)
for thread in threads:
thread.start()
for thread in threads:
thread.join()
self.assertEqual(expected_hash, hasher.hexdigest())
......
......@@ -66,6 +66,7 @@ class TestServerThread(threading.Thread):
def stop(self):
self.server.shutdown()
self.join()
class BaseTestCase(unittest.TestCase):
......
......@@ -306,12 +306,14 @@ class TooLongLineTests(unittest.TestCase):
self.sock.settimeout(15)
self.port = test_support.bind_port(self.sock)
servargs = (self.evt, self.respdata, self.sock)
threading.Thread(target=server, args=servargs).start()
self.thread = threading.Thread(target=server, args=servargs)
self.thread.start()
self.evt.wait()
self.evt.clear()
def tearDown(self):
self.evt.wait()
self.thread.join()
sys.stdout = self.old_stdout
def testLineTooLong(self):
......
import os
import unittest
import random
from test import test_support
thread = test_support.import_module('thread')
from test import support
thread = support.import_module('thread')
import time
import sys
import weakref
......@@ -17,7 +17,7 @@ _print_mutex = thread.allocate_lock()
def verbose_print(arg):
"""Helper function for printing out debugging output."""
if test_support.verbose:
if support.verbose:
with _print_mutex:
print arg
......@@ -34,8 +34,8 @@ class BasicThreadTest(unittest.TestCase):
self.running = 0
self.next_ident = 0
key = test_support.threading_setup()
self.addCleanup(test_support.threading_cleanup, *key)
key = support.threading_setup()
self.addCleanup(support.threading_cleanup, *key)
class ThreadRunningTests(BasicThreadTest):
......@@ -60,12 +60,13 @@ class ThreadRunningTests(BasicThreadTest):
self.done_mutex.release()
def test_starting_threads(self):
# Basic test for thread creation.
for i in range(NUMTASKS):
self.newtask()
verbose_print("waiting for tasks to complete...")
self.done_mutex.acquire()
verbose_print("all tasks done")
with support.wait_threads_exit():
# Basic test for thread creation.
for i in range(NUMTASKS):
self.newtask()
verbose_print("waiting for tasks to complete...")
self.done_mutex.acquire()
verbose_print("all tasks done")
def test_stack_size(self):
# Various stack size tests.
......@@ -95,12 +96,13 @@ class ThreadRunningTests(BasicThreadTest):
verbose_print("trying stack_size = (%d)" % tss)
self.next_ident = 0
self.created = 0
for i in range(NUMTASKS):
self.newtask()
with support.wait_threads_exit():
for i in range(NUMTASKS):
self.newtask()
verbose_print("waiting for all tasks to complete")
self.done_mutex.acquire()
verbose_print("all tasks done")
verbose_print("waiting for all tasks to complete")
self.done_mutex.acquire()
verbose_print("all tasks done")
thread.stack_size(0)
......@@ -110,25 +112,28 @@ class ThreadRunningTests(BasicThreadTest):
mut = thread.allocate_lock()
mut.acquire()
started = []
def task():
started.append(None)
mut.acquire()
mut.release()
thread.start_new_thread(task, ())
while not started:
time.sleep(0.01)
self.assertEqual(thread._count(), orig + 1)
# Allow the task to finish.
mut.release()
# The only reliable way to be sure that the thread ended from the
# interpreter's point of view is to wait for the function object to be
# destroyed.
done = []
wr = weakref.ref(task, lambda _: done.append(None))
del task
while not done:
time.sleep(0.01)
self.assertEqual(thread._count(), orig)
with support.wait_threads_exit():
thread.start_new_thread(task, ())
while not started:
time.sleep(0.01)
self.assertEqual(thread._count(), orig + 1)
# Allow the task to finish.
mut.release()
# The only reliable way to be sure that the thread ended from the
# interpreter's point of view is to wait for the function object to be
# destroyed.
done = []
wr = weakref.ref(task, lambda _: done.append(None))
del task
while not done:
time.sleep(0.01)
self.assertEqual(thread._count(), orig)
def test_save_exception_state_on_error(self):
# See issue #14474
......@@ -143,14 +148,13 @@ class ThreadRunningTests(BasicThreadTest):
real_write(self, *args)
c = thread._count()
started = thread.allocate_lock()
with test_support.captured_output("stderr") as stderr:
with support.captured_output("stderr") as stderr:
real_write = stderr.write
stderr.write = mywrite
started.acquire()
thread.start_new_thread(task, ())
started.acquire()
while thread._count() > c:
time.sleep(0.01)
with support.wait_threads_exit():
thread.start_new_thread(task, ())
started.acquire()
self.assertIn("Traceback", stderr.getvalue())
......@@ -182,13 +186,14 @@ class Barrier:
class BarrierTest(BasicThreadTest):
def test_barrier(self):
self.bar = Barrier(NUMTASKS)
self.running = NUMTASKS
for i in range(NUMTASKS):
thread.start_new_thread(self.task2, (i,))
verbose_print("waiting for tasks to end")
self.done_mutex.acquire()
verbose_print("tasks done")
with support.wait_threads_exit():
self.bar = Barrier(NUMTASKS)
self.running = NUMTASKS
for i in range(NUMTASKS):
thread.start_new_thread(self.task2, (i,))
verbose_print("waiting for tasks to end")
self.done_mutex.acquire()
verbose_print("tasks done")
def task2(self, ident):
for i in range(NUMTRIPS):
......@@ -226,8 +231,9 @@ class TestForkInThread(unittest.TestCase):
@unittest.skipIf(sys.platform.startswith('win'),
"This test is only appropriate for POSIX-like systems.")
@test_support.reap_threads
@support.reap_threads
def test_forkinthread(self):
non_local = {'status': None}
def thread1():
try:
pid = os.fork() # fork in a thread
......@@ -246,11 +252,13 @@ class TestForkInThread(unittest.TestCase):
else: # parent
os.close(self.write_fd)
pid, status = os.waitpid(pid, 0)
self.assertEqual(status, 0)
non_local['status'] = status
thread.start_new_thread(thread1, ())
self.assertEqual(os.read(self.read_fd, 2), "OK",
"Unable to fork() in thread")
with support.wait_threads_exit():
thread.start_new_thread(thread1, ())
self.assertEqual(os.read(self.read_fd, 2), "OK",
"Unable to fork() in thread")
self.assertEqual(non_local['status'], 0)
def tearDown(self):
try:
......@@ -265,7 +273,7 @@ class TestForkInThread(unittest.TestCase):
def test_main():
test_support.run_unittest(ThreadRunningTests, BarrierTest, LockTests,
support.run_unittest(ThreadRunningTests, BarrierTest, LockTests,
TestForkInThread)
if __name__ == "__main__":
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment