Commit 519883d2 authored by Yury Selivanov's avatar Yury Selivanov

asyncio: Add support for UNIX Domain Sockets.

parent 242e2659
...@@ -407,6 +407,13 @@ class BaseEventLoop(events.AbstractEventLoop): ...@@ -407,6 +407,13 @@ class BaseEventLoop(events.AbstractEventLoop):
sock.setblocking(False) sock.setblocking(False)
transport, protocol = yield from self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname)
return transport, protocol
@tasks.coroutine
def _create_connection_transport(self, sock, protocol_factory, ssl,
server_hostname):
protocol = protocol_factory() protocol = protocol_factory()
waiter = futures.Future(loop=self) waiter = futures.Future(loop=self)
if ssl: if ssl:
......
...@@ -220,6 +220,32 @@ class AbstractEventLoop: ...@@ -220,6 +220,32 @@ class AbstractEventLoop:
""" """
raise NotImplementedError raise NotImplementedError
def create_unix_connection(self, protocol_factory, path, *,
ssl=None, sock=None,
server_hostname=None):
raise NotImplementedError
def create_unix_server(self, protocol_factory, path, *,
sock=None, backlog=100, ssl=None):
"""A coroutine which creates a UNIX Domain Socket server.
The return valud is a Server object, which can be used to stop
the service.
path is a str, representing a file systsem path to bind the
server socket to.
sock can optionally be specified in order to use a preexisting
socket object.
backlog is the maximum number of queued connections passed to
listen() (defaults to 100).
ssl can be set to an SSLContext to enable SSL over the
accepted connections.
"""
raise NotImplementedError
def create_datagram_endpoint(self, protocol_factory, def create_datagram_endpoint(self, protocol_factory,
local_addr=None, remote_addr=None, *, local_addr=None, remote_addr=None, *,
family=0, proto=0, flags=0): family=0, proto=0, flags=0):
......
"""Stream-related things.""" """Stream-related things."""
__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
'open_connection', 'start_server', 'IncompleteReadError', 'open_connection', 'start_server',
'open_unix_connection', 'start_unix_server',
'IncompleteReadError',
] ]
import socket
from . import events from . import events
from . import futures from . import futures
from . import protocols from . import protocols
...@@ -93,6 +97,39 @@ def start_server(client_connected_cb, host=None, port=None, *, ...@@ -93,6 +97,39 @@ def start_server(client_connected_cb, host=None, port=None, *,
return (yield from loop.create_server(factory, host, port, **kwds)) return (yield from loop.create_server(factory, host, port, **kwds))
if hasattr(socket, 'AF_UNIX'):
# UNIX Domain Sockets are supported on this platform
@tasks.coroutine
def open_unix_connection(path=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Similar to `open_connection` but works with UNIX Domain Sockets."""
if loop is None:
loop = events.get_event_loop()
reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, loop=loop)
transport, _ = yield from loop.create_unix_connection(
lambda: protocol, path, **kwds)
writer = StreamWriter(transport, protocol, reader, loop)
return reader, writer
@tasks.coroutine
def start_unix_server(client_connected_cb, path=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Similar to `start_server` but works with UNIX Domain Sockets."""
if loop is None:
loop = events.get_event_loop()
def factory():
reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, client_connected_cb,
loop=loop)
return protocol
return (yield from loop.create_unix_server(factory, path, **kwds))
class FlowControlMixin(protocols.Protocol): class FlowControlMixin(protocols.Protocol):
"""Reusable flow control logic for StreamWriter.drain(). """Reusable flow control logic for StreamWriter.drain().
......
...@@ -4,12 +4,18 @@ import collections ...@@ -4,12 +4,18 @@ import collections
import contextlib import contextlib
import io import io
import os import os
import socket
import socketserver
import sys import sys
import tempfile
import threading import threading
import time import time
import unittest import unittest
import unittest.mock import unittest.mock
from http.server import HTTPServer
from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
try: try:
import ssl import ssl
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
...@@ -70,42 +76,51 @@ def run_once(loop): ...@@ -70,42 +76,51 @@ def run_once(loop):
loop.run_forever() loop.run_forever()
@contextlib.contextmanager class SilentWSGIRequestHandler(WSGIRequestHandler):
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
class SilentWSGIRequestHandler(WSGIRequestHandler): def get_stderr(self):
def get_stderr(self): return io.StringIO()
return io.StringIO()
def log_message(self, format, *args): def log_message(self, format, *args):
pass pass
class SilentWSGIServer(WSGIServer):
def handle_error(self, request, client_address): class SilentWSGIServer(WSGIServer):
def handle_error(self, request, client_address):
pass
class SSLWSGIServerMixin:
def finish_request(self, request, client_address):
# The relative location of our test directory (which
# contains the ssl key and certificate files) differs
# between the stdlib and stand-alone asyncio.
# Prefer our own if we can find it.
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
if not os.path.isdir(here):
here = os.path.join(os.path.dirname(os.__file__),
'test', 'test_asyncio')
keyfile = os.path.join(here, 'ssl_key.pem')
certfile = os.path.join(here, 'ssl_cert.pem')
ssock = ssl.wrap_socket(request,
keyfile=keyfile,
certfile=certfile,
server_side=True)
try:
self.RequestHandlerClass(ssock, client_address, self)
ssock.close()
except OSError:
# maybe socket has been closed by peer
pass pass
class SSLWSGIServer(SilentWSGIServer):
def finish_request(self, request, client_address): class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
# The relative location of our test directory (which pass
# contains the ssl key and certificate files) differs
# between the stdlib and stand-alone asyncio.
# Prefer our own if we can find it. def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
here = os.path.join(os.path.dirname(__file__), '..', 'tests')
if not os.path.isdir(here):
here = os.path.join(os.path.dirname(os.__file__),
'test', 'test_asyncio')
keyfile = os.path.join(here, 'ssl_key.pem')
certfile = os.path.join(here, 'ssl_cert.pem')
ssock = ssl.wrap_socket(request,
keyfile=keyfile,
certfile=certfile,
server_side=True)
try:
self.RequestHandlerClass(ssock, client_address, self)
ssock.close()
except OSError:
# maybe socket has been closed by peer
pass
def app(environ, start_response): def app(environ, start_response):
status = '200 OK' status = '200 OK'
...@@ -115,9 +130,9 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): ...@@ -115,9 +130,9 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
# Run the test WSGI server in a separate thread in order not to # Run the test WSGI server in a separate thread in order not to
# interfere with event handling in the main thread # interfere with event handling in the main thread
server_class = SSLWSGIServer if use_ssl else SilentWSGIServer server_class = server_ssl_cls if use_ssl else server_cls
httpd = make_server(host, port, app, httpd = server_class(address, SilentWSGIRequestHandler)
server_class, SilentWSGIRequestHandler) httpd.set_app(app)
httpd.address = httpd.server_address httpd.address = httpd.server_address
server_thread = threading.Thread(target=httpd.serve_forever) server_thread = threading.Thread(target=httpd.serve_forever)
server_thread.start() server_thread.start()
...@@ -129,6 +144,75 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): ...@@ -129,6 +144,75 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
server_thread.join() server_thread.join()
if hasattr(socket, 'AF_UNIX'):
class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
def server_bind(self):
socketserver.UnixStreamServer.server_bind(self)
self.server_name = '127.0.0.1'
self.server_port = 80
class UnixWSGIServer(UnixHTTPServer, WSGIServer):
def server_bind(self):
UnixHTTPServer.server_bind(self)
self.setup_environ()
def get_request(self):
request, client_addr = super().get_request()
# Code in the stdlib expects that get_request
# will return a socket and a tuple (host, port).
# However, this isn't true for UNIX sockets,
# as the second return value will be a path;
# hence we return some fake data sufficient
# to get the tests going
return request, ('127.0.0.1', '')
class SilentUnixWSGIServer(UnixWSGIServer):
def handle_error(self, request, client_address):
pass
class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
pass
def gen_unix_socket_path():
with tempfile.NamedTemporaryFile() as file:
return file.name
@contextlib.contextmanager
def unix_socket_path():
path = gen_unix_socket_path()
try:
yield path
finally:
try:
os.unlink(path)
except OSError:
pass
@contextlib.contextmanager
def run_test_unix_server(*, use_ssl=False):
with unix_socket_path() as path:
yield from _run_test_server(address=path, use_ssl=use_ssl,
server_cls=SilentUnixWSGIServer,
server_ssl_cls=UnixSSLWSGIServer)
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
server_cls=SilentWSGIServer,
server_ssl_cls=SSLWSGIServer)
def make_test_protocol(base): def make_test_protocol(base):
dct = {} dct = {}
for name in dir(base): for name in dir(base):
...@@ -275,5 +359,6 @@ class TestLoop(base_events.BaseEventLoop): ...@@ -275,5 +359,6 @@ class TestLoop(base_events.BaseEventLoop):
def _write_to_self(self): def _write_to_self(self):
pass pass
def MockCallback(**kwargs): def MockCallback(**kwargs):
return unittest.mock.Mock(spec=['__call__'], **kwargs) return unittest.mock.Mock(spec=['__call__'], **kwargs)
...@@ -11,6 +11,7 @@ import sys ...@@ -11,6 +11,7 @@ import sys
import threading import threading
from . import base_events
from . import base_subprocess from . import base_subprocess
from . import constants from . import constants
from . import events from . import events
...@@ -31,9 +32,9 @@ if sys.platform == 'win32': # pragma: no cover ...@@ -31,9 +32,9 @@ if sys.platform == 'win32': # pragma: no cover
class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
"""Unix event loop """Unix event loop.
Adds signal handling to SelectorEventLoop Adds signal handling and UNIX Domain Socket support to SelectorEventLoop.
""" """
def __init__(self, selector=None): def __init__(self, selector=None):
...@@ -164,6 +165,76 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): ...@@ -164,6 +165,76 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
def _child_watcher_callback(self, pid, returncode, transp): def _child_watcher_callback(self, pid, returncode, transp):
self.call_soon_threadsafe(transp._process_exited, returncode) self.call_soon_threadsafe(transp._process_exited, returncode)
@tasks.coroutine
def create_unix_connection(self, protocol_factory, path, *,
ssl=None, sock=None,
server_hostname=None):
assert server_hostname is None or isinstance(server_hostname, str)
if ssl:
if server_hostname is None:
raise ValueError(
'you have to pass server_hostname when using ssl')
else:
if server_hostname is not None:
raise ValueError('server_hostname is only meaningful with ssl')
if path is not None:
if sock is not None:
raise ValueError(
'path and sock can not be specified at the same time')
try:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
sock.setblocking(False)
yield from self.sock_connect(sock, path)
except OSError:
if sock is not None:
sock.close()
raise
else:
if sock is None:
raise ValueError('no path and sock were specified')
sock.setblocking(False)
transport, protocol = yield from self._create_connection_transport(
sock, protocol_factory, ssl, server_hostname)
return transport, protocol
@tasks.coroutine
def create_unix_server(self, protocol_factory, path=None, *,
sock=None, backlog=100, ssl=None):
if isinstance(ssl, bool):
raise TypeError('ssl argument must be an SSLContext or None')
if path is not None:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
sock.bind(path)
except OSError as exc:
if exc.errno == errno.EADDRINUSE:
# Let's improve the error message by adding
# with what exact address it occurs.
msg = 'Address {!r} is already in use'.format(path)
raise OSError(errno.EADDRINUSE, msg) from None
else:
raise
else:
if sock is None:
raise ValueError(
'path was not specified, and no sock specified')
if sock.family != socket.AF_UNIX:
raise ValueError(
'A UNIX Domain Socket was expected, got {!r}'.format(sock))
server = base_events.Server(self, [sock])
sock.listen(backlog)
sock.setblocking(False)
self._start_serving(protocol_factory, sock, ssl, server)
return server
def _set_nonblocking(fd): def _set_nonblocking(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFL) flags = fcntl.fcntl(fd, fcntl.F_GETFL)
......
...@@ -212,7 +212,7 @@ class BaseEventLoopTests(unittest.TestCase): ...@@ -212,7 +212,7 @@ class BaseEventLoopTests(unittest.TestCase):
idx = -1 idx = -1
data = [10.0, 10.0, 10.3, 13.0] data = [10.0, 10.0, 10.3, 13.0]
self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda:True, ())] self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda: True, ())]
self.loop._run_once() self.loop._run_once()
self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0]) self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0])
......
This diff is collapsed.
...@@ -55,7 +55,8 @@ class BaseSelectorEventLoopTests(unittest.TestCase): ...@@ -55,7 +55,8 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
self.loop.remove_reader = unittest.mock.Mock() self.loop.remove_reader = unittest.mock.Mock()
self.loop.remove_writer = unittest.mock.Mock() self.loop.remove_writer = unittest.mock.Mock()
waiter = asyncio.Future(loop=self.loop) waiter = asyncio.Future(loop=self.loop)
transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter) transport = self.loop._make_ssl_transport(
m, asyncio.Protocol(), m, waiter)
self.assertIsInstance(transport, _SelectorSslTransport) self.assertIsInstance(transport, _SelectorSslTransport)
@unittest.mock.patch('asyncio.selector_events.ssl', None) @unittest.mock.patch('asyncio.selector_events.ssl', None)
......
This diff is collapsed.
...@@ -7,8 +7,10 @@ import io ...@@ -7,8 +7,10 @@ import io
import os import os
import pprint import pprint
import signal import signal
import socket
import stat import stat
import sys import sys
import tempfile
import threading import threading
import unittest import unittest
import unittest.mock import unittest.mock
...@@ -24,7 +26,7 @@ from asyncio import unix_events ...@@ -24,7 +26,7 @@ from asyncio import unix_events
@unittest.skipUnless(signal, 'Signals are not supported') @unittest.skipUnless(signal, 'Signals are not supported')
class SelectorEventLoopTests(unittest.TestCase): class SelectorEventLoopSignalTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = asyncio.SelectorEventLoop() self.loop = asyncio.SelectorEventLoop()
...@@ -200,6 +202,84 @@ class SelectorEventLoopTests(unittest.TestCase): ...@@ -200,6 +202,84 @@ class SelectorEventLoopTests(unittest.TestCase):
m_signal.set_wakeup_fd.assert_called_once_with(-1) m_signal.set_wakeup_fd.assert_called_once_with(-1)
@unittest.skipUnless(hasattr(socket, 'AF_UNIX'),
'UNIX Sockets are not supported')
class SelectorEventLoopUnixSocketTests(unittest.TestCase):
def setUp(self):
self.loop = asyncio.SelectorEventLoop()
asyncio.set_event_loop(None)
def tearDown(self):
self.loop.close()
def test_create_unix_server_existing_path_sock(self):
with test_utils.unix_socket_path() as path:
sock = socket.socket(socket.AF_UNIX)
sock.bind(path)
coro = self.loop.create_unix_server(lambda: None, path)
with self.assertRaisesRegexp(OSError,
'Address.*is already in use'):
self.loop.run_until_complete(coro)
def test_create_unix_server_existing_path_nonsock(self):
with tempfile.NamedTemporaryFile() as file:
coro = self.loop.create_unix_server(lambda: None, file.name)
with self.assertRaisesRegexp(OSError,
'Address.*is already in use'):
self.loop.run_until_complete(coro)
def test_create_unix_server_ssl_bool(self):
coro = self.loop.create_unix_server(lambda: None, path='spam',
ssl=True)
with self.assertRaisesRegex(TypeError,
'ssl argument must be an SSLContext'):
self.loop.run_until_complete(coro)
def test_create_unix_server_nopath_nosock(self):
coro = self.loop.create_unix_server(lambda: None, path=None)
with self.assertRaisesRegex(ValueError,
'path was not specified, and no sock'):
self.loop.run_until_complete(coro)
def test_create_unix_server_path_inetsock(self):
coro = self.loop.create_unix_server(lambda: None, path=None,
sock=socket.socket())
with self.assertRaisesRegex(ValueError,
'A UNIX Domain Socket was expected'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_path_sock(self):
coro = self.loop.create_unix_connection(
lambda: None, '/dev/null', sock=object())
with self.assertRaisesRegex(ValueError, 'path and sock can not be'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_nopath_nosock(self):
coro = self.loop.create_unix_connection(
lambda: None, None)
with self.assertRaisesRegex(ValueError,
'no path and sock were specified'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_nossl_serverhost(self):
coro = self.loop.create_unix_connection(
lambda: None, '/dev/null', server_hostname='spam')
with self.assertRaisesRegex(ValueError,
'server_hostname is only meaningful'):
self.loop.run_until_complete(coro)
def test_create_unix_connection_ssl_noserverhost(self):
coro = self.loop.create_unix_connection(
lambda: None, '/dev/null', ssl=True)
with self.assertRaisesRegexp(
ValueError, 'you have to pass server_hostname when using ssl'):
self.loop.run_until_complete(coro)
class UnixReadPipeTransportTests(unittest.TestCase): class UnixReadPipeTransportTests(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment