Commit 8b03f943 authored by Lisa Roach's avatar Lisa Roach Committed by GitHub

bpo-38093: Correctly returns AsyncMock for async subclasses. (GH-15947)

parent 2702638e
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import asyncio import asyncio
import unittest import unittest
from unittest.mock import Mock, MagicMock, patch, call, sentinel from unittest.mock import Mock, MagicMock, AsyncMock, patch, call, sentinel
class SomeClass: class SomeClass:
attribute = 'this is a doctest' attribute = 'this is a doctest'
...@@ -280,14 +280,16 @@ function returns is what the call returns: ...@@ -280,14 +280,16 @@ function returns is what the call returns:
Mocking asynchronous iterators Mocking asynchronous iterators
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Since Python 3.8, ``MagicMock`` has support to mock :ref:`async-iterators` Since Python 3.8, ``AsyncMock`` and ``MagicMock`` have support to mock
through ``__aiter__``. The :attr:`~Mock.return_value` attribute of ``__aiter__`` :ref:`async-iterators` through ``__aiter__``. The :attr:`~Mock.return_value`
can be used to set the return values to be used for iteration. attribute of ``__aiter__`` can be used to set the return values to be used for
iteration.
>>> mock = MagicMock() >>> mock = MagicMock() # AsyncMock also works here
>>> mock.__aiter__.return_value = [1, 2, 3] >>> mock.__aiter__.return_value = [1, 2, 3]
>>> async def main(): >>> async def main():
... return [i async for i in mock] ... return [i async for i in mock]
...
>>> asyncio.run(main()) >>> asyncio.run(main())
[1, 2, 3] [1, 2, 3]
...@@ -295,24 +297,25 @@ can be used to set the return values to be used for iteration. ...@@ -295,24 +297,25 @@ can be used to set the return values to be used for iteration.
Mocking asynchronous context manager Mocking asynchronous context manager
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Since Python 3.8, ``MagicMock`` has support to mock Since Python 3.8, ``AsyncMock`` and ``MagicMock`` have support to mock
:ref:`async-context-managers` through ``__aenter__`` and ``__aexit__``. The :ref:`async-context-managers` through ``__aenter__`` and ``__aexit__``.
return value of ``__aenter__`` is an :class:`AsyncMock`. By default, ``__aenter__`` and ``__aexit__`` are ``AsyncMock`` instances that
return an async function.
>>> class AsyncContextManager: >>> class AsyncContextManager:
...
... async def __aenter__(self): ... async def __aenter__(self):
... return self ... return self
... ... async def __aexit__(self, exc_type, exc, tb):
... async def __aexit__(self):
... pass ... pass
>>> mock_instance = MagicMock(AsyncContextManager()) ...
>>> mock_instance = MagicMock(AsyncContextManager()) # AsyncMock also works here
>>> async def main(): >>> async def main():
... async with mock_instance as result: ... async with mock_instance as result:
... pass ... pass
...
>>> asyncio.run(main()) >>> asyncio.run(main())
>>> mock_instance.__aenter__.assert_called_once() >>> mock_instance.__aenter__.assert_awaited_once()
>>> mock_instance.__aexit__.assert_called_once() >>> mock_instance.__aexit__.assert_awaited_once()
Creating a Mock from an Existing Object Creating a Mock from an Existing Object
......
...@@ -983,9 +983,13 @@ class NonCallableMock(Base): ...@@ -983,9 +983,13 @@ class NonCallableMock(Base):
_type = type(self) _type = type(self)
if issubclass(_type, MagicMock) and _new_name in _async_method_magics: if issubclass(_type, MagicMock) and _new_name in _async_method_magics:
klass = AsyncMock klass = AsyncMock
if issubclass(_type, AsyncMockMixin): elif _new_name in _sync_async_magics:
# Special case these ones b/c users will assume they are async,
# but they are actually sync (ie. __aiter__)
klass = MagicMock klass = MagicMock
if not issubclass(_type, CallableMixin): elif issubclass(_type, AsyncMockMixin):
klass = AsyncMock
elif not issubclass(_type, CallableMixin):
if issubclass(_type, NonCallableMagicMock): if issubclass(_type, NonCallableMagicMock):
klass = MagicMock klass = MagicMock
elif issubclass(_type, NonCallableMock) : elif issubclass(_type, NonCallableMock) :
...@@ -1881,7 +1885,7 @@ _non_defaults = { ...@@ -1881,7 +1885,7 @@ _non_defaults = {
'__reduce__', '__reduce_ex__', '__getinitargs__', '__getnewargs__', '__reduce__', '__reduce_ex__', '__getinitargs__', '__getnewargs__',
'__getstate__', '__setstate__', '__getformat__', '__setformat__', '__getstate__', '__setstate__', '__getformat__', '__setformat__',
'__repr__', '__dir__', '__subclasses__', '__format__', '__repr__', '__dir__', '__subclasses__', '__format__',
'__getnewargs_ex__', '__aenter__', '__aexit__', '__anext__', '__aiter__', '__getnewargs_ex__',
} }
...@@ -1900,10 +1904,12 @@ _magics = { ...@@ -1900,10 +1904,12 @@ _magics = {
# Magic methods used for async `with` statements # Magic methods used for async `with` statements
_async_method_magics = {"__aenter__", "__aexit__", "__anext__"} _async_method_magics = {"__aenter__", "__aexit__", "__anext__"}
# `__aiter__` is a plain function but used with async calls # Magic methods that are only used with async calls but are synchronous functions themselves
_async_magics = _async_method_magics | {"__aiter__"} _sync_async_magics = {"__aiter__"}
_async_magics = _async_method_magics | _sync_async_magics
_all_magics = _magics | _non_defaults _all_sync_magics = _magics | _non_defaults
_all_magics = _all_sync_magics | _async_magics
_unsupported_magics = { _unsupported_magics = {
'__getattr__', '__setattr__', '__getattr__', '__setattr__',
......
...@@ -382,35 +382,88 @@ class AsyncArguments(unittest.TestCase): ...@@ -382,35 +382,88 @@ class AsyncArguments(unittest.TestCase):
class AsyncContextManagerTest(unittest.TestCase): class AsyncContextManagerTest(unittest.TestCase):
class WithAsyncContextManager: class WithAsyncContextManager:
async def __aenter__(self, *args, **kwargs): async def __aenter__(self, *args, **kwargs):
return self return self
async def __aexit__(self, *args, **kwargs): async def __aexit__(self, *args, **kwargs):
pass pass
def test_magic_methods_are_async_mocks(self): class WithSyncContextManager:
mock = MagicMock(self.WithAsyncContextManager()) def __enter__(self, *args, **kwargs):
self.assertIsInstance(mock.__aenter__, AsyncMock) return self
self.assertIsInstance(mock.__aexit__, AsyncMock)
def __exit__(self, *args, **kwargs):
pass
class ProductionCode:
# Example real-world(ish) code
def __init__(self):
self.session = None
async def main(self):
async with self.session.post('https://python.org') as response:
val = await response.json()
return val
def test_async_magic_methods_are_async_mocks_with_magicmock(self):
cm_mock = MagicMock(self.WithAsyncContextManager())
self.assertIsInstance(cm_mock.__aenter__, AsyncMock)
self.assertIsInstance(cm_mock.__aexit__, AsyncMock)
def test_magicmock_has_async_magic_methods(self):
cm = MagicMock(name='magic_cm')
self.assertTrue(hasattr(cm, "__aenter__"))
self.assertTrue(hasattr(cm, "__aexit__"))
def test_magic_methods_are_async_functions(self):
cm = MagicMock(name='magic_cm')
self.assertIsInstance(cm.__aenter__, AsyncMock)
self.assertIsInstance(cm.__aexit__, AsyncMock)
# AsyncMocks are also coroutine functions
self.assertTrue(asyncio.iscoroutinefunction(cm.__aenter__))
self.assertTrue(asyncio.iscoroutinefunction(cm.__aexit__))
def test_set_return_value_of_aenter(self):
def inner_test(mock_type):
pc = self.ProductionCode()
pc.session = MagicMock(name='sessionmock')
cm = mock_type(name='magic_cm')
response = AsyncMock(name='response')
response.json = AsyncMock(return_value={'json': 123})
cm.__aenter__.return_value = response
pc.session.post.return_value = cm
result = asyncio.run(pc.main())
self.assertEqual(result, {'json': 123})
for mock_type in [AsyncMock, MagicMock]:
with self.subTest(f"test set return value of aenter with {mock_type}"):
inner_test(mock_type)
def test_mock_supports_async_context_manager(self): def test_mock_supports_async_context_manager(self):
called = False def inner_test(mock_type):
instance = self.WithAsyncContextManager() called = False
mock_instance = MagicMock(instance) cm = self.WithAsyncContextManager()
cm_mock = mock_type(cm)
async def use_context_manager():
nonlocal called
async with cm_mock as result:
called = True
return result
async def use_context_manager(): cm_result = asyncio.run(use_context_manager())
nonlocal called self.assertTrue(called)
async with mock_instance as result: self.assertTrue(cm_mock.__aenter__.called)
called = True self.assertTrue(cm_mock.__aexit__.called)
return result cm_mock.__aenter__.assert_awaited()
cm_mock.__aexit__.assert_awaited()
# We mock __aenter__ so it does not return self
self.assertIsNot(cm_mock, cm_result)
for mock_type in [AsyncMock, MagicMock]:
with self.subTest(f"test context manager magics with {mock_type}"):
inner_test(mock_type)
result = asyncio.run(use_context_manager())
self.assertTrue(called)
self.assertTrue(mock_instance.__aenter__.called)
self.assertTrue(mock_instance.__aexit__.called)
self.assertIsNot(mock_instance, result)
self.assertIsInstance(result, AsyncMock)
def test_mock_customize_async_context_manager(self): def test_mock_customize_async_context_manager(self):
instance = self.WithAsyncContextManager() instance = self.WithAsyncContextManager()
...@@ -478,27 +531,30 @@ class AsyncIteratorTest(unittest.TestCase): ...@@ -478,27 +531,30 @@ class AsyncIteratorTest(unittest.TestCase):
raise StopAsyncIteration raise StopAsyncIteration
def test_mock_aiter_and_anext(self): def test_aiter_set_return_value(self):
instance = self.WithAsyncIterator() mock_iter = AsyncMock(name="tester")
mock_instance = MagicMock(instance) mock_iter.__aiter__.return_value = [1, 2, 3]
async def main():
self.assertEqual(asyncio.iscoroutine(instance.__aiter__), return [i async for i in mock_iter]
asyncio.iscoroutine(mock_instance.__aiter__)) result = asyncio.run(main())
self.assertEqual(asyncio.iscoroutine(instance.__anext__), self.assertEqual(result, [1, 2, 3])
asyncio.iscoroutine(mock_instance.__anext__))
def test_mock_aiter_and_anext_asyncmock(self):
iterator = instance.__aiter__() def inner_test(mock_type):
if asyncio.iscoroutine(iterator): instance = self.WithAsyncIterator()
iterator = asyncio.run(iterator) mock_instance = mock_type(instance)
# Check that the mock and the real thing bahave the same
mock_iterator = mock_instance.__aiter__() # __aiter__ is not actually async, so not a coroutinefunction
if asyncio.iscoroutine(mock_iterator): self.assertFalse(asyncio.iscoroutinefunction(instance.__aiter__))
mock_iterator = asyncio.run(mock_iterator) self.assertFalse(asyncio.iscoroutinefunction(mock_instance.__aiter__))
# __anext__ is async
self.assertTrue(asyncio.iscoroutinefunction(instance.__anext__))
self.assertTrue(asyncio.iscoroutinefunction(mock_instance.__anext__))
for mock_type in [AsyncMock, MagicMock]:
with self.subTest(f"test aiter and anext corourtine with {mock_type}"):
inner_test(mock_type)
self.assertEqual(asyncio.iscoroutine(iterator.__aiter__),
asyncio.iscoroutine(mock_iterator.__aiter__))
self.assertEqual(asyncio.iscoroutine(iterator.__anext__),
asyncio.iscoroutine(mock_iterator.__anext__))
def test_mock_async_for(self): def test_mock_async_for(self):
async def iterate(iterator): async def iterate(iterator):
...@@ -509,19 +565,30 @@ class AsyncIteratorTest(unittest.TestCase): ...@@ -509,19 +565,30 @@ class AsyncIteratorTest(unittest.TestCase):
return accumulator return accumulator
expected = ["FOO", "BAR", "BAZ"] expected = ["FOO", "BAR", "BAZ"]
with self.subTest("iterate through default value"): def test_default(mock_type):
mock_instance = MagicMock(self.WithAsyncIterator()) mock_instance = mock_type(self.WithAsyncIterator())
self.assertEqual([], asyncio.run(iterate(mock_instance))) self.assertEqual(asyncio.run(iterate(mock_instance)), [])
with self.subTest("iterate through set return_value"): def test_set_return_value(mock_type):
mock_instance = MagicMock(self.WithAsyncIterator()) mock_instance = mock_type(self.WithAsyncIterator())
mock_instance.__aiter__.return_value = expected[:] mock_instance.__aiter__.return_value = expected[:]
self.assertEqual(expected, asyncio.run(iterate(mock_instance))) self.assertEqual(asyncio.run(iterate(mock_instance)), expected)
with self.subTest("iterate through set return_value iterator"): def test_set_return_value_iter(mock_type):
mock_instance = MagicMock(self.WithAsyncIterator()) mock_instance = mock_type(self.WithAsyncIterator())
mock_instance.__aiter__.return_value = iter(expected[:]) mock_instance.__aiter__.return_value = iter(expected[:])
self.assertEqual(expected, asyncio.run(iterate(mock_instance))) self.assertEqual(asyncio.run(iterate(mock_instance)), expected)
for mock_type in [AsyncMock, MagicMock]:
with self.subTest(f"default value with {mock_type}"):
test_default(mock_type)
with self.subTest(f"set return_value with {mock_type}"):
test_set_return_value(mock_type)
with self.subTest(f"set return_value iterator with {mock_type}"):
test_set_return_value_iter(mock_type)
class AsyncMockAssert(unittest.TestCase): class AsyncMockAssert(unittest.TestCase):
......
import asyncio
import math import math
import unittest import unittest
import os import os
import sys import sys
from unittest.mock import Mock, MagicMock, _magics from unittest.mock import AsyncMock, Mock, MagicMock, _magics
...@@ -271,6 +272,34 @@ class TestMockingMagicMethods(unittest.TestCase): ...@@ -271,6 +272,34 @@ class TestMockingMagicMethods(unittest.TestCase):
self.assertEqual(mock != mock, False) self.assertEqual(mock != mock, False)
# This should be fixed with issue38163
@unittest.expectedFailure
def test_asyncmock_defaults(self):
mock = AsyncMock()
self.assertEqual(int(mock), 1)
self.assertEqual(complex(mock), 1j)
self.assertEqual(float(mock), 1.0)
self.assertNotIn(object(), mock)
self.assertEqual(len(mock), 0)
self.assertEqual(list(mock), [])
self.assertEqual(hash(mock), object.__hash__(mock))
self.assertEqual(str(mock), object.__str__(mock))
self.assertTrue(bool(mock))
self.assertEqual(round(mock), mock.__round__())
self.assertEqual(math.trunc(mock), mock.__trunc__())
self.assertEqual(math.floor(mock), mock.__floor__())
self.assertEqual(math.ceil(mock), mock.__ceil__())
self.assertTrue(asyncio.iscoroutinefunction(mock.__aexit__))
self.assertTrue(asyncio.iscoroutinefunction(mock.__aenter__))
self.assertIsInstance(mock.__aenter__, AsyncMock)
self.assertIsInstance(mock.__aexit__, AsyncMock)
# in Python 3 oct and hex use __index__
# so these tests are for __index__ in py3k
self.assertEqual(oct(mock), '0o1')
self.assertEqual(hex(mock), '0x1')
# how to test __sizeof__ ?
def test_magicmock_defaults(self): def test_magicmock_defaults(self):
mock = MagicMock() mock = MagicMock()
self.assertEqual(int(mock), 1) self.assertEqual(int(mock), 1)
...@@ -286,6 +315,10 @@ class TestMockingMagicMethods(unittest.TestCase): ...@@ -286,6 +315,10 @@ class TestMockingMagicMethods(unittest.TestCase):
self.assertEqual(math.trunc(mock), mock.__trunc__()) self.assertEqual(math.trunc(mock), mock.__trunc__())
self.assertEqual(math.floor(mock), mock.__floor__()) self.assertEqual(math.floor(mock), mock.__floor__())
self.assertEqual(math.ceil(mock), mock.__ceil__()) self.assertEqual(math.ceil(mock), mock.__ceil__())
self.assertTrue(asyncio.iscoroutinefunction(mock.__aexit__))
self.assertTrue(asyncio.iscoroutinefunction(mock.__aenter__))
self.assertIsInstance(mock.__aenter__, AsyncMock)
self.assertIsInstance(mock.__aexit__, AsyncMock)
# in Python 3 oct and hex use __index__ # in Python 3 oct and hex use __index__
# so these tests are for __index__ in py3k # so these tests are for __index__ in py3k
......
Fixes AsyncMock so it doesn't crash when used with AsyncContextManagers
or AsyncIterators.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment