Commit 7661db62 authored by Yury Selivanov's avatar Yury Selivanov

Issue #27041: asyncio: Add loop.create_future method

parent 7ed7ce6e
...@@ -209,7 +209,7 @@ class Server(events.AbstractServer): ...@@ -209,7 +209,7 @@ class Server(events.AbstractServer):
def wait_closed(self): def wait_closed(self):
if self.sockets is None or self._waiters is None: if self.sockets is None or self._waiters is None:
return return
waiter = futures.Future(loop=self._loop) waiter = self._loop.create_future()
self._waiters.append(waiter) self._waiters.append(waiter)
yield from waiter yield from waiter
...@@ -243,6 +243,10 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -243,6 +243,10 @@ class BaseEventLoop(events.AbstractEventLoop):
% (self.__class__.__name__, self.is_running(), % (self.__class__.__name__, self.is_running(),
self.is_closed(), self.get_debug())) self.is_closed(), self.get_debug()))
def create_future(self):
"""Create a Future object attached to the loop."""
return futures.Future(loop=self)
def create_task(self, coro): def create_task(self, coro):
"""Schedule a coroutine object. """Schedule a coroutine object.
...@@ -536,7 +540,7 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -536,7 +540,7 @@ class BaseEventLoop(events.AbstractEventLoop):
assert not args assert not args
assert not isinstance(func, events.TimerHandle) assert not isinstance(func, events.TimerHandle)
if func._cancelled: if func._cancelled:
f = futures.Future(loop=self) f = self.create_future()
f.set_result(None) f.set_result(None)
return f return f
func, args = func._callback, func._args func, args = func._callback, func._args
...@@ -579,7 +583,7 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -579,7 +583,7 @@ class BaseEventLoop(events.AbstractEventLoop):
family=0, type=0, proto=0, flags=0): family=0, type=0, proto=0, flags=0):
info = _ipaddr_info(host, port, family, type, proto) info = _ipaddr_info(host, port, family, type, proto)
if info is not None: if info is not None:
fut = futures.Future(loop=self) fut = self.create_future()
fut.set_result([info]) fut.set_result([info])
return fut return fut
elif self._debug: elif self._debug:
...@@ -720,7 +724,7 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -720,7 +724,7 @@ class BaseEventLoop(events.AbstractEventLoop):
def _create_connection_transport(self, sock, protocol_factory, ssl, def _create_connection_transport(self, sock, protocol_factory, ssl,
server_hostname): server_hostname):
protocol = protocol_factory() protocol = protocol_factory()
waiter = futures.Future(loop=self) waiter = self.create_future()
if ssl: if ssl:
sslcontext = None if isinstance(ssl, bool) else ssl sslcontext = None if isinstance(ssl, bool) else ssl
transport = self._make_ssl_transport( transport = self._make_ssl_transport(
...@@ -840,7 +844,7 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -840,7 +844,7 @@ class BaseEventLoop(events.AbstractEventLoop):
raise exceptions[0] raise exceptions[0]
protocol = protocol_factory() protocol = protocol_factory()
waiter = futures.Future(loop=self) waiter = self.create_future()
transport = self._make_datagram_transport( transport = self._make_datagram_transport(
sock, protocol, r_addr, waiter) sock, protocol, r_addr, waiter)
if self._debug: if self._debug:
...@@ -979,7 +983,7 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -979,7 +983,7 @@ class BaseEventLoop(events.AbstractEventLoop):
@coroutine @coroutine
def connect_read_pipe(self, protocol_factory, pipe): def connect_read_pipe(self, protocol_factory, pipe):
protocol = protocol_factory() protocol = protocol_factory()
waiter = futures.Future(loop=self) waiter = self.create_future()
transport = self._make_read_pipe_transport(pipe, protocol, waiter) transport = self._make_read_pipe_transport(pipe, protocol, waiter)
try: try:
...@@ -996,7 +1000,7 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -996,7 +1000,7 @@ class BaseEventLoop(events.AbstractEventLoop):
@coroutine @coroutine
def connect_write_pipe(self, protocol_factory, pipe): def connect_write_pipe(self, protocol_factory, pipe):
protocol = protocol_factory() protocol = protocol_factory()
waiter = futures.Future(loop=self) waiter = self.create_future()
transport = self._make_write_pipe_transport(pipe, protocol, waiter) transport = self._make_write_pipe_transport(pipe, protocol, waiter)
try: try:
......
...@@ -227,7 +227,7 @@ class BaseSubprocessTransport(transports.SubprocessTransport): ...@@ -227,7 +227,7 @@ class BaseSubprocessTransport(transports.SubprocessTransport):
if self._returncode is not None: if self._returncode is not None:
return self._returncode return self._returncode
waiter = futures.Future(loop=self._loop) waiter = self._loop.create_future()
self._exit_waiters.append(waiter) self._exit_waiters.append(waiter)
return (yield from waiter) return (yield from waiter)
......
...@@ -266,6 +266,9 @@ class AbstractEventLoop: ...@@ -266,6 +266,9 @@ class AbstractEventLoop:
def time(self): def time(self):
raise NotImplementedError raise NotImplementedError
def create_future(self):
raise NotImplementedError
# Method scheduling a coroutine object: create a task. # Method scheduling a coroutine object: create a task.
def create_task(self, coro): def create_task(self, coro):
......
...@@ -451,6 +451,8 @@ def wrap_future(future, *, loop=None): ...@@ -451,6 +451,8 @@ def wrap_future(future, *, loop=None):
return future return future
assert isinstance(future, concurrent.futures.Future), \ assert isinstance(future, concurrent.futures.Future), \
'concurrent.futures.Future is expected, got {!r}'.format(future) 'concurrent.futures.Future is expected, got {!r}'.format(future)
new_future = Future(loop=loop) if loop is None:
loop = events.get_event_loop()
new_future = loop.create_future()
_chain_future(future, new_future) _chain_future(future, new_future)
return new_future return new_future
...@@ -170,7 +170,7 @@ class Lock(_ContextManagerMixin): ...@@ -170,7 +170,7 @@ class Lock(_ContextManagerMixin):
self._locked = True self._locked = True
return True return True
fut = futures.Future(loop=self._loop) fut = self._loop.create_future()
self._waiters.append(fut) self._waiters.append(fut)
try: try:
yield from fut yield from fut
...@@ -258,7 +258,7 @@ class Event: ...@@ -258,7 +258,7 @@ class Event:
if self._value: if self._value:
return True return True
fut = futures.Future(loop=self._loop) fut = self._loop.create_future()
self._waiters.append(fut) self._waiters.append(fut)
try: try:
yield from fut yield from fut
...@@ -320,7 +320,7 @@ class Condition(_ContextManagerMixin): ...@@ -320,7 +320,7 @@ class Condition(_ContextManagerMixin):
self.release() self.release()
try: try:
fut = futures.Future(loop=self._loop) fut = self._loop.create_future()
self._waiters.append(fut) self._waiters.append(fut)
try: try:
yield from fut yield from fut
...@@ -433,7 +433,7 @@ class Semaphore(_ContextManagerMixin): ...@@ -433,7 +433,7 @@ class Semaphore(_ContextManagerMixin):
True. True.
""" """
while self._value <= 0: while self._value <= 0:
fut = futures.Future(loop=self._loop) fut = self._loop.create_future()
self._waiters.append(fut) self._waiters.append(fut)
try: try:
yield from fut yield from fut
......
...@@ -443,7 +443,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): ...@@ -443,7 +443,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
try: try:
base_events._check_resolved_address(sock, address) base_events._check_resolved_address(sock, address)
except ValueError as err: except ValueError as err:
fut = futures.Future(loop=self) fut = self.create_future()
fut.set_exception(err) fut.set_exception(err)
return fut return fut
else: else:
......
...@@ -128,7 +128,7 @@ class Queue: ...@@ -128,7 +128,7 @@ class Queue:
This method is a coroutine. This method is a coroutine.
""" """
while self.full(): while self.full():
putter = futures.Future(loop=self._loop) putter = self._loop.create_future()
self._putters.append(putter) self._putters.append(putter)
try: try:
yield from putter yield from putter
...@@ -162,7 +162,7 @@ class Queue: ...@@ -162,7 +162,7 @@ class Queue:
This method is a coroutine. This method is a coroutine.
""" """
while self.empty(): while self.empty():
getter = futures.Future(loop=self._loop) getter = self._loop.create_future()
self._getters.append(getter) self._getters.append(getter)
try: try:
yield from getter yield from getter
......
...@@ -196,7 +196,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -196,7 +196,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
transport = None transport = None
try: try:
protocol = protocol_factory() protocol = protocol_factory()
waiter = futures.Future(loop=self) waiter = self.create_future()
if sslcontext: if sslcontext:
transport = self._make_ssl_transport( transport = self._make_ssl_transport(
conn, protocol, sslcontext, waiter=waiter, conn, protocol, sslcontext, waiter=waiter,
...@@ -314,7 +314,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -314,7 +314,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
""" """
if self._debug and sock.gettimeout() != 0: if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking") raise ValueError("the socket must be non-blocking")
fut = futures.Future(loop=self) fut = self.create_future()
self._sock_recv(fut, False, sock, n) self._sock_recv(fut, False, sock, n)
return fut return fut
...@@ -352,7 +352,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -352,7 +352,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
""" """
if self._debug and sock.gettimeout() != 0: if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking") raise ValueError("the socket must be non-blocking")
fut = futures.Future(loop=self) fut = self.create_future()
if data: if data:
self._sock_sendall(fut, False, sock, data) self._sock_sendall(fut, False, sock, data)
else: else:
...@@ -395,7 +395,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -395,7 +395,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
""" """
if self._debug and sock.gettimeout() != 0: if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking") raise ValueError("the socket must be non-blocking")
fut = futures.Future(loop=self) fut = self.create_future()
try: try:
base_events._check_resolved_address(sock, address) base_events._check_resolved_address(sock, address)
except ValueError as err: except ValueError as err:
...@@ -453,7 +453,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): ...@@ -453,7 +453,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
""" """
if self._debug and sock.gettimeout() != 0: if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking") raise ValueError("the socket must be non-blocking")
fut = futures.Future(loop=self) fut = self.create_future()
self._sock_accept(fut, False, sock) self._sock_accept(fut, False, sock)
return fut return fut
......
...@@ -210,7 +210,7 @@ class FlowControlMixin(protocols.Protocol): ...@@ -210,7 +210,7 @@ class FlowControlMixin(protocols.Protocol):
return return
waiter = self._drain_waiter waiter = self._drain_waiter
assert waiter is None or waiter.cancelled() assert waiter is None or waiter.cancelled()
waiter = futures.Future(loop=self._loop) waiter = self._loop.create_future()
self._drain_waiter = waiter self._drain_waiter = waiter
yield from waiter yield from waiter
...@@ -449,7 +449,7 @@ class StreamReader: ...@@ -449,7 +449,7 @@ class StreamReader:
self._paused = False self._paused = False
self._transport.resume_reading() self._transport.resume_reading()
self._waiter = futures.Future(loop=self._loop) self._waiter = self._loop.create_future()
try: try:
yield from self._waiter yield from self._waiter
finally: finally:
......
...@@ -373,7 +373,7 @@ def wait_for(fut, timeout, *, loop=None): ...@@ -373,7 +373,7 @@ def wait_for(fut, timeout, *, loop=None):
if timeout is None: if timeout is None:
return (yield from fut) return (yield from fut)
waiter = futures.Future(loop=loop) waiter = loop.create_future()
timeout_handle = loop.call_later(timeout, _release_waiter, waiter) timeout_handle = loop.call_later(timeout, _release_waiter, waiter)
cb = functools.partial(_release_waiter, waiter) cb = functools.partial(_release_waiter, waiter)
...@@ -406,7 +406,7 @@ def _wait(fs, timeout, return_when, loop): ...@@ -406,7 +406,7 @@ def _wait(fs, timeout, return_when, loop):
The fs argument must be a collection of Futures. The fs argument must be a collection of Futures.
""" """
assert fs, 'Set of Futures is empty.' assert fs, 'Set of Futures is empty.'
waiter = futures.Future(loop=loop) waiter = loop.create_future()
timeout_handle = None timeout_handle = None
if timeout is not None: if timeout is not None:
timeout_handle = loop.call_later(timeout, _release_waiter, waiter) timeout_handle = loop.call_later(timeout, _release_waiter, waiter)
...@@ -507,7 +507,9 @@ def sleep(delay, result=None, *, loop=None): ...@@ -507,7 +507,9 @@ def sleep(delay, result=None, *, loop=None):
yield yield
return result return result
future = futures.Future(loop=loop) if loop is None:
loop = events.get_event_loop()
future = loop.create_future()
h = future._loop.call_later(delay, h = future._loop.call_later(delay,
futures._set_result_unless_cancelled, futures._set_result_unless_cancelled,
future, result) future, result)
...@@ -604,7 +606,9 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): ...@@ -604,7 +606,9 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False):
be cancelled.) be cancelled.)
""" """
if not coros_or_futures: if not coros_or_futures:
outer = futures.Future(loop=loop) if loop is None:
loop = events.get_event_loop()
outer = loop.create_future()
outer.set_result([]) outer.set_result([])
return outer return outer
...@@ -692,7 +696,7 @@ def shield(arg, *, loop=None): ...@@ -692,7 +696,7 @@ def shield(arg, *, loop=None):
# Shortcut. # Shortcut.
return inner return inner
loop = inner._loop loop = inner._loop
outer = futures.Future(loop=loop) outer = loop.create_future()
def _done_callback(inner): def _done_callback(inner):
if outer.cancelled(): if outer.cancelled():
......
...@@ -177,7 +177,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): ...@@ -177,7 +177,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
stdin, stdout, stderr, bufsize, stdin, stdout, stderr, bufsize,
extra=None, **kwargs): extra=None, **kwargs):
with events.get_child_watcher() as watcher: with events.get_child_watcher() as watcher:
waiter = futures.Future(loop=self) waiter = self.create_future()
transp = _UnixSubprocessTransport(self, protocol, args, shell, transp = _UnixSubprocessTransport(self, protocol, args, shell,
stdin, stdout, stderr, bufsize, stdin, stdout, stderr, bufsize,
waiter=waiter, extra=extra, waiter=waiter, extra=extra,
......
...@@ -366,7 +366,7 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop): ...@@ -366,7 +366,7 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop):
def _make_subprocess_transport(self, protocol, args, shell, def _make_subprocess_transport(self, protocol, args, shell,
stdin, stdout, stderr, bufsize, stdin, stdout, stderr, bufsize,
extra=None, **kwargs): extra=None, **kwargs):
waiter = futures.Future(loop=self) waiter = self.create_future()
transp = _WindowsSubprocessTransport(self, protocol, args, shell, transp = _WindowsSubprocessTransport(self, protocol, args, shell,
stdin, stdout, stderr, bufsize, stdin, stdout, stderr, bufsize,
waiter=waiter, extra=extra, waiter=waiter, extra=extra,
...@@ -417,7 +417,7 @@ class IocpProactor: ...@@ -417,7 +417,7 @@ class IocpProactor:
return tmp return tmp
def _result(self, value): def _result(self, value):
fut = futures.Future(loop=self._loop) fut = self._loop.create_future()
fut.set_result(value) fut.set_result(value)
return fut return fut
......
...@@ -278,14 +278,15 @@ class FutureTests(test_utils.TestCase): ...@@ -278,14 +278,15 @@ class FutureTests(test_utils.TestCase):
f2 = asyncio.wrap_future(f1) f2 = asyncio.wrap_future(f1)
self.assertIs(f1, f2) self.assertIs(f1, f2)
@mock.patch('asyncio.futures.events') def test_wrap_future_use_global_loop(self):
def test_wrap_future_use_global_loop(self, m_events): with mock.patch('asyncio.futures.events') as events:
def run(arg): events.get_event_loop = lambda: self.loop
return (arg, threading.get_ident()) def run(arg):
ex = concurrent.futures.ThreadPoolExecutor(1) return (arg, threading.get_ident())
f1 = ex.submit(run, 'oi') ex = concurrent.futures.ThreadPoolExecutor(1)
f2 = asyncio.wrap_future(f1) f1 = ex.submit(run, 'oi')
self.assertIs(m_events.get_event_loop.return_value, f2._loop) f2 = asyncio.wrap_future(f1)
self.assertIs(self.loop, f2._loop)
def test_wrap_future_cancel(self): def test_wrap_future_cancel(self):
f1 = concurrent.futures.Future() f1 = concurrent.futures.Future()
......
...@@ -456,6 +456,8 @@ Library ...@@ -456,6 +456,8 @@ Library
- Issue #27040: Add loop.get_exception_handler method - Issue #27040: Add loop.get_exception_handler method
- Issue #27041: asyncio: Add loop.create_future method
Documentation Documentation
------------- -------------
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment