Commit 436c2b0d authored by Xtreak's avatar Xtreak Committed by Miss Islington (bot)

bpo-36996: Handle async functions when mock.patch is used as a decorator (GH-13562)



Return a coroutine while patching async functions with a decorator. 
Co-authored-by: default avatarAndrew Svetlov <andrew.svetlov@gmail.com>


https://bugs.python.org/issue36996
parent 71dc7c5f
...@@ -26,6 +26,7 @@ __all__ = ( ...@@ -26,6 +26,7 @@ __all__ = (
__version__ = '1.0' __version__ = '1.0'
import asyncio import asyncio
import contextlib
import io import io
import inspect import inspect
import pprint import pprint
...@@ -1220,6 +1221,8 @@ class _patch(object): ...@@ -1220,6 +1221,8 @@ class _patch(object):
def __call__(self, func): def __call__(self, func):
if isinstance(func, type): if isinstance(func, type):
return self.decorate_class(func) return self.decorate_class(func)
if inspect.iscoroutinefunction(func):
return self.decorate_async_callable(func)
return self.decorate_callable(func) return self.decorate_callable(func)
...@@ -1237,15 +1240,11 @@ class _patch(object): ...@@ -1237,15 +1240,11 @@ class _patch(object):
return klass return klass
def decorate_callable(self, func): @contextlib.contextmanager
if hasattr(func, 'patchings'): def decoration_helper(self, patched, args, keywargs):
func.patchings.append(self)
return func
@wraps(func)
def patched(*args, **keywargs):
extra_args = [] extra_args = []
entered_patchers = [] entered_patchers = []
patching = None
exc_info = tuple() exc_info = tuple()
try: try:
...@@ -1258,7 +1257,7 @@ class _patch(object): ...@@ -1258,7 +1257,7 @@ class _patch(object):
extra_args.append(arg) extra_args.append(arg)
args += tuple(extra_args) args += tuple(extra_args)
return func(*args, **keywargs) yield (args, keywargs)
except: except:
if (patching not in entered_patchers and if (patching not in entered_patchers and
_is_started(patching)): _is_started(patching)):
...@@ -1273,6 +1272,37 @@ class _patch(object): ...@@ -1273,6 +1272,37 @@ class _patch(object):
for patching in reversed(entered_patchers): for patching in reversed(entered_patchers):
patching.__exit__(*exc_info) patching.__exit__(*exc_info)
def decorate_callable(self, func):
# NB. Keep the method in sync with decorate_async_callable()
if hasattr(func, 'patchings'):
func.patchings.append(self)
return func
@wraps(func)
def patched(*args, **keywargs):
with self.decoration_helper(patched,
args,
keywargs) as (newargs, newkeywargs):
return func(*newargs, **newkeywargs)
patched.patchings = [self]
return patched
def decorate_async_callable(self, func):
# NB. Keep the method in sync with decorate_callable()
if hasattr(func, 'patchings'):
func.patchings.append(self)
return func
@wraps(func)
async def patched(*args, **keywargs):
with self.decoration_helper(patched,
args,
keywargs) as (newargs, newkeywargs):
return await func(*newargs, **newkeywargs)
patched.patchings = [self] patched.patchings = [self]
return patched return patched
......
...@@ -66,6 +66,14 @@ class AsyncPatchDecoratorTest(unittest.TestCase): ...@@ -66,6 +66,14 @@ class AsyncPatchDecoratorTest(unittest.TestCase):
test_async() test_async()
def test_async_def_patch(self):
@patch(f"{__name__}.async_func", AsyncMock())
async def test_async():
self.assertIsInstance(async_func, AsyncMock)
asyncio.run(test_async())
self.assertTrue(inspect.iscoroutinefunction(async_func))
class AsyncPatchCMTest(unittest.TestCase): class AsyncPatchCMTest(unittest.TestCase):
def test_is_async_function_cm(self): def test_is_async_function_cm(self):
...@@ -91,6 +99,14 @@ class AsyncPatchCMTest(unittest.TestCase): ...@@ -91,6 +99,14 @@ class AsyncPatchCMTest(unittest.TestCase):
test_async() test_async()
def test_async_def_cm(self):
async def test_async():
with patch(f"{__name__}.async_func", AsyncMock()):
self.assertIsInstance(async_func, AsyncMock)
self.assertTrue(inspect.iscoroutinefunction(async_func))
asyncio.run(test_async())
class AsyncMockTest(unittest.TestCase): class AsyncMockTest(unittest.TestCase):
def test_iscoroutinefunction_default(self): def test_iscoroutinefunction_default(self):
......
Handle :func:`unittest.mock.patch` used as a decorator on async functions.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment