Commit 029a8d9f authored by Jason Madden's avatar Jason Madden

Use tests from 3.7a4

parent 689538e5
...@@ -726,14 +726,10 @@ class BaseTestAPI: ...@@ -726,14 +726,10 @@ class BaseTestAPI:
def test_create_socket(self): def test_create_socket(self):
s = asyncore.dispatcher() s = asyncore.dispatcher()
s.create_socket(self.family) s.create_socket(self.family)
self.assertEqual(s.socket.type, socket.SOCK_STREAM)
self.assertEqual(s.socket.family, self.family) self.assertEqual(s.socket.family, self.family)
SOCK_NONBLOCK = getattr(socket, 'SOCK_NONBLOCK', 0) self.assertEqual(s.socket.gettimeout(), 0)
sock_type = socket.SOCK_STREAM | SOCK_NONBLOCK self.assertFalse(s.socket.get_inheritable())
if hasattr(socket, 'SOCK_CLOEXEC'):
self.assertIn(s.socket.type,
(sock_type | socket.SOCK_CLOEXEC, sock_type))
else:
self.assertEqual(s.socket.type, sock_type)
def test_bind(self): def test_bind(self):
if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX:
......
...@@ -1444,7 +1444,7 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -1444,7 +1444,7 @@ class GeneralModuleTests(unittest.TestCase):
socket.gethostbyname(domain) socket.gethostbyname(domain)
socket.gethostbyname_ex(domain) socket.gethostbyname_ex(domain)
socket.getaddrinfo(domain,0,socket.AF_UNSPEC,socket.SOCK_STREAM) socket.getaddrinfo(domain,0,socket.AF_UNSPEC,socket.SOCK_STREAM)
# this may not work if the forward lookup choses the IPv6 address, as that doesn't # this may not work if the forward lookup chooses the IPv6 address, as that doesn't
# have a reverse entry yet # have a reverse entry yet
# socket.gethostbyaddr('испытание.python.org') # socket.gethostbyaddr('испытание.python.org')
...@@ -1577,6 +1577,22 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -1577,6 +1577,22 @@ class GeneralModuleTests(unittest.TestCase):
self.assertEqual(str(s.family), 'AddressFamily.AF_INET') self.assertEqual(str(s.family), 'AddressFamily.AF_INET')
self.assertEqual(str(s.type), 'SocketKind.SOCK_STREAM') self.assertEqual(str(s.type), 'SocketKind.SOCK_STREAM')
def test_socket_consistent_sock_type(self):
SOCK_NONBLOCK = getattr(socket, 'SOCK_NONBLOCK', 0)
SOCK_CLOEXEC = getattr(socket, 'SOCK_CLOEXEC', 0)
sock_type = socket.SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC
with socket.socket(socket.AF_INET, sock_type) as s:
self.assertEqual(s.type, socket.SOCK_STREAM)
s.settimeout(1)
self.assertEqual(s.type, socket.SOCK_STREAM)
s.settimeout(0)
self.assertEqual(s.type, socket.SOCK_STREAM)
s.setblocking(True)
self.assertEqual(s.type, socket.SOCK_STREAM)
s.setblocking(False)
self.assertEqual(s.type, socket.SOCK_STREAM)
@unittest.skipIf(os.name == 'nt', 'Will not work on Windows') @unittest.skipIf(os.name == 'nt', 'Will not work on Windows')
def test_uknown_socket_family_repr(self): def test_uknown_socket_family_repr(self):
# Test that when created with a family that's not one of the known # Test that when created with a family that's not one of the known
...@@ -1589,9 +1605,18 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -1589,9 +1605,18 @@ class GeneralModuleTests(unittest.TestCase):
# On Windows this trick won't work, so the test is skipped. # On Windows this trick won't work, so the test is skipped.
fd, path = tempfile.mkstemp() fd, path = tempfile.mkstemp()
self.addCleanup(os.unlink, path) self.addCleanup(os.unlink, path)
with socket.socket(family=42424, type=13331, fileno=fd) as s: unknown_family = max(socket.AddressFamily.__members__.values()) + 1
self.assertEqual(s.family, 42424)
self.assertEqual(s.type, 13331) unknown_type = max(
kind
for name, kind in socket.SocketKind.__members__.items()
if name not in {'SOCK_NONBLOCK', 'SOCK_CLOEXEC'}
) + 1
with socket.socket(
family=unknown_family, type=unknown_type, fileno=fd) as s:
self.assertEqual(s.family, unknown_family)
self.assertEqual(s.type, unknown_type)
@unittest.skipUnless(hasattr(os, 'sendfile'), 'test needs os.sendfile()') @unittest.skipUnless(hasattr(os, 'sendfile'), 'test needs os.sendfile()')
def test__sendfile_use_sendfile(self): def test__sendfile_use_sendfile(self):
...@@ -4399,7 +4424,7 @@ class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase): ...@@ -4399,7 +4424,7 @@ class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):
self.write_file.write(self.write_msg) self.write_file.write(self.write_msg)
self.write_file.flush() self.write_file.flush()
self.evt2.set() self.evt2.set()
# Avoid cloding the socket before the server test has finished, # Avoid closing the socket before the server test has finished,
# otherwise system recv() will return 0 instead of EWOULDBLOCK. # otherwise system recv() will return 0 instead of EWOULDBLOCK.
self.serv_finished.wait(5.0) self.serv_finished.wait(5.0)
...@@ -4533,6 +4558,10 @@ class NetworkConnectionNoServer(unittest.TestCase): ...@@ -4533,6 +4558,10 @@ class NetworkConnectionNoServer(unittest.TestCase):
expected_errnos = [ errno.ECONNREFUSED, ] expected_errnos = [ errno.ECONNREFUSED, ]
if hasattr(errno, 'ENETUNREACH'): if hasattr(errno, 'ENETUNREACH'):
expected_errnos.append(errno.ENETUNREACH) expected_errnos.append(errno.ENETUNREACH)
if hasattr(errno, 'EADDRNOTAVAIL'):
# bpo-31910: socket.create_connection() fails randomly
# with EADDRNOTAVAIL on Travis CI
expected_errnos.append(errno.EADDRNOTAVAIL)
self.assertIn(cm.exception.errno, expected_errnos) self.assertIn(cm.exception.errno, expected_errnos)
...@@ -4671,7 +4700,7 @@ class TCPTimeoutTest(SocketTCPTest): ...@@ -4671,7 +4700,7 @@ class TCPTimeoutTest(SocketTCPTest):
'test needs signal.alarm()') 'test needs signal.alarm()')
def testInterruptedTimeout(self): def testInterruptedTimeout(self):
# XXX I don't know how to do this test on MSWindows or any other # XXX I don't know how to do this test on MSWindows or any other
# plaform that doesn't support signal.alarm() or os.kill(), though # platform that doesn't support signal.alarm() or os.kill(), though
# the bug should have existed on all platforms. # the bug should have existed on all platforms.
self.serv.settimeout(5.0) # must be longer than alarm self.serv.settimeout(5.0) # must be longer than alarm
class Alarm(Exception): class Alarm(Exception):
...@@ -5080,7 +5109,7 @@ class InheritanceTest(unittest.TestCase): ...@@ -5080,7 +5109,7 @@ class InheritanceTest(unittest.TestCase):
def test_SOCK_CLOEXEC(self): def test_SOCK_CLOEXEC(self):
with socket.socket(socket.AF_INET, with socket.socket(socket.AF_INET,
socket.SOCK_STREAM | socket.SOCK_CLOEXEC) as s: socket.SOCK_STREAM | socket.SOCK_CLOEXEC) as s:
self.assertTrue(s.type & socket.SOCK_CLOEXEC) self.assertEqual(s.type, socket.SOCK_STREAM)
self.assertFalse(s.get_inheritable()) self.assertFalse(s.get_inheritable())
def test_default_inheritable(self): def test_default_inheritable(self):
...@@ -5132,8 +5161,6 @@ class InheritanceTest(unittest.TestCase): ...@@ -5132,8 +5161,6 @@ class InheritanceTest(unittest.TestCase):
0) 0)
@unittest.skipUnless(hasattr(socket, "socketpair"),
"need socket.socketpair()")
def test_socketpair(self): def test_socketpair(self):
s1, s2 = socket.socketpair() s1, s2 = socket.socketpair()
self.addCleanup(s1.close) self.addCleanup(s1.close)
...@@ -5147,11 +5174,15 @@ class InheritanceTest(unittest.TestCase): ...@@ -5147,11 +5174,15 @@ class InheritanceTest(unittest.TestCase):
class NonblockConstantTest(unittest.TestCase): class NonblockConstantTest(unittest.TestCase):
def checkNonblock(self, s, nonblock=True, timeout=0.0): def checkNonblock(self, s, nonblock=True, timeout=0.0):
if nonblock: if nonblock:
self.assertTrue(s.type & socket.SOCK_NONBLOCK) self.assertEqual(s.type, socket.SOCK_STREAM)
self.assertEqual(s.gettimeout(), timeout) self.assertEqual(s.gettimeout(), timeout)
self.assertTrue(
fcntl.fcntl(s, fcntl.F_GETFL, os.O_NONBLOCK) & os.O_NONBLOCK)
else: else:
self.assertFalse(s.type & socket.SOCK_NONBLOCK) self.assertEqual(s.type, socket.SOCK_STREAM)
self.assertEqual(s.gettimeout(), None) self.assertEqual(s.gettimeout(), None)
self.assertFalse(
fcntl.fcntl(s, fcntl.F_GETFL, os.O_NONBLOCK) & os.O_NONBLOCK)
@support.requires_linux_version(2, 6, 28) @support.requires_linux_version(2, 6, 28)
def test_SOCK_NONBLOCK(self): def test_SOCK_NONBLOCK(self):
...@@ -5295,7 +5326,7 @@ class SendfileUsingSendTest(ThreadedTCPSocketTest): ...@@ -5295,7 +5326,7 @@ class SendfileUsingSendTest(ThreadedTCPSocketTest):
Test the send() implementation of socket.sendfile(). Test the send() implementation of socket.sendfile().
""" """
FILESIZE = (10 * 1024 * 1024) # 10MB FILESIZE = (10 * 1024 * 1024) # 10 MiB
BUFSIZE = 8192 BUFSIZE = 8192
FILEDATA = b"" FILEDATA = b""
TIMEOUT = 2 TIMEOUT = 2
...@@ -5571,6 +5602,9 @@ class LinuxKernelCryptoAPI(unittest.TestCase): ...@@ -5571,6 +5602,9 @@ class LinuxKernelCryptoAPI(unittest.TestCase):
else: else:
return sock return sock
# bpo-31705: On kernel older than 4.5, sendto() failed with ENOKEY,
# at least on ppc64le architecture
@support.requires_linux_version(4, 5)
def test_sha256(self): def test_sha256(self):
expected = bytes.fromhex("ba7816bf8f01cfea414140de5dae2223b00361a396" expected = bytes.fromhex("ba7816bf8f01cfea414140de5dae2223b00361a396"
"177a9cb410ff61f20015ad") "177a9cb410ff61f20015ad")
......
...@@ -850,7 +850,7 @@ class BasicSocketTests(unittest.TestCase): ...@@ -850,7 +850,7 @@ class BasicSocketTests(unittest.TestCase):
self.cert_time_ok("Jan 5 09:34:61 2018 GMT", 1515144901) self.cert_time_ok("Jan 5 09:34:61 2018 GMT", 1515144901)
self.cert_time_fail("Jan 5 09:34:62 2018 GMT") # invalid seconds self.cert_time_fail("Jan 5 09:34:62 2018 GMT") # invalid seconds
# no special treatement for the special value: # no special treatment for the special value:
# 99991231235959Z (rfc 5280) # 99991231235959Z (rfc 5280)
self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0) self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0)
......
...@@ -370,7 +370,7 @@ class ProcessTestCase(BaseTestCase): ...@@ -370,7 +370,7 @@ class ProcessTestCase(BaseTestCase):
# is relative. # is relative.
python_dir, python_base = self._split_python_path() python_dir, python_base = self._split_python_path()
rel_python = os.path.join(os.curdir, python_base) rel_python = os.path.join(os.curdir, python_base)
with support.temp_cwd('test_cwd_with_relative_arg', quiet=True) as wrong_dir: # gevent: use distinct name, avoid Travis CI failure with support.temp_cwd('test_cwd_with_relative_arg', quiet=True) as wrong_dir: # gevent: use distinct name, avoid Travis CI failure)
# Before calling with the correct cwd, confirm that the call fails # Before calling with the correct cwd, confirm that the call fails
# without cwd and with the wrong cwd. # without cwd and with the wrong cwd.
self.assertRaises(FileNotFoundError, subprocess.Popen, self.assertRaises(FileNotFoundError, subprocess.Popen,
...@@ -2290,11 +2290,11 @@ class POSIXProcessTestCase(BaseTestCase): ...@@ -2290,11 +2290,11 @@ class POSIXProcessTestCase(BaseTestCase):
fds_to_keep = set(open_fds.pop() for _ in range(8)) fds_to_keep = set(open_fds.pop() for _ in range(8))
p = subprocess.Popen([sys.executable, fd_status], p = subprocess.Popen([sys.executable, fd_status],
stdout=subprocess.PIPE, close_fds=True, stdout=subprocess.PIPE, close_fds=True,
pass_fds=()) pass_fds=fds_to_keep)
output, ignored = p.communicate() output, ignored = p.communicate()
remaining_fds = set(map(int, output.split(b','))) remaining_fds = set(map(int, output.split(b',')))
self.assertFalse(remaining_fds & fds_to_keep & open_fds, self.assertFalse((remaining_fds - fds_to_keep) & open_fds,
"Some fds not in pass_fds were left open") "Some fds not in pass_fds were left open")
self.assertIn(1, remaining_fds, "Subprocess failed") self.assertIn(1, remaining_fds, "Subprocess failed")
...@@ -2743,11 +2743,6 @@ class Win32ProcessTestCase(BaseTestCase): ...@@ -2743,11 +2743,6 @@ class Win32ProcessTestCase(BaseTestCase):
[sys.executable, "-c", [sys.executable, "-c",
"import sys; sys.exit(47)"], "import sys; sys.exit(47)"],
preexec_fn=lambda: 1) preexec_fn=lambda: 1)
self.assertRaises(ValueError, subprocess.call,
[sys.executable, "-c",
"import sys; sys.exit(47)"],
stdout=subprocess.PIPE,
close_fds=True)
@support.cpython_only @support.cpython_only
def test_issue31471(self): def test_issue31471(self):
...@@ -2765,6 +2760,67 @@ class Win32ProcessTestCase(BaseTestCase): ...@@ -2765,6 +2760,67 @@ class Win32ProcessTestCase(BaseTestCase):
close_fds=True) close_fds=True)
self.assertEqual(rc, 47) self.assertEqual(rc, 47)
def test_close_fds_with_stdio(self):
import msvcrt
fds = os.pipe()
self.addCleanup(os.close, fds[0])
self.addCleanup(os.close, fds[1])
handles = []
for fd in fds:
os.set_inheritable(fd, True)
handles.append(msvcrt.get_osfhandle(fd))
p = subprocess.Popen([sys.executable, "-c",
"import msvcrt; print(msvcrt.open_osfhandle({}, 0))".format(handles[0])],
stdout=subprocess.PIPE, close_fds=False)
stdout, stderr = p.communicate()
self.assertEqual(p.returncode, 0)
int(stdout.strip()) # Check that stdout is an integer
p = subprocess.Popen([sys.executable, "-c",
"import msvcrt; print(msvcrt.open_osfhandle({}, 0))".format(handles[0])],
stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True)
stdout, stderr = p.communicate()
self.assertEqual(p.returncode, 1)
self.assertIn(b"OSError", stderr)
# The same as the previous call, but with an empty handle_list
handle_list = []
startupinfo = subprocess.STARTUPINFO()
startupinfo.lpAttributeList = {"handle_list": handle_list}
p = subprocess.Popen([sys.executable, "-c",
"import msvcrt; print(msvcrt.open_osfhandle({}, 0))".format(handles[0])],
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
startupinfo=startupinfo, close_fds=True)
stdout, stderr = p.communicate()
self.assertEqual(p.returncode, 1)
self.assertIn(b"OSError", stderr)
# Check for a warning due to using handle_list and close_fds=False
with support.check_warnings((".*overriding close_fds", RuntimeWarning)):
startupinfo = subprocess.STARTUPINFO()
startupinfo.lpAttributeList = {"handle_list": handles[:]}
p = subprocess.Popen([sys.executable, "-c",
"import msvcrt; print(msvcrt.open_osfhandle({}, 0))".format(handles[0])],
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
startupinfo=startupinfo, close_fds=False)
stdout, stderr = p.communicate()
self.assertEqual(p.returncode, 0)
def test_empty_attribute_list(self):
startupinfo = subprocess.STARTUPINFO()
startupinfo.lpAttributeList = {}
subprocess.call([sys.executable, "-c", "import sys; sys.exit(0)"],
startupinfo=startupinfo)
def test_empty_handle_list(self):
startupinfo = subprocess.STARTUPINFO()
startupinfo.lpAttributeList = {"handle_list": []}
subprocess.call([sys.executable, "-c", "import sys; sys.exit(0)"],
startupinfo=startupinfo)
def test_shell_sequence(self): def test_shell_sequence(self):
# Run command through the shell (sequence) # Run command through the shell (sequence)
newenv = os.environ.copy() newenv = os.environ.copy()
......
...@@ -17,7 +17,7 @@ import weakref ...@@ -17,7 +17,7 @@ import weakref
import os import os
import subprocess import subprocess
from test import lock_tests import lock_tests # gevent: use our local copy
from test import support from test import support
...@@ -132,10 +132,10 @@ class ThreadTests(BaseTestCase): ...@@ -132,10 +132,10 @@ class ThreadTests(BaseTestCase):
# Kill the "immortal" _DummyThread # Kill the "immortal" _DummyThread
del threading._active[ident[0]] del threading._active[ident[0]]
# run with a small(ish) thread stack size (256kB) # run with a small(ish) thread stack size (256 KiB)
def test_various_ops_small_stack(self): def test_various_ops_small_stack(self):
if verbose: if verbose:
print('with 256kB thread stack size...') print('with 256 KiB thread stack size...')
try: try:
threading.stack_size(262144) threading.stack_size(262144)
except _thread.error: except _thread.error:
...@@ -144,10 +144,10 @@ class ThreadTests(BaseTestCase): ...@@ -144,10 +144,10 @@ class ThreadTests(BaseTestCase):
self.test_various_ops() self.test_various_ops()
threading.stack_size(0) threading.stack_size(0)
# run with a large thread stack size (1MB) # run with a large thread stack size (1 MiB)
def test_various_ops_large_stack(self): def test_various_ops_large_stack(self):
if verbose: if verbose:
print('with 1MB thread stack size...') print('with 1 MiB thread stack size...')
try: try:
threading.stack_size(0x100000) threading.stack_size(0x100000)
except _thread.error: except _thread.error:
...@@ -427,7 +427,7 @@ class ThreadTests(BaseTestCase): ...@@ -427,7 +427,7 @@ class ThreadTests(BaseTestCase):
t.daemon = True t.daemon = True
self.assertIn('daemon', repr(t)) self.assertIn('daemon', repr(t))
def test_deamon_param(self): def test_daemon_param(self):
t = threading.Thread() t = threading.Thread()
self.assertFalse(t.daemon) self.assertFalse(t.daemon)
t = threading.Thread(daemon=False) t = threading.Thread(daemon=False)
...@@ -1138,14 +1138,6 @@ class TimerTests(BaseTestCase): ...@@ -1138,14 +1138,6 @@ class TimerTests(BaseTestCase):
class LockTests(lock_tests.LockTests): class LockTests(lock_tests.LockTests):
locktype = staticmethod(threading.Lock) locktype = staticmethod(threading.Lock)
@unittest.skip("not on gevent")
def test_locked_repr(self):
pass
@unittest.skip("not on gevent")
def test_repr(self):
pass
class PyRLockTests(lock_tests.RLockTests): class PyRLockTests(lock_tests.RLockTests):
locktype = staticmethod(threading._PyRLock) locktype = staticmethod(threading._PyRLock)
...@@ -1156,11 +1148,6 @@ class CRLockTests(lock_tests.RLockTests): ...@@ -1156,11 +1148,6 @@ class CRLockTests(lock_tests.RLockTests):
class EventTests(lock_tests.EventTests): class EventTests(lock_tests.EventTests):
eventtype = staticmethod(threading.Event) eventtype = staticmethod(threading.Event)
@unittest.skip("not on gevent")
def test_reset_internal_locks(self):
# xxx: gevent: This uses an internal _cond attribute we don't have
pass
class ConditionAsRLockTests(lock_tests.RLockTests): class ConditionAsRLockTests(lock_tests.RLockTests):
# Condition uses an RLock by default and exports its API. # Condition uses an RLock by default and exports its API.
locktype = staticmethod(threading.Condition) locktype = staticmethod(threading.Condition)
......
...@@ -569,5 +569,198 @@ class BoundedSemaphoreTests(BaseSemaphoreTests): ...@@ -569,5 +569,198 @@ class BoundedSemaphoreTests(BaseSemaphoreTests):
sem.release() sem.release()
self.assertRaises(ValueError, sem.release) self.assertRaises(ValueError, sem.release)
class BarrierTests(BaseTestCase):
"""
Tests for Barrier objects.
"""
N = 5
defaultTimeout = 2.0
def setUp(self):
self.barrier = self.barriertype(self.N, timeout=self.defaultTimeout)
def tearDown(self):
self.barrier.abort()
def run_threads(self, f):
b = Bunch(f, self.N-1)
f()
b.wait_for_finished()
def multipass(self, results, n):
m = self.barrier.parties
self.assertEqual(m, self.N)
for i in range(n):
results[0].append(True)
self.assertEqual(len(results[1]), i * m)
self.barrier.wait()
results[1].append(True)
self.assertEqual(len(results[0]), (i + 1) * m)
self.barrier.wait()
self.assertEqual(self.barrier.n_waiting, 0)
self.assertFalse(self.barrier.broken)
def test_barrier(self, passes=1):
"""
Test that a barrier is passed in lockstep
"""
results = [[],[]]
def f():
self.multipass(results, passes)
self.run_threads(f)
def test_barrier_10(self):
"""
Test that a barrier works for 10 consecutive runs
"""
return self.test_barrier(10)
def test_wait_return(self):
"""
test the return value from barrier.wait
"""
results = []
def f():
r = self.barrier.wait()
results.append(r)
self.run_threads(f)
self.assertEqual(sum(results), sum(range(self.N)))
def test_action(self):
"""
Test the 'action' callback
"""
results = []
def action():
results.append(True)
barrier = self.barriertype(self.N, action)
def f():
barrier.wait()
self.assertEqual(len(results), 1)
self.run_threads(f)
def test_abort(self):
"""
Test that an abort will put the barrier in a broken state
"""
results1 = []
results2 = []
def f():
try:
i = self.barrier.wait()
if i == self.N//2:
raise RuntimeError
self.barrier.wait()
results1.append(True)
except threading.BrokenBarrierError:
results2.append(True)
except RuntimeError:
self.barrier.abort()
pass
self.run_threads(f)
self.assertEqual(len(results1), 0)
self.assertEqual(len(results2), self.N-1)
self.assertTrue(self.barrier.broken)
def test_reset(self):
"""
Test that a 'reset' on a barrier frees the waiting threads
"""
results1 = []
results2 = []
results3 = []
def f():
i = self.barrier.wait()
if i == self.N//2:
# Wait until the other threads are all in the barrier.
while self.barrier.n_waiting < self.N-1:
time.sleep(0.001)
self.barrier.reset()
else:
try:
self.barrier.wait()
results1.append(True)
except threading.BrokenBarrierError:
results2.append(True)
# Now, pass the barrier again
self.barrier.wait()
results3.append(True)
self.run_threads(f)
self.assertEqual(len(results1), 0)
self.assertEqual(len(results2), self.N-1)
self.assertEqual(len(results3), self.N)
def test_abort_and_reset(self):
"""
Test that a barrier can be reset after being broken.
"""
results1 = []
results2 = []
results3 = []
barrier2 = self.barriertype(self.N)
def f():
try:
i = self.barrier.wait()
if i == self.N//2:
raise RuntimeError
self.barrier.wait()
results1.append(True)
except threading.BrokenBarrierError:
results2.append(True)
except RuntimeError:
self.barrier.abort()
pass
# Synchronize and reset the barrier. Must synchronize first so
# that everyone has left it when we reset, and after so that no
# one enters it before the reset.
if barrier2.wait() == self.N//2:
self.barrier.reset()
barrier2.wait()
self.barrier.wait()
results3.append(True)
self.run_threads(f)
self.assertEqual(len(results1), 0)
self.assertEqual(len(results2), self.N-1)
self.assertEqual(len(results3), self.N)
def test_timeout(self):
"""
Test wait(timeout)
"""
def f():
i = self.barrier.wait()
if i == self.N // 2:
# One thread is late!
time.sleep(1.0)
# Default timeout is 2.0, so this is shorter.
self.assertRaises(threading.BrokenBarrierError,
self.barrier.wait, 0.5)
self.run_threads(f)
def test_default_timeout(self):
"""
Test the barrier's default timeout
"""
# create a barrier with a low default timeout
barrier = self.barriertype(self.N, timeout=0.3)
def f():
i = barrier.wait()
if i == self.N // 2:
# One thread is later than the default timeout of 0.3s.
time.sleep(1.0)
self.assertRaises(threading.BrokenBarrierError, barrier.wait)
self.run_threads(f)
def test_single_thread(self):
b = self.barriertype(1)
b.wait()
b.wait()
if __name__ == '__main__': if __name__ == '__main__':
print("This module contains no tests; it is used by other test cases like test_threading_2") print("This module contains no tests; it is used by other test cases like test_threading_2")
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment