Commit 041dd8ee authored by Serhiy Storchaka's avatar Serhiy Storchaka

Issue #15836: assertRaises(), assertRaisesRegex(), assertWarns() and

assertWarnsRegex() assertments now check the type of the first argument
to prevent possible user error.  Based on patch by Daniel Wagner-Hall.
parent ff542236
...@@ -97,7 +97,6 @@ class InspectLoaderTests: ...@@ -97,7 +97,6 @@ class InspectLoaderTests:
method = getattr(self.machinery.BuiltinImporter, meth_name) method = getattr(self.machinery.BuiltinImporter, meth_name)
with self.assertRaises(ImportError) as cm: with self.assertRaises(ImportError) as cm:
method(util.BUILTINS.bad_name) method(util.BUILTINS.bad_name)
self.assertRaises(util.BUILTINS.bad_name)
(Frozen_InspectLoaderTests, (Frozen_InspectLoaderTests,
......
...@@ -119,6 +119,10 @@ def expectedFailure(test_item): ...@@ -119,6 +119,10 @@ def expectedFailure(test_item):
test_item.__unittest_expecting_failure__ = True test_item.__unittest_expecting_failure__ = True
return test_item return test_item
def _is_subtype(expected, basetype):
if isinstance(expected, tuple):
return all(_is_subtype(e, basetype) for e in expected)
return isinstance(expected, type) and issubclass(expected, basetype)
class _BaseTestCaseContext: class _BaseTestCaseContext:
...@@ -148,6 +152,9 @@ class _AssertRaisesBaseContext(_BaseTestCaseContext): ...@@ -148,6 +152,9 @@ class _AssertRaisesBaseContext(_BaseTestCaseContext):
If args is not empty, call a callable passing positional and keyword If args is not empty, call a callable passing positional and keyword
arguments. arguments.
""" """
if not _is_subtype(self.expected, self._base_type):
raise TypeError('%s() arg 1 must be %s' %
(name, self._base_type_str))
if args and args[0] is None: if args and args[0] is None:
warnings.warn("callable is None", warnings.warn("callable is None",
DeprecationWarning, 3) DeprecationWarning, 3)
...@@ -172,6 +179,9 @@ class _AssertRaisesBaseContext(_BaseTestCaseContext): ...@@ -172,6 +179,9 @@ class _AssertRaisesBaseContext(_BaseTestCaseContext):
class _AssertRaisesContext(_AssertRaisesBaseContext): class _AssertRaisesContext(_AssertRaisesBaseContext):
"""A context manager used to implement TestCase.assertRaises* methods.""" """A context manager used to implement TestCase.assertRaises* methods."""
_base_type = BaseException
_base_type_str = 'an exception type or tuple of exception types'
def __enter__(self): def __enter__(self):
return self return self
...@@ -206,6 +216,9 @@ class _AssertRaisesContext(_AssertRaisesBaseContext): ...@@ -206,6 +216,9 @@ class _AssertRaisesContext(_AssertRaisesBaseContext):
class _AssertWarnsContext(_AssertRaisesBaseContext): class _AssertWarnsContext(_AssertRaisesBaseContext):
"""A context manager used to implement TestCase.assertWarns* methods.""" """A context manager used to implement TestCase.assertWarns* methods."""
_base_type = Warning
_base_type_str = 'a warning type or tuple of warning types'
def __enter__(self): def __enter__(self):
# The __warningregistry__'s need to be in a pristine state for tests # The __warningregistry__'s need to be in a pristine state for tests
# to work properly. # to work properly.
......
...@@ -1185,6 +1185,18 @@ test case ...@@ -1185,6 +1185,18 @@ test case
with self.assertRaises(ExceptionMock): with self.assertRaises(ExceptionMock):
self.assertRaises(ValueError, Stub) self.assertRaises(ValueError, Stub)
def testAssertRaisesNoExceptionType(self):
with self.assertRaises(TypeError):
self.assertRaises()
with self.assertRaises(TypeError):
self.assertRaises(1)
with self.assertRaises(TypeError):
self.assertRaises(object)
with self.assertRaises(TypeError):
self.assertRaises((ValueError, 1))
with self.assertRaises(TypeError):
self.assertRaises((ValueError, object))
def testAssertRaisesRegex(self): def testAssertRaisesRegex(self):
class ExceptionMock(Exception): class ExceptionMock(Exception):
pass pass
...@@ -1258,6 +1270,20 @@ test case ...@@ -1258,6 +1270,20 @@ test case
self.assertIsInstance(e, ExceptionMock) self.assertIsInstance(e, ExceptionMock)
self.assertEqual(e.args[0], v) self.assertEqual(e.args[0], v)
def testAssertRaisesRegexNoExceptionType(self):
with self.assertRaises(TypeError):
self.assertRaisesRegex()
with self.assertRaises(TypeError):
self.assertRaisesRegex(ValueError)
with self.assertRaises(TypeError):
self.assertRaisesRegex(1, 'expect')
with self.assertRaises(TypeError):
self.assertRaisesRegex(object, 'expect')
with self.assertRaises(TypeError):
self.assertRaisesRegex((ValueError, 1), 'expect')
with self.assertRaises(TypeError):
self.assertRaisesRegex((ValueError, object), 'expect')
def testAssertWarnsCallable(self): def testAssertWarnsCallable(self):
def _runtime_warn(): def _runtime_warn():
warnings.warn("foo", RuntimeWarning) warnings.warn("foo", RuntimeWarning)
...@@ -1336,6 +1362,20 @@ test case ...@@ -1336,6 +1362,20 @@ test case
with self.assertWarns(DeprecationWarning): with self.assertWarns(DeprecationWarning):
_runtime_warn() _runtime_warn()
def testAssertWarnsNoExceptionType(self):
with self.assertRaises(TypeError):
self.assertWarns()
with self.assertRaises(TypeError):
self.assertWarns(1)
with self.assertRaises(TypeError):
self.assertWarns(object)
with self.assertRaises(TypeError):
self.assertWarns((UserWarning, 1))
with self.assertRaises(TypeError):
self.assertWarns((UserWarning, object))
with self.assertRaises(TypeError):
self.assertWarns((UserWarning, Exception))
def testAssertWarnsRegexCallable(self): def testAssertWarnsRegexCallable(self):
def _runtime_warn(msg): def _runtime_warn(msg):
warnings.warn(msg, RuntimeWarning) warnings.warn(msg, RuntimeWarning)
...@@ -1414,6 +1454,22 @@ test case ...@@ -1414,6 +1454,22 @@ test case
with self.assertWarnsRegex(RuntimeWarning, "o+"): with self.assertWarnsRegex(RuntimeWarning, "o+"):
_runtime_warn("barz") _runtime_warn("barz")
def testAssertWarnsRegexNoExceptionType(self):
with self.assertRaises(TypeError):
self.assertWarnsRegex()
with self.assertRaises(TypeError):
self.assertWarnsRegex(UserWarning)
with self.assertRaises(TypeError):
self.assertWarnsRegex(1, 'expect')
with self.assertRaises(TypeError):
self.assertWarnsRegex(object, 'expect')
with self.assertRaises(TypeError):
self.assertWarnsRegex((UserWarning, 1), 'expect')
with self.assertRaises(TypeError):
self.assertWarnsRegex((UserWarning, object), 'expect')
with self.assertRaises(TypeError):
self.assertWarnsRegex((UserWarning, Exception), 'expect')
@contextlib.contextmanager @contextlib.contextmanager
def assertNoStderr(self): def assertNoStderr(self):
with captured_stderr() as buf: with captured_stderr() as buf:
......
...@@ -1472,6 +1472,7 @@ Alex Volkov ...@@ -1472,6 +1472,7 @@ Alex Volkov
Martijn Vries Martijn Vries
Sjoerd de Vries Sjoerd de Vries
Guido Vranken Guido Vranken
Daniel Wagner-Hall
Niki W. Waibel Niki W. Waibel
Wojtek Walczak Wojtek Walczak
Charles Waldman Charles Waldman
......
...@@ -52,6 +52,10 @@ Core and Builtins ...@@ -52,6 +52,10 @@ Core and Builtins
Library Library
------- -------
- Issue #15836: assertRaises(), assertRaisesRegex(), assertWarns() and
assertWarnsRegex() assertments now check the type of the first argument
to prevent possible user error. Based on patch by Daniel Wagner-Hall.
- Issue #9858: Add missing method stubs to _io.RawIOBase. Patch by Laura - Issue #9858: Add missing method stubs to _io.RawIOBase. Patch by Laura
Rupprecht. Rupprecht.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment