Commit 7c5699a0 authored by Victor Stinner's avatar Victor Stinner

asyncio, Tulip issue 126: call_soon(), call_soon_threadsafe(), call_later(),

call_at() and run_in_executor() now raise a TypeError if the callback is a
coroutine function.
parent f21aa349
...@@ -227,6 +227,8 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -227,6 +227,8 @@ class BaseEventLoop(events.AbstractEventLoop):
def call_at(self, when, callback, *args): def call_at(self, when, callback, *args):
"""Like call_later(), but uses an absolute time.""" """Like call_later(), but uses an absolute time."""
if tasks.iscoroutinefunction(callback):
raise TypeError("coroutines cannot be used with call_at()")
timer = events.TimerHandle(when, callback, args) timer = events.TimerHandle(when, callback, args)
heapq.heappush(self._scheduled, timer) heapq.heappush(self._scheduled, timer)
return timer return timer
...@@ -241,6 +243,8 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -241,6 +243,8 @@ class BaseEventLoop(events.AbstractEventLoop):
Any positional arguments after the callback will be passed to Any positional arguments after the callback will be passed to
the callback when it is called. the callback when it is called.
""" """
if tasks.iscoroutinefunction(callback):
raise TypeError("coroutines cannot be used with call_soon()")
handle = events.Handle(callback, args) handle = events.Handle(callback, args)
self._ready.append(handle) self._ready.append(handle)
return handle return handle
...@@ -252,6 +256,8 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -252,6 +256,8 @@ class BaseEventLoop(events.AbstractEventLoop):
return handle return handle
def run_in_executor(self, executor, callback, *args): def run_in_executor(self, executor, callback, *args):
if tasks.iscoroutinefunction(callback):
raise TypeError("coroutines cannot be used with run_in_executor()")
if isinstance(callback, events.Handle): if isinstance(callback, events.Handle):
assert not args assert not args
assert not isinstance(callback, events.TimerHandle) assert not isinstance(callback, events.TimerHandle)
......
...@@ -135,7 +135,7 @@ def make_test_protocol(base): ...@@ -135,7 +135,7 @@ def make_test_protocol(base):
if name.startswith('__') and name.endswith('__'): if name.startswith('__') and name.endswith('__'):
# skip magic names # skip magic names
continue continue
dct[name] = unittest.mock.Mock(return_value=None) dct[name] = MockCallback(return_value=None)
return type('TestProtocol', (base,) + base.__bases__, dct)() return type('TestProtocol', (base,) + base.__bases__, dct)()
...@@ -274,3 +274,6 @@ class TestLoop(base_events.BaseEventLoop): ...@@ -274,3 +274,6 @@ class TestLoop(base_events.BaseEventLoop):
def _write_to_self(self): def _write_to_self(self):
pass pass
def MockCallback(**kwargs):
return unittest.mock.Mock(spec=['__call__'], **kwargs)
...@@ -567,6 +567,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -567,6 +567,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
m_socket.getaddrinfo.return_value = [ m_socket.getaddrinfo.return_value = [
(2, 1, 6, '', ('127.0.0.1', 10100))] (2, 1, 6, '', ('127.0.0.1', 10100))]
m_socket.getaddrinfo._is_coroutine = False
m_sock = m_socket.socket.return_value = unittest.mock.Mock() m_sock = m_socket.socket.return_value = unittest.mock.Mock()
m_sock.bind.side_effect = Err m_sock.bind.side_effect = Err
...@@ -577,6 +578,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -577,6 +578,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
@unittest.mock.patch('asyncio.base_events.socket') @unittest.mock.patch('asyncio.base_events.socket')
def test_create_datagram_endpoint_no_addrinfo(self, m_socket): def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
m_socket.getaddrinfo.return_value = [] m_socket.getaddrinfo.return_value = []
m_socket.getaddrinfo._is_coroutine = False
coro = self.loop.create_datagram_endpoint( coro = self.loop.create_datagram_endpoint(
MyDatagramProto, local_addr=('localhost', 0)) MyDatagramProto, local_addr=('localhost', 0))
...@@ -681,6 +683,22 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): ...@@ -681,6 +683,22 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
unittest.mock.ANY, unittest.mock.ANY,
MyProto, sock, None, None) MyProto, sock, None, None)
def test_call_coroutine(self):
@asyncio.coroutine
def coroutine_function():
pass
with self.assertRaises(TypeError):
self.loop.call_soon(coroutine_function)
with self.assertRaises(TypeError):
self.loop.call_soon_threadsafe(coroutine_function)
with self.assertRaises(TypeError):
self.loop.call_later(60, coroutine_function)
with self.assertRaises(TypeError):
self.loop.call_at(self.loop.time() + 60, coroutine_function)
with self.assertRaises(TypeError):
self.loop.run_in_executor(None, coroutine_function)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -402,7 +402,7 @@ class BaseProactorEventLoopTests(unittest.TestCase): ...@@ -402,7 +402,7 @@ class BaseProactorEventLoopTests(unittest.TestCase):
NotImplementedError, BaseProactorEventLoop, self.proactor) NotImplementedError, BaseProactorEventLoop, self.proactor)
def test_make_socket_transport(self): def test_make_socket_transport(self):
tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol())
self.assertIsInstance(tr, _ProactorSocketTransport) self.assertIsInstance(tr, _ProactorSocketTransport)
def test_loop_self_reading(self): def test_loop_self_reading(self):
......
...@@ -44,8 +44,8 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -44,8 +44,8 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
def test_make_socket_transport(self): def test_make_socket_transport(self):
m = unittest.mock.Mock() m = unittest.mock.Mock()
self.loop.add_reader = unittest.mock.Mock() self.loop.add_reader = unittest.mock.Mock()
self.assertIsInstance( transport = self.loop._make_socket_transport(m, asyncio.Protocol())
self.loop._make_socket_transport(m, m), _SelectorSocketTransport) self.assertIsInstance(transport, _SelectorSocketTransport)
@unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipIf(ssl is None, 'No ssl module')
def test_make_ssl_transport(self): def test_make_ssl_transport(self):
...@@ -54,8 +54,9 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -54,8 +54,9 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop.add_writer = unittest.mock.Mock() self.loop.add_writer = unittest.mock.Mock()
self.loop.remove_reader = unittest.mock.Mock() self.loop.remove_reader = unittest.mock.Mock()
self.loop.remove_writer = unittest.mock.Mock() self.loop.remove_writer = unittest.mock.Mock()
self.assertIsInstance( waiter = asyncio.Future(loop=self.loop)
self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport) transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter)
self.assertIsInstance(transport, _SelectorSslTransport)
@unittest.mock.patch('asyncio.selector_events.ssl', None) @unittest.mock.patch('asyncio.selector_events.ssl', None)
def test_make_ssl_transport_without_ssl_error(self): def test_make_ssl_transport_without_ssl_error(self):
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
import gc import gc
import unittest import unittest
import unittest.mock
from unittest.mock import Mock
import asyncio import asyncio
from asyncio import test_utils from asyncio import test_utils
...@@ -1358,7 +1356,7 @@ class GatherTestsBase: ...@@ -1358,7 +1356,7 @@ class GatherTestsBase:
def _check_success(self, **kwargs): def _check_success(self, **kwargs):
a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)] a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)]
fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs) fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs)
cb = Mock() cb = test_utils.MockCallback()
fut.add_done_callback(cb) fut.add_done_callback(cb)
b.set_result(1) b.set_result(1)
a.set_result(2) a.set_result(2)
...@@ -1380,7 +1378,7 @@ class GatherTestsBase: ...@@ -1380,7 +1378,7 @@ class GatherTestsBase:
def test_one_exception(self): def test_one_exception(self):
a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e)) fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e))
cb = Mock() cb = test_utils.MockCallback()
fut.add_done_callback(cb) fut.add_done_callback(cb)
exc = ZeroDivisionError() exc = ZeroDivisionError()
a.set_result(1) a.set_result(1)
...@@ -1399,7 +1397,7 @@ class GatherTestsBase: ...@@ -1399,7 +1397,7 @@ class GatherTestsBase:
a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)] a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)]
fut = asyncio.gather(*self.wrap_futures(a, b, c, d), fut = asyncio.gather(*self.wrap_futures(a, b, c, d),
return_exceptions=True) return_exceptions=True)
cb = Mock() cb = test_utils.MockCallback()
fut.add_done_callback(cb) fut.add_done_callback(cb)
exc = ZeroDivisionError() exc = ZeroDivisionError()
exc2 = RuntimeError() exc2 = RuntimeError()
...@@ -1460,7 +1458,7 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase): ...@@ -1460,7 +1458,7 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase):
def test_one_cancellation(self): def test_one_cancellation(self):
a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
fut = asyncio.gather(a, b, c, d, e) fut = asyncio.gather(a, b, c, d, e)
cb = Mock() cb = test_utils.MockCallback()
fut.add_done_callback(cb) fut.add_done_callback(cb)
a.set_result(1) a.set_result(1)
b.cancel() b.cancel()
...@@ -1479,7 +1477,7 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase): ...@@ -1479,7 +1477,7 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase):
a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop) a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop)
for i in range(6)] for i in range(6)]
fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True) fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True)
cb = Mock() cb = test_utils.MockCallback()
fut.add_done_callback(cb) fut.add_done_callback(cb)
a.set_result(1) a.set_result(1)
zde = ZeroDivisionError() zde = ZeroDivisionError()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment