"""Supporting definitions for the Python regression tests."""

if __name__ != 'test.support':
    raise ImportError('support must be imported from the test package')

import contextlib
import errno
import functools
import gc
import socket
import sys
import os
import platform
import shutil
import warnings
import unittest
import importlib
import collections.abc
import re
import subprocess
import imp
import time
import sysconfig
import fnmatch
import logging.handlers
import struct
import tempfile
import _testcapi

try:
    import _thread, threading
except ImportError:
    _thread = None
    threading = None
try:
    import multiprocessing.process
except ImportError:
    multiprocessing = None

try:
    import zlib
except ImportError:
    zlib = None

try:
    import gzip
except ImportError:
    gzip = None

try:
    import bz2
except ImportError:
    bz2 = None

try:
    import lzma
except ImportError:
    lzma = None

__all__ = [
    "Error", "TestFailed", "ResourceDenied", "import_module", "verbose",
    "use_resources", "max_memuse", "record_original_stdout",
    "get_original_stdout", "unload", "unlink", "rmtree", "forget",
    "is_resource_enabled", "requires", "requires_freebsd_version",
    "requires_linux_version", "requires_mac_ver", "find_unused_port",
    "bind_port", "IPV6_ENABLED", "is_jython", "TESTFN", "HOST", "SAVEDCWD",
    "temp_cwd", "findfile", "create_empty_file", "sortdict",
    "check_syntax_error", "open_urlresource", "check_warnings", "CleanImport",
    "EnvironmentVarGuard", "TransientResource", "captured_stdout",
    "captured_stdin", "captured_stderr", "time_out", "socket_peer_reset",
    "ioerror_peer_reset", "run_with_locale", 'temp_umask',
    "transient_internet", "set_memlimit", "bigmemtest", "bigaddrspacetest",
    "BasicTestRunner", "run_unittest", "run_doctest", "threading_setup",
    "threading_cleanup", "reap_children", "cpython_only", "check_impl_detail",
    "get_attribute", "swap_item", "swap_attr", "requires_IEEE_754",
    "TestHandler", "Matcher", "can_symlink", "skip_unless_symlink",
    "skip_unless_xattr", "import_fresh_module", "requires_zlib",
    "PIPE_MAX_SIZE", "failfast", "anticipate_failure", "run_with_tz",
    "requires_gzip", "requires_bz2", "requires_lzma", "suppress_crash_popup",
    ]

class Error(Exception):
    """Base class for regression test exceptions."""

class TestFailed(Error):
    """Test failed."""

class ResourceDenied(unittest.SkipTest):
    """Test skipped because it requested a disallowed resource.

    This is raised when a test calls requires() for a resource that
    has not been enabled.  It is used to distinguish between expected
    and unexpected skips.
    """

@contextlib.contextmanager
def _ignore_deprecated_imports(ignore=True):
    """Context manager to suppress package and module deprecation
    warnings when importing them.

    If ignore is False, this context manager has no effect."""
    if ignore:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", ".+ (module|package)",
                                    DeprecationWarning)
            yield
    else:
        yield


def import_module(name, deprecated=False):
    """Import and return the module to be tested, raising SkipTest if
    it is not available.

    If deprecated is True, any module or package deprecation messages
    will be suppressed."""
    with _ignore_deprecated_imports(deprecated):
        try:
            return importlib.import_module(name)
        except ImportError as msg:
            raise unittest.SkipTest(str(msg))
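
# Illustrative usage sketch (not part of the original module): a test file
# typically calls import_module() at import time so the whole file is skipped
# cleanly when an optional dependency is missing, e.g.:
#
#     from test import support
#     ssl = support.import_module('ssl')   # raises unittest.SkipTest if absent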


def _save_and_remove_module(name, orig_modules):
    """Helper function to save and remove a module from sys.modules

    Raise ImportError if the module can't be imported.
    """
    # try to import the module and raise an error if it can't be imported
    if name not in sys.modules:
        __import__(name)
        del sys.modules[name]
    for modname in list(sys.modules):
        if modname == name or modname.startswith(name + '.'):
            orig_modules[modname] = sys.modules[modname]
            del sys.modules[modname]

def _save_and_block_module(name, orig_modules):
    """Helper function to save and block a module in sys.modules

    Return True if the module was in sys.modules, False otherwise.
    """
    saved = True
    try:
        orig_modules[name] = sys.modules[name]
    except KeyError:
        saved = False
    sys.modules[name] = None
    return saved


def anticipate_failure(condition):
    """Decorator to mark a test that is known to be broken in some cases

       Any use of this decorator should have a comment identifying the
       associated tracker issue.
    """
    if condition:
        return unittest.expectedFailure
    return lambda f: f
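
# Illustrative usage sketch (hypothetical test name and issue number):
#
#     @support.anticipate_failure(sys.platform == 'win32')  # see issue #NNNNN
#     def test_flaky_on_windows(self):
#         ...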


def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
    """Import and return a module, deliberately bypassing sys.modules.

    This function imports and returns a fresh copy of the named Python module
    by removing the named module from sys.modules before doing the import.
    Note that unlike reload, the original module is not affected by
    this operation.

    *fresh* is an iterable of additional module names that are also removed
    from the sys.modules cache before doing the import.

    *blocked* is an iterable of module names that are replaced with None
    in the module cache during the import to ensure that attempts to import
    them raise ImportError.

    The named module and any modules named in the *fresh* and *blocked*
    parameters are saved before starting the import and then reinserted into
    sys.modules when the fresh import is complete.

    Module and package deprecation messages are suppressed during this import
    if *deprecated* is True.

    This function will raise ImportError if the named module cannot be
    imported.
    """
    # NOTE: test_heapq, test_json and test_warnings include extra sanity checks
    # to make sure that this utility function is working as expected
    with _ignore_deprecated_imports(deprecated):
        # Keep track of modules saved for later restoration as well
        # as those which just need a blocking entry removed
        orig_modules = {}
        names_to_remove = []
        _save_and_remove_module(name, orig_modules)
        try:
            for fresh_name in fresh:
                _save_and_remove_module(fresh_name, orig_modules)
            for blocked_name in blocked:
                if not _save_and_block_module(blocked_name, orig_modules):
                    names_to_remove.append(blocked_name)
            fresh_module = importlib.import_module(name)
        except ImportError:
            fresh_module = None
        finally:
            for orig_name, module in orig_modules.items():
                sys.modules[orig_name] = module
            for name_to_remove in names_to_remove:
                del sys.modules[name_to_remove]
        return fresh_module
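
# Illustrative usage sketch, mirroring what tests such as test_heapq do to
# exercise the pure-Python module with and without its C accelerator:
#
#     py_heapq = support.import_fresh_module('heapq', blocked=['_heapq'])
#     c_heapq = support.import_fresh_module('heapq', fresh=['_heapq'])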


def get_attribute(obj, name):
    """Get an attribute, raising SkipTest if AttributeError is raised."""
    try:
        attribute = getattr(obj, name)
    except AttributeError:
        raise unittest.SkipTest("object %r has no attribute %r" % (obj, name))
    else:
        return attribute

verbose = 1              # Flag set to 0 by regrtest.py
use_resources = None     # Flag set to [] by regrtest.py
max_memuse = 0           # Disable bigmem tests (they will still be run with
                         # small sizes, to make sure they work.)
real_max_memuse = 0
failfast = False
match_tests = None

# _original_stdout is meant to hold stdout at the time regrtest began.
# This may be "the real" stdout, or IDLE's emulation of stdout, or whatever.
# The point is to have some flavor of stdout the user can actually see.
_original_stdout = None
def record_original_stdout(stdout):
    global _original_stdout
    _original_stdout = stdout

def get_original_stdout():
    return _original_stdout or sys.stdout

def unload(name):
    try:
        del sys.modules[name]
    except KeyError:
        pass

if sys.platform.startswith("win"):
    def _waitfor(func, pathname, waitall=False):
        # Perform the operation
        func(pathname)
        # Now setup the wait loop
        if waitall:
            dirname = pathname
        else:
            dirname, name = os.path.split(pathname)
            dirname = dirname or '.'
        # Check for `pathname` to be removed from the filesystem.
        # The exponential backoff of the timeout amounts to a total
        # of ~1 second after which the deletion is probably an error
        # anyway.
        # Testing on a i7@4.3GHz shows that usually only 1 iteration is
        # required when contention occurs.
        timeout = 0.001
        while timeout < 1.0:
            # Note we are only testing for the existence of the file(s) in
            # the contents of the directory regardless of any security or
            # access rights.  If we have made it this far, we have sufficient
            # permissions to do that much using Python's equivalent of the
            # Windows API FindFirstFile.
            # Other Windows APIs can fail or give incorrect results when
            # dealing with files that are pending deletion.
            L = os.listdir(dirname)
            if not (L if waitall else name in L):
                return
            # Increase the timeout and try again
            time.sleep(timeout)
            timeout *= 2
        warnings.warn('tests may fail, delete still pending for ' + pathname,
                      RuntimeWarning, stacklevel=4)

    def _unlink(filename):
        _waitfor(os.unlink, filename)

    def _rmdir(dirname):
        _waitfor(os.rmdir, dirname)

    def _rmtree(path):
        def _rmtree_inner(path):
            for name in os.listdir(path):
                fullname = os.path.join(path, name)
                if os.path.isdir(fullname):
                    _waitfor(_rmtree_inner, fullname, waitall=True)
                    os.rmdir(fullname)
                else:
                    os.unlink(fullname)
        _waitfor(_rmtree_inner, path, waitall=True)
        _waitfor(os.rmdir, path)
else:
    _unlink = os.unlink
    _rmdir = os.rmdir
    _rmtree = shutil.rmtree

def unlink(filename):
    try:
        _unlink(filename)
    except OSError as error:
        # The filename need not exist.
        if error.errno not in (errno.ENOENT, errno.ENOTDIR):
            raise

def rmdir(dirname):
    try:
        _rmdir(dirname)
    except OSError as error:
        # The directory need not exist.
        if error.errno != errno.ENOENT:
            raise

def rmtree(path):
    try:
        _rmtree(path)
    except OSError as error:
        if error.errno != errno.ENOENT:
            raise

def make_legacy_pyc(source):
    """Move a PEP 3147 pyc/pyo file to its legacy pyc/pyo location.

    The choice of .pyc or .pyo extension is done based on the __debug__ flag
    value.

    :param source: The file system path to the source file.  The source file
        does not need to exist, however the PEP 3147 pyc file must exist.
    :return: The file system path to the legacy pyc file.
    """
    pyc_file = imp.cache_from_source(source)
    up_one = os.path.dirname(os.path.abspath(source))
    legacy_pyc = os.path.join(up_one, source + ('c' if __debug__ else 'o'))
    os.rename(pyc_file, legacy_pyc)
    return legacy_pyc

def forget(modname):
    """'Forget' a module was ever imported.

    This removes the module from sys.modules and deletes any PEP 3147 or
    legacy .pyc and .pyo files.
    """
    unload(modname)
    for dirname in sys.path:
        source = os.path.join(dirname, modname + '.py')
        # It doesn't matter if they exist or not, unlink all possible
        # combinations of PEP 3147 and legacy pyc and pyo files.
        unlink(source + 'c')
        unlink(source + 'o')
        unlink(imp.cache_from_source(source, debug_override=True))
        unlink(imp.cache_from_source(source, debug_override=False))
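
# Illustrative usage sketch (hypothetical module and directory names): a test
# that writes a module to disk calls forget() first so stale bytecode from a
# previous run cannot shadow it:
#
#     support.forget('spam_module')
#     with open(os.path.join(script_dir, 'spam_module.py'), 'w') as f:
#         f.write('x = 1\n')
#     import spam_module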

# On some platforms, should not run gui test even if it is allowed
# in `use_resources'.
if sys.platform.startswith('win'):
    import ctypes
    import ctypes.wintypes
    def _is_gui_available():
        UOI_FLAGS = 1
        WSF_VISIBLE = 0x0001
        class USEROBJECTFLAGS(ctypes.Structure):
            _fields_ = [("fInherit", ctypes.wintypes.BOOL),
                        ("fReserved", ctypes.wintypes.BOOL),
                        ("dwFlags", ctypes.wintypes.DWORD)]
        dll = ctypes.windll.user32
        h = dll.GetProcessWindowStation()
        if not h:
            raise ctypes.WinError()
        uof = USEROBJECTFLAGS()
        needed = ctypes.wintypes.DWORD()
        res = dll.GetUserObjectInformationW(h,
            UOI_FLAGS,
            ctypes.byref(uof),
            ctypes.sizeof(uof),
            ctypes.byref(needed))
        if not res:
            raise ctypes.WinError()
        return bool(uof.dwFlags & WSF_VISIBLE)
else:
    def _is_gui_available():
        return True

def is_resource_enabled(resource):
    """Test whether a resource is enabled.  Known resources are set by
    regrtest.py."""
    return use_resources is not None and resource in use_resources

def requires(resource, msg=None):
    """Raise ResourceDenied if the specified resource is not available.

    If the caller's module is __main__ then the resource is treated as
    enabled automatically; resource checks only take effect when the test
    is being run by regrtest.py.
    """
    if resource == 'gui' and not _is_gui_available():
        raise unittest.SkipTest("Cannot use the 'gui' resource")
    # see if the caller's module is __main__ - if so, treat as if
    # the resource was set
    if sys._getframe(1).f_globals.get("__name__") == "__main__":
        return
    if not is_resource_enabled(resource):
        if msg is None:
            msg = "Use of the %r resource not enabled" % resource
        raise ResourceDenied(msg)
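
# Illustrative usage sketch: resource-hungry tests guard themselves near the
# top of the test (or in setUpModule), e.g.:
#
#     support.requires('network')   # ResourceDenied unless run with -u network
#     support.requires('gui', 'this test needs a display')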

def _requires_unix_version(sysname, min_version):
    """Decorator raising SkipTest if the OS is `sysname` and the version is less
    than `min_version`.

    For example, @_requires_unix_version('FreeBSD', (7, 2)) raises SkipTest if
    the FreeBSD version is less than 7.2.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kw):
            if platform.system() == sysname:
                version_txt = platform.release().split('-', 1)[0]
                try:
                    version = tuple(map(int, version_txt.split('.')))
                except ValueError:
                    pass
                else:
                    if version < min_version:
                        min_version_txt = '.'.join(map(str, min_version))
                        raise unittest.SkipTest(
                            "%s version %s or higher required, not %s"
                            % (sysname, min_version_txt, version_txt))
            return func(*args, **kw)
        wrapper.min_version = min_version
        return wrapper
    return decorator

def requires_freebsd_version(*min_version):
    """Decorator raising SkipTest if the OS is FreeBSD and the FreeBSD version is
    less than `min_version`.

    For example, @requires_freebsd_version(7, 2) raises SkipTest if the FreeBSD
    version is less than 7.2.
    """
    return _requires_unix_version('FreeBSD', min_version)

def requires_linux_version(*min_version):
    """Decorator raising SkipTest if the OS is Linux and the Linux version is
    less than `min_version`.

    For example, @requires_linux_version(2, 6, 32) raises SkipTest if the Linux
    version is less than 2.6.32.
    """
    return _requires_unix_version('Linux', min_version)

def requires_mac_ver(*min_version):
    """Decorator raising SkipTest if the OS is Mac OS X and the OS X
    version is less than min_version.

    For example, @requires_mac_ver(10, 5) raises SkipTest if the OS X version
    is less than 10.5.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kw):
            if sys.platform == 'darwin':
                version_txt = platform.mac_ver()[0]
                try:
                    version = tuple(map(int, version_txt.split('.')))
                except ValueError:
                    pass
                else:
                    if version < min_version:
                        min_version_txt = '.'.join(map(str, min_version))
                        raise unittest.SkipTest(
                            "Mac OS X %s or higher required, not %s"
                            % (min_version_txt, version_txt))
            return func(*args, **kw)
        wrapper.min_version = min_version
        return wrapper
    return decorator
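
# Illustrative usage sketch (hypothetical test names): these decorators wrap
# individual test methods, e.g.:
#
#     @support.requires_mac_ver(10, 5)
#     def test_needs_leopard(self): ...
#
#     @support.requires_linux_version(2, 6, 32)
#     def test_needs_recent_kernel(self): ...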


# Don't use "localhost", since resolving it uses the DNS under recent
# Windows versions (see issue #18792).
HOST = "127.0.0.1"
HOSTv6 = "::1"


def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM):
    """Returns an unused port that should be suitable for binding.  This is
    achieved by creating a temporary socket with the same family and type as
    the 'sock' parameter (default is AF_INET, SOCK_STREAM), and binding it to
    the specified host address (defaults to 0.0.0.0) with the port set to 0,
    eliciting an unused ephemeral port from the OS.  The temporary socket is
    then closed and deleted, and the ephemeral port is returned.

    Either this method or bind_port() should be used for any tests where a
    server socket needs to be bound to a particular port for the duration of
    the test.  Which one to use depends on whether the calling code is creating
    a python socket, or if an unused port needs to be provided in a constructor
    or passed to an external program (i.e. the -accept argument to openssl's
    s_server mode).  Always prefer bind_port() over find_unused_port() where
    possible.  Hard coded ports should *NEVER* be used.  As soon as a server
    socket is bound to a hard coded port, the ability to run multiple instances
    of the test simultaneously on the same host is compromised, which makes the
    test a ticking time bomb in a buildbot environment. On Unix buildbots, this
    may simply manifest as a failed test, which can be recovered from without
    intervention in most cases, but on Windows, the entire python process can
    completely and utterly wedge, requiring someone to log in to the buildbot
    and manually kill the affected process.

    (This is easy to reproduce on Windows, unfortunately, and can be traced to
    the SO_REUSEADDR socket option having different semantics on Windows versus
    Unix/Linux.  On Unix, you can't have two AF_INET SOCK_STREAM sockets bind,
    listen and then accept connections on identical host/ports.  An EADDRINUSE
    socket.error will be raised at some point (depending on the platform and
    the order bind and listen were called on each socket).

    However, on Windows, if SO_REUSEADDR is set on the sockets, no EADDRINUSE
    will ever be raised when attempting to bind two identical host/ports. When
    accept() is called on each socket, the second caller's process will steal
    the port from the first caller, leaving them both in an awkwardly wedged
    state where they'll no longer respond to any signals or graceful kills, and
    must be forcibly killed via OpenProcess()/TerminateProcess().

    The solution on Windows is to use the SO_EXCLUSIVEADDRUSE socket option
    instead of SO_REUSEADDR, which effectively affords the same semantics as
    SO_REUSEADDR on Unix.  Given the propensity of Unix developers in the Open
    Source world compared to Windows ones, this is a common mistake.  A quick
    look over OpenSSL's 0.9.8g source shows that they use SO_REUSEADDR when
    openssl.exe is called with the 's_server' option, for example. See
    http://bugs.python.org/issue2550 for more info.  The following site also
    has a very thorough description about the implications of both REUSEADDR
    and EXCLUSIVEADDRUSE on Windows:
    http://msdn2.microsoft.com/en-us/library/ms740621(VS.85).aspx)

    XXX: although this approach is a vast improvement on previous attempts to
    elicit unused ports, it rests heavily on the assumption that the ephemeral
    port returned to us by the OS won't immediately be dished back out to some
    other process when we close and delete our temporary socket but before our
    calling code has a chance to bind the returned port.  We can deal with this
    issue if/when we come across it.
    """

    tempsock = socket.socket(family, socktype)
    port = bind_port(tempsock)
    tempsock.close()
    del tempsock
    return port

def bind_port(sock, host=HOST):
    """Bind the socket to a free port and return the port number.  Relies on
    ephemeral ports in order to ensure we are using an unbound port.  This is
    important as many tests may be running simultaneously, especially in a
    buildbot environment.  This method raises an exception if the sock.family
    is AF_INET and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR
    or SO_REUSEPORT set on it.  Tests should *never* set these socket options
    for TCP/IP sockets.  The only case for setting these options is testing
    multicasting via multiple UDP sockets.

    Additionally, if the SO_EXCLUSIVEADDRUSE socket option is available (i.e.
    on Windows), it will be set on the socket.  This will prevent anyone else
    from bind()'ing to our host/port for the duration of the test.
    """

    if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM:
        if hasattr(socket, 'SO_REUSEADDR'):
            if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1:
                raise TestFailed("tests should never set the SO_REUSEADDR "   \
                                 "socket option on TCP/IP sockets!")
        if hasattr(socket, 'SO_REUSEPORT'):
            if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 1:
                raise TestFailed("tests should never set the SO_REUSEPORT "   \
                                 "socket option on TCP/IP sockets!")
        if hasattr(socket, 'SO_EXCLUSIVEADDRUSE'):
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1)

    sock.bind((host, 0))
    port = sock.getsockname()[1]
    return port
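
# Illustrative usage sketch: a server-socket test binds with bind_port()
# rather than a hard-coded port:
#
#     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
#     port = support.bind_port(sock)   # free ephemeral port on HOST
#     sock.listen(1)
#
# find_unused_port() is for the rarer case where the port number must be
# handed to an external program before any Python socket exists.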

def _is_ipv6_enabled():
    """Check whether IPv6 is enabled on this host."""
    if socket.has_ipv6:
        sock = None
        try:
            sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
            sock.bind(('::1', 0))
            return True
        except (socket.error, socket.gaierror):
            pass
        finally:
            if sock:
                sock.close()
    return False

IPV6_ENABLED = _is_ipv6_enabled()


# A constant likely larger than the underlying OS pipe buffer size, to
# make writes blocking.
# Windows limit seems to be around 512 B, and many Unix kernels have a
# 64 KiB pipe buffer size or 16 * PAGE_SIZE: take a few megs to be sure.
# (see issue #17835 for a discussion of this number).
PIPE_MAX_SIZE = 4 * 1024 * 1024 + 1

# A constant likely larger than the underlying OS socket buffer size, to make
# writes blocking.
# The socket buffer sizes can usually be tuned system-wide (e.g. through sysctl
# on Linux), or on a per-socket basis (SO_SNDBUF/SO_RCVBUF). See issue #18643
# for a discussion of this number.
SOCK_MAX_SIZE = 16 * 1024 * 1024 + 1

# decorator for skipping tests on non-IEEE 754 platforms
requires_IEEE_754 = unittest.skipUnless(
    float.__getformat__("double").startswith("IEEE"),
    "test requires IEEE 754 doubles")

requires_zlib = unittest.skipUnless(zlib, 'requires zlib')

requires_gzip = unittest.skipUnless(gzip, 'requires gzip')

requires_bz2 = unittest.skipUnless(bz2, 'requires bz2')

requires_lzma = unittest.skipUnless(lzma, 'requires lzma')

is_jython = sys.platform.startswith('java')

# Filename used for testing
if os.name == 'java':
    # Jython disallows @ in module names
    TESTFN = '$test'
else:
    TESTFN = '@test'

# Disambiguate TESTFN for parallel testing, while letting it remain a valid
# module name.
TESTFN = "{}_{}_tmp".format(TESTFN, os.getpid())

# FS_NONASCII: non-ASCII character encodable by os.fsencode(),
# or None if there is no such character.
FS_NONASCII = None
for character in (
    # First try printable and common characters to have a readable filename.
    # For each character, the encoding list are just example of encodings able
    # to encode the character (the list is not exhaustive).

    # U+00E6 (Latin Small Letter Ae): cp1252, iso-8859-1
    '\u00E6',
    # U+0130 (Latin Capital Letter I With Dot Above): cp1254, iso8859_3
    '\u0130',
    # U+0141 (Latin Capital Letter L With Stroke): cp1250, cp1257
    '\u0141',
    # U+03C6 (Greek Small Letter Phi): cp1253
    '\u03C6',
    # U+041A (Cyrillic Capital Letter Ka): cp1251
    '\u041A',
    # U+05D0 (Hebrew Letter Alef): Encodable to cp424
    '\u05D0',
    # U+060C (Arabic Comma): cp864, cp1006, iso8859_6, mac_arabic
    '\u060C',
    # U+062A (Arabic Letter Teh): cp720
    '\u062A',
    # U+0E01 (Thai Character Ko Kai): cp874
    '\u0E01',

    # Then try more "special" characters. "special" because they may be
    # interpreted or displayed differently depending on the exact locale
    # encoding and the font.

    # U+00A0 (No-Break Space)
    '\u00A0',
    # U+20AC (Euro Sign)
    '\u20AC',
):
    try:
        os.fsdecode(os.fsencode(character))
    except UnicodeError:
        pass
    else:
        FS_NONASCII = character
        break

# TESTFN_UNICODE is a non-ascii filename
TESTFN_UNICODE = TESTFN + "-\xe0\xf2\u0258\u0141\u011f"
if sys.platform == 'darwin':
    # In Mac OS X's VFS API file names are, by definition, canonically
    # decomposed Unicode, encoded using UTF-8. See QA1173:
    # http://developer.apple.com/mac/library/qa/qa2001/qa1173.html
    import unicodedata
    TESTFN_UNICODE = unicodedata.normalize('NFD', TESTFN_UNICODE)
TESTFN_ENCODING = sys.getfilesystemencoding()

# TESTFN_UNENCODABLE is a filename (str type) that should *not* be able to be
# encoded by the filesystem encoding (in strict mode). It can be None if we
# cannot generate such filename.
TESTFN_UNENCODABLE = None
if os.name in ('nt', 'ce'):
    # skip win32s (0) or Windows 9x/ME (1)
    if sys.getwindowsversion().platform >= 2:
        # Different kinds of characters from various languages to minimize the
        # probability that the whole name is encodable to MBCS (issue #9819)
        TESTFN_UNENCODABLE = TESTFN + "-\u5171\u0141\u2661\u0363\uDC80"
        try:
            TESTFN_UNENCODABLE.encode(TESTFN_ENCODING)
        except UnicodeEncodeError:
            pass
        else:
            print('WARNING: The filename %r CAN be encoded by the filesystem encoding (%s). '
                  'Unicode filename tests may not be effective'
                  % (TESTFN_UNENCODABLE, TESTFN_ENCODING))
            TESTFN_UNENCODABLE = None
# Mac OS X denies unencodable filenames (invalid utf-8)
elif sys.platform != 'darwin':
    try:
        # ascii and utf-8 cannot encode the byte 0xff
        b'\xff'.decode(TESTFN_ENCODING)
    except UnicodeDecodeError:
        # 0xff will be encoded using the surrogate character u+DCFF
        TESTFN_UNENCODABLE = TESTFN \
            + b'-\xff'.decode(TESTFN_ENCODING, 'surrogateescape')
    else:
        # File system encoding (eg. ISO-8859-* encodings) can encode
        # the byte 0xff. Skip some unicode filename tests.
        pass

# TESTFN_UNDECODABLE is a filename (bytes type) that should *not* be able to be
# decoded from the filesystem encoding (in strict mode). It can be None if we
# cannot generate such filename (ex: the latin1 encoding can decode any byte
# sequence). On UNIX, TESTFN_UNDECODABLE can be decoded by os.fsdecode() thanks
# to the surrogateescape error handler (PEP 383), but not from the filesystem
# encoding in strict mode.
TESTFN_UNDECODABLE = None
for name in (
    # b'\xff' is not decodable by os.fsdecode() with code page 932.  Windows
    # accepts it when creating a file or a directory, but refuses to enter such
    # a directory when the bytes name is used.  So test b'\xe7' first: it is
    # not decodable from cp932.
    b'\xe7w\xf0',
    # undecodable from ASCII, UTF-8
    b'\xff',
    # undecodable from iso8859-3, iso8859-6, iso8859-7, cp424, iso8859-8, cp856
    # and cp857
    b'\xae\xd5'
    # undecodable from UTF-8 (UNIX and Mac OS X)
    b'\xed\xb2\x80', b'\xed\xb4\x80',
    # undecodable from shift_jis, cp869, cp874, cp932, cp1250, cp1251, cp1252,
    # cp1253, cp1254, cp1255, cp1257, cp1258
    b'\x81\x98',
):
    try:
        name.decode(TESTFN_ENCODING)
    except UnicodeDecodeError:
        TESTFN_UNDECODABLE = os.fsencode(TESTFN) + name
        break

if FS_NONASCII:
    TESTFN_NONASCII = TESTFN + '-' + FS_NONASCII
else:
    TESTFN_NONASCII = None

# Save the initial cwd
SAVEDCWD = os.getcwd()

@contextlib.contextmanager
def temp_dir(path=None, quiet=False):
    """Return a context manager that creates a temporary directory.

    Arguments:

      path: the directory to create temporarily.  If omitted or None,
        defaults to creating a temporary directory using tempfile.mkdtemp.

      quiet: if False (the default), the context manager raises an exception
        on error.  Otherwise, if the path is specified and cannot be
        created, only a warning is issued.

    """
    dir_created = False
    if path is None:
        path = tempfile.mkdtemp()
        dir_created = True
        path = os.path.realpath(path)
    else:
        try:
            os.mkdir(path)
            dir_created = True
        except OSError:
            if not quiet:
                raise
            warnings.warn('tests may fail, unable to create temp dir: ' + path,
                          RuntimeWarning, stacklevel=3)
    try:
        yield path
    finally:
        if dir_created:
            shutil.rmtree(path)

@contextlib.contextmanager
def change_cwd(path, quiet=False):
    """Return a context manager that changes the current working directory.

    Arguments:

      path: the directory to use as the temporary current working directory.

      quiet: if False (the default), the context manager raises an exception
        on error.  Otherwise, it issues only a warning and keeps the current
        working directory the same.

    """
    saved_dir = os.getcwd()
    try:
        os.chdir(path)
    except OSError:
        if not quiet:
            raise
        warnings.warn('tests may fail, unable to change CWD to: ' + path,
                      RuntimeWarning, stacklevel=3)
    try:
        yield os.getcwd()
    finally:
        os.chdir(saved_dir)


@contextlib.contextmanager
def temp_cwd(name='tempcwd', quiet=False):
    """
    Context manager that temporarily creates and changes the CWD.

    The function temporarily changes the current working directory
    after creating a temporary directory in the current directory with
    name *name*.  If *name* is None, the temporary directory is
    created using tempfile.mkdtemp.

    If *quiet* is False (default) and it is not possible to
    create or change the CWD, an error is raised.  If *quiet* is True,
    only a warning is raised and the original CWD is used.

    """
    with temp_dir(path=name, quiet=quiet) as temp_path:
        with change_cwd(temp_path, quiet=quiet) as cwd_dir:
            yield cwd_dir
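
# Illustrative usage sketch: tests that create scratch files usually run
# inside temp_cwd() so everything is cleaned up automatically:
#
#     with support.temp_cwd() as cwd:
#         with open('scratch.txt', 'w') as f:   # lands in the temp directory
#             f.write('data')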

if hasattr(os, "umask"):
    @contextlib.contextmanager
    def temp_umask(umask):
        """Context manager that temporarily sets the process umask."""
        oldmask = os.umask(umask)
        try:
            yield
        finally:
            os.umask(oldmask)
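
# Illustrative usage sketch: temp_umask() forces predictable permissions on
# files a test creates, e.g.:
#
#     with support.temp_umask(0o022):
#         open(support.TESTFN, 'w').close()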

# TEST_HOME_DIR refers to the top level directory of the "test" package
# that contains Python's regression test suite
TEST_SUPPORT_DIR = os.path.dirname(os.path.abspath(__file__))
TEST_HOME_DIR = os.path.dirname(TEST_SUPPORT_DIR)

# TEST_DATA_DIR is used as a target download location for remote resources
TEST_DATA_DIR = os.path.join(TEST_HOME_DIR, "data")

def findfile(filename, subdir=None):
    """Try to find a file on sys.path or in the test directory.  If it is not
    found the argument passed to the function is returned (this does not
    necessarily signal failure; could still be the legitimate path).

    Setting *subdir* indicates a relative path to use to find the file
    rather than looking directly in the path directories.
    """
    if os.path.isabs(filename):
        return filename
    if subdir is not None:
        filename = os.path.join(subdir, filename)
    path = [TEST_HOME_DIR] + sys.path
    for dn in path:
        fn = os.path.join(dn, filename)
        if os.path.exists(fn): return fn
    return filename

def create_empty_file(filename):
    """Create an empty file. If the file already exists, truncate it."""
    fd = os.open(filename, os.O_WRONLY | os.O_CREAT | os.O_TRUNC)
    os.close(fd)

def sortdict(dict):
    "Like repr(dict), but in sorted order."
    items = sorted(dict.items())
    reprpairs = ["%r: %r" % pair for pair in items]
    withcommas = ", ".join(reprpairs)
    return "{%s}" % withcommas

def make_bad_fd():
    """
    Create an invalid file descriptor by opening and closing a file and return
    its fd.
    """
    file = open(TESTFN, "wb")
    try:
        return file.fileno()
    finally:
        file.close()
        unlink(TESTFN)

def check_syntax_error(testcase, statement):
    testcase.assertRaises(SyntaxError, compile, statement,
                          '<test string>', 'exec')

def open_urlresource(url, *args, **kw):
    import urllib.request, urllib.parse

    check = kw.pop('check', None)

    filename = urllib.parse.urlparse(url)[2].split('/')[-1] # '/': it's URL!

    fn = os.path.join(TEST_DATA_DIR, filename)

    def check_valid_file(fn):
        f = open(fn, *args, **kw)
        if check is None:
            return f
        elif check(f):
            f.seek(0)
            return f
        f.close()

    if os.path.exists(fn):
        f = check_valid_file(fn)
        if f is not None:
            return f
        unlink(fn)

    # Verify the requirement before downloading the file
    requires('urlfetch')

    print('\tfetching %s ...' % url, file=get_original_stdout())
    f = urllib.request.urlopen(url, timeout=15)
    try:
        with open(fn, "wb") as out:
            s = f.read()
            while s:
                out.write(s)
                s = f.read()
    finally:
        f.close()

    f = check_valid_file(fn)
    if f is not None:
        return f
    raise TestFailed('invalid resource %r' % fn)


class WarningsRecorder(object):
    """Convenience wrapper for the warnings list returned on
       entry to the warnings.catch_warnings() context manager.
    """
    def __init__(self, warnings_list):
        self._warnings = warnings_list
        self._last = 0

    def __getattr__(self, attr):
        if len(self._warnings) > self._last:
            return getattr(self._warnings[-1], attr)
        elif attr in warnings.WarningMessage._WARNING_DETAILS:
            return None
        raise AttributeError("%r has no attribute %r" % (self, attr))

    @property
    def warnings(self):
        return self._warnings[self._last:]

    def reset(self):
        self._last = len(self._warnings)


def _filterwarnings(filters, quiet=False):
    """Catch the warnings, then check if all the expected
    warnings have been raised and re-raise unexpected warnings.
    If 'quiet' is True, only re-raise the unexpected warnings.
    """
    # Clear the warning registry of the calling module
    # in order to re-raise the warnings.
    frame = sys._getframe(2)
    registry = frame.f_globals.get('__warningregistry__')
    if registry:
        registry.clear()
    with warnings.catch_warnings(record=True) as w:
        # Set filter "always" to record all warnings.  Because
        # test_warnings swaps the module, we need to look it up in
        # the sys.modules dictionary.
        sys.modules['warnings'].simplefilter("always")
        yield WarningsRecorder(w)
    # Filter the recorded warnings
    reraise = list(w)
    missing = []
    for msg, cat in filters:
        seen = False
        for w in reraise[:]:
            warning = w.message
            # Filter out the matching messages
            if (re.match(msg, str(warning), re.I) and
                issubclass(warning.__class__, cat)):
                seen = True
                reraise.remove(w)
        if not seen and not quiet:
            # This filter caught nothing
            missing.append((msg, cat.__name__))
    if reraise:
        raise AssertionError("unhandled warning %s" % reraise[0])
    if missing:
        raise AssertionError("filter (%r, %s) did not catch any warning" %
                             missing[0])


@contextlib.contextmanager
def check_warnings(*filters, **kwargs):
    """Context manager to silence warnings.

    Accept 2-tuples as positional arguments:
        ("message regexp", WarningCategory)

    Optional argument:
     - if 'quiet' is True, it does not fail if a filter catches nothing
        (default True without argument,
         default False if some filters are defined)

    Without argument, it defaults to:
        check_warnings(("", Warning), quiet=True)
    """
    quiet = kwargs.get('quiet')
    if not filters:
        filters = (("", Warning),)
        # Preserve backward compatibility
        if quiet is None:
            quiet = True
    return _filterwarnings(filters, quiet)
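
# Illustrative usage sketch (hypothetical calls under test):
#
#     with support.check_warnings(('.*deprecated', DeprecationWarning)):
#         some_deprecated_api()
#
#     with support.check_warnings(quiet=True):
#         noisy_api()                    # warnings are silenced, none required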


class CleanImport(object):
    """Context manager to force import to return a new module reference.

    This is useful for testing module-level behaviours, such as
    the emission of a DeprecationWarning on import.

    Use like this:

        with CleanImport("foo"):
            importlib.import_module("foo") # new reference
    """

    def __init__(self, *module_names):
        self.original_modules = sys.modules.copy()
        for module_name in module_names:
            if module_name in sys.modules:
                module = sys.modules[module_name]
                # It is possible that module_name is just an alias for
                # another module (e.g. stub for modules renamed in 3.x).
                # In that case, we also need delete the real module to clear
                # the import cache.
                if module.__name__ != module_name:
                    del sys.modules[module.__name__]
                del sys.modules[module_name]

    def __enter__(self):
        return self

    def __exit__(self, *ignore_exc):
        sys.modules.update(self.original_modules)


class EnvironmentVarGuard(collections.abc.MutableMapping):

    """Class to help protect the environment variable properly.  Can be used as
    a context manager."""

    def __init__(self):
        self._environ = os.environ
        self._changed = {}

    def __getitem__(self, envvar):
        return self._environ[envvar]

    def __setitem__(self, envvar, value):
        # Remember the initial value on the first access
        if envvar not in self._changed:
            self._changed[envvar] = self._environ.get(envvar)
        self._environ[envvar] = value

    def __delitem__(self, envvar):
        # Remember the initial value on the first access
        if envvar not in self._changed:
            self._changed[envvar] = self._environ.get(envvar)
        if envvar in self._environ:
            del self._environ[envvar]

    def keys(self):
        return self._environ.keys()

    def __iter__(self):
        return iter(self._environ)

    def __len__(self):
        return len(self._environ)

    def set(self, envvar, value):
        self[envvar] = value

    def unset(self, envvar):
        del self[envvar]

    def __enter__(self):
        return self

    def __exit__(self, *ignore_exc):
        for (k, v) in self._changed.items():
            if v is None:
                if k in self._environ:
                    del self._environ[k]
            else:
                self._environ[k] = v
        os.environ = self._environ
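
# Illustrative usage sketch: temporarily overriding environment variables;
# the original values are restored when the block exits:
#
#     with support.EnvironmentVarGuard() as env:
#         env['LANG'] = 'C'
#         env.unset('PYTHONPATH')
#         run_child_interpreter()        # hypothetical call under test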


class DirsOnSysPath(object):
    """Context manager to temporarily add directories to sys.path.

    This makes a copy of sys.path, appends any directories given
    as positional arguments, then reverts sys.path to the copied
    settings when the context ends.

    Note that *all* sys.path modifications in the body of the
    context manager, including replacement of the object,
    will be reverted at the end of the block.
    """

    def __init__(self, *paths):
        self.original_value = sys.path[:]
        self.original_object = sys.path
        sys.path.extend(paths)

    def __enter__(self):
        return self

    def __exit__(self, *ignore_exc):
        sys.path = self.original_object
        sys.path[:] = self.original_value


class TransientResource(object):

    """Raise ResourceDenied if an exception is raised while the context manager
    is in effect that matches the specified exception and attributes."""

    def __init__(self, exc, **kwargs):
        self.exc = exc
        self.attrs = kwargs

    def __enter__(self):
        return self

    def __exit__(self, type_=None, value=None, traceback=None):
        """If type_ is a subclass of self.exc and value has attributes matching
        self.attrs, raise ResourceDenied.  Otherwise let the exception
        propagate (if any)."""
        if type_ is not None and issubclass(self.exc, type_):
            for attr, attr_value in self.attrs.items():
                if not hasattr(value, attr):
                    break
                if getattr(value, attr) != attr_value:
                    break
            else:
                raise ResourceDenied("an optional resource is not available")

# Context managers that raise ResourceDenied when various issues
# with the Internet connection manifest themselves as exceptions.
# XXX deprecate these and use transient_internet() instead
time_out = TransientResource(IOError, errno=errno.ETIMEDOUT)
socket_peer_reset = TransientResource(socket.error, errno=errno.ECONNRESET)
ioerror_peer_reset = TransientResource(IOError, errno=errno.ECONNRESET)


@contextlib.contextmanager
def transient_internet(resource_name, *, timeout=30.0, errnos=()):
    """Return a context manager that raises ResourceDenied when various issues
    with the Internet connection manifest themselves as exceptions."""
    default_errnos = [
        ('ECONNREFUSED', 111),
        ('ECONNRESET', 104),
        ('EHOSTUNREACH', 113),
        ('ENETUNREACH', 101),
        ('ETIMEDOUT', 110),
    ]
    default_gai_errnos = [
        ('EAI_AGAIN', -3),
        ('EAI_FAIL', -4),
        ('EAI_NONAME', -2),
        ('EAI_NODATA', -5),
        # Encountered when trying to resolve IPv6-only hostnames
        ('WSANO_DATA', 11004),
    ]

    denied = ResourceDenied("Resource %r is not available" % resource_name)
    captured_errnos = errnos
    gai_errnos = []
    if not captured_errnos:
        captured_errnos = [getattr(errno, name, num)
                           for (name, num) in default_errnos]
        gai_errnos = [getattr(socket, name, num)
                      for (name, num) in default_gai_errnos]

    def filter_error(err):
        n = getattr(err, 'errno', None)
        if (isinstance(err, socket.timeout) or
            (isinstance(err, socket.gaierror) and n in gai_errnos) or
            n in captured_errnos):
            if not verbose:
                sys.stderr.write(denied.args[0] + "\n")
            raise denied from err

    old_timeout = socket.getdefaulttimeout()
    try:
        if timeout is not None:
            socket.setdefaulttimeout(timeout)
        yield
    except IOError as err:
        # urllib can wrap original socket errors multiple times (!), we must
        # unwrap to get at the original error.
        while True:
            a = err.args
            if len(a) >= 1 and isinstance(a[0], IOError):
                err = a[0]
            # The error can also be wrapped as args[1]:
            #    except socket.error as msg:
            #        raise IOError('socket error', msg).with_traceback(sys.exc_info()[2])
            elif len(a) >= 2 and isinstance(a[1], IOError):
                err = a[1]
            else:
                break
        filter_error(err)
        raise
    # XXX should we catch generic exceptions and look for their
    # __cause__ or __context__?
    finally:
        socket.setdefaulttimeout(old_timeout)
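
# Illustrative usage sketch: network tests wrap their traffic so flaky
# connectivity is reported as a skip rather than a failure:
#
#     support.requires('network')
#     with support.transient_internet('www.example.com'):
#         urllib.request.urlopen('http://www.example.com/', timeout=30)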


@contextlib.contextmanager
def captured_output(stream_name):
    """Return a context manager used by captured_stdout/stdin/stderr
    that temporarily replaces the sys stream *stream_name* with a StringIO."""
    import io
    orig_stdout = getattr(sys, stream_name)
    setattr(sys, stream_name, io.StringIO())
    try:
        yield getattr(sys, stream_name)
    finally:
        setattr(sys, stream_name, orig_stdout)

def captured_stdout():
    """Capture the output of sys.stdout:

       with captured_stdout() as stdout:
           print("hello")
       self.assertEqual(stdout.getvalue(), "hello\n")
    """
    return captured_output("stdout")

def captured_stderr():
    """Capture the output of sys.stderr:

       with captured_stderr() as stderr:
           print("hello", file=sys.stderr)
       self.assertEqual(stderr.getvalue(), "hello\n")
    """
    return captured_output("stderr")

def captured_stdin():
    """Capture the input to sys.stdin:

       with captured_stdin() as stdin:
           stdin.write('hello\n')
           stdin.seek(0)
           # call test code that consumes from sys.stdin
           captured = input()
       self.assertEqual(captured, "hello")
    """
    return captured_output("stdin")


def gc_collect():
    """Force as many objects as possible to be collected.

    In non-CPython implementations of Python, this is needed because timely
    deallocation is not guaranteed by the garbage collector.  (Even in CPython
    this can be the case in case of reference cycles.)  This means that __del__
    methods may be called later than expected and weakrefs may remain alive for
    longer than expected.  This function tries its best to force all garbage
    objects to disappear.
    """
    gc.collect()
    if is_jython:
        time.sleep(0.1)
    gc.collect()
    gc.collect()

@contextlib.contextmanager
def disable_gc():
    have_gc = gc.isenabled()
    gc.disable()
    try:
        yield
    finally:
        if have_gc:
            gc.enable()


def python_is_optimized():
    """Find if Python was built with optimizations."""
    cflags = sysconfig.get_config_var('PY_CFLAGS') or ''
    final_opt = ""
    for opt in cflags.split():
        if opt.startswith('-O'):
            final_opt = opt
    return final_opt != '' and final_opt != '-O0'


_header = 'nP'
_align = '0n'
if hasattr(sys, "gettotalrefcount"):
    _header = '2P' + _header
    _align = '0P'
_vheader = _header + 'n'

def calcobjsize(fmt):
    return struct.calcsize(_header + fmt + _align)

def calcvobjsize(fmt):
    return struct.calcsize(_vheader + fmt + _align)


_TPFLAGS_HAVE_GC = 1<<14
_TPFLAGS_HEAPTYPE = 1<<9

def check_sizeof(test, o, size):
    result = sys.getsizeof(o)
    # add GC header size
    if ((type(o) == type) and (o.__flags__ & _TPFLAGS_HEAPTYPE) or\
        ((type(o) != type) and (type(o).__flags__ & _TPFLAGS_HAVE_GC))):
        size += _testcapi.SIZEOF_PYGC_HEAD
    msg = 'wrong size for %s: got %d, expected %d' \
            % (type(o), result, size)
    test.assertEqual(result, size, msg)
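
# Illustrative usage sketch (format strings are examples, not authoritative
# C layouts): sizeof tests build the expected size from the struct format of
# the object's fields, e.g.:
#
#     check_sizeof(test, 1.0, calcobjsize('d'))          # float: one C double
#     check_sizeof(test, (1, 2), calcvobjsize('2P'))     # 2-tuple: two pointers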

#=======================================================================
# Decorator for running a function in a different locale, correctly resetting
# it afterwards.

def run_with_locale(catstr, *locales):
    def decorator(func):
        def inner(*args, **kwds):
            try:
                import locale
                category = getattr(locale, catstr)
                orig_locale = locale.setlocale(category)
            except AttributeError:
                # if the test author gives us an invalid category string
                raise
            except:
                # cannot retrieve original locale, so do nothing
                locale = orig_locale = None
            else:
                for loc in locales:
                    try:
                        locale.setlocale(category, loc)
                        break
                    except:
                        pass

            # now run the function, resetting the locale on exceptions
            try:
                return func(*args, **kwds)
            finally:
                if locale and orig_locale:
                    locale.setlocale(category, orig_locale)
        inner.__name__ = func.__name__
        inner.__doc__ = func.__doc__
        return inner
    return decorator

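# Example usage of run_with_locale() (an illustrative sketch; the locale
# names are assumptions, the first one that can be set is used, and if none
# can be set the test simply runs in the current locale):
#
#     @run_with_locale('LC_NUMERIC', 'fr_FR.UTF-8', 'fr_FR')
#     def test_float_formatting(self):
#         ...
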
#=======================================================================
# Decorator for running a function in a specific timezone, correctly
# resetting it afterwards.

def run_with_tz(tz):
    def decorator(func):
        def inner(*args, **kwds):
            try:
                tzset = time.tzset
            except AttributeError:
                raise unittest.SkipTest("tzset required")
            if 'TZ' in os.environ:
                orig_tz = os.environ['TZ']
            else:
                orig_tz = None
            os.environ['TZ'] = tz
            tzset()

            # now run the function, resetting the tz on exceptions
            try:
                return func(*args, **kwds)
            finally:
                if orig_tz is None:
                    del os.environ['TZ']
                else:
                    os.environ['TZ'] = orig_tz
                time.tzset()

        inner.__name__ = func.__name__
        inner.__doc__ = func.__doc__
        return inner
    return decorator

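# Example usage of run_with_tz() (an illustrative sketch; the TZ string is
# an assumption, and the test is skipped on platforms without time.tzset()):
#
#     @run_with_tz('EST+05EDT,M4.1.0,M10.5.0')
#     def test_localtime_during_dst(self):
#         ...
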
#=======================================================================
# Big-memory-test support. Separate from 'resources' because memory use
# should be configurable.

# Some handy shorthands. Note that these are used for byte-limits as well
# as size-limits, in the various bigmem tests
_1M = 1024*1024
_1G = 1024 * _1M
_2G = 2 * _1G
_4G = 4 * _1G

MAX_Py_ssize_t = sys.maxsize

def set_memlimit(limit):
    global max_memuse
    global real_max_memuse
    sizes = {
        'k': 1024,
        'm': _1M,
        'g': _1G,
        't': 1024*_1G,
    }
    m = re.match(r'(\d+(\.\d+)?) (K|M|G|T)b?$', limit,
                 re.IGNORECASE | re.VERBOSE)
    if m is None:
        raise ValueError('Invalid memory limit %r' % (limit,))
    memlimit = int(float(m.group(1)) * sizes[m.group(3).lower()])
    real_max_memuse = memlimit
    if memlimit > MAX_Py_ssize_t:
        memlimit = MAX_Py_ssize_t
    if memlimit < _2G - 1:
        raise ValueError('Memory limit %r too low to be useful' % (limit,))
    max_memuse = memlimit

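# Example usage of set_memlimit() (an illustrative sketch; regrtest normally
# calls this to honor its -M option, but it can also be called directly):
#
#     set_memlimit('2.5G')    # let bigmem tests use up to ~2.5 GiB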

class _MemoryWatchdog:
    """An object which periodically watches the process' memory consumption
    and prints it out.
    """

    def __init__(self):
        self.procfile = '/proc/{pid}/statm'.format(pid=os.getpid())
        self.started = False

    def start(self):
        try:
            f = open(self.procfile, 'r')
        except OSError as e:
            warnings.warn('/proc not available for stats: {}'.format(e),
                          RuntimeWarning)
            sys.stderr.flush()
            return

        watchdog_script = findfile("memory_watchdog.py")
        self.mem_watchdog = subprocess.Popen([sys.executable, watchdog_script],
                                             stdin=f, stderr=subprocess.DEVNULL)
        f.close()
        self.started = True

    def stop(self):
        if self.started:
            self.mem_watchdog.terminate()
            self.mem_watchdog.wait()


def bigmemtest(size, memuse, dry_run=True):
    """Decorator for bigmem tests.

    'size' is the minimum useful size for the test (in arbitrary,
    test-interpreted units.) 'memuse' is the number of bytes per unit of
    size for the test, or a good estimate of it.

    If 'dry_run' is False, it means the test doesn't support dummy runs
    when -M is not specified.
    """
    def decorator(f):
        def wrapper(self):
            size = wrapper.size
            memuse = wrapper.memuse
            if not real_max_memuse:
                maxsize = 5147
            else:
                maxsize = size

            if ((real_max_memuse or not dry_run)
                and real_max_memuse < maxsize * memuse):
                raise unittest.SkipTest(
                    "not enough memory: %.1fG minimum needed"
                    % (size * memuse / (1024 ** 3)))

            if real_max_memuse and verbose:
                print()
                print(" ... expected peak memory use: {peak:.1f}G"
                      .format(peak=size * memuse / (1024 ** 3)))
                watchdog = _MemoryWatchdog()
                watchdog.start()
            else:
                watchdog = None

            try:
                return f(self, maxsize)
            finally:
                if watchdog:
                    watchdog.stop()

        wrapper.size = size
        wrapper.memuse = memuse
        return wrapper
    return decorator

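# Example usage of bigmemtest() (an illustrative sketch on a
# unittest.TestCase; 'size' is in test-defined units and 'memuse' is a rough
# bytes-per-unit estimate, both chosen here purely for illustration):
#
#     @bigmemtest(size=_2G, memuse=2)
#     def test_large_string(self, size):
#         s = ' ' * size
#         self.assertEqual(len(s), size)
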
def bigaddrspacetest(f):
    """Decorator for tests that fill the address space."""
    def wrapper(self):
        if max_memuse < MAX_Py_ssize_t:
            if MAX_Py_ssize_t >= 2**63 - 1 and max_memuse >= 2**31:
                raise unittest.SkipTest(
                    "not enough memory: try a 32-bit build instead")
            else:
                raise unittest.SkipTest(
                    "not enough memory: %.1fG minimum needed"
                    % (MAX_Py_ssize_t / (1024 ** 3)))
        else:
            return f(self)
    return wrapper

#=======================================================================
# unittest integration.

class BasicTestRunner:
    def run(self, test):
        result = unittest.TestResult()
        test(result)
        return result

def _id(obj):
    return obj

def requires_resource(resource):
    if resource == 'gui' and not _is_gui_available():
        return unittest.skip("resource 'gui' is not available")
    if is_resource_enabled(resource):
        return _id
    else:
        return unittest.skip("resource {0!r} is not enabled".format(resource))

def cpython_only(test):
    """
    Decorator for tests only applicable on CPython.
    """
    return impl_detail(cpython=True)(test)

def impl_detail(msg=None, **guards):
    if check_impl_detail(**guards):
        return _id
    if msg is None:
        guardnames, default = _parse_guards(guards)
        if default:
            msg = "implementation detail not available on {0}"
        else:
            msg = "implementation detail specific to {0}"
        guardnames = sorted(guardnames.keys())
        msg = msg.format(' or '.join(guardnames))
    return unittest.skip(msg)

def _parse_guards(guards):
    # Returns a tuple ({platform_name: run_me}, default_value)
    if not guards:
        return ({'cpython': True}, False)
    is_true = list(guards.values())[0]
    assert list(guards.values()) == [is_true] * len(guards)   # all True or all False
    return (guards, not is_true)

# Use the following check to guard CPython's implementation-specific tests --
# or to run them only on the implementation(s) guarded by the arguments.
def check_impl_detail(**guards):
    """This function returns True or False depending on the host platform.
       Examples:
          if check_impl_detail():               # only on CPython (default)
          if check_impl_detail(jython=True):    # only on Jython
          if check_impl_detail(cpython=False):  # everywhere except on CPython
    """
    guards, default = _parse_guards(guards)
    return guards.get(platform.python_implementation().lower(), default)


def no_tracing(func):
    """Decorator to temporarily turn off tracing for the duration of a test."""
    if not hasattr(sys, 'gettrace'):
        return func
    else:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            original_trace = sys.gettrace()
            try:
                sys.settrace(None)
                return func(*args, **kwargs)
            finally:
                sys.settrace(original_trace)
        return wrapper


def refcount_test(test):
    """Decorator for tests which involve reference counting.

    To start, the decorator does not run the test if it is not run by CPython.
    After that, any trace function is unset during the test to prevent
    unexpected refcounts caused by the trace function.

    """
    return no_tracing(cpython_only(test))


def _filter_suite(suite, pred):
    """Recursively filter test cases in a suite based on a predicate."""
    newtests = []
    for test in suite._tests:
        if isinstance(test, unittest.TestSuite):
            _filter_suite(test, pred)
            newtests.append(test)
        else:
            if pred(test):
                newtests.append(test)
    suite._tests = newtests

def _run_suite(suite):
    """Run tests from a unittest.TestSuite-derived class."""
    if verbose:
        runner = unittest.TextTestRunner(sys.stdout, verbosity=2,
                                         failfast=failfast)
    else:
        runner = BasicTestRunner()

    result = runner.run(suite)
    if not result.wasSuccessful():
        if len(result.errors) == 1 and not result.failures:
            err = result.errors[0][1]
        elif len(result.failures) == 1 and not result.errors:
            err = result.failures[0][1]
        else:
            err = "multiple errors occurred"
            if not verbose: err += "; run in verbose mode for details"
        raise TestFailed(err)


def run_unittest(*classes):
    """Run tests from unittest.TestCase-derived classes."""
    valid_types = (unittest.TestSuite, unittest.TestCase)
    suite = unittest.TestSuite()
    for cls in classes:
        if isinstance(cls, str):
            if cls in sys.modules:
                suite.addTest(unittest.findTestCases(sys.modules[cls]))
            else:
                raise ValueError("str arguments must be keys in sys.modules")
        elif isinstance(cls, valid_types):
            suite.addTest(cls)
        else:
            suite.addTest(unittest.makeSuite(cls))
    def case_pred(test):
        if match_tests is None:
            return True
        for name in test.id().split("."):
            if fnmatch.fnmatchcase(name, match_tests):
                return True
        return False
    _filter_suite(suite, case_pred)
    _run_suite(suite)
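
# Example usage of run_unittest() (an illustrative sketch; 'MyTests' is a
# hypothetical TestCase subclass defined in the calling test module):
#
#     def test_main():
#         run_unittest(MyTests)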

#=======================================================================
# Check for the presence of docstrings.

HAVE_DOCSTRINGS = (check_impl_detail(cpython=False) or
                   sys.platform == 'win32' or
                   sysconfig.get_config_var('WITH_DOC_STRINGS'))

requires_docstrings = unittest.skipUnless(HAVE_DOCSTRINGS,
                                          "test requires docstrings")


#=======================================================================
# doctest driver.

def run_doctest(module, verbosity=None, optionflags=0):
    """Run doctest on the given module.  Return (#failures, #tests).

    If optional argument verbosity is not specified (or is None), pass
    support's belief about verbosity on to doctest.  Else doctest's
    usual behavior is used (it searches sys.argv for -v).
    """

    import doctest

    if verbosity is None:
        verbosity = verbose
    else:
        verbosity = None

    f, t = doctest.testmod(module, verbose=verbosity, optionflags=optionflags)
    if f:
        raise TestFailed("%d of %d doctests failed" % (f, t))
    if verbose:
        print('doctest (%s) ... %d tests with zero failures' %
              (module.__name__, t))
    return f, t

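# Example usage of run_doctest() (an illustrative sketch;
# 'test.test_something' is a hypothetical module containing doctests):
#
#     from test import test_something
#     run_doctest(test_something)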

#=======================================================================
# Support for saving and restoring the imported modules.

def modules_setup():
    return sys.modules.copy(),

def modules_cleanup(oldmodules):
    # Encoders/decoders are registered permanently within the internal
    # codec cache. If we destroy the corresponding modules their
    # globals will be set to None which will trip up the cached functions.
    encodings = [(k, v) for k, v in sys.modules.items()
                 if k.startswith('encodings.')]
    sys.modules.clear()
    sys.modules.update(encodings)
    # XXX: This kind of problem can affect more than just encodings. In particular
    # extension modules (such as _ssl) don't cope with reloading properly.
    # Really, test modules should be cleaning out the test specific modules they
    # know they added (ala test_runpy) rather than relying on this function (as
    # test_importhooks and test_pkg do currently).
    # Implicitly imported *real* modules should be left alone (see issue 10556).
    sys.modules.update(oldmodules)

#=======================================================================
# Threading support to prevent reporting refleaks when running regrtest.py -R

# NOTE: we use thread._count() rather than threading.enumerate() (or the
# moral equivalent thereof) because a threading.Thread object is still alive
# until its __bootstrap() method has returned, even after it has been
# unregistered from the threading module.
# thread._count(), on the other hand, only gets decremented *after* the
# __bootstrap() method has returned, which gives us reliable reference counts
# at the end of a test run.

def threading_setup():
    if _thread:
        return _thread._count(), threading._dangling.copy()
    else:
        return 1, ()

def threading_cleanup(*original_values):
    if not _thread:
        return
    _MAX_COUNT = 10
    for count in range(_MAX_COUNT):
        values = _thread._count(), threading._dangling
        if values == original_values:
            break
        time.sleep(0.1)
        gc_collect()
    # XXX print a warning in case of failure?

def reap_threads(func):
    """Use this function when threads are being used.  This will
    ensure that the threads are cleaned up even when the test fails.
    If threading is unavailable this function does nothing.
    """
    if not _thread:
        return func

    @functools.wraps(func)
    def decorator(*args):
        key = threading_setup()
        try:
            return func(*args)
        finally:
            threading_cleanup(*key)
    return decorator

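# Example usage of reap_threads() (an illustrative sketch; 'worker' is a
# hypothetical target function and 'self' a unittest.TestCase instance):
#
#     @reap_threads
#     def test_spawns_a_thread(self):
#         t = threading.Thread(target=worker)
#         t.start()
#         t.join()
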
def reap_children():
    """Use this function at the end of test_main() whenever sub-processes
    are started.  This will help ensure that no extra children (zombies)
    stick around to hog resources and create problems when looking
    for refleaks.
    """

    # Reap all our dead child processes so we don't leave zombies around.
    # These hog resources and might be causing some of the buildbots to die.
    if hasattr(os, 'waitpid'):
        any_process = -1
        while True:
            try:
                # This will raise an exception on Windows.  That's ok.
                pid, status = os.waitpid(any_process, os.WNOHANG)
                if pid == 0:
                    break
            except:
                break

@contextlib.contextmanager
def swap_attr(obj, attr, new_val):
    """Temporary swap out an attribute with a new object.

    Usage:
        with swap_attr(obj, "attr", 5):
            ...

        This will set obj.attr to 5 for the duration of the with: block,
        restoring the old value at the end of the block. If `attr` doesn't
        exist on `obj`, it will be created and then deleted at the end of the
        block.
    """
    if hasattr(obj, attr):
        real_val = getattr(obj, attr)
        setattr(obj, attr, new_val)
        try:
            yield
        finally:
            setattr(obj, attr, real_val)
    else:
        setattr(obj, attr, new_val)
        try:
            yield
        finally:
            delattr(obj, attr)

@contextlib.contextmanager
def swap_item(obj, item, new_val):
    """Temporary swap out an item with a new object.

    Usage:
        with swap_item(obj, "item", 5):
            ...

        This will set obj["item"] to 5 for the duration of the with: block,
        restoring the old value at the end of the block. If `item` doesn't
        exist on `obj`, it will be created and then deleted at the end of the
        block.
    """
    if item in obj:
        real_val = obj[item]
        obj[item] = new_val
        try:
            yield
        finally:
            obj[item] = real_val
    else:
        obj[item] = new_val
        try:
            yield
        finally:
            del obj[item]

def strip_python_stderr(stderr):
    """Strip the stderr of a Python process from potential debug output
    emitted by the interpreter.

    This will typically be run on the result of the communicate() method
    of a subprocess.Popen object.
    """
    stderr = re.sub(br"\[\d+ refs\]\r?\n?", b"", stderr).strip()
    return stderr

def args_from_interpreter_flags():
    """Return a list of command-line arguments reproducing the current
    settings in sys.flags and sys.warnoptions."""
    return subprocess._args_from_interpreter_flags()

#============================================================
# Support for assertions about logging.
#============================================================

class TestHandler(logging.handlers.BufferingHandler):
    def __init__(self, matcher):
        # BufferingHandler takes a "capacity" argument
        # so as to know when to flush. As we're overriding
        # shouldFlush anyway, we can set a capacity of zero.
        # You can call flush() manually to clear out the
        # buffer.
        logging.handlers.BufferingHandler.__init__(self, 0)
        self.matcher = matcher

    def shouldFlush(self):
        return False

    def emit(self, record):
        self.format(record)
        self.buffer.append(record.__dict__)

    def matches(self, **kwargs):
        """
        Look for a saved dict whose keys/values match the supplied arguments.
        """
        result = False
        for d in self.buffer:
            if self.matcher.matches(d, **kwargs):
                result = True
                break
        return result

class Matcher(object):

    _partial_matches = ('msg', 'message')

    def matches(self, d, **kwargs):
        """
        Try to match a single dict with the supplied arguments.

        Keys whose values are strings and which are in self._partial_matches
        will be checked for partial (i.e. substring) matches. You can extend
        this scheme to (for example) do regular expression matching, etc.
        """
        result = True
        for k in kwargs:
            v = kwargs[k]
            dv = d.get(k)
            if not self.match_value(k, dv, v):
                result = False
                break
        return result

    def match_value(self, k, dv, v):
        """
        Try to match a single stored value (dv) with a supplied value (v).
        """
        if type(v) != type(dv):
            result = False
        elif type(dv) is not str or k not in self._partial_matches:
            result = (v == dv)
        else:
            result = dv.find(v) >= 0
        return result
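
# Example usage of TestHandler and Matcher (an illustrative sketch, inside a
# unittest.TestCase method; attaches the handler to the root logger and
# checks that a matching record was captured):
#
#     handler = TestHandler(Matcher())
#     logger = logging.getLogger()
#     logger.addHandler(handler)
#     try:
#         logger.warning("disk space low")
#         self.assertTrue(handler.matches(levelname="WARNING",
#                                         message="disk"))
#     finally:
#         logger.removeHandler(handler)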


_can_symlink = None
def can_symlink():
    global _can_symlink
    if _can_symlink is not None:
        return _can_symlink
    symlink_path = TESTFN + "can_symlink"
    try:
        os.symlink(TESTFN, symlink_path)
        can = True
    except (OSError, NotImplementedError, AttributeError):
        can = False
    else:
        os.remove(symlink_path)
    _can_symlink = can
    return can

def skip_unless_symlink(test):
    """Skip decorator for tests that require functional symlink"""
    ok = can_symlink()
    msg = "Requires functional symlink implementation"
    return test if ok else unittest.skip(msg)(test)

_can_xattr = None
def can_xattr():
    global _can_xattr
    if _can_xattr is not None:
        return _can_xattr
    if not hasattr(os, "setxattr"):
        can = False
    else:
        tmp_fp, tmp_name = tempfile.mkstemp()
        try:
            with open(TESTFN, "wb") as fp:
                try:
                    # TESTFN & tempfile may use different file systems with
                    # different capabilities
                    os.setxattr(tmp_fp, b"user.test", b"")
                    os.setxattr(fp.fileno(), b"user.test", b"")
                    # Kernels < 2.6.39 don't respect setxattr flags.
                    kernel_version = platform.release()
                    m = re.match(r"2.6.(\d{1,2})", kernel_version)
                    can = m is None or int(m.group(1)) >= 39
                except OSError:
                    can = False
        finally:
            unlink(TESTFN)
            unlink(tmp_name)
    _can_xattr = can
    return can

def skip_unless_xattr(test):
    """Skip decorator for tests that require functional extended attributes"""
    ok = can_xattr()
    msg = "no non-broken extended attribute support"
    return test if ok else unittest.skip(msg)(test)


if sys.platform.startswith('win'):
    @contextlib.contextmanager
    def suppress_crash_popup():
        """Disable Windows Error Reporting dialogs using SetErrorMode."""
        # see http://msdn.microsoft.com/en-us/library/windows/desktop/ms680621%28v=vs.85%29.aspx
        # GetErrorMode is not available on Windows XP and Windows Server 2003,
        # but SetErrorMode returns the previous value, so we can use that
        import ctypes
        k32 = ctypes.windll.kernel32
        SEM_NOGPFAULTERRORBOX = 0x02
        old_error_mode = k32.SetErrorMode(SEM_NOGPFAULTERRORBOX)
        k32.SetErrorMode(old_error_mode | SEM_NOGPFAULTERRORBOX)
        try:
            yield
        finally:
            k32.SetErrorMode(old_error_mode)
else:
    # this is a no-op for other platforms
    @contextlib.contextmanager
    def suppress_crash_popup():
        yield


def patch(test_instance, object_to_patch, attr_name, new_value):
    """Override 'object_to_patch'.'attr_name' with 'new_value'.

    Also, add a cleanup procedure to 'test_instance' to restore
    'object_to_patch' value for 'attr_name'.
    The 'attr_name' should be a valid attribute for 'object_to_patch'.

    """
    # check that 'attr_name' is a real attribute for 'object_to_patch'
    # will raise AttributeError if it does not exist
    getattr(object_to_patch, attr_name)

    # keep a copy of the old value
    attr_is_local = False
    try:
        old_value = object_to_patch.__dict__[attr_name]
    except (AttributeError, KeyError):
        old_value = getattr(object_to_patch, attr_name, None)
    else:
        attr_is_local = True

    # restore the value when the test is done
    def cleanup():
        if attr_is_local:
            setattr(object_to_patch, attr_name, old_value)
        else:
            delattr(object_to_patch, attr_name)

    test_instance.addCleanup(cleanup)

    # actually override the attribute
    setattr(object_to_patch, attr_name, new_value)
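
# Example usage of patch() (an illustrative sketch, inside a
# unittest.TestCase method; the original attribute is restored automatically
# via addCleanup when the test finishes):
#
#     def test_with_fake_cwd(self):
#         patch(self, os, 'getcwd', lambda: '/tmp/fake')
#         self.assertEqual(os.getcwd(), '/tmp/fake')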