Commit 09aa7520 authored by Gregory P. Smith's avatar Gregory P. Smith

Refactor recently added bugfix into more testable code by using a

method for windows file name sanitization.  Splits the unittest up
into several based on platform.
parent 6d29628d
...@@ -538,8 +538,15 @@ class TestsWithSourceFile(unittest.TestCase): ...@@ -538,8 +538,15 @@ class TestsWithSourceFile(unittest.TestCase):
with open(filename, 'rb') as f: with open(filename, 'rb') as f:
self.assertEqual(f.read(), content) self.assertEqual(f.read(), content)
def test_extract_hackers_arcnames(self): def test_sanitize_windows_name(self):
hacknames = [ san = zipfile.ZipFile._sanitize_windows_name
# Passing pathsep in allows this test to work regardless of platform.
self.assertEqual(san(r',,?,C:,foo,bar/z', ','), r'_,C_,foo,bar/z')
self.assertEqual(san(r'a\b,c<d>e|f"g?h*i', ','), r'a\b,c_d_e_f_g_h_i')
self.assertEqual(san('../../foo../../ba..r', '/'), r'foo/ba..r')
def test_extract_hackers_arcnames_common_cases(self):
common_hacknames = [
('../foo/bar', 'foo/bar'), ('../foo/bar', 'foo/bar'),
('foo/../bar', 'foo/bar'), ('foo/../bar', 'foo/bar'),
('foo/../../bar', 'foo/bar'), ('foo/../../bar', 'foo/bar'),
...@@ -549,8 +556,12 @@ class TestsWithSourceFile(unittest.TestCase): ...@@ -549,8 +556,12 @@ class TestsWithSourceFile(unittest.TestCase):
('/foo/../bar', 'foo/bar'), ('/foo/../bar', 'foo/bar'),
('/foo/../../bar', 'foo/bar'), ('/foo/../../bar', 'foo/bar'),
] ]
if os.path.sep == '\\': # Windows. self._test_extract_hackers_arcnames(common_hacknames)
hacknames.extend([
@unittest.skipIf(os.path.sep != '\\', 'Requires \\ as path separator.')
def test_extract_hackers_arcnames_windows_only(self):
"""Test combination of path fixing and windows name sanitization."""
windows_hacknames = [
(r'..\foo\bar', 'foo/bar'), (r'..\foo\bar', 'foo/bar'),
(r'..\/foo\/bar', 'foo/bar'), (r'..\/foo\/bar', 'foo/bar'),
(r'foo/\..\/bar', 'foo/bar'), (r'foo/\..\/bar', 'foo/bar'),
...@@ -570,14 +581,19 @@ class TestsWithSourceFile(unittest.TestCase): ...@@ -570,14 +581,19 @@ class TestsWithSourceFile(unittest.TestCase):
(r'C:/../C:/foo/bar', 'C_/foo/bar'), (r'C:/../C:/foo/bar', 'C_/foo/bar'),
(r'a:b\c<d>e|f"g?h*i', 'b/c_d_e_f_g_h_i'), (r'a:b\c<d>e|f"g?h*i', 'b/c_d_e_f_g_h_i'),
('../../foo../../ba..r', 'foo/ba..r'), ('../../foo../../ba..r', 'foo/ba..r'),
]) ]
else: # Unix self._test_extract_hackers_arcnames(windows_hacknames)
hacknames.extend([
@unittest.skipIf(os.path.sep != '/', r'Requires / as path separator.')
def test_extract_hackers_arcnames_posix_only(self):
posix_hacknames = [
('//foo/bar', 'foo/bar'), ('//foo/bar', 'foo/bar'),
('../../foo../../ba..r', 'foo../ba..r'), ('../../foo../../ba..r', 'foo../ba..r'),
(r'foo/..\bar', r'foo/..\bar'), (r'foo/..\bar', r'foo/..\bar'),
]) ]
self._test_extract_hackers_arcnames(posix_hacknames)
def _test_extract_hackers_arcnames(self, hacknames):
for arcname, fixedname in hacknames: for arcname, fixedname in hacknames:
content = b'foobar' + arcname.encode() content = b'foobar' + arcname.encode()
with zipfile.ZipFile(TESTFN2, 'w', zipfile.ZIP_STORED) as zipfp: with zipfile.ZipFile(TESTFN2, 'w', zipfile.ZIP_STORED) as zipfp:
...@@ -594,7 +610,8 @@ class TestsWithSourceFile(unittest.TestCase): ...@@ -594,7 +610,8 @@ class TestsWithSourceFile(unittest.TestCase):
with zipfile.ZipFile(TESTFN2, 'r') as zipfp: with zipfile.ZipFile(TESTFN2, 'r') as zipfp:
writtenfile = zipfp.extract(arcname, targetpath) writtenfile = zipfp.extract(arcname, targetpath)
self.assertEqual(writtenfile, correctfile, self.assertEqual(writtenfile, correctfile,
msg="extract %r" % arcname) msg='extract %r: %r != %r' %
(arcname, writtenfile, correctfile))
self.check_file(correctfile, content) self.check_file(correctfile, content)
shutil.rmtree('target') shutil.rmtree('target')
......
...@@ -883,6 +883,7 @@ class ZipFile: ...@@ -883,6 +883,7 @@ class ZipFile:
""" """
fp = None # Set here since __del__ checks it fp = None # Set here since __del__ checks it
_windows_illegal_name_trans_table = None
def __init__(self, file, mode="r", compression=ZIP_STORED, allowZip64=False): def __init__(self, file, mode="r", compression=ZIP_STORED, allowZip64=False):
"""Open the ZIP file with mode read "r", write "w" or append "a".""" """Open the ZIP file with mode read "r", write "w" or append "a"."""
...@@ -1223,6 +1224,21 @@ class ZipFile: ...@@ -1223,6 +1224,21 @@ class ZipFile:
for zipinfo in members: for zipinfo in members:
self.extract(zipinfo, path, pwd) self.extract(zipinfo, path, pwd)
@classmethod
def _sanitize_windows_name(cls, arcname, pathsep):
"""Replace bad characters and remove trailing dots from parts."""
table = cls._windows_illegal_name_trans_table
if not table:
illegal = ':<>|"?*'
table = str.maketrans(illegal, '_' * len(illegal))
cls._windows_illegal_name_trans_table = table
arcname = arcname.translate(table)
# remove trailing dots
arcname = (x.rstrip('.') for x in arcname.split(pathsep))
# rejoin, removing empty parts.
arcname = pathsep.join(x for x in arcname if x)
return arcname
def _extract_member(self, member, targetpath, pwd): def _extract_member(self, member, targetpath, pwd):
"""Extract the ZipInfo object 'member' to a physical """Extract the ZipInfo object 'member' to a physical
file on the path targetpath. file on the path targetpath.
...@@ -1236,16 +1252,12 @@ class ZipFile: ...@@ -1236,16 +1252,12 @@ class ZipFile:
# interpret absolute pathname as relative, remove drive letter or # interpret absolute pathname as relative, remove drive letter or
# UNC path, redundant separators, "." and ".." components. # UNC path, redundant separators, "." and ".." components.
arcname = os.path.splitdrive(arcname)[1] arcname = os.path.splitdrive(arcname)[1]
invalid_path_parts = ('', os.path.curdir, os.path.pardir)
arcname = os.path.sep.join(x for x in arcname.split(os.path.sep) arcname = os.path.sep.join(x for x in arcname.split(os.path.sep)
if x not in ('', os.path.curdir, os.path.pardir)) if x not in invalid_path_parts)
if os.path.sep == '\\': if os.path.sep == '\\':
# filter illegal characters on Windows # filter illegal characters on Windows
illegal = ':<>|"?*' arcname = self._sanitize_windows_name(arcname, os.path.sep)
table = str.maketrans(illegal, '_' * len(illegal))
arcname = arcname.translate(table)
# remove trailing dots
arcname = (x.rstrip('.') for x in arcname.split(os.path.sep))
arcname = os.path.sep.join(x for x in arcname if x)
targetpath = os.path.join(targetpath, arcname) targetpath = os.path.join(targetpath, arcname)
targetpath = os.path.normpath(targetpath) targetpath = os.path.normpath(targetpath)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment