Commit b9bf913a authored by Guido van Rossum's avatar Guido van Rossum

Issue #23972: updates to asyncio datagram API. By Chris Laws.

parent d17e9785
......@@ -283,17 +283,50 @@ Creating connections
(:class:`StreamReader`, :class:`StreamWriter`) instead of a protocol.
.. coroutinemethod:: BaseEventLoop.create_datagram_endpoint(protocol_factory, local_addr=None, remote_addr=None, \*, family=0, proto=0, flags=0)
.. coroutinemethod:: BaseEventLoop.create_datagram_endpoint(protocol_factory, local_addr=None, remote_addr=None, \*, family=0, proto=0, flags=0, reuse_address=None, reuse_port=None, allow_broadcast=None, sock=None)
Create datagram connection: socket family :py:data:`~socket.AF_INET` or
:py:data:`~socket.AF_INET6` depending on *host* (or *family* if specified),
socket type :py:data:`~socket.SOCK_DGRAM`.
socket type :py:data:`~socket.SOCK_DGRAM`. *protocol_factory* must be a
callable returning a :ref:`protocol <asyncio-protocol>` instance.
This method is a :ref:`coroutine <coroutine>` which will try to
establish the connection in the background. When successful, the
coroutine returns a ``(transport, protocol)`` pair.
See the :meth:`BaseEventLoop.create_connection` method for parameters.
Options changing how the connection is created:
* *local_addr*, if given, is a ``(local_host, local_port)`` tuple used
to bind the socket to locally. The *local_host* and *local_port*
are looked up using :meth:`getaddrinfo`.
* *remote_addr*, if given, is a ``(remote_host, remote_port)`` tuple used
to connect the socket to a remote address. The *remote_host* and
*remote_port* are looked up using :meth:`getaddrinfo`.
* *family*, *proto*, *flags* are the optional address family, protocol
and flags to be passed through to :meth:`getaddrinfo` for *host*
resolution. If given, these should all be integers from the
corresponding :mod:`socket` module constants.
* *reuse_address* tells the kernel to reuse a local socket in
TIME_WAIT state, without waiting for its natural timeout to
expire. If not specified will automatically be set to True on
UNIX.
* *reuse_port* tells the kernel to allow this endpoint to be bound to the
same port as other existing endpoints are bound to, so long as they all
set this flag when being created. This option is not supported on Windows
and some UNIX's. If the :py:data:`~socket.SO_REUSEPORT` constant is not
defined then this capability is unsupported.
* *allow_broadcast* tells the kernel to allow this endpoint to send
messages to the broadcast address.
* *sock* can optionally be specified in order to use a preexisting,
already connected, :class:`socket.socket` object to be used by the
transport. If specified, *local_addr* and *remote_addr* should be omitted
(must be :const:`None`).
On Windows with :class:`ProactorEventLoop`, this method is not supported.
......@@ -320,7 +353,7 @@ Creating connections
Creating listening connections
------------------------------
.. coroutinemethod:: BaseEventLoop.create_server(protocol_factory, host=None, port=None, \*, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None)
.. coroutinemethod:: BaseEventLoop.create_server(protocol_factory, host=None, port=None, \*, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None)
Create a TCP server (socket type :data:`~socket.SOCK_STREAM`) bound to
*host* and *port*.
......@@ -359,6 +392,11 @@ Creating listening connections
expire. If not specified will automatically be set to True on
UNIX.
* *reuse_port* tells the kernel to allow this endpoint to be bound to the
same port as other existing endpoints are bound to, so long as they all
set this flag when being created. This option is not supported on
Windows.
This method is a :ref:`coroutine <coroutine>`.
On Windows with :class:`ProactorEventLoop`, SSL/TLS is not supported.
......
......@@ -700,75 +700,109 @@ class BaseEventLoop(events.AbstractEventLoop):
@coroutine
def create_datagram_endpoint(self, protocol_factory,
local_addr=None, remote_addr=None, *,
family=0, proto=0, flags=0):
family=0, proto=0, flags=0,
reuse_address=None, reuse_port=None,
allow_broadcast=None, sock=None):
"""Create datagram connection."""
if not (local_addr or remote_addr):
if family == 0:
raise ValueError('unexpected address family')
addr_pairs_info = (((family, proto), (None, None)),)
else:
# join address by (family, protocol)
addr_infos = collections.OrderedDict()
for idx, addr in ((0, local_addr), (1, remote_addr)):
if addr is not None:
assert isinstance(addr, tuple) and len(addr) == 2, (
'2-tuple is expected')
infos = yield from self.getaddrinfo(
*addr, family=family, type=socket.SOCK_DGRAM,
proto=proto, flags=flags)
if not infos:
raise OSError('getaddrinfo() returned empty list')
for fam, _, pro, _, address in infos:
key = (fam, pro)
if key not in addr_infos:
addr_infos[key] = [None, None]
addr_infos[key][idx] = address
# each addr has to have info for each (family, proto) pair
addr_pairs_info = [
(key, addr_pair) for key, addr_pair in addr_infos.items()
if not ((local_addr and addr_pair[0] is None) or
(remote_addr and addr_pair[1] is None))]
if not addr_pairs_info:
raise ValueError('can not get address information')
exceptions = []
for ((family, proto),
(local_address, remote_address)) in addr_pairs_info:
sock = None
if sock is not None:
if (local_addr or remote_addr or
family or proto or flags or
reuse_address or reuse_port or allow_broadcast):
# show the problematic kwargs in exception msg
opts = dict(local_addr=local_addr, remote_addr=remote_addr,
family=family, proto=proto, flags=flags,
reuse_address=reuse_address, reuse_port=reuse_port,
allow_broadcast=allow_broadcast)
problems = ', '.join(
'{}={}'.format(k, v) for k, v in opts.items() if v)
raise ValueError(
'socket modifier keyword arguments can not be used '
'when sock is specified. ({})'.format(problems))
sock.setblocking(False)
r_addr = None
try:
sock = socket.socket(
family=family, type=socket.SOCK_DGRAM, proto=proto)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setblocking(False)
if local_addr:
sock.bind(local_address)
if remote_addr:
yield from self.sock_connect(sock, remote_address)
r_addr = remote_address
except OSError as exc:
if sock is not None:
sock.close()
exceptions.append(exc)
except:
if sock is not None:
sock.close()
raise
else:
break
else:
raise exceptions[0]
if not (local_addr or remote_addr):
if family == 0:
raise ValueError('unexpected address family')
addr_pairs_info = (((family, proto), (None, None)),)
else:
# join address by (family, protocol)
addr_infos = collections.OrderedDict()
for idx, addr in ((0, local_addr), (1, remote_addr)):
if addr is not None:
assert isinstance(addr, tuple) and len(addr) == 2, (
'2-tuple is expected')
infos = yield from self.getaddrinfo(
*addr, family=family, type=socket.SOCK_DGRAM,
proto=proto, flags=flags)
if not infos:
raise OSError('getaddrinfo() returned empty list')
for fam, _, pro, _, address in infos:
key = (fam, pro)
if key not in addr_infos:
addr_infos[key] = [None, None]
addr_infos[key][idx] = address
# each addr has to have info for each (family, proto) pair
addr_pairs_info = [
(key, addr_pair) for key, addr_pair in addr_infos.items()
if not ((local_addr and addr_pair[0] is None) or
(remote_addr and addr_pair[1] is None))]
if not addr_pairs_info:
raise ValueError('can not get address information')
exceptions = []
if reuse_address is None:
reuse_address = os.name == 'posix' and sys.platform != 'cygwin'
for ((family, proto),
(local_address, remote_address)) in addr_pairs_info:
sock = None
r_addr = None
try:
sock = socket.socket(
family=family, type=socket.SOCK_DGRAM, proto=proto)
if reuse_address:
sock.setsockopt(
socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if reuse_port:
if not hasattr(socket, 'SO_REUSEPORT'):
raise ValueError(
'reuse_port not supported by socket module')
else:
sock.setsockopt(
socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
if allow_broadcast:
sock.setsockopt(
socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.setblocking(False)
if local_addr:
sock.bind(local_address)
if remote_addr:
yield from self.sock_connect(sock, remote_address)
r_addr = remote_address
except OSError as exc:
if sock is not None:
sock.close()
exceptions.append(exc)
except:
if sock is not None:
sock.close()
raise
else:
break
else:
raise exceptions[0]
protocol = protocol_factory()
waiter = futures.Future(loop=self)
transport = self._make_datagram_transport(sock, protocol, r_addr,
waiter)
transport = self._make_datagram_transport(
sock, protocol, r_addr, waiter)
if self._debug:
if local_addr:
logger.info("Datagram endpoint local_addr=%r remote_addr=%r "
......@@ -804,7 +838,8 @@ class BaseEventLoop(events.AbstractEventLoop):
sock=None,
backlog=100,
ssl=None,
reuse_address=None):
reuse_address=None,
reuse_port=None):
"""Create a TCP server.
The host parameter can be a string, in that case the TCP server is bound
......@@ -857,8 +892,15 @@ class BaseEventLoop(events.AbstractEventLoop):
continue
sockets.append(sock)
if reuse_address:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
True)
sock.setsockopt(
socket.SOL_SOCKET, socket.SO_REUSEADDR, True)
if reuse_port:
if not hasattr(socket, 'SO_REUSEPORT'):
raise ValueError(
'reuse_port not supported by socket module')
else:
sock.setsockopt(
socket.SOL_SOCKET, socket.SO_REUSEPORT, True)
# Disable IPv4/IPv6 dual stack support (enabled by
# default on Linux) which makes a single socket
# listen on both address families.
......
......@@ -297,7 +297,8 @@ class AbstractEventLoop:
def create_server(self, protocol_factory, host=None, port=None, *,
family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE,
sock=None, backlog=100, ssl=None, reuse_address=None):
sock=None, backlog=100, ssl=None, reuse_address=None,
reuse_port=None):
"""A coroutine which creates a TCP server bound to host and port.
The return value is a Server object which can be used to stop
......@@ -327,6 +328,11 @@ class AbstractEventLoop:
TIME_WAIT state, without waiting for its natural timeout to
expire. If not specified will automatically be set to True on
UNIX.
reuse_port tells the kernel to allow this endpoint to be bound to
the same port as other existing endpoints are bound to, so long as
they all set this flag when being created. This option is not
supported on Windows.
"""
raise NotImplementedError
......@@ -358,7 +364,37 @@ class AbstractEventLoop:
def create_datagram_endpoint(self, protocol_factory,
local_addr=None, remote_addr=None, *,
family=0, proto=0, flags=0):
family=0, proto=0, flags=0,
reuse_address=None, reuse_port=None,
allow_broadcast=None, sock=None):
"""A coroutine which creates a datagram endpoint.
This method will try to establish the endpoint in the background.
When successful, the coroutine returns a (transport, protocol) pair.
protocol_factory must be a callable returning a protocol instance.
socket family AF_INET or socket.AF_INET6 depending on host (or
family if specified), socket type SOCK_DGRAM.
reuse_address tells the kernel to reuse a local socket in
TIME_WAIT state, without waiting for its natural timeout to
expire. If not specified it will automatically be set to True on
UNIX.
reuse_port tells the kernel to allow this endpoint to be bound to
the same port as other existing endpoints are bound to, so long as
they all set this flag when being created. This option is not
supported on Windows and some UNIX's. If the
:py:data:`~socket.SO_REUSEPORT` constant is not defined then this
capability is unsupported.
allow_broadcast tells the kernel to allow this endpoint to send
messages to the broadcast address.
sock can optionally be specified in order to use a preexisting
socket object.
"""
raise NotImplementedError
# Pipes and subprocesses.
......
......@@ -3,6 +3,7 @@
import errno
import logging
import math
import os
import socket
import sys
import threading
......@@ -790,11 +791,11 @@ class MyProto(asyncio.Protocol):
class MyDatagramProto(asyncio.DatagramProtocol):
done = None
def __init__(self, create_future=False):
def __init__(self, create_future=False, loop=None):
self.state = 'INITIAL'
self.nbytes = 0
if create_future:
self.done = asyncio.Future()
self.done = asyncio.Future(loop=loop)
def connection_made(self, transport):
self.transport = transport
......@@ -1099,6 +1100,19 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
f = self.loop.create_server(MyProto, '0.0.0.0', 0)
self.assertRaises(OSError, self.loop.run_until_complete, f)
@mock.patch('asyncio.base_events.socket')
def test_create_server_nosoreuseport(self, m_socket):
m_socket.getaddrinfo = socket.getaddrinfo
m_socket.SOCK_STREAM = socket.SOCK_STREAM
m_socket.SOL_SOCKET = socket.SOL_SOCKET
del m_socket.SO_REUSEPORT
m_socket.socket.return_value = mock.Mock()
f = self.loop.create_server(
MyProto, '0.0.0.0', 0, reuse_port=True)
self.assertRaises(ValueError, self.loop.run_until_complete, f)
@mock.patch('asyncio.base_events.socket')
def test_create_server_cant_bind(self, m_socket):
......@@ -1199,6 +1213,128 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertRaises(Err, self.loop.run_until_complete, fut)
self.assertTrue(m_sock.close.called)
def test_create_datagram_endpoint_sock(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
fut = self.loop.create_datagram_endpoint(
lambda: MyDatagramProto(create_future=True, loop=self.loop),
sock=sock)
transport, protocol = self.loop.run_until_complete(fut)
transport.close()
self.loop.run_until_complete(protocol.done)
self.assertEqual('CLOSED', protocol.state)
def test_create_datagram_endpoint_sock_sockopts(self):
fut = self.loop.create_datagram_endpoint(
MyDatagramProto, local_addr=('127.0.0.1', 0), sock=object())
self.assertRaises(ValueError, self.loop.run_until_complete, fut)
fut = self.loop.create_datagram_endpoint(
MyDatagramProto, remote_addr=('127.0.0.1', 0), sock=object())
self.assertRaises(ValueError, self.loop.run_until_complete, fut)
fut = self.loop.create_datagram_endpoint(
MyDatagramProto, family=1, sock=object())
self.assertRaises(ValueError, self.loop.run_until_complete, fut)
fut = self.loop.create_datagram_endpoint(
MyDatagramProto, proto=1, sock=object())
self.assertRaises(ValueError, self.loop.run_until_complete, fut)
fut = self.loop.create_datagram_endpoint(
MyDatagramProto, flags=1, sock=object())
self.assertRaises(ValueError, self.loop.run_until_complete, fut)
fut = self.loop.create_datagram_endpoint(
MyDatagramProto, reuse_address=True, sock=object())
self.assertRaises(ValueError, self.loop.run_until_complete, fut)
fut = self.loop.create_datagram_endpoint(
MyDatagramProto, reuse_port=True, sock=object())
self.assertRaises(ValueError, self.loop.run_until_complete, fut)
fut = self.loop.create_datagram_endpoint(
MyDatagramProto, allow_broadcast=True, sock=object())
self.assertRaises(ValueError, self.loop.run_until_complete, fut)
def test_create_datagram_endpoint_sockopts(self):
# Socket options should not be applied unless asked for.
# SO_REUSEADDR defaults to on for UNIX.
# SO_REUSEPORT is not available on all platforms.
coro = self.loop.create_datagram_endpoint(
lambda: MyDatagramProto(create_future=True, loop=self.loop),
local_addr=('127.0.0.1', 0))
transport, protocol = self.loop.run_until_complete(coro)
sock = transport.get_extra_info('socket')
reuse_address_default_on = (
os.name == 'posix' and sys.platform != 'cygwin')
reuseport_supported = hasattr(socket, 'SO_REUSEPORT')
if reuse_address_default_on:
self.assertTrue(
sock.getsockopt(
socket.SOL_SOCKET, socket.SO_REUSEADDR))
else:
self.assertFalse(
sock.getsockopt(
socket.SOL_SOCKET, socket.SO_REUSEADDR))
if reuseport_supported:
self.assertFalse(
sock.getsockopt(
socket.SOL_SOCKET, socket.SO_REUSEPORT))
self.assertFalse(
sock.getsockopt(
socket.SOL_SOCKET, socket.SO_BROADCAST))
transport.close()
self.loop.run_until_complete(protocol.done)
self.assertEqual('CLOSED', protocol.state)
coro = self.loop.create_datagram_endpoint(
lambda: MyDatagramProto(create_future=True, loop=self.loop),
local_addr=('127.0.0.1', 0),
reuse_address=True,
reuse_port=reuseport_supported,
allow_broadcast=True)
transport, protocol = self.loop.run_until_complete(coro)
sock = transport.get_extra_info('socket')
self.assertTrue(
sock.getsockopt(
socket.SOL_SOCKET, socket.SO_REUSEADDR))
if reuseport_supported:
self.assertTrue(
sock.getsockopt(
socket.SOL_SOCKET, socket.SO_REUSEPORT))
else:
self.assertFalse(
sock.getsockopt(
socket.SOL_SOCKET, socket.SO_REUSEPORT))
self.assertTrue(
sock.getsockopt(
socket.SOL_SOCKET, socket.SO_BROADCAST))
transport.close()
self.loop.run_until_complete(protocol.done)
self.assertEqual('CLOSED', protocol.state)
@mock.patch('asyncio.base_events.socket')
def test_create_datagram_endpoint_nosoreuseport(self, m_socket):
m_socket.getaddrinfo = socket.getaddrinfo
m_socket.SOCK_DGRAM = socket.SOCK_DGRAM
m_socket.SOL_SOCKET = socket.SOL_SOCKET
del m_socket.SO_REUSEPORT
m_socket.socket.return_value = mock.Mock()
coro = self.loop.create_datagram_endpoint(
lambda: MyDatagramProto(loop=self.loop),
local_addr=('127.0.0.1', 0),
reuse_address=False,
reuse_port=True)
self.assertRaises(ValueError, self.loop.run_until_complete, coro)
def test_accept_connection_retry(self):
sock = mock.Mock()
sock.accept.side_effect = BlockingIOError()
......
......@@ -814,6 +814,32 @@ class EventLoopTestsMixin:
# close server
server.close()
@unittest.skipUnless(hasattr(socket, 'SO_REUSEPORT'), 'No SO_REUSEPORT')
def test_create_server_reuse_port(self):
proto = MyProto(self.loop)
f = self.loop.create_server(
lambda: proto, '0.0.0.0', 0)
server = self.loop.run_until_complete(f)
self.assertEqual(len(server.sockets), 1)
sock = server.sockets[0]
self.assertFalse(
sock.getsockopt(
socket.SOL_SOCKET, socket.SO_REUSEPORT))
server.close()
test_utils.run_briefly(self.loop)
proto = MyProto(self.loop)
f = self.loop.create_server(
lambda: proto, '0.0.0.0', 0, reuse_port=True)
server = self.loop.run_until_complete(f)
self.assertEqual(len(server.sockets), 1)
sock = server.sockets[0]
self.assertTrue(
sock.getsockopt(
socket.SOL_SOCKET, socket.SO_REUSEPORT))
server.close()
def _make_unix_server(self, factory, **kwargs):
path = test_utils.gen_unix_socket_path()
self.addCleanup(lambda: os.path.exists(path) and os.unlink(path))
......@@ -1264,6 +1290,32 @@ class EventLoopTestsMixin:
self.assertEqual('CLOSED', client.state)
server.transport.close()
def test_create_datagram_endpoint_sock(self):
sock = None
local_address = ('127.0.0.1', 0)
infos = self.loop.run_until_complete(
self.loop.getaddrinfo(
*local_address, type=socket.SOCK_DGRAM))
for family, type, proto, cname, address in infos:
try:
sock = socket.socket(family=family, type=type, proto=proto)
sock.setblocking(False)
sock.bind(address)
except:
pass
else:
break
else:
assert False, 'Can not create socket.'
f = self.loop.create_connection(
lambda: MyDatagramProto(loop=self.loop), sock=sock)
tr, pr = self.loop.run_until_complete(f)
self.assertIsInstance(tr, asyncio.Transport)
self.assertIsInstance(pr, MyDatagramProto)
tr.close()
self.loop.run_until_complete(pr.done)
def test_internal_fds(self):
loop = self.create_event_loop()
if not isinstance(loop, selector_events.BaseSelectorEventLoop):
......
......@@ -789,6 +789,7 @@ Ben Laurie
Simon Law
Julia Lawall
Chris Lawrence
Chris Laws
Brian Leair
Mathieu Leduc-Hamel
Amandine Lee
......
......@@ -90,6 +90,12 @@ Core and Builtins
Library
-------
- Issue #23972: Updates asyncio datagram create method allowing reuseport
and reuseaddr socket options to be set prior to binding the socket.
Mirroring the existing asyncio create_server method the reuseaddr option
for datagram sockets defaults to True if the O/S is 'posix' (except if the
platform is Cygwin). Patch by Chris Laws.
- Issue #25304: Add asyncio.run_coroutine_threadsafe(). This lets you
submit a coroutine to a loop from another thread, returning a
concurrent.futures.Future. By Vincent Michel.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment