Commit 9e4861f5 authored by Serhiy Storchaka's avatar Serhiy Storchaka Committed by GitHub

bpo-22831: Use "with" to avoid possible fd leaks in tests (part 1). (GH-10928)

parent b7272395
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
Original by Roger E. Masse Original by Roger E. Masse
""" """
import contextlib
import io import io
import operator import operator
import os import os
...@@ -32,12 +33,11 @@ class DumbDBMTestCase(unittest.TestCase): ...@@ -32,12 +33,11 @@ class DumbDBMTestCase(unittest.TestCase):
} }
def test_dumbdbm_creation(self): def test_dumbdbm_creation(self):
f = dumbdbm.open(_fname, 'c') with contextlib.closing(dumbdbm.open(_fname, 'c')) as f:
self.assertEqual(list(f.keys()), []) self.assertEqual(list(f.keys()), [])
for key in self._dict: for key in self._dict:
f[key] = self._dict[key] f[key] = self._dict[key]
self.read_helper(f) self.read_helper(f)
f.close()
@unittest.skipUnless(hasattr(os, 'umask'), 'test needs os.umask()') @unittest.skipUnless(hasattr(os, 'umask'), 'test needs os.umask()')
def test_dumbdbm_creation_mode(self): def test_dumbdbm_creation_mode(self):
...@@ -69,78 +69,70 @@ class DumbDBMTestCase(unittest.TestCase): ...@@ -69,78 +69,70 @@ class DumbDBMTestCase(unittest.TestCase):
def test_dumbdbm_modification(self): def test_dumbdbm_modification(self):
self.init_db() self.init_db()
f = dumbdbm.open(_fname, 'w') with contextlib.closing(dumbdbm.open(_fname, 'w')) as f:
self._dict[b'g'] = f[b'g'] = b"indented" self._dict[b'g'] = f[b'g'] = b"indented"
self.read_helper(f) self.read_helper(f)
# setdefault() works as in the dict interface # setdefault() works as in the dict interface
self.assertEqual(f.setdefault(b'xxx', b'foo'), b'foo') self.assertEqual(f.setdefault(b'xxx', b'foo'), b'foo')
self.assertEqual(f[b'xxx'], b'foo') self.assertEqual(f[b'xxx'], b'foo')
f.close()
def test_dumbdbm_read(self): def test_dumbdbm_read(self):
self.init_db() self.init_db()
f = dumbdbm.open(_fname, 'r') with contextlib.closing(dumbdbm.open(_fname, 'r')) as f:
self.read_helper(f) self.read_helper(f)
with self.assertRaisesRegex(dumbdbm.error, with self.assertRaisesRegex(dumbdbm.error,
'The database is opened for reading only'): 'The database is opened for reading only'):
f[b'g'] = b'x' f[b'g'] = b'x'
with self.assertRaisesRegex(dumbdbm.error, with self.assertRaisesRegex(dumbdbm.error,
'The database is opened for reading only'): 'The database is opened for reading only'):
del f[b'a'] del f[b'a']
# get() works as in the dict interface # get() works as in the dict interface
self.assertEqual(f.get(b'a'), self._dict[b'a']) self.assertEqual(f.get(b'a'), self._dict[b'a'])
self.assertEqual(f.get(b'xxx', b'foo'), b'foo') self.assertEqual(f.get(b'xxx', b'foo'), b'foo')
self.assertIsNone(f.get(b'xxx')) self.assertIsNone(f.get(b'xxx'))
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
f[b'xxx'] f[b'xxx']
f.close()
def test_dumbdbm_keys(self): def test_dumbdbm_keys(self):
self.init_db() self.init_db()
f = dumbdbm.open(_fname) with contextlib.closing(dumbdbm.open(_fname)) as f:
keys = self.keys_helper(f) keys = self.keys_helper(f)
f.close()
def test_write_contains(self): def test_write_contains(self):
f = dumbdbm.open(_fname) with contextlib.closing(dumbdbm.open(_fname)) as f:
f[b'1'] = b'hello' f[b'1'] = b'hello'
self.assertIn(b'1', f) self.assertIn(b'1', f)
f.close()
def test_write_write_read(self): def test_write_write_read(self):
# test for bug #482460 # test for bug #482460
f = dumbdbm.open(_fname) with contextlib.closing(dumbdbm.open(_fname)) as f:
f[b'1'] = b'hello' f[b'1'] = b'hello'
f[b'1'] = b'hello2' f[b'1'] = b'hello2'
f.close() with contextlib.closing(dumbdbm.open(_fname)) as f:
f = dumbdbm.open(_fname) self.assertEqual(f[b'1'], b'hello2')
self.assertEqual(f[b'1'], b'hello2')
f.close()
def test_str_read(self): def test_str_read(self):
self.init_db() self.init_db()
f = dumbdbm.open(_fname, 'r') with contextlib.closing(dumbdbm.open(_fname, 'r')) as f:
self.assertEqual(f['\u00fc'], self._dict['\u00fc'.encode('utf-8')]) self.assertEqual(f['\u00fc'], self._dict['\u00fc'.encode('utf-8')])
def test_str_write_contains(self): def test_str_write_contains(self):
self.init_db() self.init_db()
f = dumbdbm.open(_fname) with contextlib.closing(dumbdbm.open(_fname)) as f:
f['\u00fc'] = b'!' f['\u00fc'] = b'!'
f['1'] = 'a' f['1'] = 'a'
f.close() with contextlib.closing(dumbdbm.open(_fname, 'r')) as f:
f = dumbdbm.open(_fname, 'r') self.assertIn('\u00fc', f)
self.assertIn('\u00fc', f) self.assertEqual(f['\u00fc'.encode('utf-8')],
self.assertEqual(f['\u00fc'.encode('utf-8')], self._dict['\u00fc'.encode('utf-8')])
self._dict['\u00fc'.encode('utf-8')]) self.assertEqual(f[b'1'], b'a')
self.assertEqual(f[b'1'], b'a')
def test_line_endings(self): def test_line_endings(self):
# test for bug #1172763: dumbdbm would die if the line endings # test for bug #1172763: dumbdbm would die if the line endings
# weren't what was expected. # weren't what was expected.
f = dumbdbm.open(_fname) with contextlib.closing(dumbdbm.open(_fname)) as f:
f[b'1'] = b'hello' f[b'1'] = b'hello'
f[b'2'] = b'hello2' f[b'2'] = b'hello2'
f.close()
# Mangle the file by changing the line separator to Windows or Unix # Mangle the file by changing the line separator to Windows or Unix
with io.open(_fname + '.dir', 'rb') as file: with io.open(_fname + '.dir', 'rb') as file:
...@@ -163,10 +155,9 @@ class DumbDBMTestCase(unittest.TestCase): ...@@ -163,10 +155,9 @@ class DumbDBMTestCase(unittest.TestCase):
self.assertEqual(self._dict[key], f[key]) self.assertEqual(self._dict[key], f[key])
def init_db(self): def init_db(self):
f = dumbdbm.open(_fname, 'n') with contextlib.closing(dumbdbm.open(_fname, 'n')) as f:
for k in self._dict: for k in self._dict:
f[k] = self._dict[k] f[k] = self._dict[k]
f.close()
def keys_helper(self, f): def keys_helper(self, f):
keys = sorted(f.keys()) keys = sorted(f.keys())
...@@ -180,25 +171,23 @@ class DumbDBMTestCase(unittest.TestCase): ...@@ -180,25 +171,23 @@ class DumbDBMTestCase(unittest.TestCase):
import random import random
d = {} # mirror the database d = {} # mirror the database
for dummy in range(5): for dummy in range(5):
f = dumbdbm.open(_fname) with contextlib.closing(dumbdbm.open(_fname)) as f:
for dummy in range(100): for dummy in range(100):
k = random.choice('abcdefghijklm') k = random.choice('abcdefghijklm')
if random.random() < 0.2: if random.random() < 0.2:
if k in d: if k in d:
del d[k] del d[k]
del f[k] del f[k]
else: else:
v = random.choice((b'a', b'b', b'c')) * random.randrange(10000) v = random.choice((b'a', b'b', b'c')) * random.randrange(10000)
d[k] = v d[k] = v
f[k] = v f[k] = v
self.assertEqual(f[k], v) self.assertEqual(f[k], v)
f.close()
with contextlib.closing(dumbdbm.open(_fname)) as f:
f = dumbdbm.open(_fname) expected = sorted((k.encode("latin-1"), v) for k, v in d.items())
expected = sorted((k.encode("latin-1"), v) for k, v in d.items()) got = sorted(f.items())
got = sorted(f.items()) self.assertEqual(expected, got)
self.assertEqual(expected, got)
f.close()
def test_context_manager(self): def test_context_manager(self):
with dumbdbm.open(_fname, 'c') as db: with dumbdbm.open(_fname, 'c') as db:
......
...@@ -268,13 +268,12 @@ class MmapTests(unittest.TestCase): ...@@ -268,13 +268,12 @@ class MmapTests(unittest.TestCase):
def test_find_end(self): def test_find_end(self):
# test the new 'end' parameter works as expected # test the new 'end' parameter works as expected
f = open(TESTFN, 'wb+') with open(TESTFN, 'wb+') as f:
data = b'one two ones' data = b'one two ones'
n = len(data) n = len(data)
f.write(data) f.write(data)
f.flush() f.flush()
m = mmap.mmap(f.fileno(), n) m = mmap.mmap(f.fileno(), n)
f.close()
self.assertEqual(m.find(b'one'), 0) self.assertEqual(m.find(b'one'), 0)
self.assertEqual(m.find(b'ones'), 8) self.assertEqual(m.find(b'ones'), 8)
...@@ -287,13 +286,12 @@ class MmapTests(unittest.TestCase): ...@@ -287,13 +286,12 @@ class MmapTests(unittest.TestCase):
def test_rfind(self): def test_rfind(self):
# test the new 'end' parameter works as expected # test the new 'end' parameter works as expected
f = open(TESTFN, 'wb+') with open(TESTFN, 'wb+') as f:
data = b'one two ones' data = b'one two ones'
n = len(data) n = len(data)
f.write(data) f.write(data)
f.flush() f.flush()
m = mmap.mmap(f.fileno(), n) m = mmap.mmap(f.fileno(), n)
f.close()
self.assertEqual(m.rfind(b'one'), 8) self.assertEqual(m.rfind(b'one'), 8)
self.assertEqual(m.rfind(b'one '), 0) self.assertEqual(m.rfind(b'one '), 0)
...@@ -306,30 +304,23 @@ class MmapTests(unittest.TestCase): ...@@ -306,30 +304,23 @@ class MmapTests(unittest.TestCase):
def test_double_close(self): def test_double_close(self):
# make sure a double close doesn't crash on Solaris (Bug# 665913) # make sure a double close doesn't crash on Solaris (Bug# 665913)
f = open(TESTFN, 'wb+') with open(TESTFN, 'wb+') as f:
f.write(2**16 * b'a') # Arbitrary character
f.write(2**16 * b'a') # Arbitrary character
f.close()
f = open(TESTFN, 'rb') with open(TESTFN, 'rb') as f:
mf = mmap.mmap(f.fileno(), 2**16, access=mmap.ACCESS_READ) mf = mmap.mmap(f.fileno(), 2**16, access=mmap.ACCESS_READ)
mf.close() mf.close()
mf.close() mf.close()
f.close()
def test_entire_file(self): def test_entire_file(self):
# test mapping of entire file by passing 0 for map length # test mapping of entire file by passing 0 for map length
f = open(TESTFN, "wb+") with open(TESTFN, "wb+") as f:
f.write(2**16 * b'm') # Arbitrary character
f.write(2**16 * b'm') # Arbitrary character with open(TESTFN, "rb+") as f, \
f.close() mmap.mmap(f.fileno(), 0) as mf:
self.assertEqual(len(mf), 2**16, "Map size should equal file size.")
f = open(TESTFN, "rb+") self.assertEqual(mf.read(2**16), 2**16 * b"m")
mf = mmap.mmap(f.fileno(), 0)
self.assertEqual(len(mf), 2**16, "Map size should equal file size.")
self.assertEqual(mf.read(2**16), 2**16 * b"m")
mf.close()
f.close()
def test_length_0_offset(self): def test_length_0_offset(self):
# Issue #10916: test mapping of remainder of file by passing 0 for # Issue #10916: test mapping of remainder of file by passing 0 for
...@@ -355,16 +346,15 @@ class MmapTests(unittest.TestCase): ...@@ -355,16 +346,15 @@ class MmapTests(unittest.TestCase):
def test_move(self): def test_move(self):
# make move works everywhere (64-bit format problem earlier) # make move works everywhere (64-bit format problem earlier)
f = open(TESTFN, 'wb+') with open(TESTFN, 'wb+') as f:
f.write(b"ABCDEabcde") # Arbitrary character f.write(b"ABCDEabcde") # Arbitrary character
f.flush() f.flush()
mf = mmap.mmap(f.fileno(), 10) mf = mmap.mmap(f.fileno(), 10)
mf.move(5, 0, 5) mf.move(5, 0, 5)
self.assertEqual(mf[:], b"ABCDEABCDE", "Map move should have duplicated front 5") self.assertEqual(mf[:], b"ABCDEABCDE", "Map move should have duplicated front 5")
mf.close() mf.close()
f.close()
# more excessive test # more excessive test
data = b"0123456789" data = b"0123456789"
...@@ -562,10 +552,9 @@ class MmapTests(unittest.TestCase): ...@@ -562,10 +552,9 @@ class MmapTests(unittest.TestCase):
mapsize = 10 mapsize = 10
with open(TESTFN, "wb") as fp: with open(TESTFN, "wb") as fp:
fp.write(b"a"*mapsize) fp.write(b"a"*mapsize)
f = open(TESTFN, "rb") with open(TESTFN, "rb") as f:
m = mmap.mmap(f.fileno(), mapsize, prot=mmap.PROT_READ) m = mmap.mmap(f.fileno(), mapsize, prot=mmap.PROT_READ)
self.assertRaises(TypeError, m.write, "foo") self.assertRaises(TypeError, m.write, "foo")
f.close()
def test_error(self): def test_error(self):
self.assertIs(mmap.error, OSError) self.assertIs(mmap.error, OSError)
...@@ -574,9 +563,8 @@ class MmapTests(unittest.TestCase): ...@@ -574,9 +563,8 @@ class MmapTests(unittest.TestCase):
data = b"0123456789" data = b"0123456789"
with open(TESTFN, "wb") as fp: with open(TESTFN, "wb") as fp:
fp.write(b"x"*len(data)) fp.write(b"x"*len(data))
f = open(TESTFN, "r+b") with open(TESTFN, "r+b") as f:
m = mmap.mmap(f.fileno(), len(data)) m = mmap.mmap(f.fileno(), len(data))
f.close()
# Test write_byte() # Test write_byte()
for i in range(len(data)): for i in range(len(data)):
self.assertEqual(m.tell(), i) self.assertEqual(m.tell(), i)
......
...@@ -799,10 +799,9 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -799,10 +799,9 @@ class GeneralModuleTests(unittest.TestCase):
self.assertEqual(repr(s), expected) self.assertEqual(repr(s), expected)
def test_weakref(self): def test_weakref(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
p = proxy(s) p = proxy(s)
self.assertEqual(p.fileno(), s.fileno()) self.assertEqual(p.fileno(), s.fileno())
s.close()
s = None s = None
try: try:
p.fileno() p.fileno()
...@@ -1072,9 +1071,8 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -1072,9 +1071,8 @@ class GeneralModuleTests(unittest.TestCase):
# Testing default timeout # Testing default timeout
# The default timeout should initially be None # The default timeout should initially be None
self.assertEqual(socket.getdefaulttimeout(), None) self.assertEqual(socket.getdefaulttimeout(), None)
s = socket.socket() with socket.socket() as s:
self.assertEqual(s.gettimeout(), None) self.assertEqual(s.gettimeout(), None)
s.close()
# Set the default timeout to 10, and see if it propagates # Set the default timeout to 10, and see if it propagates
with socket_setdefaulttimeout(10): with socket_setdefaulttimeout(10):
...@@ -1297,9 +1295,8 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -1297,9 +1295,8 @@ class GeneralModuleTests(unittest.TestCase):
def testSendAfterClose(self): def testSendAfterClose(self):
# testing send() after close() with timeout # testing send() after close() with timeout
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.settimeout(1) sock.settimeout(1)
sock.close()
self.assertRaises(OSError, sock.send, b"spam") self.assertRaises(OSError, sock.send, b"spam")
def testCloseException(self): def testCloseException(self):
...@@ -1317,16 +1314,15 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -1317,16 +1314,15 @@ class GeneralModuleTests(unittest.TestCase):
def testNewAttributes(self): def testNewAttributes(self):
# testing .family, .type and .protocol # testing .family, .type and .protocol
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
self.assertEqual(sock.family, socket.AF_INET) self.assertEqual(sock.family, socket.AF_INET)
if hasattr(socket, 'SOCK_CLOEXEC'): if hasattr(socket, 'SOCK_CLOEXEC'):
self.assertIn(sock.type, self.assertIn(sock.type,
(socket.SOCK_STREAM | socket.SOCK_CLOEXEC, (socket.SOCK_STREAM | socket.SOCK_CLOEXEC,
socket.SOCK_STREAM)) socket.SOCK_STREAM))
else: else:
self.assertEqual(sock.type, socket.SOCK_STREAM) self.assertEqual(sock.type, socket.SOCK_STREAM)
self.assertEqual(sock.proto, 0) self.assertEqual(sock.proto, 0)
sock.close()
def test_getsockaddrarg(self): def test_getsockaddrarg(self):
sock = socket.socket() sock = socket.socket()
...@@ -1601,10 +1597,9 @@ class GeneralModuleTests(unittest.TestCase): ...@@ -1601,10 +1597,9 @@ class GeneralModuleTests(unittest.TestCase):
def test_listen_backlog_overflow(self): def test_listen_backlog_overflow(self):
# Issue 15989 # Issue 15989
import _testcapi import _testcapi
srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
srv.bind((HOST, 0)) srv.bind((HOST, 0))
self.assertRaises(OverflowError, srv.listen, _testcapi.INT_MAX + 1) self.assertRaises(OverflowError, srv.listen, _testcapi.INT_MAX + 1)
srv.close()
@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.') @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
def test_flowinfo(self): def test_flowinfo(self):
......
...@@ -132,49 +132,47 @@ class UstarReadTest(ReadTest, unittest.TestCase): ...@@ -132,49 +132,47 @@ class UstarReadTest(ReadTest, unittest.TestCase):
data = fobj.read() data = fobj.read()
tarinfo = self.tar.getmember("ustar/regtype") tarinfo = self.tar.getmember("ustar/regtype")
fobj = self.tar.extractfile(tarinfo) with self.tar.extractfile(tarinfo) as fobj:
text = fobj.read()
text = fobj.read() fobj.seek(0)
fobj.seek(0) self.assertEqual(0, fobj.tell(),
self.assertEqual(0, fobj.tell(), "seek() to file's start failed")
"seek() to file's start failed") fobj.seek(2048, 0)
fobj.seek(2048, 0) self.assertEqual(2048, fobj.tell(),
self.assertEqual(2048, fobj.tell(), "seek() to absolute position failed")
"seek() to absolute position failed") fobj.seek(-1024, 1)
fobj.seek(-1024, 1) self.assertEqual(1024, fobj.tell(),
self.assertEqual(1024, fobj.tell(), "seek() to negative relative position failed")
"seek() to negative relative position failed") fobj.seek(1024, 1)
fobj.seek(1024, 1) self.assertEqual(2048, fobj.tell(),
self.assertEqual(2048, fobj.tell(), "seek() to positive relative position failed")
"seek() to positive relative position failed") s = fobj.read(10)
s = fobj.read(10) self.assertEqual(s, data[2048:2058],
self.assertEqual(s, data[2048:2058], "read() after seek failed")
"read() after seek failed") fobj.seek(0, 2)
fobj.seek(0, 2) self.assertEqual(tarinfo.size, fobj.tell(),
self.assertEqual(tarinfo.size, fobj.tell(), "seek() to file's end failed")
"seek() to file's end failed") self.assertEqual(fobj.read(), b"",
self.assertEqual(fobj.read(), b"", "read() at file's end did not return empty string")
"read() at file's end did not return empty string") fobj.seek(-tarinfo.size, 2)
fobj.seek(-tarinfo.size, 2) self.assertEqual(0, fobj.tell(),
self.assertEqual(0, fobj.tell(), "relative seek() to file's end failed")
"relative seek() to file's end failed") fobj.seek(512)
fobj.seek(512) s1 = fobj.readlines()
s1 = fobj.readlines() fobj.seek(512)
fobj.seek(512) s2 = fobj.readlines()
s2 = fobj.readlines() self.assertEqual(s1, s2,
self.assertEqual(s1, s2, "readlines() after seek failed")
"readlines() after seek failed") fobj.seek(0)
fobj.seek(0) self.assertEqual(len(fobj.readline()), fobj.tell(),
self.assertEqual(len(fobj.readline()), fobj.tell(), "tell() after readline() failed")
"tell() after readline() failed") fobj.seek(512)
fobj.seek(512) self.assertEqual(len(fobj.readline()) + 512, fobj.tell(),
self.assertEqual(len(fobj.readline()) + 512, fobj.tell(), "tell() after seek() and readline() failed")
"tell() after seek() and readline() failed") fobj.seek(0)
fobj.seek(0) line = fobj.readline()
line = fobj.readline() self.assertEqual(fobj.read(), data[len(line):],
self.assertEqual(fobj.read(), data[len(line):], "read() after readline() failed")
"read() after readline() failed")
fobj.close()
def test_fileobj_text(self): def test_fileobj_text(self):
with self.tar.extractfile("ustar/regtype") as fobj: with self.tar.extractfile("ustar/regtype") as fobj:
...@@ -486,15 +484,14 @@ class MiscReadTestBase(CommonReadTest): ...@@ -486,15 +484,14 @@ class MiscReadTestBase(CommonReadTest):
fobj.seek(offset) fobj.seek(offset)
# Test if the tarfile starts with the second member. # Test if the tarfile starts with the second member.
tar = tar.open(self.tarname, mode="r:", fileobj=fobj) with tar.open(self.tarname, mode="r:", fileobj=fobj) as tar:
t = tar.next() t = tar.next()
self.assertEqual(t.name, name) self.assertEqual(t.name, name)
# Read to the end of fileobj and test if seeking back to the # Read to the end of fileobj and test if seeking back to the
# beginning works. # beginning works.
tar.getmembers() tar.getmembers()
self.assertEqual(tar.extractfile(t).read(), data, self.assertEqual(tar.extractfile(t).read(), data,
"seek back did not work") "seek back did not work")
tar.close()
def test_fail_comp(self): def test_fail_comp(self):
# For Gzip and Bz2 Tests: fail with a ReadError on an uncompressed file. # For Gzip and Bz2 Tests: fail with a ReadError on an uncompressed file.
...@@ -1042,9 +1039,8 @@ class WriteTestBase(TarTest): ...@@ -1042,9 +1039,8 @@ class WriteTestBase(TarTest):
def test_fileobj_no_close(self): def test_fileobj_no_close(self):
fobj = io.BytesIO() fobj = io.BytesIO()
tar = tarfile.open(fileobj=fobj, mode=self.mode) with tarfile.open(fileobj=fobj, mode=self.mode) as tar:
tar.addfile(tarfile.TarInfo("foo")) tar.addfile(tarfile.TarInfo("foo"))
tar.close()
self.assertFalse(fobj.closed, "external fileobjs must never closed") self.assertFalse(fobj.closed, "external fileobjs must never closed")
# Issue #20238: Incomplete gzip output with mode="w:gz" # Issue #20238: Incomplete gzip output with mode="w:gz"
data = fobj.getvalue() data = fobj.getvalue()
...@@ -1306,19 +1302,16 @@ class WriteTest(WriteTestBase, unittest.TestCase): ...@@ -1306,19 +1302,16 @@ class WriteTest(WriteTestBase, unittest.TestCase):
with open(source_file,'w') as f: with open(source_file,'w') as f:
f.write('something\n') f.write('something\n')
os.symlink(source_file, target_file) os.symlink(source_file, target_file)
tar = tarfile.open(temparchive,'w') with tarfile.open(temparchive, 'w') as tar:
tar.add(source_file) tar.add(source_file)
tar.add(target_file) tar.add(target_file)
tar.close()
# Let's extract it to the location which contains the symlink # Let's extract it to the location which contains the symlink
tar = tarfile.open(temparchive,'r') with tarfile.open(temparchive) as tar:
# this should not raise OSError: [Errno 17] File exists # this should not raise OSError: [Errno 17] File exists
try: try:
tar.extractall(path=tempdir) tar.extractall(path=tempdir)
except OSError: except OSError:
self.fail("extractall failed with symlinked files") self.fail("extractall failed with symlinked files")
finally:
tar.close()
finally: finally:
support.unlink(temparchive) support.unlink(temparchive)
support.rmtree(tempdir) support.rmtree(tempdir)
......
...@@ -31,42 +31,39 @@ class TestsWithSourceFile(unittest.TestCase): ...@@ -31,42 +31,39 @@ class TestsWithSourceFile(unittest.TestCase):
self.data = '\n'.join(line_gen).encode('ascii') self.data = '\n'.join(line_gen).encode('ascii')
# And write it to a file. # And write it to a file.
fp = open(TESTFN, "wb") with open(TESTFN, "wb") as fp:
fp.write(self.data) fp.write(self.data)
fp.close()
def zipTest(self, f, compression): def zipTest(self, f, compression):
# Create the ZIP archive. # Create the ZIP archive.
zipfp = zipfile.ZipFile(f, "w", compression) with zipfile.ZipFile(f, "w", compression) as zipfp:
# It will contain enough copies of self.data to reach about 6 GiB of # It will contain enough copies of self.data to reach about 6 GiB of
# raw data to store. # raw data to store.
filecount = 6*1024**3 // len(self.data) filecount = 6*1024**3 // len(self.data)
next_time = time.monotonic() + _PRINT_WORKING_MSG_INTERVAL next_time = time.monotonic() + _PRINT_WORKING_MSG_INTERVAL
for num in range(filecount): for num in range(filecount):
zipfp.writestr("testfn%d" % num, self.data) zipfp.writestr("testfn%d" % num, self.data)
# Print still working message since this test can be really slow # Print still working message since this test can be really slow
if next_time <= time.monotonic(): if next_time <= time.monotonic():
next_time = time.monotonic() + _PRINT_WORKING_MSG_INTERVAL next_time = time.monotonic() + _PRINT_WORKING_MSG_INTERVAL
print(( print((
' zipTest still writing %d of %d, be patient...' % ' zipTest still writing %d of %d, be patient...' %
(num, filecount)), file=sys.__stdout__) (num, filecount)), file=sys.__stdout__)
sys.__stdout__.flush() sys.__stdout__.flush()
zipfp.close()
# Read the ZIP archive # Read the ZIP archive
zipfp = zipfile.ZipFile(f, "r", compression) with zipfile.ZipFile(f, "r", compression) as zipfp:
for num in range(filecount): for num in range(filecount):
self.assertEqual(zipfp.read("testfn%d" % num), self.data) self.assertEqual(zipfp.read("testfn%d" % num), self.data)
# Print still working message since this test can be really slow # Print still working message since this test can be really slow
if next_time <= time.monotonic(): if next_time <= time.monotonic():
next_time = time.monotonic() + _PRINT_WORKING_MSG_INTERVAL next_time = time.monotonic() + _PRINT_WORKING_MSG_INTERVAL
print(( print((
' zipTest still reading %d of %d, be patient...' % ' zipTest still reading %d of %d, be patient...' %
(num, filecount)), file=sys.__stdout__) (num, filecount)), file=sys.__stdout__)
sys.__stdout__.flush() sys.__stdout__.flush()
zipfp.close()
def testStored(self): def testStored(self):
# Try the temp file first. If we do TESTFN2 first, then it hogs # Try the temp file first. If we do TESTFN2 first, then it hogs
...@@ -95,56 +92,50 @@ class OtherTests(unittest.TestCase): ...@@ -95,56 +92,50 @@ class OtherTests(unittest.TestCase):
def testMoreThan64kFiles(self): def testMoreThan64kFiles(self):
# This test checks that more than 64k files can be added to an archive, # This test checks that more than 64k files can be added to an archive,
# and that the resulting archive can be read properly by ZipFile # and that the resulting archive can be read properly by ZipFile
zipf = zipfile.ZipFile(TESTFN, mode="w", allowZip64=True) with zipfile.ZipFile(TESTFN, mode="w", allowZip64=True) as zipf:
zipf.debug = 100 zipf.debug = 100
numfiles = (1 << 16) * 3//2 numfiles = (1 << 16) * 3//2
for i in range(numfiles): for i in range(numfiles):
zipf.writestr("foo%08d" % i, "%d" % (i**3 % 57)) zipf.writestr("foo%08d" % i, "%d" % (i**3 % 57))
self.assertEqual(len(zipf.namelist()), numfiles) self.assertEqual(len(zipf.namelist()), numfiles)
zipf.close()
with zipfile.ZipFile(TESTFN, mode="r") as zipf2:
zipf2 = zipfile.ZipFile(TESTFN, mode="r") self.assertEqual(len(zipf2.namelist()), numfiles)
self.assertEqual(len(zipf2.namelist()), numfiles) for i in range(numfiles):
for i in range(numfiles): content = zipf2.read("foo%08d" % i).decode('ascii')
content = zipf2.read("foo%08d" % i).decode('ascii') self.assertEqual(content, "%d" % (i**3 % 57))
self.assertEqual(content, "%d" % (i**3 % 57))
zipf2.close()
def testMoreThan64kFilesAppend(self): def testMoreThan64kFilesAppend(self):
zipf = zipfile.ZipFile(TESTFN, mode="w", allowZip64=False) with zipfile.ZipFile(TESTFN, mode="w", allowZip64=False) as zipf:
zipf.debug = 100 zipf.debug = 100
numfiles = (1 << 16) - 1 numfiles = (1 << 16) - 1
for i in range(numfiles): for i in range(numfiles):
zipf.writestr("foo%08d" % i, "%d" % (i**3 % 57)) zipf.writestr("foo%08d" % i, "%d" % (i**3 % 57))
self.assertEqual(len(zipf.namelist()), numfiles) self.assertEqual(len(zipf.namelist()), numfiles)
with self.assertRaises(zipfile.LargeZipFile): with self.assertRaises(zipfile.LargeZipFile):
zipf.writestr("foo%08d" % numfiles, b'') zipf.writestr("foo%08d" % numfiles, b'')
self.assertEqual(len(zipf.namelist()), numfiles) self.assertEqual(len(zipf.namelist()), numfiles)
zipf.close()
with zipfile.ZipFile(TESTFN, mode="a", allowZip64=False) as zipf:
zipf = zipfile.ZipFile(TESTFN, mode="a", allowZip64=False) zipf.debug = 100
zipf.debug = 100 self.assertEqual(len(zipf.namelist()), numfiles)
self.assertEqual(len(zipf.namelist()), numfiles) with self.assertRaises(zipfile.LargeZipFile):
with self.assertRaises(zipfile.LargeZipFile): zipf.writestr("foo%08d" % numfiles, b'')
zipf.writestr("foo%08d" % numfiles, b'') self.assertEqual(len(zipf.namelist()), numfiles)
self.assertEqual(len(zipf.namelist()), numfiles)
zipf.close() with zipfile.ZipFile(TESTFN, mode="a", allowZip64=True) as zipf:
zipf.debug = 100
zipf = zipfile.ZipFile(TESTFN, mode="a", allowZip64=True) self.assertEqual(len(zipf.namelist()), numfiles)
zipf.debug = 100 numfiles2 = (1 << 16) * 3//2
self.assertEqual(len(zipf.namelist()), numfiles) for i in range(numfiles, numfiles2):
numfiles2 = (1 << 16) * 3//2 zipf.writestr("foo%08d" % i, "%d" % (i**3 % 57))
for i in range(numfiles, numfiles2): self.assertEqual(len(zipf.namelist()), numfiles2)
zipf.writestr("foo%08d" % i, "%d" % (i**3 % 57))
self.assertEqual(len(zipf.namelist()), numfiles2) with zipfile.ZipFile(TESTFN, mode="r") as zipf2:
zipf.close() self.assertEqual(len(zipf2.namelist()), numfiles2)
for i in range(numfiles2):
zipf2 = zipfile.ZipFile(TESTFN, mode="r") content = zipf2.read("foo%08d" % i).decode('ascii')
self.assertEqual(len(zipf2.namelist()), numfiles2) self.assertEqual(content, "%d" % (i**3 % 57))
for i in range(numfiles2):
content = zipf2.read("foo%08d" % i).decode('ascii')
self.assertEqual(content, "%d" % (i**3 % 57))
zipf2.close()
def tearDown(self): def tearDown(self):
support.unlink(TESTFN) support.unlink(TESTFN)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment