Commit 9e4861f5 authored by Serhiy Storchaka's avatar Serhiy Storchaka Committed by GitHub

bpo-22831: Use "with" to avoid possible fd leaks in tests (part 1). (GH-10928)

parent b7272395
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
Original by Roger E. Masse Original by Roger E. Masse
""" """
import contextlib
import io import io
import operator import operator
import os import os
...@@ -32,12 +33,11 @@ class DumbDBMTestCase(unittest.TestCase): ...@@ -32,12 +33,11 @@ class DumbDBMTestCase(unittest.TestCase):
} }
def test_dumbdbm_creation(self): def test_dumbdbm_creation(self):
f = dumbdbm.open(_fname, 'c') with contextlib.closing(dumbdbm.open(_fname, 'c')) as f:
self.assertEqual(list(f.keys()), []) self.assertEqual(list(f.keys()), [])
for key in self._dict: for key in self._dict:
f[key] = self._dict[key] f[key] = self._dict[key]
self.read_helper(f) self.read_helper(f)
f.close()
@unittest.skipUnless(hasattr(os, 'umask'), 'test needs os.umask()') @unittest.skipUnless(hasattr(os, 'umask'), 'test needs os.umask()')
def test_dumbdbm_creation_mode(self): def test_dumbdbm_creation_mode(self):
...@@ -69,17 +69,16 @@ class DumbDBMTestCase(unittest.TestCase): ...@@ -69,17 +69,16 @@ class DumbDBMTestCase(unittest.TestCase):
def test_dumbdbm_modification(self): def test_dumbdbm_modification(self):
self.init_db() self.init_db()
f = dumbdbm.open(_fname, 'w') with contextlib.closing(dumbdbm.open(_fname, 'w')) as f:
self._dict[b'g'] = f[b'g'] = b"indented" self._dict[b'g'] = f[b'g'] = b"indented"
self.read_helper(f) self.read_helper(f)
# setdefault() works as in the dict interface # setdefault() works as in the dict interface
self.assertEqual(f.setdefault(b'xxx', b'foo'), b'foo') self.assertEqual(f.setdefault(b'xxx', b'foo'), b'foo')
self.assertEqual(f[b'xxx'], b'foo') self.assertEqual(f[b'xxx'], b'foo')
f.close()
def test_dumbdbm_read(self): def test_dumbdbm_read(self):
self.init_db() self.init_db()
f = dumbdbm.open(_fname, 'r') with contextlib.closing(dumbdbm.open(_fname, 'r')) as f:
self.read_helper(f) self.read_helper(f)
with self.assertRaisesRegex(dumbdbm.error, with self.assertRaisesRegex(dumbdbm.error,
'The database is opened for reading only'): 'The database is opened for reading only'):
...@@ -93,42 +92,36 @@ class DumbDBMTestCase(unittest.TestCase): ...@@ -93,42 +92,36 @@ class DumbDBMTestCase(unittest.TestCase):
self.assertIsNone(f.get(b'xxx')) self.assertIsNone(f.get(b'xxx'))
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
f[b'xxx'] f[b'xxx']
f.close()
def test_dumbdbm_keys(self): def test_dumbdbm_keys(self):
self.init_db() self.init_db()
f = dumbdbm.open(_fname) with contextlib.closing(dumbdbm.open(_fname)) as f:
keys = self.keys_helper(f) keys = self.keys_helper(f)
f.close()
def test_write_contains(self): def test_write_contains(self):
f = dumbdbm.open(_fname) with contextlib.closing(dumbdbm.open(_fname)) as f:
f[b'1'] = b'hello' f[b'1'] = b'hello'
self.assertIn(b'1', f) self.assertIn(b'1', f)
f.close()
def test_write_write_read(self): def test_write_write_read(self):
# test for bug #482460 # test for bug #482460
f = dumbdbm.open(_fname) with contextlib.closing(dumbdbm.open(_fname)) as f:
f[b'1'] = b'hello' f[b'1'] = b'hello'
f[b'1'] = b'hello2' f[b'1'] = b'hello2'
f.close() with contextlib.closing(dumbdbm.open(_fname)) as f:
f = dumbdbm.open(_fname)
self.assertEqual(f[b'1'], b'hello2') self.assertEqual(f[b'1'], b'hello2')
f.close()
def test_str_read(self): def test_str_read(self):
self.init_db() self.init_db()
f = dumbdbm.open(_fname, 'r') with contextlib.closing(dumbdbm.open(_fname, 'r')) as f:
self.assertEqual(f['\u00fc'], self._dict['\u00fc'.encode('utf-8')]) self.assertEqual(f['\u00fc'], self._dict['\u00fc'.encode('utf-8')])
def test_str_write_contains(self): def test_str_write_contains(self):
self.init_db() self.init_db()
f = dumbdbm.open(_fname) with contextlib.closing(dumbdbm.open(_fname)) as f:
f['\u00fc'] = b'!' f['\u00fc'] = b'!'
f['1'] = 'a' f['1'] = 'a'
f.close() with contextlib.closing(dumbdbm.open(_fname, 'r')) as f:
f = dumbdbm.open(_fname, 'r')
self.assertIn('\u00fc', f) self.assertIn('\u00fc', f)
self.assertEqual(f['\u00fc'.encode('utf-8')], self.assertEqual(f['\u00fc'.encode('utf-8')],
self._dict['\u00fc'.encode('utf-8')]) self._dict['\u00fc'.encode('utf-8')])
...@@ -137,10 +130,9 @@ class DumbDBMTestCase(unittest.TestCase): ...@@ -137,10 +130,9 @@ class DumbDBMTestCase(unittest.TestCase):
def test_line_endings(self): def test_line_endings(self):
# test for bug #1172763: dumbdbm would die if the line endings # test for bug #1172763: dumbdbm would die if the line endings
# weren't what was expected. # weren't what was expected.
f = dumbdbm.open(_fname) with contextlib.closing(dumbdbm.open(_fname)) as f:
f[b'1'] = b'hello' f[b'1'] = b'hello'
f[b'2'] = b'hello2' f[b'2'] = b'hello2'
f.close()
# Mangle the file by changing the line separator to Windows or Unix # Mangle the file by changing the line separator to Windows or Unix
with io.open(_fname + '.dir', 'rb') as file: with io.open(_fname + '.dir', 'rb') as file:
...@@ -163,10 +155,9 @@ class DumbDBMTestCase(unittest.TestCase): ...@@ -163,10 +155,9 @@ class DumbDBMTestCase(unittest.TestCase):
self.assertEqual(self._dict[key], f[key]) self.assertEqual(self._dict[key], f[key])
def init_db(self): def init_db(self):
f = dumbdbm.open(_fname, 'n') with contextlib.closing(dumbdbm.open(_fname, 'n')) as f:
for k in self._dict: for k in self._dict:
f[k] = self._dict[k] f[k] = self._dict[k]
f.close()
def keys_helper(self, f): def keys_helper(self, f):
keys = sorted(f.keys()) keys = sorted(f.keys())
...@@ -180,7 +171,7 @@ class DumbDBMTestCase(unittest.TestCase): ...@@ -180,7 +171,7 @@ class DumbDBMTestCase(unittest.TestCase):
import random import random
d = {} # mirror the database d = {} # mirror the database
for dummy in range(5): for dummy in range(5):
f = dumbdbm.open(_fname) with contextlib.closing(dumbdbm.open(_fname)) as f:
for dummy in range(100): for dummy in range(100):
k = random.choice('abcdefghijklm') k = random.choice('abcdefghijklm')
if random.random() < 0.2: if random.random() < 0.2:
...@@ -192,13 +183,11 @@ class DumbDBMTestCase(unittest.TestCase): ...@@ -192,13 +183,11 @@ class DumbDBMTestCase(unittest.TestCase):
d[k] = v d[k] = v
f[k] = v f[k] = v
self.assertEqual(f[k], v) self.assertEqual(f[k], v)
f.close()
f = dumbdbm.open(_fname) with contextlib.closing(dumbdbm.open(_fname)) as f:
expected = sorted((k.encode("latin-1"), v) for k, v in d.items()) expected = sorted((k.encode("latin-1"), v) for k, v in d.items())
got = sorted(f.items()) got = sorted(f.items())
self.assertEqual(expected, got) self.assertEqual(expected, got)
f.close()
def test_context_manager(self): def test_context_manager(self):
with dumbdbm.open(_fname, 'c') as db: with dumbdbm.open(_fname, 'c') as db:
......
...@@ -268,13 +268,12 @@ class MmapTests(unittest.TestCase): ...@@ -268,13 +268,12 @@ class MmapTests(unittest.TestCase):
def test_find_end(self): def test_find_end(self):
# test the new 'end' parameter works as expected # test the new 'end' parameter works as expected
f = open(TESTFN, 'wb+') with open(TESTFN, 'wb+') as f:
data = b'one two ones' data = b'one two ones'
n = len(data) n = len(data)
f.write(data) f.write(data)
f.flush() f.flush()
m = mmap.mmap(f.fileno(), n) m = mmap.mmap(f.fileno(), n)
f.close()
self.assertEqual(m.find(b'one'), 0) self.assertEqual(m.find(b'one'), 0)
self.assertEqual(m.find(b'ones'), 8) self.assertEqual(m.find(b'ones'), 8)
...@@ -287,13 +286,12 @@ class MmapTests(unittest.TestCase): ...@@ -287,13 +286,12 @@ class MmapTests(unittest.TestCase):
def test_rfind(self): def test_rfind(self):
# test the new 'end' parameter works as expected # test the new 'end' parameter works as expected
f = open(TESTFN, 'wb+') with open(TESTFN, 'wb+') as f:
data = b'one two ones' data = b'one two ones'
n = len(data) n = len(data)
f.write(data) f.write(data)
f.flush() f.flush()
m = mmap.mmap(f.fileno(), n) m = mmap.mmap(f.fileno(), n)
f.close()
self.assertEqual(m.rfind(b'one'), 8) self.assertEqual(m.rfind(b'one'), 8)
self.assertEqual(m.rfind(b'one '), 0) self.assertEqual(m.rfind(b'one '), 0)
...@@ -306,30 +304,23 @@ class MmapTests(unittest.TestCase): ...@@ -306,30 +304,23 @@ class MmapTests(unittest.TestCase):
def test_double_close(self): def test_double_close(self):
# make sure a double close doesn't crash on Solaris (Bug# 665913) # make sure a double close doesn't crash on Solaris (Bug# 665913)
f = open(TESTFN, 'wb+') with open(TESTFN, 'wb+') as f:
f.write(2**16 * b'a') # Arbitrary character f.write(2**16 * b'a') # Arbitrary character
f.close()
f = open(TESTFN, 'rb') with open(TESTFN, 'rb') as f:
mf = mmap.mmap(f.fileno(), 2**16, access=mmap.ACCESS_READ) mf = mmap.mmap(f.fileno(), 2**16, access=mmap.ACCESS_READ)
mf.close() mf.close()
mf.close() mf.close()
f.close()
def test_entire_file(self): def test_entire_file(self):
# test mapping of entire file by passing 0 for map length # test mapping of entire file by passing 0 for map length
f = open(TESTFN, "wb+") with open(TESTFN, "wb+") as f:
f.write(2**16 * b'm') # Arbitrary character f.write(2**16 * b'm') # Arbitrary character
f.close()
f = open(TESTFN, "rb+") with open(TESTFN, "rb+") as f, \
mf = mmap.mmap(f.fileno(), 0) mmap.mmap(f.fileno(), 0) as mf:
self.assertEqual(len(mf), 2**16, "Map size should equal file size.") self.assertEqual(len(mf), 2**16, "Map size should equal file size.")
self.assertEqual(mf.read(2**16), 2**16 * b"m") self.assertEqual(mf.read(2**16), 2**16 * b"m")
mf.close()
f.close()
def test_length_0_offset(self): def test_length_0_offset(self):
# Issue #10916: test mapping of remainder of file by passing 0 for # Issue #10916: test mapping of remainder of file by passing 0 for
...@@ -355,7 +346,7 @@ class MmapTests(unittest.TestCase): ...@@ -355,7 +346,7 @@ class MmapTests(unittest.TestCase):
def test_move(self): def test_move(self):
# make move works everywhere (64-bit format problem earlier) # make move works everywhere (64-bit format problem earlier)
f = open(TESTFN, 'wb+') with open(TESTFN, 'wb+') as f:
f.write(b"ABCDEabcde") # Arbitrary character f.write(b"ABCDEabcde") # Arbitrary character
f.flush() f.flush()
...@@ -364,7 +355,6 @@ class MmapTests(unittest.TestCase): ...@@ -364,7 +355,6 @@ class MmapTests(unittest.TestCase):
mf.move(5, 0, 5) mf.move(5, 0, 5)
self.assertEqual(mf[:], b"ABCDEABCDE", "Map move should have duplicated front 5") self.assertEqual(mf[:], b"ABCDEABCDE", "Map move should have duplicated front 5")
mf.close() mf.close()
f.close()
# more excessive test # more excessive test
data = b"0123456789" data = b"0123456789"
...@@ -562,10 +552,9 @@ class MmapTests(unittest.TestCase): ...@@ -562,10 +552,9 @@ class MmapTests(unittest.TestCase):
mapsize = 10 mapsize = 10
with open(TESTFN, "wb") as fp: with open(TESTFN, "wb") as fp:
fp.write(b"a"*mapsize) fp.write(b"a"*mapsize)
f = open(TESTFN, "rb") with open(TESTFN, "rb") as f:
m = mmap.mmap(f.fileno(), mapsize, prot=mmap.PROT_READ) m = mmap.mmap(f.fileno(), mapsize, prot=mmap.PROT_READ)
self.assertRaises(TypeError, m.write, "foo") self.assertRaises(TypeError, m.write, "foo")
f.close()
def test_error(self): def test_error(self):
self.assertIs(mmap.error, OSError) self.assertIs(mmap.error, OSError)
...@@ -574,9 +563,8 @@ class MmapTests(unittest.TestCase): ...@@ -574,9 +563,8 @@ class MmapTests(unittest.TestCase):
data = b"0123456789" data = b"0123456789"
with open(TESTFN, "wb") as fp: with open(TESTFN, "wb") as fp:
fp.write(b"x"*len(data)) fp.write(b"x"*len(data))
f = open(TESTFN, "r+b") with open(TESTFN, "r+b") as f:
m = mmap.mmap(f.fileno(), len(data)) m = mmap.mmap(f.fileno(), len(data))
f.close()
# Test write_byte() # Test write_byte()
for i in range(len(data)): for i in range(len(data)):
self.assertEqual(m.tell(), i) self.assertEqual(m.tell(), i)
......
...@@ -799,10 +799,9 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -799,10 +799,9 @@ class GeneralModuleTests(unittest.TestCase):
self.assertEqual(repr(s), expected) self.assertEqual(repr(s), expected)
def test_weakref(self): def test_weakref(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
p = proxy(s) p = proxy(s)
self.assertEqual(p.fileno(), s.fileno()) self.assertEqual(p.fileno(), s.fileno())
s.close()
s = None s = None
try: try:
p.fileno() p.fileno()
...@@ -1072,9 +1071,8 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -1072,9 +1071,8 @@ class GeneralModuleTests(unittest.TestCase):
# Testing default timeout # Testing default timeout
# The default timeout should initially be None # The default timeout should initially be None
self.assertEqual(socket.getdefaulttimeout(), None) self.assertEqual(socket.getdefaulttimeout(), None)
s = socket.socket() with socket.socket() as s:
self.assertEqual(s.gettimeout(), None) self.assertEqual(s.gettimeout(), None)
s.close()
# Set the default timeout to 10, and see if it propagates # Set the default timeout to 10, and see if it propagates
with socket_setdefaulttimeout(10): with socket_setdefaulttimeout(10):
...@@ -1297,9 +1295,8 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -1297,9 +1295,8 @@ class GeneralModuleTests(unittest.TestCase):
def testSendAfterClose(self): def testSendAfterClose(self):
# testing send() after close() with timeout # testing send() after close() with timeout
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(1) sock.settimeout(1)
sock.close()
self.assertRaises(OSError, sock.send, b"spam") self.assertRaises(OSError, sock.send, b"spam")
def testCloseException(self): def testCloseException(self):
...@@ -1317,7 +1314,7 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -1317,7 +1314,7 @@ class GeneralModuleTests(unittest.TestCase):
def testNewAttributes(self): def testNewAttributes(self):
# testing .family, .type and .protocol # testing .family, .type and .protocol
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
self.assertEqual(sock.family, socket.AF_INET) self.assertEqual(sock.family, socket.AF_INET)
if hasattr(socket, 'SOCK_CLOEXEC'): if hasattr(socket, 'SOCK_CLOEXEC'):
self.assertIn(sock.type, self.assertIn(sock.type,
...@@ -1326,7 +1323,6 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -1326,7 +1323,6 @@ class GeneralModuleTests(unittest.TestCase):
else: else:
self.assertEqual(sock.type, socket.SOCK_STREAM) self.assertEqual(sock.type, socket.SOCK_STREAM)
self.assertEqual(sock.proto, 0) self.assertEqual(sock.proto, 0)
sock.close()
def test_getsockaddrarg(self): def test_getsockaddrarg(self):
sock = socket.socket() sock = socket.socket()
...@@ -1601,10 +1597,9 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -1601,10 +1597,9 @@ class GeneralModuleTests(unittest.TestCase):
def test_listen_backlog_overflow(self): def test_listen_backlog_overflow(self):
# Issue 15989 # Issue 15989
import _testcapi import _testcapi
srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
srv.bind((HOST, 0)) srv.bind((HOST, 0))
self.assertRaises(OverflowError, srv.listen, _testcapi.INT_MAX + 1) self.assertRaises(OverflowError, srv.listen, _testcapi.INT_MAX + 1)
srv.close()
@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.') @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
def test_flowinfo(self): def test_flowinfo(self):
......
...@@ -132,8 +132,7 @@ class UstarReadTest(ReadTest, unittest.TestCase): ...@@ -132,8 +132,7 @@ class UstarReadTest(ReadTest, unittest.TestCase):
data = fobj.read() data = fobj.read()
tarinfo = self.tar.getmember("ustar/regtype") tarinfo = self.tar.getmember("ustar/regtype")
fobj = self.tar.extractfile(tarinfo) with self.tar.extractfile(tarinfo) as fobj:
text = fobj.read() text = fobj.read()
fobj.seek(0) fobj.seek(0)
self.assertEqual(0, fobj.tell(), self.assertEqual(0, fobj.tell(),
...@@ -174,7 +173,6 @@ class UstarReadTest(ReadTest, unittest.TestCase): ...@@ -174,7 +173,6 @@ class UstarReadTest(ReadTest, unittest.TestCase):
line = fobj.readline() line = fobj.readline()
self.assertEqual(fobj.read(), data[len(line):], self.assertEqual(fobj.read(), data[len(line):],
"read() after readline() failed") "read() after readline() failed")
fobj.close()
def test_fileobj_text(self): def test_fileobj_text(self):
with self.tar.extractfile("ustar/regtype") as fobj: with self.tar.extractfile("ustar/regtype") as fobj:
...@@ -486,7 +484,7 @@ class MiscReadTestBase(CommonReadTest): ...@@ -486,7 +484,7 @@ class MiscReadTestBase(CommonReadTest):
fobj.seek(offset) fobj.seek(offset)
# Test if the tarfile starts with the second member. # Test if the tarfile starts with the second member.
tar = tar.open(self.tarname, mode="r:", fileobj=fobj) with tar.open(self.tarname, mode="r:", fileobj=fobj) as tar:
t = tar.next() t = tar.next()
self.assertEqual(t.name, name) self.assertEqual(t.name, name)
# Read to the end of fileobj and test if seeking back to the # Read to the end of fileobj and test if seeking back to the
...@@ -494,7 +492,6 @@ class MiscReadTestBase(CommonReadTest): ...@@ -494,7 +492,6 @@ class MiscReadTestBase(CommonReadTest):
tar.getmembers() tar.getmembers()
self.assertEqual(tar.extractfile(t).read(), data, self.assertEqual(tar.extractfile(t).read(), data,
"seek back did not work") "seek back did not work")
tar.close()
def test_fail_comp(self): def test_fail_comp(self):
# For Gzip and Bz2 Tests: fail with a ReadError on an uncompressed file. # For Gzip and Bz2 Tests: fail with a ReadError on an uncompressed file.
...@@ -1042,9 +1039,8 @@ class WriteTestBase(TarTest): ...@@ -1042,9 +1039,8 @@ class WriteTestBase(TarTest):
def test_fileobj_no_close(self): def test_fileobj_no_close(self):
fobj = io.BytesIO() fobj = io.BytesIO()
tar = tarfile.open(fileobj=fobj, mode=self.mode) with tarfile.open(fileobj=fobj, mode=self.mode) as tar:
tar.addfile(tarfile.TarInfo("foo")) tar.addfile(tarfile.TarInfo("foo"))
tar.close()
self.assertFalse(fobj.closed, "external fileobjs must never closed") self.assertFalse(fobj.closed, "external fileobjs must never closed")
# Issue #20238: Incomplete gzip output with mode="w:gz" # Issue #20238: Incomplete gzip output with mode="w:gz"
data = fobj.getvalue() data = fobj.getvalue()
...@@ -1306,19 +1302,16 @@ class WriteTest(WriteTestBase, unittest.TestCase): ...@@ -1306,19 +1302,16 @@ class WriteTest(WriteTestBase, unittest.TestCase):
with open(source_file,'w') as f: with open(source_file,'w') as f:
f.write('something\n') f.write('something\n')
os.symlink(source_file, target_file) os.symlink(source_file, target_file)
tar = tarfile.open(temparchive,'w') with tarfile.open(temparchive, 'w') as tar:
tar.add(source_file) tar.add(source_file)
tar.add(target_file) tar.add(target_file)
tar.close()
# Let's extract it to the location which contains the symlink # Let's extract it to the location which contains the symlink
tar = tarfile.open(temparchive,'r') with tarfile.open(temparchive) as tar:
# this should not raise OSError: [Errno 17] File exists # this should not raise OSError: [Errno 17] File exists
try: try:
tar.extractall(path=tempdir) tar.extractall(path=tempdir)
except OSError: except OSError:
self.fail("extractall failed with symlinked files") self.fail("extractall failed with symlinked files")
finally:
tar.close()
finally: finally:
support.unlink(temparchive) support.unlink(temparchive)
support.rmtree(tempdir) support.rmtree(tempdir)
......
...@@ -31,13 +31,12 @@ class TestsWithSourceFile(unittest.TestCase): ...@@ -31,13 +31,12 @@ class TestsWithSourceFile(unittest.TestCase):
self.data = '\n'.join(line_gen).encode('ascii') self.data = '\n'.join(line_gen).encode('ascii')
# And write it to a file. # And write it to a file.
fp = open(TESTFN, "wb") with open(TESTFN, "wb") as fp:
fp.write(self.data) fp.write(self.data)
fp.close()
def zipTest(self, f, compression): def zipTest(self, f, compression):
# Create the ZIP archive. # Create the ZIP archive.
zipfp = zipfile.ZipFile(f, "w", compression) with zipfile.ZipFile(f, "w", compression) as zipfp:
# It will contain enough copies of self.data to reach about 6 GiB of # It will contain enough copies of self.data to reach about 6 GiB of
# raw data to store. # raw data to store.
...@@ -53,10 +52,9 @@ class TestsWithSourceFile(unittest.TestCase): ...@@ -53,10 +52,9 @@ class TestsWithSourceFile(unittest.TestCase):
' zipTest still writing %d of %d, be patient...' % ' zipTest still writing %d of %d, be patient...' %
(num, filecount)), file=sys.__stdout__) (num, filecount)), file=sys.__stdout__)
sys.__stdout__.flush() sys.__stdout__.flush()
zipfp.close()
# Read the ZIP archive # Read the ZIP archive
zipfp = zipfile.ZipFile(f, "r", compression) with zipfile.ZipFile(f, "r", compression) as zipfp:
for num in range(filecount): for num in range(filecount):
self.assertEqual(zipfp.read("testfn%d" % num), self.data) self.assertEqual(zipfp.read("testfn%d" % num), self.data)
# Print still working message since this test can be really slow # Print still working message since this test can be really slow
...@@ -66,7 +64,6 @@ class TestsWithSourceFile(unittest.TestCase): ...@@ -66,7 +64,6 @@ class TestsWithSourceFile(unittest.TestCase):
' zipTest still reading %d of %d, be patient...' % ' zipTest still reading %d of %d, be patient...' %
(num, filecount)), file=sys.__stdout__) (num, filecount)), file=sys.__stdout__)
sys.__stdout__.flush() sys.__stdout__.flush()
zipfp.close()
def testStored(self): def testStored(self):
# Try the temp file first. If we do TESTFN2 first, then it hogs # Try the temp file first. If we do TESTFN2 first, then it hogs
...@@ -95,23 +92,21 @@ class OtherTests(unittest.TestCase): ...@@ -95,23 +92,21 @@ class OtherTests(unittest.TestCase):
def testMoreThan64kFiles(self): def testMoreThan64kFiles(self):
# This test checks that more than 64k files can be added to an archive, # This test checks that more than 64k files can be added to an archive,
# and that the resulting archive can be read properly by ZipFile # and that the resulting archive can be read properly by ZipFile
zipf = zipfile.ZipFile(TESTFN, mode="w", allowZip64=True) with zipfile.ZipFile(TESTFN, mode="w", allowZip64=True) as zipf:
zipf.debug = 100 zipf.debug = 100
numfiles = (1 << 16) * 3//2 numfiles = (1 << 16) * 3//2
for i in range(numfiles): for i in range(numfiles):
zipf.writestr("foo%08d" % i, "%d" % (i**3 % 57)) zipf.writestr("foo%08d" % i, "%d" % (i**3 % 57))
self.assertEqual(len(zipf.namelist()), numfiles) self.assertEqual(len(zipf.namelist()), numfiles)
zipf.close()
zipf2 = zipfile.ZipFile(TESTFN, mode="r") with zipfile.ZipFile(TESTFN, mode="r") as zipf2:
self.assertEqual(len(zipf2.namelist()), numfiles) self.assertEqual(len(zipf2.namelist()), numfiles)
for i in range(numfiles): for i in range(numfiles):
content = zipf2.read("foo%08d" % i).decode('ascii') content = zipf2.read("foo%08d" % i).decode('ascii')
self.assertEqual(content, "%d" % (i**3 % 57)) self.assertEqual(content, "%d" % (i**3 % 57))
zipf2.close()
def testMoreThan64kFilesAppend(self): def testMoreThan64kFilesAppend(self):
zipf = zipfile.ZipFile(TESTFN, mode="w", allowZip64=False) with zipfile.ZipFile(TESTFN, mode="w", allowZip64=False) as zipf:
zipf.debug = 100 zipf.debug = 100
numfiles = (1 << 16) - 1 numfiles = (1 << 16) - 1
for i in range(numfiles): for i in range(numfiles):
...@@ -120,31 +115,27 @@ class OtherTests(unittest.TestCase): ...@@ -120,31 +115,27 @@ class OtherTests(unittest.TestCase):
with self.assertRaises(zipfile.LargeZipFile): with self.assertRaises(zipfile.LargeZipFile):
zipf.writestr("foo%08d" % numfiles, b'') zipf.writestr("foo%08d" % numfiles, b'')
self.assertEqual(len(zipf.namelist()), numfiles) self.assertEqual(len(zipf.namelist()), numfiles)
zipf.close()
zipf = zipfile.ZipFile(TESTFN, mode="a", allowZip64=False) with zipfile.ZipFile(TESTFN, mode="a", allowZip64=False) as zipf:
zipf.debug = 100 zipf.debug = 100
self.assertEqual(len(zipf.namelist()), numfiles) self.assertEqual(len(zipf.namelist()), numfiles)
with self.assertRaises(zipfile.LargeZipFile): with self.assertRaises(zipfile.LargeZipFile):
zipf.writestr("foo%08d" % numfiles, b'') zipf.writestr("foo%08d" % numfiles, b'')
self.assertEqual(len(zipf.namelist()), numfiles) self.assertEqual(len(zipf.namelist()), numfiles)
zipf.close()
zipf = zipfile.ZipFile(TESTFN, mode="a", allowZip64=True) with zipfile.ZipFile(TESTFN, mode="a", allowZip64=True) as zipf:
zipf.debug = 100 zipf.debug = 100
self.assertEqual(len(zipf.namelist()), numfiles) self.assertEqual(len(zipf.namelist()), numfiles)
numfiles2 = (1 << 16) * 3//2 numfiles2 = (1 << 16) * 3//2
for i in range(numfiles, numfiles2): for i in range(numfiles, numfiles2):
zipf.writestr("foo%08d" % i, "%d" % (i**3 % 57)) zipf.writestr("foo%08d" % i, "%d" % (i**3 % 57))
self.assertEqual(len(zipf.namelist()), numfiles2) self.assertEqual(len(zipf.namelist()), numfiles2)
zipf.close()
zipf2 = zipfile.ZipFile(TESTFN, mode="r") with zipfile.ZipFile(TESTFN, mode="r") as zipf2:
self.assertEqual(len(zipf2.namelist()), numfiles2) self.assertEqual(len(zipf2.namelist()), numfiles2)
for i in range(numfiles2): for i in range(numfiles2):
content = zipf2.read("foo%08d" % i).decode('ascii') content = zipf2.read("foo%08d" % i).decode('ascii')
self.assertEqual(content, "%d" % (i**3 % 57)) self.assertEqual(content, "%d" % (i**3 % 57))
zipf2.close()
def tearDown(self): def tearDown(self):
support.unlink(TESTFN) support.unlink(TESTFN)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment