Commit 17babc5e authored by Antoine Pitrou

Issue #16408: Fix file descriptors not being closed in error conditions in the zipfile module.

Patch by Serhiy Storchaka.
parent a39a22dc
...@@ -719,30 +719,34 @@ class ZipFile: ...@@ -719,30 +719,34 @@ class ZipFile:
self.fp = file self.fp = file
self.filename = getattr(file, 'name', None) self.filename = getattr(file, 'name', None)
if key == 'r': try:
self._GetContents() if key == 'r':
elif key == 'w':
# set the modified flag so central directory gets written
# even if no files are added to the archive
self._didModify = True
elif key == 'a':
try:
# See if file is a zip file
self._RealGetContents() self._RealGetContents()
# seek to start of directory and overwrite elif key == 'w':
self.fp.seek(self.start_dir, 0)
except BadZipFile:
# file is not a zip file, just append
self.fp.seek(0, 2)
# set the modified flag so central directory gets written # set the modified flag so central directory gets written
# even if no files are added to the archive # even if no files are added to the archive
self._didModify = True self._didModify = True
else: elif key == 'a':
try:
# See if file is a zip file
self._RealGetContents()
# seek to start of directory and overwrite
self.fp.seek(self.start_dir, 0)
except BadZipFile:
# file is not a zip file, just append
self.fp.seek(0, 2)
# set the modified flag so central directory gets written
# even if no files are added to the archive
self._didModify = True
else:
raise RuntimeError('Mode must be "r", "w" or "a"')
except:
fp = self.fp
self.fp = None
if not self._filePassed: if not self._filePassed:
self.fp.close() fp.close()
self.fp = None raise
raise RuntimeError('Mode must be "r", "w" or "a"')
def __enter__(self): def __enter__(self):
return self return self
...@@ -750,17 +754,6 @@ class ZipFile: ...@@ -750,17 +754,6 @@ class ZipFile:
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
self.close() self.close()
def _GetContents(self):
"""Read the directory, making sure we close the file if the format
is bad."""
try:
self._RealGetContents()
except BadZipFile:
if not self._filePassed:
self.fp.close()
self.fp = None
raise
def _RealGetContents(self): def _RealGetContents(self):
"""Read in the table of contents for the ZIP file.""" """Read in the table of contents for the ZIP file."""
fp = self.fp fp = self.fp
...@@ -862,9 +855,9 @@ class ZipFile: ...@@ -862,9 +855,9 @@ class ZipFile:
try: try:
# Read by chunks, to avoid an OverflowError or a # Read by chunks, to avoid an OverflowError or a
# MemoryError with very large embedded files. # MemoryError with very large embedded files.
f = self.open(zinfo.filename, "r") with self.open(zinfo.filename, "r") as f:
while f.read(chunk_size): # Check CRC-32 while f.read(chunk_size): # Check CRC-32
pass pass
except BadZipFile: except BadZipFile:
return zinfo.filename return zinfo.filename
...@@ -926,76 +919,70 @@ class ZipFile: ...@@ -926,76 +919,70 @@ class ZipFile:
else: else:
zef_file = io.open(self.filename, 'rb') zef_file = io.open(self.filename, 'rb')
# Make sure we have an info object try:
if isinstance(name, ZipInfo): # Make sure we have an info object
# 'name' is already an info object if isinstance(name, ZipInfo):
zinfo = name # 'name' is already an info object
else: zinfo = name
# Get info object for name else:
try: # Get info object for name
zinfo = self.getinfo(name) zinfo = self.getinfo(name)
except KeyError: zef_file.seek(zinfo.header_offset, 0)
if not self._filePassed:
zef_file.close() # Skip the file header:
raise fheader = zef_file.read(sizeFileHeader)
zef_file.seek(zinfo.header_offset, 0) if fheader[0:4] != stringFileHeader:
raise BadZipFile("Bad magic number for file header")
# Skip the file header:
fheader = zef_file.read(sizeFileHeader)
if fheader[0:4] != stringFileHeader:
raise BadZipFile("Bad magic number for file header")
fheader = struct.unpack(structFileHeader, fheader)
fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
if fheader[_FH_EXTRA_FIELD_LENGTH]:
zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])
if zinfo.flag_bits & 0x800:
# UTF-8 filename
fname_str = fname.decode("utf-8")
else:
fname_str = fname.decode("cp437")
if fname_str != zinfo.orig_filename: fheader = struct.unpack(structFileHeader, fheader)
fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
if fheader[_FH_EXTRA_FIELD_LENGTH]:
zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])
if zinfo.flag_bits & 0x800:
# UTF-8 filename
fname_str = fname.decode("utf-8")
else:
fname_str = fname.decode("cp437")
if fname_str != zinfo.orig_filename:
raise BadZipFile(
'File name in directory %r and header %r differ.'
% (zinfo.orig_filename, fname))
# check for encrypted flag & handle password
is_encrypted = zinfo.flag_bits & 0x1
zd = None
if is_encrypted:
if not pwd:
pwd = self.pwd
if not pwd:
raise RuntimeError("File %s is encrypted, password "
"required for extraction" % name)
zd = _ZipDecrypter(pwd)
# The first 12 bytes in the cypher stream is an encryption header
# used to strengthen the algorithm. The first 11 bytes are
# completely random, while the 12th contains the MSB of the CRC,
# or the MSB of the file time depending on the header type
# and is used to check the correctness of the password.
header = zef_file.read(12)
h = list(map(zd, header[0:12]))
if zinfo.flag_bits & 0x8:
# compare against the file type from extended local headers
check_byte = (zinfo._raw_time >> 8) & 0xff
else:
# compare against the CRC otherwise
check_byte = (zinfo.CRC >> 24) & 0xff
if h[11] != check_byte:
raise RuntimeError("Bad password for file", name)
return ZipExtFile(zef_file, mode, zinfo, zd,
close_fileobj=not self._filePassed)
except:
if not self._filePassed: if not self._filePassed:
zef_file.close() zef_file.close()
raise BadZipFile( raise
'File name in directory %r and header %r differ.'
% (zinfo.orig_filename, fname))
# check for encrypted flag & handle password
is_encrypted = zinfo.flag_bits & 0x1
zd = None
if is_encrypted:
if not pwd:
pwd = self.pwd
if not pwd:
if not self._filePassed:
zef_file.close()
raise RuntimeError("File %s is encrypted, "
"password required for extraction" % name)
zd = _ZipDecrypter(pwd)
# The first 12 bytes in the cypher stream is an encryption header
# used to strengthen the algorithm. The first 11 bytes are
# completely random, while the 12th contains the MSB of the CRC,
# or the MSB of the file time depending on the header type
# and is used to check the correctness of the password.
header = zef_file.read(12)
h = list(map(zd, header[0:12]))
if zinfo.flag_bits & 0x8:
# compare against the file type from extended local headers
check_byte = (zinfo._raw_time >> 8) & 0xff
else:
# compare against the CRC otherwise
check_byte = (zinfo.CRC >> 24) & 0xff
if h[11] != check_byte:
if not self._filePassed:
zef_file.close()
raise RuntimeError("Bad password for file", name)
return ZipExtFile(zef_file, mode, zinfo, zd,
close_fileobj=not self._filePassed)
def extract(self, member, path=None, pwd=None): def extract(self, member, path=None, pwd=None):
"""Extract a member from the archive to the current working directory, """Extract a member from the archive to the current working directory,
...@@ -1052,11 +1039,9 @@ class ZipFile: ...@@ -1052,11 +1039,9 @@ class ZipFile:
os.mkdir(targetpath) os.mkdir(targetpath)
return targetpath return targetpath
source = self.open(member, pwd=pwd) with self.open(member, pwd=pwd) as source, \
target = open(targetpath, "wb") open(targetpath, "wb") as target:
shutil.copyfileobj(source, target) shutil.copyfileobj(source, target)
source.close()
target.close()
return targetpath return targetpath
...@@ -1220,101 +1205,103 @@ class ZipFile: ...@@ -1220,101 +1205,103 @@ class ZipFile:
if self.fp is None: if self.fp is None:
return return
if self.mode in ("w", "a") and self._didModify: # write ending records try:
count = 0 if self.mode in ("w", "a") and self._didModify: # write ending records
pos1 = self.fp.tell() count = 0
for zinfo in self.filelist: # write central directory pos1 = self.fp.tell()
count = count + 1 for zinfo in self.filelist: # write central directory
dt = zinfo.date_time count = count + 1
dosdate = (dt[0] - 1980) << 9 | dt[1] << 5 | dt[2] dt = zinfo.date_time
dostime = dt[3] << 11 | dt[4] << 5 | (dt[5] // 2) dosdate = (dt[0] - 1980) << 9 | dt[1] << 5 | dt[2]
extra = [] dostime = dt[3] << 11 | dt[4] << 5 | (dt[5] // 2)
if zinfo.file_size > ZIP64_LIMIT \ extra = []
or zinfo.compress_size > ZIP64_LIMIT: if zinfo.file_size > ZIP64_LIMIT \
extra.append(zinfo.file_size) or zinfo.compress_size > ZIP64_LIMIT:
extra.append(zinfo.compress_size) extra.append(zinfo.file_size)
file_size = 0xffffffff extra.append(zinfo.compress_size)
compress_size = 0xffffffff file_size = 0xffffffff
else: compress_size = 0xffffffff
file_size = zinfo.file_size else:
compress_size = zinfo.compress_size file_size = zinfo.file_size
compress_size = zinfo.compress_size
if zinfo.header_offset > ZIP64_LIMIT:
extra.append(zinfo.header_offset) if zinfo.header_offset > ZIP64_LIMIT:
header_offset = 0xffffffff extra.append(zinfo.header_offset)
else: header_offset = 0xffffffff
header_offset = zinfo.header_offset else:
header_offset = zinfo.header_offset
extra_data = zinfo.extra
if extra: extra_data = zinfo.extra
# Append a ZIP64 field to the extra's if extra:
extra_data = struct.pack( # Append a ZIP64 field to the extra's
'<HH' + 'Q'*len(extra), extra_data = struct.pack(
1, 8*len(extra), *extra) + extra_data '<HH' + 'Q'*len(extra),
1, 8*len(extra), *extra) + extra_data
extract_version = max(45, zinfo.extract_version)
create_version = max(45, zinfo.create_version) extract_version = max(45, zinfo.extract_version)
else: create_version = max(45, zinfo.create_version)
extract_version = zinfo.extract_version else:
create_version = zinfo.create_version extract_version = zinfo.extract_version
create_version = zinfo.create_version
try:
filename, flag_bits = zinfo._encodeFilenameFlags() try:
centdir = struct.pack(structCentralDir, filename, flag_bits = zinfo._encodeFilenameFlags()
stringCentralDir, create_version, centdir = struct.pack(structCentralDir,
zinfo.create_system, extract_version, zinfo.reserved, stringCentralDir, create_version,
flag_bits, zinfo.compress_type, dostime, dosdate, zinfo.create_system, extract_version, zinfo.reserved,
zinfo.CRC, compress_size, file_size, flag_bits, zinfo.compress_type, dostime, dosdate,
len(filename), len(extra_data), len(zinfo.comment), zinfo.CRC, compress_size, file_size,
0, zinfo.internal_attr, zinfo.external_attr, len(filename), len(extra_data), len(zinfo.comment),
header_offset) 0, zinfo.internal_attr, zinfo.external_attr,
except DeprecationWarning: header_offset)
print((structCentralDir, stringCentralDir, create_version, except DeprecationWarning:
zinfo.create_system, extract_version, zinfo.reserved, print((structCentralDir, stringCentralDir, create_version,
zinfo.flag_bits, zinfo.compress_type, dostime, dosdate, zinfo.create_system, extract_version, zinfo.reserved,
zinfo.CRC, compress_size, file_size, zinfo.flag_bits, zinfo.compress_type, dostime, dosdate,
len(zinfo.filename), len(extra_data), len(zinfo.comment), zinfo.CRC, compress_size, file_size,
0, zinfo.internal_attr, zinfo.external_attr, len(zinfo.filename), len(extra_data), len(zinfo.comment),
header_offset), file=sys.stderr) 0, zinfo.internal_attr, zinfo.external_attr,
raise header_offset), file=sys.stderr)
self.fp.write(centdir) raise
self.fp.write(filename) self.fp.write(centdir)
self.fp.write(extra_data) self.fp.write(filename)
self.fp.write(zinfo.comment) self.fp.write(extra_data)
self.fp.write(zinfo.comment)
pos2 = self.fp.tell()
# Write end-of-zip-archive record pos2 = self.fp.tell()
centDirCount = count # Write end-of-zip-archive record
centDirSize = pos2 - pos1 centDirCount = count
centDirOffset = pos1 centDirSize = pos2 - pos1
if (centDirCount >= ZIP_FILECOUNT_LIMIT or centDirOffset = pos1
centDirOffset > ZIP64_LIMIT or if (centDirCount >= ZIP_FILECOUNT_LIMIT or
centDirSize > ZIP64_LIMIT): centDirOffset > ZIP64_LIMIT or
# Need to write the ZIP64 end-of-archive records centDirSize > ZIP64_LIMIT):
zip64endrec = struct.pack( # Need to write the ZIP64 end-of-archive records
structEndArchive64, stringEndArchive64, zip64endrec = struct.pack(
44, 45, 45, 0, 0, centDirCount, centDirCount, structEndArchive64, stringEndArchive64,
centDirSize, centDirOffset) 44, 45, 45, 0, 0, centDirCount, centDirCount,
self.fp.write(zip64endrec) centDirSize, centDirOffset)
self.fp.write(zip64endrec)
zip64locrec = struct.pack(
structEndArchive64Locator, zip64locrec = struct.pack(
stringEndArchive64Locator, 0, pos2, 1) structEndArchive64Locator,
self.fp.write(zip64locrec) stringEndArchive64Locator, 0, pos2, 1)
centDirCount = min(centDirCount, 0xFFFF) self.fp.write(zip64locrec)
centDirSize = min(centDirSize, 0xFFFFFFFF) centDirCount = min(centDirCount, 0xFFFF)
centDirOffset = min(centDirOffset, 0xFFFFFFFF) centDirSize = min(centDirSize, 0xFFFFFFFF)
centDirOffset = min(centDirOffset, 0xFFFFFFFF)
endrec = struct.pack(structEndArchive, stringEndArchive,
0, 0, centDirCount, centDirCount, endrec = struct.pack(structEndArchive, stringEndArchive,
centDirSize, centDirOffset, len(self._comment)) 0, 0, centDirCount, centDirCount,
self.fp.write(endrec) centDirSize, centDirOffset, len(self._comment))
self.fp.write(self._comment) self.fp.write(endrec)
self.fp.flush() self.fp.write(self._comment)
self.fp.flush()
if not self._filePassed: finally:
self.fp.close() fp = self.fp
self.fp = None self.fp = None
if not self._filePassed:
fp.close()
class PyZipFile(ZipFile): class PyZipFile(ZipFile):
...@@ -1481,16 +1468,15 @@ def main(args = None): ...@@ -1481,16 +1468,15 @@ def main(args = None):
if len(args) != 2: if len(args) != 2:
print(USAGE) print(USAGE)
sys.exit(1) sys.exit(1)
zf = ZipFile(args[1], 'r') with ZipFile(args[1], 'r') as zf:
zf.printdir() zf.printdir()
zf.close()
elif args[0] == '-t': elif args[0] == '-t':
if len(args) != 2: if len(args) != 2:
print(USAGE) print(USAGE)
sys.exit(1) sys.exit(1)
zf = ZipFile(args[1], 'r') with ZipFile(args[1], 'r') as zf:
badfile = zf.testzip() badfile = zf.testzip()
if badfile: if badfile:
print("The following enclosed file is corrupted: {!r}".format(badfile)) print("The following enclosed file is corrupted: {!r}".format(badfile))
print("Done testing") print("Done testing")
...@@ -1500,20 +1486,19 @@ def main(args = None): ...@@ -1500,20 +1486,19 @@ def main(args = None):
print(USAGE) print(USAGE)
sys.exit(1) sys.exit(1)
zf = ZipFile(args[1], 'r') with ZipFile(args[1], 'r') as zf:
out = args[2] out = args[2]
for path in zf.namelist(): for path in zf.namelist():
if path.startswith('./'): if path.startswith('./'):
tgt = os.path.join(out, path[2:]) tgt = os.path.join(out, path[2:])
else: else:
tgt = os.path.join(out, path) tgt = os.path.join(out, path)
tgtdir = os.path.dirname(tgt) tgtdir = os.path.dirname(tgt)
if not os.path.exists(tgtdir): if not os.path.exists(tgtdir):
os.makedirs(tgtdir) os.makedirs(tgtdir)
with open(tgt, 'wb') as fp: with open(tgt, 'wb') as fp:
fp.write(zf.read(path)) fp.write(zf.read(path))
zf.close()
elif args[0] == '-c': elif args[0] == '-c':
if len(args) < 3: if len(args) < 3:
...@@ -1529,11 +1514,9 @@ def main(args = None): ...@@ -1529,11 +1514,9 @@ def main(args = None):
os.path.join(path, nm), os.path.join(zippath, nm)) os.path.join(path, nm), os.path.join(zippath, nm))
# else: ignore # else: ignore
zf = ZipFile(args[1], 'w', allowZip64=True) with ZipFile(args[1], 'w', allowZip64=True) as zf:
for src in args[2:]: for src in args[2:]:
addToZip(zf, src, os.path.basename(src)) addToZip(zf, src, os.path.basename(src))
zf.close()
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -164,6 +164,9 @@ Core and Builtins ...@@ -164,6 +164,9 @@ Core and Builtins
Library Library
------- -------
- Issue #16408: Fix file descriptors not being closed in error conditions
in the zipfile module. Patch by Serhiy Storchaka.
- Issue #16140: The subprocess module no longer double closes its child - Issue #16140: The subprocess module no longer double closes its child
subprocess.PIPE parent file descriptors on child error prior to exec(). subprocess.PIPE parent file descriptors on child error prior to exec().
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment