"""Utility functions for copying and archiving files and directory trees.

XXX The functions here don't copy the resource fork or other metadata on Mac.

"""
Guido van Rossum's avatar
Guido van Rossum committed
6

Guido van Rossum's avatar
Guido van Rossum committed
7
import os
8
import sys
9
import stat
Georg Brandl's avatar
Georg Brandl committed
10
import fnmatch
11
import collections
12
import errno
13 14 15 16 17 18 19

try:
    import zlib
    del zlib
    _ZLIB_SUPPORTED = True
except ImportError:
    _ZLIB_SUPPORTED = False
20

21 22
try:
    import bz2
Florent Xicluna's avatar
Florent Xicluna committed
23
    del bz2
24
    _BZ2_SUPPORTED = True
25
except ImportError:
26 27
    _BZ2_SUPPORTED = False

28 29 30 31 32 33 34
try:
    import lzma
    del lzma
    _LZMA_SUPPORTED = True
except ImportError:
    _LZMA_SUPPORTED = False

35 36
try:
    from pwd import getpwnam
37
except ImportError:
38 39 40 41
    getpwnam = None

try:
    from grp import getgrnam
42
except ImportError:
43
    getgrnam = None
Guido van Rossum's avatar
Guido van Rossum committed
44

45 46 47
# Names exported by ``from shutil import *``.
__all__ = ["copyfileobj", "copyfile", "copymode", "copystat", "copy", "copy2",
           "copytree", "move", "rmtree", "Error", "SpecialFileError",
           "ExecError", "make_archive", "get_archive_formats",
           "register_archive_format", "unregister_archive_format",
           "get_unpack_formats", "register_unpack_format",
           "unregister_unpack_format", "unpack_archive",
           "ignore_patterns", "chown", "which", "get_terminal_size",
           "SameFileError"]
           # disk_usage is added later, if available on the platform
54

55
class Error(OSError):
    """Generic error raised by operations in this module.

    Also the base class of SameFileError.
    """


class SameFileError(Error):
    """Raised when source and destination are the same file."""


class SpecialFileError(OSError):
    """Raised when trying to do a kind of operation (e.g. copying) which is
    not supported on a special file (e.g. a named pipe)"""


class ExecError(OSError):
    """Raised when a command could not be executed"""


class ReadError(OSError):
    """Raised when an archive cannot be read"""


class RegistryError(Exception):
    """Raised when a registry operation with the archiving
    and unpacking registries fails"""
74 75


76 77 78 79 80 81 82 83
def copyfileobj(fsrc, fdst, length=16*1024):
    """Copy data from file-like object fsrc to file-like object fdst.

    Data is transferred in chunks of at most `length` (as passed to
    fsrc.read()) until a read returns an empty/falsy result.
    """
    chunk = fsrc.read(length)
    while chunk:
        fdst.write(chunk)
        chunk = fsrc.read(length)

84 85
def _samefile(src, dst):
    # Macintosh, Unix.
86
    if hasattr(os.path, 'samefile'):
87 88 89 90
        try:
            return os.path.samefile(src, dst)
        except OSError:
            return False
91 92 93 94

    # All other platforms: check for same pathname.
    return (os.path.normcase(os.path.abspath(src)) ==
            os.path.normcase(os.path.abspath(dst)))
Tim Peters's avatar
Tim Peters committed
95

96
def copyfile(src, dst, *, follow_symlinks=True):
    """Copy data from src to dst and return dst.

    If follow_symlinks is not set and src is a symbolic link, a new
    symlink will be created instead of copying the file it points to.

    Raises SameFileError if src and dst are the same file, and
    SpecialFileError if either path is a named pipe.

    """
    if _samefile(src, dst):
        raise SameFileError("{!r} and {!r} are the same file".format(src, dst))

    for fn in [src, dst]:
        try:
            st = os.stat(fn)
        except OSError:
            # File most likely does not exist
            pass
        else:
            # XXX What about other special files? (sockets, devices...)
            if stat.S_ISFIFO(st.st_mode):
                raise SpecialFileError("`%s` is a named pipe" % fn)

    if not follow_symlinks and os.path.islink(src):
        # Recreate the link itself rather than copying its target.
        os.symlink(os.readlink(src), dst)
    else:
        with open(src, 'rb') as fsrc:
            with open(dst, 'wb') as fdst:
                copyfileobj(fsrc, fdst)
    return dst
124

125
def copymode(src, dst, *, follow_symlinks=True):
    """Copy mode bits from src to dst.

    If follow_symlinks is not set, symlinks aren't followed if and only
    if both `src` and `dst` are symlinks.  If `lchmod` isn't available
    (e.g. Linux) this method does nothing.

    """
    if not follow_symlinks and os.path.islink(src) and os.path.islink(dst):
        # Operate on the links themselves; only possible with lchmod().
        if hasattr(os, 'lchmod'):
            stat_func, chmod_func = os.lstat, os.lchmod
        else:
            return
    elif hasattr(os, 'chmod'):
        stat_func, chmod_func = os.stat, os.chmod
    else:
        return

    st = stat_func(src)
    chmod_func(dst, stat.S_IMODE(st.st_mode))

146
if hasattr(os, 'listxattr'):
147
    def _copyxattr(src, dst, *, follow_symlinks=True):
148 149 150 151
        """Copy extended filesystem attributes from `src` to `dst`.

        Overwrite existing attributes.

152
        If `follow_symlinks` is false, symlinks won't be followed.
153 154 155

        """

156 157 158 159 160 161 162
        try:
            names = os.listxattr(src, follow_symlinks=follow_symlinks)
        except OSError as e:
            if e.errno not in (errno.ENOTSUP, errno.ENODATA):
                raise
            return
        for name in names:
163
            try:
164 165
                value = os.getxattr(src, name, follow_symlinks=follow_symlinks)
                os.setxattr(dst, name, value, follow_symlinks=follow_symlinks)
166 167 168 169 170 171 172
            except OSError as e:
                if e.errno not in (errno.EPERM, errno.ENOTSUP, errno.ENODATA):
                    raise
else:
    def _copyxattr(*args, **kwargs):
        pass

173
def copystat(src, dst, *, follow_symlinks=True):
    """Copy all stat info (mode bits, atime, mtime, flags) from src to dst.

    If the optional flag `follow_symlinks` is not set, symlinks aren't followed if and
    only if both `src` and `dst` are symlinks.

    Extended attributes are copied as well via _copyxattr().

    """
    def _nop(*args, ns=None, follow_symlinks=None):
        pass

    # follow symlinks (aka don't not follow symlinks)
    follow = follow_symlinks or not (os.path.islink(src) and os.path.islink(dst))
    if follow:
        # use the real function if it exists
        def lookup(name):
            return getattr(os, name, _nop)
    else:
        # use the real function only if it exists
        # *and* it supports follow_symlinks
        def lookup(name):
            fn = getattr(os, name, _nop)
            if fn in os.supports_follow_symlinks:
                return fn
            return _nop

    st = lookup("stat")(src, follow_symlinks=follow)
    mode = stat.S_IMODE(st.st_mode)
    lookup("utime")(dst, ns=(st.st_atime_ns, st.st_mtime_ns),
        follow_symlinks=follow)
    try:
        lookup("chmod")(dst, mode, follow_symlinks=follow)
    except NotImplementedError:
        # if we got a NotImplementedError, it's because
        #   * follow_symlinks=False,
        #   * lchmod() is unavailable, and
        #   * either
        #       * fchmodat() is unavailable or
        #       * fchmodat() doesn't implement AT_SYMLINK_NOFOLLOW.
        #         (it returned ENOTSUP.)
        # therefore we're out of options--we simply cannot chmod the
        # symlink.  give up, suppress the error.
        # (which is what shutil always did in this circumstance.)
        pass
    if hasattr(st, 'st_flags'):
        try:
            lookup("chflags")(dst, st.st_flags, follow_symlinks=follow)
        except OSError as why:
            # Ignore "operation not supported" errors; re-raise anything else.
            for err in 'EOPNOTSUPP', 'ENOTSUP':
                if hasattr(errno, err) and why.errno == getattr(errno, err):
                    break
            else:
                raise
    _copyxattr(src, dst, follow_symlinks=follow)
226

227
def copy(src, dst, *, follow_symlinks=True):
    """Copy data and mode bits ("cp src dst"). Return the file's destination.

    The destination may be a directory.

    If follow_symlinks is false, symlinks won't be followed. This
    resembles GNU's "cp -P src dst".

    If source and destination are the same file, a SameFileError will be
    raised.

    """
    if os.path.isdir(dst):
        # Copy into the directory, keeping the source's base name.
        dst = os.path.join(dst, os.path.basename(src))
    copyfile(src, dst, follow_symlinks=follow_symlinks)
    copymode(src, dst, follow_symlinks=follow_symlinks)
    return dst
Guido van Rossum's avatar
Guido van Rossum committed
244

245
def copy2(src, dst, *, follow_symlinks=True):
    """Copy data and all stat info ("cp -p src dst"). Return the file's
    destination.

    The destination may be a directory.

    If follow_symlinks is false, symlinks won't be followed. This
    resembles GNU's "cp -P src dst".

    """
    if os.path.isdir(dst):
        # Copy into the directory, keeping the source's base name.
        dst = os.path.join(dst, os.path.basename(src))
    copyfile(src, dst, follow_symlinks=follow_symlinks)
    copystat(src, dst, follow_symlinks=follow_symlinks)
    return dst
Guido van Rossum's avatar
Guido van Rossum committed
260

Georg Brandl's avatar
Georg Brandl committed
261 262
def ignore_patterns(*patterns):
    """Function that can be used as copytree() ignore parameter.

    Patterns is a sequence of glob-style patterns
    that are used to exclude files.

    Returns a callable (path, names) -> set of the names that match at
    least one pattern; `path` is accepted but not used for matching.
    """
    def _ignore_patterns(path, names):
        ignored_names = set()
        for pattern in patterns:
            ignored_names.update(fnmatch.filter(names, pattern))
        return ignored_names
    return _ignore_patterns

273 274
def copytree(src, dst, symlinks=False, ignore=None, copy_function=copy2,
             ignore_dangling_symlinks=False):
    """Recursively copy a directory tree and return dst.

    The destination directory must not already exist.
    If exception(s) occur, an Error is raised with a list of reasons.

    If the optional symlinks flag is true, symbolic links in the
    source tree result in symbolic links in the destination tree; if
    it is false, the contents of the files pointed to by symbolic
    links are copied. If the file pointed by the symlink doesn't
    exist, an exception will be added in the list of errors raised in
    an Error exception at the end of the copy process.

    You can set the optional ignore_dangling_symlinks flag to true if you
    want to silence this exception. Notice that this has no effect on
    platforms that don't support os.symlink. The flag is applied to the
    whole tree, including sub-directories.

    The optional ignore argument is a callable. If given, it
    is called with the `src` parameter, which is the directory
    being visited by copytree(), and `names` which is the list of
    `src` contents, as returned by os.listdir():

        callable(src, names) -> ignored_names

    Since copytree() is called recursively, the callable will be
    called once for each directory that is copied. It returns a
    list of names relative to the `src` directory that should
    not be copied.

    The optional copy_function argument is a callable that will be used
    to copy each file. It will be called with the source path and the
    destination path as arguments. By default, copy2() is used, but any
    function that supports the same signature (like copy()) can be used.

    """
    names = os.listdir(src)
    if ignore is not None:
        ignored_names = ignore(src, names)
    else:
        ignored_names = set()

    os.makedirs(dst)
    errors = []
    for name in names:
        if name in ignored_names:
            continue
        srcname = os.path.join(src, name)
        dstname = os.path.join(dst, name)
        try:
            if os.path.islink(srcname):
                linkto = os.readlink(srcname)
                if symlinks:
                    # We can't just leave it to `copy_function` because legacy
                    # code with a custom `copy_function` may rely on copytree
                    # doing the right thing.
                    os.symlink(linkto, dstname)
                    copystat(srcname, dstname, follow_symlinks=not symlinks)
                else:
                    # ignore dangling symlink if the flag is on.
                    # os.path.exists() follows the link, so testing srcname
                    # resolves a relative target against the link's own
                    # directory rather than the current working directory.
                    if not os.path.exists(srcname) and ignore_dangling_symlinks:
                        continue
                    # otherwise let the copy occur. copy2 will raise an error
                    if os.path.isdir(srcname):
                        copytree(srcname, dstname, symlinks, ignore,
                                 copy_function, ignore_dangling_symlinks)
                    else:
                        copy_function(srcname, dstname)
            elif os.path.isdir(srcname):
                copytree(srcname, dstname, symlinks, ignore, copy_function,
                         ignore_dangling_symlinks)
            else:
                # Will raise a SpecialFileError for unsupported file types
                copy_function(srcname, dstname)
        # catch the Error from the recursive copytree so that we can
        # continue with other files
        except Error as err:
            errors.extend(err.args[0])
        except OSError as why:
            errors.append((srcname, dstname, str(why)))
    try:
        copystat(src, dst)
    except OSError as why:
        # Copying file access times may fail on Windows
        if getattr(why, 'winerror', None) is None:
            errors.append((src, dst, str(why)))
    if errors:
        raise Error(errors)
    return dst
361

362 363
# version vulnerable to race conditions
def _rmtree_unsafe(path, onerror):
364 365 366 367 368 369 370 371
    try:
        if os.path.islink(path):
            # symlinks to directories are forbidden, see bug #1669
            raise OSError("Cannot call rmtree on a symbolic link")
    except OSError:
        onerror(os.path.islink, path, sys.exc_info())
        # can't continue even if onerror hook returns
        return
372 373 374
    names = []
    try:
        names = os.listdir(path)
375
    except OSError:
376 377 378 379 380
        onerror(os.listdir, path, sys.exc_info())
    for name in names:
        fullname = os.path.join(path, name)
        try:
            mode = os.lstat(fullname).st_mode
381
        except OSError:
382 383
            mode = 0
        if stat.S_ISDIR(mode):
384
            _rmtree_unsafe(fullname, onerror)
385
        else:
386
            try:
387
                os.unlink(fullname)
388
            except OSError:
389
                onerror(os.unlink, fullname, sys.exc_info())
390 391
    try:
        os.rmdir(path)
392
    except OSError:
393
        onerror(os.rmdir, path, sys.exc_info())
394

395 396 397 398
# Version using fd-based APIs to protect against races
def _rmtree_safe_fd(topfd, path, onerror):
    names = []
    try:
399
        names = os.listdir(topfd)
400 401
    except OSError as err:
        err.filename = path
402
        onerror(os.listdir, path, sys.exc_info())
403 404 405
    for name in names:
        fullname = os.path.join(path, name)
        try:
406
            orig_st = os.stat(name, dir_fd=topfd, follow_symlinks=False)
407
            mode = orig_st.st_mode
408
        except OSError:
409 410 411
            mode = 0
        if stat.S_ISDIR(mode):
            try:
412
                dirfd = os.open(name, os.O_RDONLY, dir_fd=topfd)
413
            except OSError:
414
                onerror(os.open, fullname, sys.exc_info())
415 416 417 418
            else:
                try:
                    if os.path.samestat(orig_st, os.fstat(dirfd)):
                        _rmtree_safe_fd(dirfd, fullname, onerror)
419 420
                        try:
                            os.rmdir(name, dir_fd=topfd)
421
                        except OSError:
422
                            onerror(os.rmdir, fullname, sys.exc_info())
423 424 425 426 427 428 429 430 431
                    else:
                        try:
                            # This can only happen if someone replaces
                            # a directory with a symlink after the call to
                            # stat.S_ISDIR above.
                            raise OSError("Cannot call rmtree on a symbolic "
                                          "link")
                        except OSError:
                            onerror(os.path.islink, fullname, sys.exc_info())
432 433 434 435
                finally:
                    os.close(dirfd)
        else:
            try:
436
                os.unlink(name, dir_fd=topfd)
437
            except OSError:
438
                onerror(os.unlink, fullname, sys.exc_info())
439

440 441 442 443
_use_fd_functions = ({os.open, os.stat, os.unlink, os.rmdir} <=
                     os.supports_dir_fd and
                     os.listdir in os.supports_fd and
                     os.stat in os.supports_follow_symlinks)
444

445 446 447 448 449
def rmtree(path, ignore_errors=False, onerror=None):
    """Recursively delete a directory tree.

    If ignore_errors is set, errors are ignored; otherwise, if onerror
    is set, it is called to handle the error with arguments (func,
    path, exc_info) where func is platform and implementation dependent;
    path is the argument to that function that caused it to fail; and
    exc_info is a tuple returned by sys.exc_info().  If ignore_errors
    is false and onerror is None, an exception is raised.

    """
    if ignore_errors:
        def onerror(*args):
            pass
    elif onerror is None:
        def onerror(*args):
            # Re-raise the exception currently being handled by the caller.
            raise
    if _use_fd_functions:
        # While the unsafe rmtree works fine on bytes, the fd based does not.
        if isinstance(path, bytes):
            path = os.fsdecode(path)
        # Note: To guard against symlink races, we use the standard
        # lstat()/open()/fstat() trick.
        try:
            orig_st = os.lstat(path)
        except Exception:
            onerror(os.lstat, path, sys.exc_info())
            return
        try:
            fd = os.open(path, os.O_RDONLY)
        except Exception:
            # Report the function that actually failed.
            onerror(os.open, path, sys.exc_info())
            return
        try:
            if os.path.samestat(orig_st, os.fstat(fd)):
                _rmtree_safe_fd(fd, path, onerror)
                try:
                    os.rmdir(path)
                except OSError:
                    onerror(os.rmdir, path, sys.exc_info())
            else:
                try:
                    # symlinks to directories are forbidden, see bug #1669
                    raise OSError("Cannot call rmtree on a symbolic link")
                except OSError:
                    onerror(os.path.islink, path, sys.exc_info())
        finally:
            os.close(fd)
    else:
        return _rmtree_unsafe(path, onerror)

# Allow introspection of whether or not the hardening against symlink
# attacks is supported on the current platform
rmtree.avoids_symlink_attacks = _use_fd_functions
499 500 501 502

def _basename(path):
    # A basename() variant which first strips the trailing slash, if present.
    # Thus we always get the last component of the path, even for directories.
503 504
    sep = os.path.sep + (os.path.altsep or '')
    return os.path.basename(path.rstrip(sep))
505

506
def move(src, dst, copy_function=copy2):
    """Recursively move a file or directory to another location. This is
    similar to the Unix "mv" command. Return the file or directory's
    destination.

    If the destination is a directory or a symlink to a directory, the source
    is moved inside the directory. The destination path must not already
    exist.

    If the destination already exists but is not a directory, it may be
    overwritten depending on os.rename() semantics.

    If the destination is on our current filesystem, then rename() is used.
    Otherwise, src is copied to the destination and then removed. Symlinks are
    recreated under the new name if os.rename() fails because of cross
    filesystem renames.

    The optional `copy_function` argument is a callable that will be used
    to copy the source or it will be delegated to `copytree`.
    By default, copy2() is used, but any function that supports the same
    signature (like copy()) can be used.

    Raises Error if the target inside an existing destination directory
    already exists, or if a directory would be moved into itself.

    A lot more could be done here...  A look at a mv.c shows a lot of
    the issues this implementation glosses over.

    """
    real_dst = dst
    if os.path.isdir(dst):
        if _samefile(src, dst):
            # We might be on a case insensitive filesystem,
            # perform the rename anyway.
            os.rename(src, dst)
            # Return the destination here too, as documented.
            return real_dst

        real_dst = os.path.join(dst, _basename(src))
        if os.path.exists(real_dst):
            raise Error("Destination path '%s' already exists" % real_dst)
    try:
        os.rename(src, real_dst)
    except OSError:
        # rename() failed: fall back to copy + remove.
        if os.path.islink(src):
            linkto = os.readlink(src)
            os.symlink(linkto, real_dst)
            os.unlink(src)
        elif os.path.isdir(src):
            if _destinsrc(src, dst):
                raise Error("Cannot move a directory '%s' into itself"
                            " '%s'." % (src, dst))
            copytree(src, real_dst, copy_function=copy_function,
                     symlinks=True)
            rmtree(src)
        else:
            copy_function(src, real_dst)
            os.unlink(src)
    return real_dst
561

562
def _destinsrc(src, dst):
563 564
    src = os.path.abspath(src)
    dst = os.path.abspath(dst)
565 566 567 568 569
    if not src.endswith(os.path.sep):
        src += os.path.sep
    if not dst.endswith(os.path.sep):
        dst += os.path.sep
    return dst.startswith(src)
570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599

def _get_gid(name):
    """Returns a gid, given a group name."""
    if getgrnam is None or name is None:
        return None
    try:
        result = getgrnam(name)
    except KeyError:
        result = None
    if result is not None:
        return result[2]
    return None

def _get_uid(name):
    """Returns an uid, given a user name."""
    if getpwnam is None or name is None:
        return None
    try:
        result = getpwnam(name)
    except KeyError:
        result = None
    if result is not None:
        return result[2]
    return None

def _make_tarball(base_name, base_dir, compress="gzip", verbose=0, dry_run=0,
                  owner=None, group=None, logger=None):
    """Create a (possibly compressed) tar file from all the files under
    'base_dir'.

    'compress' must be "gzip" (the default), "bzip2", "xz", or None.
    A ValueError is raised for any other value, or when the module needed
    for the requested compression is not available.

    'owner' and 'group' can be used to define an owner and a group for the
    archive that is being built. If not provided, the current owner and group
    will be used.

    If 'dry_run' is true, nothing is written; only the name is computed.
    'verbose' is accepted but not used here.

    The output tar file will be named 'base_name' +  ".tar", possibly plus
    the appropriate compression extension (".gz", ".bz2", or ".xz").

    Returns the output filename.
    """
    if compress is None:
        tar_compression = ''
    elif _ZLIB_SUPPORTED and compress == 'gzip':
        tar_compression = 'gz'
    elif _BZ2_SUPPORTED and compress == 'bzip2':
        tar_compression = 'bz2'
    elif _LZMA_SUPPORTED and compress == 'xz':
        tar_compression = 'xz'
    else:
        raise ValueError("bad value for 'compress', or compression format not "
                         "supported : {0}".format(compress))

    import tarfile  # late import for breaking circular dependency

    compress_ext = '.' + tar_compression if compress else ''
    archive_name = base_name + '.tar' + compress_ext
    archive_dir = os.path.dirname(archive_name)

    if archive_dir and not os.path.exists(archive_dir):
        if logger is not None:
            logger.info("creating %s", archive_dir)
        if not dry_run:
            os.makedirs(archive_dir)

    # creating the tarball
    if logger is not None:
        logger.info('Creating tar archive')

    uid = _get_uid(owner)
    gid = _get_gid(group)

    def _set_uid_gid(tarinfo):
        # Override ownership only for the ids that could be resolved.
        if gid is not None:
            tarinfo.gid = gid
            tarinfo.gname = group
        if uid is not None:
            tarinfo.uid = uid
            tarinfo.uname = owner
        return tarinfo

    if not dry_run:
        tar = tarfile.open(archive_name, 'w|%s' % tar_compression)
        try:
            tar.add(base_dir, filter=_set_uid_gid)
        finally:
            tar.close()

    return archive_name

def _make_zipfile(base_name, base_dir, verbose=0, dry_run=0, logger=None):
    """Create a zip file from all the files under 'base_dir'.

663 664
    The output zip file will be named 'base_name' + ".zip".  Returns the
    name of the output zip file.
665
    """
666
    import zipfile  # late import for breaking circular dependency
667

668 669 670
    zip_filename = base_name + ".zip"
    archive_dir = os.path.dirname(base_name)

671
    if archive_dir and not os.path.exists(archive_dir):
672 673 674 675 676
        if logger is not None:
            logger.info("creating %s", archive_dir)
        if not dry_run:
            os.makedirs(archive_dir)

677 678 679
    if logger is not None:
        logger.info("creating '%s' and adding '%s' to it",
                    zip_filename, base_dir)
680

681 682 683
    if not dry_run:
        with zipfile.ZipFile(zip_filename, "w",
                             compression=zipfile.ZIP_DEFLATED) as zf:
684
            path = os.path.normpath(base_dir)
685 686 687 688
            if path != os.curdir:
                zf.write(path, path)
                if logger is not None:
                    logger.info("adding '%s'", path)
689
            for dirpath, dirnames, filenames in os.walk(base_dir):
690 691 692 693 694
                for name in sorted(dirnames):
                    path = os.path.normpath(os.path.join(dirpath, name))
                    zf.write(path, path)
                    if logger is not None:
                        logger.info("adding '%s'", path)
695 696 697 698 699 700
                for name in filenames:
                    path = os.path.normpath(os.path.join(dirpath, name))
                    if os.path.isfile(path):
                        zf.write(path, path)
                        if logger is not None:
                            logger.info("adding '%s'", path)
701 702 703 704 705

    return zip_filename

# Registry of archive formats:
#   name -> (function, extra_args as (name, value) pairs, description)
_ARCHIVE_FORMATS = {
    'tar':   (_make_tarball, [('compress', None)], "uncompressed tar file"),
}

# Formats that need an optional compression module are registered only
# when that module could be imported.
if _ZLIB_SUPPORTED:
    _ARCHIVE_FORMATS['gztar'] = (_make_tarball, [('compress', 'gzip')],
                                "gzip'ed tar-file")
    _ARCHIVE_FORMATS['zip'] = (_make_zipfile, [], "ZIP file")

if _BZ2_SUPPORTED:
    _ARCHIVE_FORMATS['bztar'] = (_make_tarball, [('compress', 'bzip2')],
                                "bzip2'ed tar-file")

if _LZMA_SUPPORTED:
    _ARCHIVE_FORMATS['xztar'] = (_make_tarball, [('compress', 'xz')],
                                "xz'ed tar-file")

721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741
def get_archive_formats():
    """Returns a list of supported formats for archiving and unarchiving.

    Each element of the returned sequence is a tuple (name, description),
    and the list is sorted.
    """
    return sorted((name, registry[2])
                  for name, registry in _ARCHIVE_FORMATS.items())

def register_archive_format(name, function, extra_args=None, description=''):
    """Registers an archive format.

    name is the name of the format. function is the callable that will be
    used to create archives. If provided, extra_args is a sequence of
    (name, value) tuples that will be passed as arguments to the callable.
    description can be provided to describe the format, and will be returned
    by the get_archive_formats() function.

    Raises TypeError if function is not callable, if extra_args is not a
    tuple or list, or if one of its elements is not a 2-item tuple or list.
    An existing registration under the same name is replaced.
    """
    if extra_args is None:
        extra_args = []
    if not callable(function):
        raise TypeError('The %s object is not callable' % function)
    if not isinstance(extra_args, (tuple, list)):
        raise TypeError('extra_args needs to be a sequence')
    for element in extra_args:
        if not isinstance(element, (tuple, list)) or len(element) != 2:
            raise TypeError('extra_args elements are : (arg_name, value)')

    _ARCHIVE_FORMATS[name] = (function, extra_args, description)

def unregister_archive_format(name):
    """Remove the archive format `name` from the registry.

    Raises KeyError if no format is registered under that name.
    """
    del _ARCHIVE_FORMATS[name]

def make_archive(base_name, format, root_dir=None, base_dir=None, verbose=0,
                 dry_run=0, owner=None, group=None, logger=None):
    """Create an archive file (eg. zip or tar).

    'base_name' is the name of the file to create, minus any format-specific
    extension; 'format' is the archive format: one of "zip", "tar", "gztar",
    "bztar", or "xztar".  Or any other registered format.  A ValueError is
    raised for an unknown format.

    'root_dir' is a directory that will be the root directory of the
    archive; ie. we typically chdir into 'root_dir' before creating the
    archive.  'base_dir' is the directory where we start archiving from;
    ie. 'base_dir' will be the common prefix of all files and
    directories in the archive.  'root_dir' and 'base_dir' both default
    to the current directory.  Returns the name of the archive file.

    'owner' and 'group' are used when creating a tar archive. By default,
    uses the current owner and group.
    """
    # Resolve the format before touching the working directory, so that an
    # unknown format cannot leave the process chdir'ed into root_dir.
    try:
        format_info = _ARCHIVE_FORMATS[format]
    except KeyError:
        raise ValueError("unknown archive format '%s'" % format) from None

    save_cwd = os.getcwd()
    if root_dir is not None:
        if logger is not None:
            logger.debug("changing into '%s'", root_dir)
        # Make base_name absolute first: it must keep pointing at the same
        # place once the working directory changes.
        base_name = os.path.abspath(base_name)
        if not dry_run:
            os.chdir(root_dir)

    if base_dir is None:
        base_dir = os.curdir

    kwargs = {'dry_run': dry_run, 'logger': logger}

    func = format_info[0]
    for arg, val in format_info[1]:
        kwargs[arg] = val

    if format != 'zip':
        kwargs['owner'] = owner
        kwargs['group'] = group

    try:
        filename = func(base_name, base_dir, **kwargs)
    finally:
        if root_dir is not None:
            if logger is not None:
                logger.debug("changing back to '%s'", save_cwd)
            os.chdir(save_cwd)

    return filename
808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834


def get_unpack_formats():
    """Return the registered unpack formats as a sorted list.

    Every entry is a ``(name, extensions, description)`` tuple taken from
    the matching record of the unpack registry.
    """
    return sorted((name, info[0], info[3])
                  for name, info in _UNPACK_FORMATS.items())

def _check_unpack_options(extensions, function, extra_args):
    """Validate the arguments of an unpacker registration.

    Raises RegistryError when one of `extensions` is already claimed by a
    registered format, and TypeError when `function` is not callable.
    `extra_args` is accepted but not inspected here.
    """
    # Map every extension already in the registry to the format owning it.
    claimed = {ext: fmt_name
               for fmt_name, fmt_info in _UNPACK_FORMATS.items()
               for ext in fmt_info[0]}

    for ext in extensions:
        if ext in claimed:
            raise RegistryError('%s is already registered for "%s"'
                                % (ext, claimed[ext]))

    if not callable(function):
        raise TypeError('The registered function must be a callable')


def register_unpack_format(name, extensions, function, extra_args=None,
                           description=''):
    """Add an unpack format to the registry.

    `name` identifies the format and `extensions` is the list of file
    extensions that belong to it.

    `function` is the callable used to unpack archives of this format; it
    is called with the archive to unpack and must raise a ReadError
    exception when it cannot handle it.

    `extra_args`, when given, is a sequence of (name, value) tuples passed
    to the callable as additional arguments.  `description` is a free-form
    text describing the format; get_unpack_formats() returns it.
    """
    extra_args = [] if extra_args is None else extra_args
    _check_unpack_options(extensions, function, extra_args)
    _UNPACK_FORMATS[name] = (extensions, function, extra_args, description)

def unregister_unpack_format(name):
    """Remove the unpack format `name` from the registry.

    Raises KeyError if no format is registered under `name`.
    """
    _UNPACK_FORMATS.pop(name)

def _ensure_directory(path):
    """Ensure that the parent directory of `path` exists"""
    dirname = os.path.dirname(path)
    if not os.path.isdir(dirname):
        os.makedirs(dirname)

def _unpack_zipfile(filename, extract_dir):
    """Unpack zip `filename` to `extract_dir`.

    Members whose name is absolute or contains '..' are skipped instead of
    being extracted.  Raises ReadError if `filename` is not a zip file.
    """
    import zipfile  # late import for breaking circular dependency

    if not zipfile.is_zipfile(filename):
        raise ReadError("%s is not a zip file" % filename)

    with zipfile.ZipFile(filename) as archive:
        for info in archive.infolist():
            name = info.filename

            # don't extract absolute paths or ones with .. in them
            if name.startswith('/') or '..' in name:
                continue

            target = os.path.join(extract_dir, *name.split('/'))
            if not target:
                continue

            _ensure_directory(target)
            if not name.endswith('/'):
                # Regular file: stream it to disk so the decompressed
                # member never has to fit in memory as a whole.
                with archive.open(info) as source, \
                        open(target, 'wb') as sink:
                    copyfileobj(source, sink)

def _unpack_tarfile(filename, extract_dir):
906
    """Unpack tar/tar.gz/tar.bz2/tar.xz `filename` to `extract_dir`
907
    """
908
    import tarfile  # late import for breaking circular dependency
909 910 911 912 913 914 915 916 917 918 919 920
    try:
        tarobj = tarfile.open(filename)
    except tarfile.TarError:
        raise ReadError(
            "%s is not a compressed or uncompressed tar file" % filename)
    try:
        tarobj.extractall(extract_dir)
    finally:
        tarobj.close()

# Registry of unpackers, keyed by format name.  Each value is an
# (extensions, function, extra_args, description) tuple.
_UNPACK_FORMATS = {
    'tar': (['.tar'], _unpack_tarfile, [], "uncompressed tar file"),
    'zip': (['.zip'], _unpack_zipfile, [], "ZIP file"),
}

# Compressed tar variants are registered only when the matching
# compression module could be imported.
if _ZLIB_SUPPORTED:
    _UNPACK_FORMATS['gztar'] = (
        ['.tar.gz', '.tgz'], _unpack_tarfile, [], "gzip'ed tar-file")

if _BZ2_SUPPORTED:
    _UNPACK_FORMATS['bztar'] = (
        ['.tar.bz2', '.tbz2'], _unpack_tarfile, [], "bzip2'ed tar-file")

if _LZMA_SUPPORTED:
    _UNPACK_FORMATS['xztar'] = (
        ['.tar.xz', '.txz'], _unpack_tarfile, [], "xz'ed tar-file")

936 937 938 939 940 941 942 943 944 945 946 947 948 949 950
def _find_unpack_format(filename):
    """Return the name of the first registered format matching `filename`.

    A format matches when `filename` ends with one of its extensions.
    Returns None when no registered format matches.
    """
    for fmt_name, fmt_info in _UNPACK_FORMATS.items():
        if any(filename.endswith(ext) for ext in fmt_info[0]):
            return fmt_name
    return None

def unpack_archive(filename, extract_dir=None, format=None):
    """Unpack the archive `filename` into `extract_dir`.

    `extract_dir` is the target directory; the current working directory
    is used when it is omitted.

    `format` names the archive format: "zip", "tar", "gztar", "bztar",
    "xztar" or any other registered format.  When it is omitted, the
    format is deduced from the extension of `filename` using the
    registered unpackers.

    Raises ValueError if `format` is given but not registered, and
    ReadError if no registered extension matches `filename`.
    """
    if extract_dir is None:
        extract_dir = os.getcwd()

    extract_dir = os.fspath(extract_dir)
    filename = os.fspath(filename)

    if format is None:
        # No explicit format: fall back to the registered extensions.
        format = _find_unpack_format(filename)
        if format is None:
            raise ReadError("Unknown archive format '{0}'".format(filename))
        format_info = _UNPACK_FORMATS[format]
    else:
        try:
            format_info = _UNPACK_FORMATS[format]
        except KeyError:
            raise ValueError("Unknown unpack format '{0}'".format(format)) from None

    unpacker = format_info[1]
    extra_kwargs = dict(format_info[2])
    unpacker(filename, extract_dir, **extra_kwargs)
981

Éric Araujo's avatar
Éric Araujo committed
982 983 984 985 986

if hasattr(os, 'statvfs'):

    __all__.append('disk_usage')
    _ntuple_diskusage = collections.namedtuple('usage', 'total used free')
    _ntuple_diskusage.total.__doc__ = 'Total space in bytes'
    _ntuple_diskusage.used.__doc__ = 'Used space in bytes'
    _ntuple_diskusage.free.__doc__ = 'Free space in bytes'

    def disk_usage(path):
        """Return disk usage statistics about the given path.

        The result is a named tuple with the attributes 'total', 'used'
        and 'free', each an amount of space in bytes.
        """
        vfs = os.statvfs(path)
        unit = vfs.f_frsize
        # 'free' is derived from f_bavail while 'used' is derived from
        # f_bfree, so used + free does not have to equal total.
        return _ntuple_diskusage(
            total=vfs.f_blocks * unit,
            used=(vfs.f_blocks - vfs.f_bfree) * unit,
            free=vfs.f_bavail * unit)

elif os.name == 'nt':

    import nt
    __all__.append('disk_usage')
    _ntuple_diskusage = collections.namedtuple('usage', 'total used free')

    def disk_usage(path):
        """Return disk usage statistics about the given path.

        The result is a named tuple with the attributes 'total', 'used'
        and 'free', each an amount of space in bytes.
        """
        total, free = nt._getdiskusage(path)
        return _ntuple_diskusage(total, total - free, free)
1018

1019

1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049
def chown(path, user=None, group=None):
    """Change the owner user and/or group of `path`.

    `user` and `group` may be given as a uid/gid or as a user/group name;
    names are converted to the corresponding uid/gid.  An argument left
    as None keeps the current value.

    Raises ValueError when neither `user` nor `group` is given, and
    LookupError when a name cannot be resolved.
    """
    if user is None and group is None:
        raise ValueError("user and/or group must be set")

    # -1 tells os.chown to leave that id unchanged.
    if user is None:
        uid = -1
    elif isinstance(user, str):
        # A string is a system user name that has to be looked up.
        uid = _get_uid(user)
        if uid is None:
            raise LookupError("no such user: {!r}".format(user))
    else:
        uid = user

    if group is None:
        gid = -1
    elif isinstance(group, int):
        gid = group
    else:
        # Anything that is not an int is treated as a group name.
        gid = _get_gid(group)
        if gid is None:
            raise LookupError("no such group: {!r}".format(group))

    os.chown(path, uid, gid)
1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084

def get_terminal_size(fallback=(80, 24)):
    """Get the size of the terminal window.

    The environment variables COLUMNS and LINES are consulted first, one
    per dimension; a variable holding a positive integer determines that
    dimension.

    For every dimension not settled that way (the common case), the
    terminal connected to sys.__stdout__ is queried through
    os.get_terminal_size.

    When that query fails, because the system doesn't support it or
    because we are not connected to a terminal, the `fallback` value is
    used instead.  It defaults to (80, 24), the default size used by many
    terminal emulators.

    The value returned is a named tuple of type os.terminal_size.
    """
    def _from_environment(variable):
        # A missing or non-integer variable counts as "not set" (0).
        try:
            return int(os.environ[variable])
        except (KeyError, ValueError):
            return 0

    columns = _from_environment('COLUMNS')
    lines = _from_environment('LINES')

    # Both dimensions known: no need to query the terminal at all.
    if columns > 0 and lines > 0:
        return os.terminal_size((columns, lines))

    try:
        size = os.get_terminal_size(sys.__stdout__.fileno())
    except (AttributeError, ValueError, OSError):
        # stdout is None, closed, detached, or not a terminal, or
        # os.get_terminal_size() is unsupported
        size = os.terminal_size(fallback)

    if columns <= 0:
        columns = size.columns
    if lines <= 0:
        lines = size.lines
    return os.terminal_size((columns, lines))
1095 1096

def which(cmd, mode=os.F_OK | os.X_OK, path=None):
    """Locate the command `cmd` and return its path, or None.

    The result is the first file found on the search path that conforms
    to `mode`; None is returned when there is no such file.

    `mode` defaults to os.F_OK | os.X_OK.  `path` is the search path
    string; it defaults to the result of os.environ.get("PATH") and can
    be overridden with a custom search path.

    """
    def _is_usable(candidate):
        # The file must be accessible with the requested mode.  Directories
        # are rejected explicitly because on Windows they pass the
        # os.access check.
        return (os.path.exists(candidate) and os.access(candidate, mode)
                and not os.path.isdir(candidate))

    # A command with a directory part is looked up directly rather than on
    # the search path.  This includes paths relative to the current
    # directory, e.g. ./script
    if os.path.dirname(cmd):
        return cmd if _is_usable(cmd) else None

    if path is None:
        path = os.environ.get("PATH", os.defpath)
    if not path:
        return None
    directories = path.split(os.pathsep)

    if sys.platform == "win32":
        # The current directory takes precedence on Windows.
        if os.curdir not in directories:
            directories.insert(0, os.curdir)

        # PATHEXT is necessary to check on Windows.
        pathext = os.environ.get("PATHEXT", "").split(os.pathsep)
        # When the command already ends with one of the known extensions
        # (e.g. "python.exe"), only that exact name is tried; otherwise
        # every extension is appended in turn.
        lowered = cmd.lower()
        if any(lowered.endswith(ext.lower()) for ext in pathext):
            candidates = [cmd]
        else:
            candidates = [cmd + ext for ext in pathext]
    else:
        # On other platforms you don't have things like PATHEXT to tell you
        # what file suffixes are executable, so just pass on cmd as-is.
        candidates = [cmd]

    visited = set()
    for directory in directories:
        # Skip directories that were already searched under another spelling.
        key = os.path.normcase(directory)
        if key in visited:
            continue
        visited.add(key)
        for candidate in candidates:
            full_path = os.path.join(directory, candidate)
            if _is_usable(full_path):
                return full_path
    return None