Commit 436c2b0d authored by Xtreak's avatar Xtreak Committed by Miss Islington (bot)

bpo-36996: Handle async functions when mock.patch is used as a decorator (GH-13562)



Return a coroutine while patching async functions with a decorator. 
Co-authored-by: default avatarAndrew Svetlov <andrew.svetlov@gmail.com>


https://bugs.python.org/issue36996
parent 71dc7c5f
......@@ -26,6 +26,7 @@ __all__ = (
__version__ = '1.0'
import asyncio
import contextlib
import io
import inspect
import pprint
......@@ -1220,6 +1221,8 @@ class _patch(object):
def __call__(self, func):
if isinstance(func, type):
return self.decorate_class(func)
if inspect.iscoroutinefunction(func):
return self.decorate_async_callable(func)
return self.decorate_callable(func)
......@@ -1237,15 +1240,11 @@ class _patch(object):
return klass
def decorate_callable(self, func):
if hasattr(func, 'patchings'):
func.patchings.append(self)
return func
@wraps(func)
def patched(*args, **keywargs):
@contextlib.contextmanager
def decoration_helper(self, patched, args, keywargs):
extra_args = []
entered_patchers = []
patching = None
exc_info = tuple()
try:
......@@ -1258,7 +1257,7 @@ class _patch(object):
extra_args.append(arg)
args += tuple(extra_args)
return func(*args, **keywargs)
yield (args, keywargs)
except:
if (patching not in entered_patchers and
_is_started(patching)):
......@@ -1273,6 +1272,37 @@ class _patch(object):
for patching in reversed(entered_patchers):
patching.__exit__(*exc_info)
def decorate_callable(self, func):
# NB. Keep the method in sync with decorate_async_callable()
if hasattr(func, 'patchings'):
func.patchings.append(self)
return func
@wraps(func)
def patched(*args, **keywargs):
with self.decoration_helper(patched,
args,
keywargs) as (newargs, newkeywargs):
return func(*newargs, **newkeywargs)
patched.patchings = [self]
return patched
def decorate_async_callable(self, func):
# NB. Keep the method in sync with decorate_callable()
if hasattr(func, 'patchings'):
func.patchings.append(self)
return func
@wraps(func)
async def patched(*args, **keywargs):
with self.decoration_helper(patched,
args,
keywargs) as (newargs, newkeywargs):
return await func(*newargs, **newkeywargs)
patched.patchings = [self]
return patched
......
......@@ -66,6 +66,14 @@ class AsyncPatchDecoratorTest(unittest.TestCase):
test_async()
def test_async_def_patch(self):
@patch(f"{__name__}.async_func", AsyncMock())
async def test_async():
self.assertIsInstance(async_func, AsyncMock)
asyncio.run(test_async())
self.assertTrue(inspect.iscoroutinefunction(async_func))
class AsyncPatchCMTest(unittest.TestCase):
def test_is_async_function_cm(self):
......@@ -91,6 +99,14 @@ class AsyncPatchCMTest(unittest.TestCase):
test_async()
def test_async_def_cm(self):
async def test_async():
with patch(f"{__name__}.async_func", AsyncMock()):
self.assertIsInstance(async_func, AsyncMock)
self.assertTrue(inspect.iscoroutinefunction(async_func))
asyncio.run(test_async())
class AsyncMockTest(unittest.TestCase):
def test_iscoroutinefunction_default(self):
......
Handle :func:`unittest.mock.patch` used as a decorator on async functions.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment