Commit 5b10b982 authored by Serhiy Storchaka's avatar Serhiy Storchaka Committed by GitHub

bpo-22831: Use "with" to avoid possible fd leaks in tests (part 2). (GH-10929)

parent 9e4861f5
...@@ -706,9 +706,8 @@ def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM): ...@@ -706,9 +706,8 @@ def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM):
issue if/when we come across it. issue if/when we come across it.
""" """
tempsock = socket.socket(family, socktype) with socket.socket(family, socktype) as tempsock:
port = bind_port(tempsock) port = bind_port(tempsock)
tempsock.close()
del tempsock del tempsock
return port return port
...@@ -1785,10 +1784,11 @@ class _MemoryWatchdog: ...@@ -1785,10 +1784,11 @@ class _MemoryWatchdog:
sys.stderr.flush() sys.stderr.flush()
return return
watchdog_script = findfile("memory_watchdog.py") with f:
self.mem_watchdog = subprocess.Popen([sys.executable, watchdog_script], watchdog_script = findfile("memory_watchdog.py")
stdin=f, stderr=subprocess.DEVNULL) self.mem_watchdog = subprocess.Popen([sys.executable, watchdog_script],
f.close() stdin=f,
stderr=subprocess.DEVNULL)
self.started = True self.started = True
def stop(self): def stop(self):
......
...@@ -205,31 +205,28 @@ def make_script(script_dir, script_basename, source, omit_suffix=False): ...@@ -205,31 +205,28 @@ def make_script(script_dir, script_basename, source, omit_suffix=False):
script_filename += os.extsep + 'py' script_filename += os.extsep + 'py'
script_name = os.path.join(script_dir, script_filename) script_name = os.path.join(script_dir, script_filename)
# The script should be encoded to UTF-8, the default string encoding # The script should be encoded to UTF-8, the default string encoding
script_file = open(script_name, 'w', encoding='utf-8') with open(script_name, 'w', encoding='utf-8') as script_file:
script_file.write(source) script_file.write(source)
script_file.close()
importlib.invalidate_caches() importlib.invalidate_caches()
return script_name return script_name
def make_zip_script(zip_dir, zip_basename, script_name, name_in_zip=None): def make_zip_script(zip_dir, zip_basename, script_name, name_in_zip=None):
zip_filename = zip_basename+os.extsep+'zip' zip_filename = zip_basename+os.extsep+'zip'
zip_name = os.path.join(zip_dir, zip_filename) zip_name = os.path.join(zip_dir, zip_filename)
zip_file = zipfile.ZipFile(zip_name, 'w') with zipfile.ZipFile(zip_name, 'w') as zip_file:
if name_in_zip is None: if name_in_zip is None:
parts = script_name.split(os.sep) parts = script_name.split(os.sep)
if len(parts) >= 2 and parts[-2] == '__pycache__': if len(parts) >= 2 and parts[-2] == '__pycache__':
legacy_pyc = make_legacy_pyc(source_from_cache(script_name)) legacy_pyc = make_legacy_pyc(source_from_cache(script_name))
name_in_zip = os.path.basename(legacy_pyc) name_in_zip = os.path.basename(legacy_pyc)
script_name = legacy_pyc script_name = legacy_pyc
else: else:
name_in_zip = os.path.basename(script_name) name_in_zip = os.path.basename(script_name)
zip_file.write(script_name, name_in_zip) zip_file.write(script_name, name_in_zip)
zip_file.close()
#if test.support.verbose: #if test.support.verbose:
# zip_file = zipfile.ZipFile(zip_name, 'r') # with zipfile.ZipFile(zip_name, 'r') as zip_file:
# print 'Contents of %r:' % zip_name # print 'Contents of %r:' % zip_name
# zip_file.printdir() # zip_file.printdir()
# zip_file.close()
return zip_name, os.path.join(zip_name, name_in_zip) return zip_name, os.path.join(zip_name, name_in_zip)
def make_pkg(pkg_dir, init_source=''): def make_pkg(pkg_dir, init_source=''):
...@@ -252,17 +249,15 @@ def make_zip_pkg(zip_dir, zip_basename, pkg_name, script_basename, ...@@ -252,17 +249,15 @@ def make_zip_pkg(zip_dir, zip_basename, pkg_name, script_basename,
script_name_in_zip = os.path.join(pkg_names[-1], os.path.basename(script_name)) script_name_in_zip = os.path.join(pkg_names[-1], os.path.basename(script_name))
zip_filename = zip_basename+os.extsep+'zip' zip_filename = zip_basename+os.extsep+'zip'
zip_name = os.path.join(zip_dir, zip_filename) zip_name = os.path.join(zip_dir, zip_filename)
zip_file = zipfile.ZipFile(zip_name, 'w') with zipfile.ZipFile(zip_name, 'w') as zip_file:
for name in pkg_names: for name in pkg_names:
init_name_in_zip = os.path.join(name, init_basename) init_name_in_zip = os.path.join(name, init_basename)
zip_file.write(init_name, init_name_in_zip) zip_file.write(init_name, init_name_in_zip)
zip_file.write(script_name, script_name_in_zip) zip_file.write(script_name, script_name_in_zip)
zip_file.close()
for name in unlink: for name in unlink:
os.unlink(name) os.unlink(name)
#if test.support.verbose: #if test.support.verbose:
# zip_file = zipfile.ZipFile(zip_name, 'r') # with zipfile.ZipFile(zip_name, 'r') as zip_file:
# print 'Contents of %r:' % zip_name # print 'Contents of %r:' % zip_name
# zip_file.printdir() # zip_file.printdir()
# zip_file.close()
return zip_name, os.path.join(zip_name, script_name_in_zip) return zip_name, os.path.join(zip_name, script_name_in_zip)
...@@ -1379,9 +1379,8 @@ class TestArgumentsFromFile(TempDirMixin, ParserTestCase): ...@@ -1379,9 +1379,8 @@ class TestArgumentsFromFile(TempDirMixin, ParserTestCase):
('invalid', '@no-such-path\n'), ('invalid', '@no-such-path\n'),
] ]
for path, text in file_texts: for path, text in file_texts:
file = open(path, 'w') with open(path, 'w') as file:
file.write(text) file.write(text)
file.close()
parser_signature = Sig(fromfile_prefix_chars='@') parser_signature = Sig(fromfile_prefix_chars='@')
argument_signatures = [ argument_signatures = [
...@@ -1410,9 +1409,8 @@ class TestArgumentsFromFileConverter(TempDirMixin, ParserTestCase): ...@@ -1410,9 +1409,8 @@ class TestArgumentsFromFileConverter(TempDirMixin, ParserTestCase):
('hello', 'hello world!\n'), ('hello', 'hello world!\n'),
] ]
for path, text in file_texts: for path, text in file_texts:
file = open(path, 'w') with open(path, 'w') as file:
file.write(text) file.write(text)
file.close()
class FromFileConverterArgumentParser(ErrorRaisingArgumentParser): class FromFileConverterArgumentParser(ErrorRaisingArgumentParser):
...@@ -1493,9 +1491,8 @@ class TestFileTypeR(TempDirMixin, ParserTestCase): ...@@ -1493,9 +1491,8 @@ class TestFileTypeR(TempDirMixin, ParserTestCase):
def setUp(self): def setUp(self):
super(TestFileTypeR, self).setUp() super(TestFileTypeR, self).setUp()
for file_name in ['foo', 'bar']: for file_name in ['foo', 'bar']:
file = open(os.path.join(self.temp_dir, file_name), 'w') with open(os.path.join(self.temp_dir, file_name), 'w') as file:
file.write(file_name) file.write(file_name)
file.close()
self.create_readonly_file('readonly') self.create_readonly_file('readonly')
argument_signatures = [ argument_signatures = [
...@@ -1534,9 +1531,8 @@ class TestFileTypeRB(TempDirMixin, ParserTestCase): ...@@ -1534,9 +1531,8 @@ class TestFileTypeRB(TempDirMixin, ParserTestCase):
def setUp(self): def setUp(self):
super(TestFileTypeRB, self).setUp() super(TestFileTypeRB, self).setUp()
for file_name in ['foo', 'bar']: for file_name in ['foo', 'bar']:
file = open(os.path.join(self.temp_dir, file_name), 'w') with open(os.path.join(self.temp_dir, file_name), 'w') as file:
file.write(file_name) file.write(file_name)
file.close()
argument_signatures = [ argument_signatures = [
Sig('-x', type=argparse.FileType('rb')), Sig('-x', type=argparse.FileType('rb')),
......
...@@ -23,17 +23,15 @@ class BinHexTestCase(unittest.TestCase): ...@@ -23,17 +23,15 @@ class BinHexTestCase(unittest.TestCase):
DATA = b'Jack is my hero' DATA = b'Jack is my hero'
def test_binhex(self): def test_binhex(self):
f = open(self.fname1, 'wb') with open(self.fname1, 'wb') as f:
f.write(self.DATA) f.write(self.DATA)
f.close()
binhex.binhex(self.fname1, self.fname2) binhex.binhex(self.fname1, self.fname2)
binhex.hexbin(self.fname2, self.fname1) binhex.hexbin(self.fname2, self.fname1)
f = open(self.fname1, 'rb') with open(self.fname1, 'rb') as f:
finish = f.readline() finish = f.readline()
f.close()
self.assertEqual(self.DATA, finish) self.assertEqual(self.DATA, finish)
......
...@@ -20,13 +20,11 @@ class BoolTest(unittest.TestCase): ...@@ -20,13 +20,11 @@ class BoolTest(unittest.TestCase):
def test_print(self): def test_print(self):
try: try:
fo = open(support.TESTFN, "w") with open(support.TESTFN, "w") as fo:
print(False, True, file=fo) print(False, True, file=fo)
fo.close() with open(support.TESTFN, "r") as fi:
fo = open(support.TESTFN, "r") self.assertEqual(fi.read(), 'False True\n')
self.assertEqual(fo.read(), 'False True\n')
finally: finally:
fo.close()
os.remove(support.TESTFN) os.remove(support.TESTFN)
def test_repr(self): def test_repr(self):
...@@ -245,9 +243,8 @@ class BoolTest(unittest.TestCase): ...@@ -245,9 +243,8 @@ class BoolTest(unittest.TestCase):
def test_fileclosed(self): def test_fileclosed(self):
try: try:
f = open(support.TESTFN, "w") with open(support.TESTFN, "w") as f:
self.assertIs(f.closed, False) self.assertIs(f.closed, False)
f.close()
self.assertIs(f.closed, True) self.assertIs(f.closed, True)
finally: finally:
os.remove(support.TESTFN) os.remove(support.TESTFN)
......
...@@ -1242,9 +1242,8 @@ class EscapeDecodeTest(unittest.TestCase): ...@@ -1242,9 +1242,8 @@ class EscapeDecodeTest(unittest.TestCase):
class RecodingTest(unittest.TestCase): class RecodingTest(unittest.TestCase):
def test_recoding(self): def test_recoding(self):
f = io.BytesIO() f = io.BytesIO()
f2 = codecs.EncodedFile(f, "unicode_internal", "utf-8") with codecs.EncodedFile(f, "unicode_internal", "utf-8") as f2:
f2.write("a") f2.write("a")
f2.close()
# Python used to crash on this at exit because of a refcount # Python used to crash on this at exit because of a refcount
# bug in _codecsmodule.c # bug in _codecsmodule.c
......
...@@ -143,18 +143,17 @@ class TestEPoll(unittest.TestCase): ...@@ -143,18 +143,17 @@ class TestEPoll(unittest.TestCase):
def test_fromfd(self): def test_fromfd(self):
server, client = self._connected_pair() server, client = self._connected_pair()
ep = select.epoll(2) with select.epoll(2) as ep:
ep2 = select.epoll.fromfd(ep.fileno()) ep2 = select.epoll.fromfd(ep.fileno())
ep2.register(server.fileno(), select.EPOLLIN | select.EPOLLOUT) ep2.register(server.fileno(), select.EPOLLIN | select.EPOLLOUT)
ep2.register(client.fileno(), select.EPOLLIN | select.EPOLLOUT) ep2.register(client.fileno(), select.EPOLLIN | select.EPOLLOUT)
events = ep.poll(1, 4) events = ep.poll(1, 4)
events2 = ep2.poll(0.9, 4) events2 = ep2.poll(0.9, 4)
self.assertEqual(len(events), 2) self.assertEqual(len(events), 2)
self.assertEqual(len(events2), 2) self.assertEqual(len(events2), 2)
ep.close()
try: try:
ep2.poll(1, 4) ep2.poll(1, 4)
except OSError as e: except OSError as e:
......
...@@ -722,15 +722,14 @@ class FormatTestCase(unittest.TestCase): ...@@ -722,15 +722,14 @@ class FormatTestCase(unittest.TestCase):
class ReprTestCase(unittest.TestCase): class ReprTestCase(unittest.TestCase):
def test_repr(self): def test_repr(self):
floats_file = open(os.path.join(os.path.split(__file__)[0], with open(os.path.join(os.path.split(__file__)[0],
'floating_points.txt')) 'floating_points.txt')) as floats_file:
for line in floats_file: for line in floats_file:
line = line.strip() line = line.strip()
if not line or line.startswith('#'): if not line or line.startswith('#'):
continue continue
v = eval(line) v = eval(line)
self.assertEqual(v, eval(repr(v))) self.assertEqual(v, eval(repr(v)))
floats_file.close()
@unittest.skipUnless(getattr(sys, 'float_repr_style', '') == 'short', @unittest.skipUnless(getattr(sys, 'float_repr_style', '') == 'short',
"applies only when using short float repr style") "applies only when using short float repr style")
......
...@@ -11,9 +11,9 @@ try: ...@@ -11,9 +11,9 @@ try:
except OSError: except OSError:
raise unittest.SkipTest("Unable to open /dev/tty") raise unittest.SkipTest("Unable to open /dev/tty")
else: else:
# Skip if another process is in foreground with tty:
r = fcntl.ioctl(tty, termios.TIOCGPGRP, " ") # Skip if another process is in foreground
tty.close() r = fcntl.ioctl(tty, termios.TIOCGPGRP, " ")
rpgrp = struct.unpack("i", r)[0] rpgrp = struct.unpack("i", r)[0]
if rpgrp not in (os.getpgrp(), os.getsid(0)): if rpgrp not in (os.getpgrp(), os.getsid(0)):
raise unittest.SkipTest("Neither the process group nor the session " raise unittest.SkipTest("Neither the process group nor the session "
......
...@@ -1200,9 +1200,8 @@ class MakedirTests(unittest.TestCase): ...@@ -1200,9 +1200,8 @@ class MakedirTests(unittest.TestCase):
def test_exist_ok_existing_regular_file(self): def test_exist_ok_existing_regular_file(self):
base = support.TESTFN base = support.TESTFN
path = os.path.join(support.TESTFN, 'dir1') path = os.path.join(support.TESTFN, 'dir1')
f = open(path, 'w') with open(path, 'w') as f:
f.write('abc') f.write('abc')
f.close()
self.assertRaises(OSError, os.makedirs, path) self.assertRaises(OSError, os.makedirs, path)
self.assertRaises(OSError, os.makedirs, path, exist_ok=False) self.assertRaises(OSError, os.makedirs, path, exist_ok=False)
self.assertRaises(OSError, os.makedirs, path, exist_ok=True) self.assertRaises(OSError, os.makedirs, path, exist_ok=True)
......
...@@ -23,9 +23,8 @@ class SimplePipeTests(unittest.TestCase): ...@@ -23,9 +23,8 @@ class SimplePipeTests(unittest.TestCase):
self.skipTest('tr is not available') self.skipTest('tr is not available')
t = pipes.Template() t = pipes.Template()
t.append(s_command, pipes.STDIN_STDOUT) t.append(s_command, pipes.STDIN_STDOUT)
f = t.open(TESTFN, 'w') with t.open(TESTFN, 'w') as f:
f.write('hello world #1') f.write('hello world #1')
f.close()
with open(TESTFN) as f: with open(TESTFN) as f:
self.assertEqual(f.read(), 'HELLO WORLD #1') self.assertEqual(f.read(), 'HELLO WORLD #1')
......
...@@ -83,13 +83,12 @@ class PollTests(unittest.TestCase): ...@@ -83,13 +83,12 @@ class PollTests(unittest.TestCase):
r = p.poll() r = p.poll()
self.assertEqual(r[0], (FD, select.POLLNVAL)) self.assertEqual(r[0], (FD, select.POLLNVAL))
f = open(TESTFN, 'w') with open(TESTFN, 'w') as f:
fd = f.fileno() fd = f.fileno()
p = select.poll() p = select.poll()
p.register(f) p.register(f)
r = p.poll() r = p.poll()
self.assertEqual(r[0][0], fd) self.assertEqual(r[0][0], fd)
f.close()
r = p.poll() r = p.poll()
self.assertEqual(r[0], (fd, select.POLLNVAL)) self.assertEqual(r[0], (fd, select.POLLNVAL))
os.unlink(TESTFN) os.unlink(TESTFN)
......
...@@ -268,9 +268,8 @@ class TestBasicOps: ...@@ -268,9 +268,8 @@ class TestBasicOps:
("randv2_64.pck", 866), ("randv2_64.pck", 866),
("randv3.pck", 343)] ("randv3.pck", 343)]
for file, value in files: for file, value in files:
f = open(support.findfile(file),"rb") with open(support.findfile(file),"rb") as f:
r = pickle.load(f) r = pickle.load(f)
f.close()
self.assertEqual(int(r.random()*1000), value) self.assertEqual(int(r.random()*1000), value)
def test_bug_9025(self): def test_bug_9025(self):
......
...@@ -238,9 +238,8 @@ class RunModuleTestCase(unittest.TestCase, CodeExecutionMixin): ...@@ -238,9 +238,8 @@ class RunModuleTestCase(unittest.TestCase, CodeExecutionMixin):
if verbose > 1: print(" Next level in:", sub_dir) if verbose > 1: print(" Next level in:", sub_dir)
if verbose > 1: print(" Created:", pkg_fname) if verbose > 1: print(" Created:", pkg_fname)
mod_fname = os.path.join(sub_dir, test_fname) mod_fname = os.path.join(sub_dir, test_fname)
mod_file = open(mod_fname, "w") with open(mod_fname, "w") as mod_file:
mod_file.write(source) mod_file.write(source)
mod_file.close()
if verbose > 1: print(" Created:", mod_fname) if verbose > 1: print(" Created:", mod_fname)
mod_name = (pkg_name+".")*depth + mod_base mod_name = (pkg_name+".")*depth + mod_base
mod_spec = importlib.util.spec_from_file_location(mod_name, mod_spec = importlib.util.spec_from_file_location(mod_name,
......
...@@ -46,24 +46,23 @@ class SelectTestCase(unittest.TestCase): ...@@ -46,24 +46,23 @@ class SelectTestCase(unittest.TestCase):
def test_select(self): def test_select(self):
cmd = 'for i in 0 1 2 3 4 5 6 7 8 9; do echo testing...; sleep 1; done' cmd = 'for i in 0 1 2 3 4 5 6 7 8 9; do echo testing...; sleep 1; done'
p = os.popen(cmd, 'r') with os.popen(cmd) as p:
for tout in (0, 1, 2, 4, 8, 16) + (None,)*10: for tout in (0, 1, 2, 4, 8, 16) + (None,)*10:
if support.verbose:
print('timeout =', tout)
rfd, wfd, xfd = select.select([p], [], [], tout)
if (rfd, wfd, xfd) == ([], [], []):
continue
if (rfd, wfd, xfd) == ([p], [], []):
line = p.readline()
if support.verbose: if support.verbose:
print(repr(line)) print('timeout =', tout)
if not line: rfd, wfd, xfd = select.select([p], [], [], tout)
if (rfd, wfd, xfd) == ([], [], []):
continue
if (rfd, wfd, xfd) == ([p], [], []):
line = p.readline()
if support.verbose: if support.verbose:
print('EOF') print(repr(line))
break if not line:
continue if support.verbose:
self.fail('Unexpected return values from select():', rfd, wfd, xfd) print('EOF')
p.close() break
continue
self.fail('Unexpected return values from select():', rfd, wfd, xfd)
# Issue 16230: Crash on select resized list # Issue 16230: Crash on select resized list
def test_select_mutated(self): def test_select_mutated(self):
......
...@@ -88,15 +88,13 @@ class TestCase(unittest.TestCase): ...@@ -88,15 +88,13 @@ class TestCase(unittest.TestCase):
def test_in_memory_shelf(self): def test_in_memory_shelf(self):
d1 = byteskeydict() d1 = byteskeydict()
s = shelve.Shelf(d1, protocol=0) with shelve.Shelf(d1, protocol=0) as s:
s['key1'] = (1,2,3,4) s['key1'] = (1,2,3,4)
self.assertEqual(s['key1'], (1,2,3,4)) self.assertEqual(s['key1'], (1,2,3,4))
s.close()
d2 = byteskeydict() d2 = byteskeydict()
s = shelve.Shelf(d2, protocol=1) with shelve.Shelf(d2, protocol=1) as s:
s['key1'] = (1,2,3,4) s['key1'] = (1,2,3,4)
self.assertEqual(s['key1'], (1,2,3,4)) self.assertEqual(s['key1'], (1,2,3,4))
s.close()
self.assertEqual(len(d1), 1) self.assertEqual(len(d1), 1)
self.assertEqual(len(d2), 1) self.assertEqual(len(d2), 1)
...@@ -104,20 +102,18 @@ class TestCase(unittest.TestCase): ...@@ -104,20 +102,18 @@ class TestCase(unittest.TestCase):
def test_mutable_entry(self): def test_mutable_entry(self):
d1 = byteskeydict() d1 = byteskeydict()
s = shelve.Shelf(d1, protocol=2, writeback=False) with shelve.Shelf(d1, protocol=2, writeback=False) as s:
s['key1'] = [1,2,3,4] s['key1'] = [1,2,3,4]
self.assertEqual(s['key1'], [1,2,3,4]) self.assertEqual(s['key1'], [1,2,3,4])
s['key1'].append(5) s['key1'].append(5)
self.assertEqual(s['key1'], [1,2,3,4]) self.assertEqual(s['key1'], [1,2,3,4])
s.close()
d2 = byteskeydict() d2 = byteskeydict()
s = shelve.Shelf(d2, protocol=2, writeback=True) with shelve.Shelf(d2, protocol=2, writeback=True) as s:
s['key1'] = [1,2,3,4] s['key1'] = [1,2,3,4]
self.assertEqual(s['key1'], [1,2,3,4]) self.assertEqual(s['key1'], [1,2,3,4])
s['key1'].append(5) s['key1'].append(5)
self.assertEqual(s['key1'], [1,2,3,4,5]) self.assertEqual(s['key1'], [1,2,3,4,5])
s.close()
self.assertEqual(len(d1), 1) self.assertEqual(len(d1), 1)
self.assertEqual(len(d2), 1) self.assertEqual(len(d2), 1)
...@@ -140,11 +136,10 @@ class TestCase(unittest.TestCase): ...@@ -140,11 +136,10 @@ class TestCase(unittest.TestCase):
d = {} d = {}
key = 'key' key = 'key'
encodedkey = key.encode('utf-8') encodedkey = key.encode('utf-8')
s = shelve.Shelf(d, writeback=True) with shelve.Shelf(d, writeback=True) as s:
s[key] = [1] s[key] = [1]
p1 = d[encodedkey] # Will give a KeyError if backing store not updated p1 = d[encodedkey] # Will give a KeyError if backing store not updated
s['key'].append(2) s['key'].append(2)
s.close()
p2 = d[encodedkey] p2 = d[encodedkey]
self.assertNotEqual(p1, p2) # Write creates new object in store self.assertNotEqual(p1, p2) # Write creates new object in store
......
...@@ -125,10 +125,9 @@ class HelperFunctionsTests(unittest.TestCase): ...@@ -125,10 +125,9 @@ class HelperFunctionsTests(unittest.TestCase):
pth_dir = os.path.abspath(pth_dir) pth_dir = os.path.abspath(pth_dir)
pth_basename = pth_name + '.pth' pth_basename = pth_name + '.pth'
pth_fn = os.path.join(pth_dir, pth_basename) pth_fn = os.path.join(pth_dir, pth_basename)
pth_file = open(pth_fn, 'w', encoding='utf-8') with open(pth_fn, 'w', encoding='utf-8') as pth_file:
self.addCleanup(lambda: os.remove(pth_fn)) self.addCleanup(lambda: os.remove(pth_fn))
pth_file.write(contents) pth_file.write(contents)
pth_file.close()
return pth_dir, pth_basename return pth_dir, pth_basename
def test_addpackage_import_bad_syntax(self): def test_addpackage_import_bad_syntax(self):
......
...@@ -157,27 +157,25 @@ class SocketServerTest(unittest.TestCase): ...@@ -157,27 +157,25 @@ class SocketServerTest(unittest.TestCase):
if verbose: print("done") if verbose: print("done")
def stream_examine(self, proto, addr): def stream_examine(self, proto, addr):
s = socket.socket(proto, socket.SOCK_STREAM) with socket.socket(proto, socket.SOCK_STREAM) as s:
s.connect(addr) s.connect(addr)
s.sendall(TEST_STR) s.sendall(TEST_STR)
buf = data = receive(s, 100) buf = data = receive(s, 100)
while data and b'\n' not in buf: while data and b'\n' not in buf:
data = receive(s, 100) data = receive(s, 100)
buf += data buf += data
self.assertEqual(buf, TEST_STR) self.assertEqual(buf, TEST_STR)
s.close()
def dgram_examine(self, proto, addr): def dgram_examine(self, proto, addr):
s = socket.socket(proto, socket.SOCK_DGRAM) with socket.socket(proto, socket.SOCK_DGRAM) as s:
if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX: if HAVE_UNIX_SOCKETS and proto == socket.AF_UNIX:
s.bind(self.pickaddr(proto)) s.bind(self.pickaddr(proto))
s.sendto(TEST_STR, addr) s.sendto(TEST_STR, addr)
buf = data = receive(s, 100) buf = data = receive(s, 100)
while data and b'\n' not in buf: while data and b'\n' not in buf:
data = receive(s, 100) data = receive(s, 100)
buf += data buf += data
self.assertEqual(buf, TEST_STR) self.assertEqual(buf, TEST_STR)
s.close()
def test_TCPServer(self): def test_TCPServer(self):
self.run_server(socketserver.TCPServer, self.run_server(socketserver.TCPServer,
......
...@@ -573,9 +573,8 @@ class TestGetTempDir(BaseTestCase): ...@@ -573,9 +573,8 @@ class TestGetTempDir(BaseTestCase):
# sneaky: just instantiate a NamedTemporaryFile, which # sneaky: just instantiate a NamedTemporaryFile, which
# defaults to writing into the directory returned by # defaults to writing into the directory returned by
# gettempdir. # gettempdir.
file = tempfile.NamedTemporaryFile() with tempfile.NamedTemporaryFile() as file:
file.write(b"blat") file.write(b"blat")
file.close()
def test_same_thing(self): def test_same_thing(self):
# gettempdir always returns the same object # gettempdir always returns the same object
...@@ -891,9 +890,8 @@ class TestNamedTemporaryFile(BaseTestCase): ...@@ -891,9 +890,8 @@ class TestNamedTemporaryFile(BaseTestCase):
# A NamedTemporaryFile is deleted when closed # A NamedTemporaryFile is deleted when closed
dir = tempfile.mkdtemp() dir = tempfile.mkdtemp()
try: try:
f = tempfile.NamedTemporaryFile(dir=dir) with tempfile.NamedTemporaryFile(dir=dir) as f:
f.write(b'blat') f.write(b'blat')
f.close()
self.assertFalse(os.path.exists(f.name), self.assertFalse(os.path.exists(f.name),
"NamedTemporaryFile %s exists after close" % f.name) "NamedTemporaryFile %s exists after close" % f.name)
finally: finally:
......
...@@ -788,13 +788,11 @@ class ThreadJoinOnShutdown(BaseTestCase): ...@@ -788,13 +788,11 @@ class ThreadJoinOnShutdown(BaseTestCase):
def random_io(): def random_io():
'''Loop for a while sleeping random tiny amounts and doing some I/O.''' '''Loop for a while sleeping random tiny amounts and doing some I/O.'''
while True: while True:
in_f = open(os.__file__, 'rb') with open(os.__file__, 'rb') as in_f:
stuff = in_f.read(200) stuff = in_f.read(200)
null_f = open(os.devnull, 'wb') with open(os.devnull, 'wb') as null_f:
null_f.write(stuff) null_f.write(stuff)
time.sleep(random.random() / 1995) time.sleep(random.random() / 1995)
null_f.close()
in_f.close()
thread_has_run.add(threading.current_thread()) thread_has_run.add(threading.current_thread())
def main(): def main():
......
...@@ -57,10 +57,8 @@ class TrivialTests(unittest.TestCase): ...@@ -57,10 +57,8 @@ class TrivialTests(unittest.TestCase):
else: else:
file_url = "file://%s" % fname file_url = "file://%s" % fname
f = urllib.request.urlopen(file_url) with urllib.request.urlopen(file_url) as f:
f.read()
f.read()
f.close()
def test_parse_http_list(self): def test_parse_http_list(self):
tests = [ tests = [
......
...@@ -371,10 +371,9 @@ class ProxyAuthTests(unittest.TestCase): ...@@ -371,10 +371,9 @@ class ProxyAuthTests(unittest.TestCase):
self.proxy_digest_handler.add_password(self.REALM, self.URL, self.proxy_digest_handler.add_password(self.REALM, self.URL,
self.USER, self.PASSWD) self.USER, self.PASSWD)
self.digest_auth_handler.set_qop("auth") self.digest_auth_handler.set_qop("auth")
result = self.opener.open(self.URL) with self.opener.open(self.URL) as result:
while result.read(): while result.read():
pass pass
result.close()
def test_proxy_qop_auth_int_works_or_throws_urlerror(self): def test_proxy_qop_auth_int_works_or_throws_urlerror(self):
self.proxy_digest_handler.add_password(self.REALM, self.URL, self.proxy_digest_handler.add_password(self.REALM, self.URL,
...@@ -386,11 +385,11 @@ class ProxyAuthTests(unittest.TestCase): ...@@ -386,11 +385,11 @@ class ProxyAuthTests(unittest.TestCase):
# It's okay if we don't support auth-int, but we certainly # It's okay if we don't support auth-int, but we certainly
# shouldn't receive any kind of exception here other than # shouldn't receive any kind of exception here other than
# a URLError. # a URLError.
result = None pass
if result: else:
while result.read(): with result:
pass while result.read():
result.close() pass
def GetRequestHandler(responses): def GetRequestHandler(responses):
...@@ -611,14 +610,11 @@ class TestUrlopen(unittest.TestCase): ...@@ -611,14 +610,11 @@ class TestUrlopen(unittest.TestCase):
def test_basic(self): def test_basic(self):
handler = self.start_server() handler = self.start_server()
open_url = urllib.request.urlopen("http://localhost:%s" % handler.port) with urllib.request.urlopen("http://localhost:%s" % handler.port) as open_url:
for attr in ("read", "close", "info", "geturl"): for attr in ("read", "close", "info", "geturl"):
self.assertTrue(hasattr(open_url, attr), "object returned from " self.assertTrue(hasattr(open_url, attr), "object returned from "
"urlopen lacks the %s attribute" % attr) "urlopen lacks the %s attribute" % attr)
try:
self.assertTrue(open_url.read(), "calling 'read' failed") self.assertTrue(open_url.read(), "calling 'read' failed")
finally:
open_url.close()
def test_info(self): def test_info(self):
handler = self.start_server() handler = self.start_server()
......
...@@ -821,10 +821,9 @@ class SimpleServerTestCase(BaseServerTestCase): ...@@ -821,10 +821,9 @@ class SimpleServerTestCase(BaseServerTestCase):
def test_404(self): def test_404(self):
# send POST with http.client, it should return 404 header and # send POST with http.client, it should return 404 header and
# 'Not Found' message. # 'Not Found' message.
conn = http.client.HTTPConnection(ADDR, PORT) with contextlib.closing(http.client.HTTPConnection(ADDR, PORT)) as conn:
conn.request('POST', '/this-is-not-valid') conn.request('POST', '/this-is-not-valid')
response = conn.getresponse() response = conn.getresponse()
conn.close()
self.assertEqual(response.status, 404) self.assertEqual(response.status, 404)
self.assertEqual(response.reason, 'Not Found') self.assertEqual(response.reason, 'Not Found')
...@@ -944,9 +943,8 @@ class SimpleServerTestCase(BaseServerTestCase): ...@@ -944,9 +943,8 @@ class SimpleServerTestCase(BaseServerTestCase):
def test_partial_post(self): def test_partial_post(self):
# Check that a partial POST doesn't make the server loop: issue #14001. # Check that a partial POST doesn't make the server loop: issue #14001.
conn = http.client.HTTPConnection(ADDR, PORT) with contextlib.closing(http.client.HTTPConnection(ADDR, PORT)) as conn:
conn.request('POST', '/RPC2 HTTP/1.0\r\nContent-Length: 100\r\n\r\nbye') conn.request('POST', '/RPC2 HTTP/1.0\r\nContent-Length: 100\r\n\r\nbye')
conn.close()
def test_context_manager(self): def test_context_manager(self):
with xmlrpclib.ServerProxy(URL) as server: with xmlrpclib.ServerProxy(URL) as server:
......
This diff is collapsed.
...@@ -122,15 +122,13 @@ class ZipSupportTests(unittest.TestCase): ...@@ -122,15 +122,13 @@ class ZipSupportTests(unittest.TestCase):
test_src) test_src)
zip_name, run_name = make_zip_script(d, 'test_zip', zip_name, run_name = make_zip_script(d, 'test_zip',
script_name) script_name)
z = zipfile.ZipFile(zip_name, 'a') with zipfile.ZipFile(zip_name, 'a') as z:
for mod_name, src in sample_sources.items(): for mod_name, src in sample_sources.items():
z.writestr(mod_name + ".py", src) z.writestr(mod_name + ".py", src)
z.close()
if verbose: if verbose:
zip_file = zipfile.ZipFile(zip_name, 'r') with zipfile.ZipFile(zip_name, 'r') as zip_file:
print ('Contents of %r:' % zip_name) print ('Contents of %r:' % zip_name)
zip_file.printdir() zip_file.printdir()
zip_file.close()
os.remove(script_name) os.remove(script_name)
sys.path.insert(0, zip_name) sys.path.insert(0, zip_name)
import test_zipped_doctest import test_zipped_doctest
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment