Commit a1a8b7d3 authored by Yury Selivanov's avatar Yury Selivanov

Issue #28652: Make loop methods reject socket kinds they do not support.

parent d2fd3599
...@@ -84,12 +84,26 @@ def _set_reuseport(sock): ...@@ -84,12 +84,26 @@ def _set_reuseport(sock):
'SO_REUSEPORT defined but not implemented.') 'SO_REUSEPORT defined but not implemented.')
# Linux's sock.type is a bitmask that can include extra info about socket. def _is_stream_socket(sock):
_SOCKET_TYPE_MASK = 0 # Linux's socket.type is a bitmask that can include extra info
if hasattr(socket, 'SOCK_NONBLOCK'): # about socket, therefore we can't do simple
_SOCKET_TYPE_MASK |= socket.SOCK_NONBLOCK # `sock_type == socket.SOCK_STREAM`.
if hasattr(socket, 'SOCK_CLOEXEC'): return (sock.type & socket.SOCK_STREAM) == socket.SOCK_STREAM
_SOCKET_TYPE_MASK |= socket.SOCK_CLOEXEC
def _is_dgram_socket(sock):
# Linux's socket.type is a bitmask that can include extra info
# about socket, therefore we can't do simple
# `sock_type == socket.SOCK_DGRAM`.
return (sock.type & socket.SOCK_DGRAM) == socket.SOCK_DGRAM
def _is_ip_socket(sock):
if sock.family == socket.AF_INET:
return True
if hasattr(socket, 'AF_INET6') and sock.family == socket.AF_INET6:
return True
return False
def _ipaddr_info(host, port, family, type, proto): def _ipaddr_info(host, port, family, type, proto):
...@@ -102,8 +116,12 @@ def _ipaddr_info(host, port, family, type, proto): ...@@ -102,8 +116,12 @@ def _ipaddr_info(host, port, family, type, proto):
host is None: host is None:
return None return None
type &= ~_SOCKET_TYPE_MASK
if type == socket.SOCK_STREAM: if type == socket.SOCK_STREAM:
# Linux only:
# getaddrinfo() can raise when socket.type is a bit mask.
# So if socket.type is a bit mask of SOCK_STREAM, and say
# SOCK_NONBLOCK, we simply return None, which will trigger
# a call to getaddrinfo() letting it process this request.
proto = socket.IPPROTO_TCP proto = socket.IPPROTO_TCP
elif type == socket.SOCK_DGRAM: elif type == socket.SOCK_DGRAM:
proto = socket.IPPROTO_UDP proto = socket.IPPROTO_UDP
...@@ -124,7 +142,9 @@ def _ipaddr_info(host, port, family, type, proto): ...@@ -124,7 +142,9 @@ def _ipaddr_info(host, port, family, type, proto):
return None return None
if family == socket.AF_UNSPEC: if family == socket.AF_UNSPEC:
afs = [socket.AF_INET, socket.AF_INET6] afs = [socket.AF_INET]
if hasattr(socket, 'AF_INET6'):
afs.append(socket.AF_INET6)
else: else:
afs = [family] afs = [family]
...@@ -771,9 +791,13 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -771,9 +791,13 @@ class BaseEventLoop(events.AbstractEventLoop):
raise OSError('Multiple exceptions: {}'.format( raise OSError('Multiple exceptions: {}'.format(
', '.join(str(exc) for exc in exceptions))) ', '.join(str(exc) for exc in exceptions)))
elif sock is None: else:
raise ValueError( if sock is None:
'host and port was not specified and no sock specified') raise ValueError(
'host and port was not specified and no sock specified')
if not _is_stream_socket(sock) or not _is_ip_socket(sock):
raise ValueError(
'A TCP Stream Socket was expected, got {!r}'.format(sock))
transport, protocol = yield from self._create_connection_transport( transport, protocol = yield from self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname) sock, protocol_factory, ssl, server_hostname)
...@@ -817,6 +841,9 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -817,6 +841,9 @@ class BaseEventLoop(events.AbstractEventLoop):
allow_broadcast=None, sock=None): allow_broadcast=None, sock=None):
"""Create datagram connection.""" """Create datagram connection."""
if sock is not None: if sock is not None:
if not _is_dgram_socket(sock):
raise ValueError(
'A UDP Socket was expected, got {!r}'.format(sock))
if (local_addr or remote_addr or if (local_addr or remote_addr or
family or proto or flags or family or proto or flags or
reuse_address or reuse_port or allow_broadcast): reuse_address or reuse_port or allow_broadcast):
...@@ -1027,6 +1054,9 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -1027,6 +1054,9 @@ class BaseEventLoop(events.AbstractEventLoop):
else: else:
if sock is None: if sock is None:
raise ValueError('Neither host/port nor sock were specified') raise ValueError('Neither host/port nor sock were specified')
if not _is_stream_socket(sock) or not _is_ip_socket(sock):
raise ValueError(
'A TCP Stream Socket was expected, got {!r}'.format(sock))
sockets = [sock] sockets = [sock]
server = Server(self, sockets) server = Server(self, sockets)
...@@ -1048,6 +1078,10 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -1048,6 +1078,10 @@ class BaseEventLoop(events.AbstractEventLoop):
This method is a coroutine. When completed, the coroutine This method is a coroutine. When completed, the coroutine
returns a (transport, protocol) pair. returns a (transport, protocol) pair.
""" """
if not _is_stream_socket(sock):
raise ValueError(
'A Stream Socket was expected, got {!r}'.format(sock))
transport, protocol = yield from self._create_connection_transport( transport, protocol = yield from self._create_connection_transport(
sock, protocol_factory, ssl, '', server_side=True) sock, protocol_factory, ssl, '', server_side=True)
if self._debug: if self._debug:
......
...@@ -235,7 +235,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): ...@@ -235,7 +235,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
if sock is None: if sock is None:
raise ValueError('no path and sock were specified') raise ValueError('no path and sock were specified')
if (sock.family != socket.AF_UNIX or if (sock.family != socket.AF_UNIX or
sock.type != socket.SOCK_STREAM): not base_events._is_stream_socket(sock)):
raise ValueError( raise ValueError(
'A UNIX Domain Stream Socket was expected, got {!r}' 'A UNIX Domain Stream Socket was expected, got {!r}'
.format(sock)) .format(sock))
...@@ -289,7 +289,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): ...@@ -289,7 +289,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
'path was not specified, and no sock specified') 'path was not specified, and no sock specified')
if (sock.family != socket.AF_UNIX or if (sock.family != socket.AF_UNIX or
sock.type != socket.SOCK_STREAM): not base_events._is_stream_socket(sock)):
raise ValueError( raise ValueError(
'A UNIX Domain Stream Socket was expected, got {!r}' 'A UNIX Domain Stream Socket was expected, got {!r}'
.format(sock)) .format(sock))
......
...@@ -116,6 +116,13 @@ class BaseEventTests(test_utils.TestCase): ...@@ -116,6 +116,13 @@ class BaseEventTests(test_utils.TestCase):
self.assertIsNone( self.assertIsNone(
base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP)) base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP))
if hasattr(socket, 'SOCK_NONBLOCK'):
self.assertEqual(
None,
base_events._ipaddr_info(
'1.2.3.4', 1, INET, STREAM | socket.SOCK_NONBLOCK, TCP))
def test_port_parameter_types(self): def test_port_parameter_types(self):
# Test obscure kinds of arguments for "port". # Test obscure kinds of arguments for "port".
INET = socket.AF_INET INET = socket.AF_INET
...@@ -1040,6 +1047,43 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1040,6 +1047,43 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
MyProto, 'example.com', 80, sock=object()) MyProto, 'example.com', 80, sock=object())
self.assertRaises(ValueError, self.loop.run_until_complete, coro) self.assertRaises(ValueError, self.loop.run_until_complete, coro)
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'no Unix sockets')
def test_create_connection_wrong_sock(self):
sock = socket.socket(socket.AF_UNIX)
with sock:
coro = self.loop.create_connection(MyProto, sock=sock)
with self.assertRaisesRegex(ValueError,
'A TCP Stream Socket was expected'):
self.loop.run_until_complete(coro)
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'no Unix sockets')
def test_create_server_wrong_sock(self):
sock = socket.socket(socket.AF_UNIX)
with sock:
coro = self.loop.create_server(MyProto, sock=sock)
with self.assertRaisesRegex(ValueError,
'A TCP Stream Socket was expected'):
self.loop.run_until_complete(coro)
@unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
'no socket.SOCK_NONBLOCK (linux only)')
def test_create_server_stream_bittype(self):
sock = socket.socket(
socket.AF_INET, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
with sock:
coro = self.loop.create_server(lambda: None, sock=sock)
srv = self.loop.run_until_complete(coro)
srv.close()
self.loop.run_until_complete(srv.wait_closed())
def test_create_datagram_endpoint_wrong_sock(self):
sock = socket.socket(socket.AF_INET)
with sock:
coro = self.loop.create_datagram_endpoint(MyProto, sock=sock)
with self.assertRaisesRegex(ValueError,
'A UDP Socket was expected'):
self.loop.run_until_complete(coro)
def test_create_connection_no_host_port_sock(self): def test_create_connection_no_host_port_sock(self):
coro = self.loop.create_connection(MyProto) coro = self.loop.create_connection(MyProto)
self.assertRaises(ValueError, self.loop.run_until_complete, coro) self.assertRaises(ValueError, self.loop.run_until_complete, coro)
...@@ -1487,36 +1531,39 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): ...@@ -1487,36 +1531,39 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
self.assertEqual('CLOSED', protocol.state) self.assertEqual('CLOSED', protocol.state)
def test_create_datagram_endpoint_sock_sockopts(self): def test_create_datagram_endpoint_sock_sockopts(self):
class FakeSock:
type = socket.SOCK_DGRAM
fut = self.loop.create_datagram_endpoint( fut = self.loop.create_datagram_endpoint(
MyDatagramProto, local_addr=('127.0.0.1', 0), sock=object()) MyDatagramProto, local_addr=('127.0.0.1', 0), sock=FakeSock())
self.assertRaises(ValueError, self.loop.run_until_complete, fut) self.assertRaises(ValueError, self.loop.run_until_complete, fut)
fut = self.loop.create_datagram_endpoint( fut = self.loop.create_datagram_endpoint(
MyDatagramProto, remote_addr=('127.0.0.1', 0), sock=object()) MyDatagramProto, remote_addr=('127.0.0.1', 0), sock=FakeSock())
self.assertRaises(ValueError, self.loop.run_until_complete, fut) self.assertRaises(ValueError, self.loop.run_until_complete, fut)
fut = self.loop.create_datagram_endpoint( fut = self.loop.create_datagram_endpoint(
MyDatagramProto, family=1, sock=object()) MyDatagramProto, family=1, sock=FakeSock())
self.assertRaises(ValueError, self.loop.run_until_complete, fut) self.assertRaises(ValueError, self.loop.run_until_complete, fut)
fut = self.loop.create_datagram_endpoint( fut = self.loop.create_datagram_endpoint(
MyDatagramProto, proto=1, sock=object()) MyDatagramProto, proto=1, sock=FakeSock())
self.assertRaises(ValueError, self.loop.run_until_complete, fut) self.assertRaises(ValueError, self.loop.run_until_complete, fut)
fut = self.loop.create_datagram_endpoint( fut = self.loop.create_datagram_endpoint(
MyDatagramProto, flags=1, sock=object()) MyDatagramProto, flags=1, sock=FakeSock())
self.assertRaises(ValueError, self.loop.run_until_complete, fut) self.assertRaises(ValueError, self.loop.run_until_complete, fut)
fut = self.loop.create_datagram_endpoint( fut = self.loop.create_datagram_endpoint(
MyDatagramProto, reuse_address=True, sock=object()) MyDatagramProto, reuse_address=True, sock=FakeSock())
self.assertRaises(ValueError, self.loop.run_until_complete, fut) self.assertRaises(ValueError, self.loop.run_until_complete, fut)
fut = self.loop.create_datagram_endpoint( fut = self.loop.create_datagram_endpoint(
MyDatagramProto, reuse_port=True, sock=object()) MyDatagramProto, reuse_port=True, sock=FakeSock())
self.assertRaises(ValueError, self.loop.run_until_complete, fut) self.assertRaises(ValueError, self.loop.run_until_complete, fut)
fut = self.loop.create_datagram_endpoint( fut = self.loop.create_datagram_endpoint(
MyDatagramProto, allow_broadcast=True, sock=object()) MyDatagramProto, allow_broadcast=True, sock=FakeSock())
self.assertRaises(ValueError, self.loop.run_until_complete, fut) self.assertRaises(ValueError, self.loop.run_until_complete, fut)
def test_create_datagram_endpoint_sockopts(self): def test_create_datagram_endpoint_sockopts(self):
......
...@@ -791,9 +791,9 @@ class EventLoopTestsMixin: ...@@ -791,9 +791,9 @@ class EventLoopTestsMixin:
conn, _ = lsock.accept() conn, _ = lsock.accept()
proto = MyProto(loop=loop) proto = MyProto(loop=loop)
proto.loop = loop proto.loop = loop
f = loop.create_task( loop.run_until_complete(
loop.connect_accepted_socket( loop.connect_accepted_socket(
(lambda : proto), conn, ssl=server_ssl)) (lambda: proto), conn, ssl=server_ssl))
loop.run_forever() loop.run_forever()
proto.transport.close() proto.transport.close()
lsock.close() lsock.close()
...@@ -1377,6 +1377,11 @@ class EventLoopTestsMixin: ...@@ -1377,6 +1377,11 @@ class EventLoopTestsMixin:
server.transport.close() server.transport.close()
def test_create_datagram_endpoint_sock(self): def test_create_datagram_endpoint_sock(self):
if (sys.platform == 'win32' and
isinstance(self.loop, proactor_events.BaseProactorEventLoop)):
raise unittest.SkipTest(
'UDP is not supported with proactor event loops')
sock = None sock = None
local_address = ('127.0.0.1', 0) local_address = ('127.0.0.1', 0)
infos = self.loop.run_until_complete( infos = self.loop.run_until_complete(
...@@ -1394,7 +1399,7 @@ class EventLoopTestsMixin: ...@@ -1394,7 +1399,7 @@ class EventLoopTestsMixin:
else: else:
assert False, 'Can not create socket.' assert False, 'Can not create socket.'
f = self.loop.create_connection( f = self.loop.create_datagram_endpoint(
lambda: MyDatagramProto(loop=self.loop), sock=sock) lambda: MyDatagramProto(loop=self.loop), sock=sock)
tr, pr = self.loop.run_until_complete(f) tr, pr = self.loop.run_until_complete(f)
self.assertIsInstance(tr, asyncio.Transport) self.assertIsInstance(tr, asyncio.Transport)
......
...@@ -280,6 +280,33 @@ class SelectorEventLoopUnixSocketTests(test_utils.TestCase): ...@@ -280,6 +280,33 @@ class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
'A UNIX Domain Stream.*was expected'): 'A UNIX Domain Stream.*was expected'):
self.loop.run_until_complete(coro) self.loop.run_until_complete(coro)
def test_create_unix_server_path_dgram(self):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
with sock:
coro = self.loop.create_unix_server(lambda: None, path=None,
sock=sock)
with self.assertRaisesRegex(ValueError,
'A UNIX Domain Stream.*was expected'):
self.loop.run_until_complete(coro)
@unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
'no socket.SOCK_NONBLOCK (linux only)')
def test_create_unix_server_path_stream_bittype(self):
sock = socket.socket(
socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
with tempfile.NamedTemporaryFile() as file:
fn = file.name
try:
with sock:
sock.bind(fn)
coro = self.loop.create_unix_server(lambda: None, path=None,
sock=sock)
srv = self.loop.run_until_complete(coro)
srv.close()
self.loop.run_until_complete(srv.wait_closed())
finally:
os.unlink(fn)
def test_create_unix_connection_path_inetsock(self): def test_create_unix_connection_path_inetsock(self):
sock = socket.socket() sock = socket.socket()
with sock: with sock:
......
...@@ -455,6 +455,8 @@ Library ...@@ -455,6 +455,8 @@ Library
- Issue #28639: Fix inspect.isawaitable to always return bool - Issue #28639: Fix inspect.isawaitable to always return bool
Patch by Justin Mayfield. Patch by Justin Mayfield.
- Issue #28652: Make loop methods reject socket kinds they do not support.
IDLE IDLE
---- ----
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment