Commit 28f25a5f authored by Michael Foord's avatar Michael Foord

Support subclassing unittest.mock._patch and fix various obscure bugs around patcher spec arguments

parent 75b1eb30
...@@ -998,7 +998,7 @@ class _patch(object): ...@@ -998,7 +998,7 @@ class _patch(object):
raise ValueError( raise ValueError(
"Cannot use 'new' and 'new_callable' together" "Cannot use 'new' and 'new_callable' together"
) )
if autospec is not False: if autospec is not None:
raise ValueError( raise ValueError(
"Cannot use 'autospec' and 'new_callable' together" "Cannot use 'autospec' and 'new_callable' together"
) )
...@@ -1059,6 +1059,7 @@ class _patch(object): ...@@ -1059,6 +1059,7 @@ class _patch(object):
extra_args = [] extra_args = []
entered_patchers = [] entered_patchers = []
exc_info = tuple()
try: try:
for patching in patched.patchings: for patching in patched.patchings:
arg = patching.__enter__() arg = patching.__enter__()
...@@ -1076,11 +1077,13 @@ class _patch(object): ...@@ -1076,11 +1077,13 @@ class _patch(object):
# the patcher may have been started, but an exception # the patcher may have been started, but an exception
# raised whilst entering one of its additional_patchers # raised whilst entering one of its additional_patchers
entered_patchers.append(patching) entered_patchers.append(patching)
# Pass the exception to __exit__
exc_info = sys.exc_info()
# re-raise the exception # re-raise the exception
raise raise
finally: finally:
for patching in reversed(entered_patchers): for patching in reversed(entered_patchers):
patching.__exit__() patching.__exit__(*exc_info)
patched.patchings = [self] patched.patchings = [self]
if hasattr(func, 'func_code'): if hasattr(func, 'func_code'):
...@@ -1120,17 +1123,40 @@ class _patch(object): ...@@ -1120,17 +1123,40 @@ class _patch(object):
new_callable = self.new_callable new_callable = self.new_callable
self.target = self.getter() self.target = self.getter()
# normalise False to None
if spec is False:
spec = None
if spec_set is False:
spec_set = None
if autospec is False:
autospec = None
if spec is not None and autospec is not None:
raise TypeError("Can't specify spec and autospec")
if ((spec is not None or autospec is not None) and
spec_set not in (True, None)):
raise TypeError("Can't provide explicit spec_set *and* spec or autospec")
original, local = self.get_original() original, local = self.get_original()
if new is DEFAULT and autospec is False: if new is DEFAULT and autospec is None:
inherit = False inherit = False
if spec_set == True: if spec is True:
spec_set = original
elif spec == True:
# set spec to the object we are replacing # set spec to the object we are replacing
spec = original spec = original
if spec_set is True:
spec_set = original
spec = None
elif spec is not None:
if spec_set is True:
spec_set = spec
spec = None
elif spec_set is True:
spec_set = original
if (spec or spec_set) is not None: if spec is not None or spec_set is not None:
if original is DEFAULT:
raise TypeError("Can't use 'spec' with create=True")
if isinstance(original, type): if isinstance(original, type):
# If we're patching out a class and there is a spec # If we're patching out a class and there is a spec
inherit = True inherit = True
...@@ -1139,7 +1165,7 @@ class _patch(object): ...@@ -1139,7 +1165,7 @@ class _patch(object):
_kwargs = {} _kwargs = {}
if new_callable is not None: if new_callable is not None:
Klass = new_callable Klass = new_callable
elif (spec or spec_set) is not None: elif spec is not None or spec_set is not None:
if not _callable(spec or spec_set): if not _callable(spec or spec_set):
Klass = NonCallableMagicMock Klass = NonCallableMagicMock
...@@ -1159,14 +1185,17 @@ class _patch(object): ...@@ -1159,14 +1185,17 @@ class _patch(object):
if inherit and _is_instance_mock(new): if inherit and _is_instance_mock(new):
# we can only tell if the instance should be callable if the # we can only tell if the instance should be callable if the
# spec is not a list # spec is not a list
if (not _is_list(spec or spec_set) and not this_spec = spec
_instance_callable(spec or spec_set)): if spec_set is not None:
this_spec = spec_set
if (not _is_list(this_spec) and not
_instance_callable(this_spec)):
Klass = NonCallableMagicMock Klass = NonCallableMagicMock
_kwargs.pop('name') _kwargs.pop('name')
new.return_value = Klass(_new_parent=new, _new_name='()', new.return_value = Klass(_new_parent=new, _new_name='()',
**_kwargs) **_kwargs)
elif autospec is not False: elif autospec is not None:
# spec is ignored, new *must* be default, spec_set is treated # spec is ignored, new *must* be default, spec_set is treated
# as a boolean. Should we check spec is not None and that spec_set # as a boolean. Should we check spec is not None and that spec_set
# is a bool? # is a bool?
...@@ -1175,6 +1204,8 @@ class _patch(object): ...@@ -1175,6 +1204,8 @@ class _patch(object):
"autospec creates the mock for you. Can't specify " "autospec creates the mock for you. Can't specify "
"autospec and new." "autospec and new."
) )
if original is DEFAULT:
raise TypeError("Can't use 'spec' with create=True")
spec_set = bool(spec_set) spec_set = bool(spec_set)
if autospec is True: if autospec is True:
autospec = original autospec = original
...@@ -1204,7 +1235,7 @@ class _patch(object): ...@@ -1204,7 +1235,7 @@ class _patch(object):
return new return new
def __exit__(self, *_): def __exit__(self, *exc_info):
"""Undo the patch.""" """Undo the patch."""
if not _is_started(self): if not _is_started(self):
raise RuntimeError('stop called on unstarted patcher') raise RuntimeError('stop called on unstarted patcher')
...@@ -1222,7 +1253,7 @@ class _patch(object): ...@@ -1222,7 +1253,7 @@ class _patch(object):
del self.target del self.target
for patcher in reversed(self.additional_patchers): for patcher in reversed(self.additional_patchers):
if _is_started(patcher): if _is_started(patcher):
patcher.__exit__() patcher.__exit__(*exc_info)
start = __enter__ start = __enter__
stop = __exit__ stop = __exit__
...@@ -1241,14 +1272,10 @@ def _get_target(target): ...@@ -1241,14 +1272,10 @@ def _get_target(target):
def _patch_object( def _patch_object(
target, attribute, new=DEFAULT, spec=None, target, attribute, new=DEFAULT, spec=None,
create=False, spec_set=None, autospec=False, create=False, spec_set=None, autospec=None,
new_callable=None, **kwargs new_callable=None, **kwargs
): ):
""" """
patch.object(target, attribute, new=DEFAULT, spec=None, create=False,
spec_set=None, autospec=False,
new_callable=None, **kwargs)
patch the named member (`attribute`) on an object (`target`) with a mock patch the named member (`attribute`) on an object (`target`) with a mock
object. object.
...@@ -1268,10 +1295,8 @@ def _patch_object( ...@@ -1268,10 +1295,8 @@ def _patch_object(
) )
def _patch_multiple(target, spec=None, create=False, def _patch_multiple(target, spec=None, create=False, spec_set=None,
spec_set=None, autospec=False, autospec=None, new_callable=None, **kwargs):
new_callable=None, **kwargs
):
"""Perform multiple patches in a single call. It takes the object to be """Perform multiple patches in a single call. It takes the object to be
patched (either as an object or a string to fetch the object by importing) patched (either as an object or a string to fetch the object by importing)
and keyword arguments for the patches:: and keyword arguments for the patches::
...@@ -1321,8 +1346,7 @@ def _patch_multiple(target, spec=None, create=False, ...@@ -1321,8 +1346,7 @@ def _patch_multiple(target, spec=None, create=False,
def patch( def patch(
target, new=DEFAULT, spec=None, create=False, target, new=DEFAULT, spec=None, create=False,
spec_set=None, autospec=False, spec_set=None, autospec=None, new_callable=None, **kwargs
new_callable=None, **kwargs
): ):
""" """
`patch` acts as a function decorator, class decorator or a context `patch` acts as a function decorator, class decorator or a context
...@@ -2079,7 +2103,7 @@ def _get_class(obj): ...@@ -2079,7 +2103,7 @@ def _get_class(obj):
try: try:
return obj.__class__ return obj.__class__
except AttributeError: except AttributeError:
# in Python 2, _sre.SRE_Pattern objects have no __class__ # it is possible for objects to have no __class__
return type(obj) return type(obj)
......
...@@ -11,14 +11,15 @@ from unittest.test.testmock.support import SomeClass, is_instance ...@@ -11,14 +11,15 @@ from unittest.test.testmock.support import SomeClass, is_instance
from unittest.mock import ( from unittest.mock import (
NonCallableMock, CallableMixin, patch, sentinel, NonCallableMock, CallableMixin, patch, sentinel,
MagicMock, Mock, NonCallableMagicMock, patch, MagicMock, Mock, NonCallableMagicMock, patch, _patch,
DEFAULT, call DEFAULT, call, _get_target
) )
builtin_string = 'builtins' builtin_string = 'builtins'
PTModule = sys.modules[__name__] PTModule = sys.modules[__name__]
MODNAME = '%s.PTModule' % __name__
def _get_proxy(obj, get_only=True): def _get_proxy(obj, get_only=True):
...@@ -724,8 +725,8 @@ class PatchTest(unittest.TestCase): ...@@ -724,8 +725,8 @@ class PatchTest(unittest.TestCase):
patcher = patch('%s.something' % __name__) patcher = patch('%s.something' % __name__)
self.assertIs(something, original) self.assertIs(something, original)
mock = patcher.start() mock = patcher.start()
self.assertIsNot(mock, original)
try: try:
self.assertIsNot(mock, original)
self.assertIs(something, mock) self.assertIs(something, mock)
finally: finally:
patcher.stop() patcher.stop()
...@@ -744,8 +745,8 @@ class PatchTest(unittest.TestCase): ...@@ -744,8 +745,8 @@ class PatchTest(unittest.TestCase):
patcher = patch.object(PTModule, 'something', 'foo') patcher = patch.object(PTModule, 'something', 'foo')
self.assertIs(something, original) self.assertIs(something, original)
replaced = patcher.start() replaced = patcher.start()
self.assertEqual(replaced, 'foo')
try: try:
self.assertEqual(replaced, 'foo')
self.assertIs(something, replaced) self.assertIs(something, replaced)
finally: finally:
patcher.stop() patcher.stop()
...@@ -759,9 +760,10 @@ class PatchTest(unittest.TestCase): ...@@ -759,9 +760,10 @@ class PatchTest(unittest.TestCase):
self.assertEqual(d, original) self.assertEqual(d, original)
patcher.start() patcher.start()
self.assertEqual(d, {'spam': 'eggs'}) try:
self.assertEqual(d, {'spam': 'eggs'})
patcher.stop() finally:
patcher.stop()
self.assertEqual(d, original) self.assertEqual(d, original)
...@@ -1647,6 +1649,99 @@ class PatchTest(unittest.TestCase): ...@@ -1647,6 +1649,99 @@ class PatchTest(unittest.TestCase):
self.assertEqual(squizz.squozz, 3) self.assertEqual(squizz.squozz, 3)
def test_patch_propogrates_exc_on_exit(self):
class holder:
exc_info = None, None, None
class custom_patch(_patch):
def __exit__(self, etype=None, val=None, tb=None):
_patch.__exit__(self, etype, val, tb)
holder.exc_info = etype, val, tb
stop = __exit__
def with_custom_patch(target):
getter, attribute = _get_target(target)
return custom_patch(
getter, attribute, DEFAULT, None, False, None,
None, None, {}
)
@with_custom_patch('squizz.squozz')
def test(mock):
raise RuntimeError
self.assertRaises(RuntimeError, test)
self.assertIs(holder.exc_info[0], RuntimeError)
self.assertIsNotNone(holder.exc_info[1],
'exception value not propgated')
self.assertIsNotNone(holder.exc_info[2],
'exception traceback not propgated')
def test_create_and_specs(self):
for kwarg in ('spec', 'spec_set', 'autospec'):
p = patch('%s.doesnotexist' % __name__, create=True,
**{kwarg: True})
self.assertRaises(TypeError, p.start)
self.assertRaises(NameError, lambda: doesnotexist)
# check that spec with create is innocuous if the original exists
p = patch(MODNAME, create=True, **{kwarg: True})
p.start()
p.stop()
def test_multiple_specs(self):
original = PTModule
for kwarg in ('spec', 'spec_set'):
p = patch(MODNAME, autospec=0, **{kwarg: 0})
self.assertRaises(TypeError, p.start)
self.assertIs(PTModule, original)
for kwarg in ('spec', 'autospec'):
p = patch(MODNAME, spec_set=0, **{kwarg: 0})
self.assertRaises(TypeError, p.start)
self.assertIs(PTModule, original)
for kwarg in ('spec_set', 'autospec'):
p = patch(MODNAME, spec=0, **{kwarg: 0})
self.assertRaises(TypeError, p.start)
self.assertIs(PTModule, original)
def test_specs_false_instead_of_none(self):
p = patch(MODNAME, spec=False, spec_set=False, autospec=False)
mock = p.start()
try:
# no spec should have been set, so attribute access should not fail
mock.does_not_exist
mock.does_not_exist = 3
finally:
p.stop()
def test_falsey_spec(self):
for kwarg in ('spec', 'autospec', 'spec_set'):
p = patch(MODNAME, **{kwarg: 0})
m = p.start()
try:
self.assertRaises(AttributeError, getattr, m, 'doesnotexit')
finally:
p.stop()
def test_spec_set_true(self):
for kwarg in ('spec', 'autospec'):
p = patch(MODNAME, spec_set=True, **{kwarg: True})
m = p.start()
try:
self.assertRaises(AttributeError, setattr, m,
'doesnotexist', 'something')
self.assertRaises(AttributeError, getattr, m, 'doesnotexist')
finally:
p.stop()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment