Commit 436c2b0d authored by Xtreak's avatar Xtreak Committed by Miss Islington (bot)

bpo-36996: Handle async functions when mock.patch is used as a decorator (GH-13562)



Return a coroutine while patching async functions with a decorator. 
Co-authored-by: default avatarAndrew Svetlov <andrew.svetlov@gmail.com>


https://bugs.python.org/issue36996
parent 71dc7c5f
...@@ -26,6 +26,7 @@ __all__ = ( ...@@ -26,6 +26,7 @@ __all__ = (
__version__ = '1.0' __version__ = '1.0'
import asyncio import asyncio
import contextlib
import io import io
import inspect import inspect
import pprint import pprint
...@@ -1220,6 +1221,8 @@ class _patch(object): ...@@ -1220,6 +1221,8 @@ class _patch(object):
def __call__(self, func): def __call__(self, func):
if isinstance(func, type): if isinstance(func, type):
return self.decorate_class(func) return self.decorate_class(func)
if inspect.iscoroutinefunction(func):
return self.decorate_async_callable(func)
return self.decorate_callable(func) return self.decorate_callable(func)
...@@ -1237,41 +1240,68 @@ class _patch(object): ...@@ -1237,41 +1240,68 @@ class _patch(object):
return klass return klass
@contextlib.contextmanager
def decoration_helper(self, patched, args, keywargs):
extra_args = []
entered_patchers = []
patching = None
exc_info = tuple()
try:
for patching in patched.patchings:
arg = patching.__enter__()
entered_patchers.append(patching)
if patching.attribute_name is not None:
keywargs.update(arg)
elif patching.new is DEFAULT:
extra_args.append(arg)
args += tuple(extra_args)
yield (args, keywargs)
except:
if (patching not in entered_patchers and
_is_started(patching)):
# the patcher may have been started, but an exception
# raised whilst entering one of its additional_patchers
entered_patchers.append(patching)
# Pass the exception to __exit__
exc_info = sys.exc_info()
# re-raise the exception
raise
finally:
for patching in reversed(entered_patchers):
patching.__exit__(*exc_info)
def decorate_callable(self, func): def decorate_callable(self, func):
# NB. Keep the method in sync with decorate_async_callable()
if hasattr(func, 'patchings'): if hasattr(func, 'patchings'):
func.patchings.append(self) func.patchings.append(self)
return func return func
@wraps(func) @wraps(func)
def patched(*args, **keywargs): def patched(*args, **keywargs):
extra_args = [] with self.decoration_helper(patched,
entered_patchers = [] args,
keywargs) as (newargs, newkeywargs):
return func(*newargs, **newkeywargs)
exc_info = tuple() patched.patchings = [self]
try: return patched
for patching in patched.patchings:
arg = patching.__enter__()
entered_patchers.append(patching) def decorate_async_callable(self, func):
if patching.attribute_name is not None: # NB. Keep the method in sync with decorate_callable()
keywargs.update(arg) if hasattr(func, 'patchings'):
elif patching.new is DEFAULT: func.patchings.append(self)
extra_args.append(arg) return func
args += tuple(extra_args) @wraps(func)
return func(*args, **keywargs) async def patched(*args, **keywargs):
except: with self.decoration_helper(patched,
if (patching not in entered_patchers and args,
_is_started(patching)): keywargs) as (newargs, newkeywargs):
# the patcher may have been started, but an exception return await func(*newargs, **newkeywargs)
# raised whilst entering one of its additional_patchers
entered_patchers.append(patching)
# Pass the exception to __exit__
exc_info = sys.exc_info()
# re-raise the exception
raise
finally:
for patching in reversed(entered_patchers):
patching.__exit__(*exc_info)
patched.patchings = [self] patched.patchings = [self]
return patched return patched
......
...@@ -66,6 +66,14 @@ class AsyncPatchDecoratorTest(unittest.TestCase): ...@@ -66,6 +66,14 @@ class AsyncPatchDecoratorTest(unittest.TestCase):
test_async() test_async()
def test_async_def_patch(self):
@patch(f"{__name__}.async_func", AsyncMock())
async def test_async():
self.assertIsInstance(async_func, AsyncMock)
asyncio.run(test_async())
self.assertTrue(inspect.iscoroutinefunction(async_func))
class AsyncPatchCMTest(unittest.TestCase): class AsyncPatchCMTest(unittest.TestCase):
def test_is_async_function_cm(self): def test_is_async_function_cm(self):
...@@ -91,6 +99,14 @@ class AsyncPatchCMTest(unittest.TestCase): ...@@ -91,6 +99,14 @@ class AsyncPatchCMTest(unittest.TestCase):
test_async() test_async()
def test_async_def_cm(self):
async def test_async():
with patch(f"{__name__}.async_func", AsyncMock()):
self.assertIsInstance(async_func, AsyncMock)
self.assertTrue(inspect.iscoroutinefunction(async_func))
asyncio.run(test_async())
class AsyncMockTest(unittest.TestCase): class AsyncMockTest(unittest.TestCase):
def test_iscoroutinefunction_default(self): def test_iscoroutinefunction_default(self):
......
Handle :func:`unittest.mock.patch` used as a decorator on async functions.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment