Commit 0d842785 authored by Guido van Rossum's avatar Guido van Rossum

asyncio: Add streams.start_server(), by Gustavo Carneiro.

parent 70c9775d
"""Stream-related things."""
__all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection']
__all__ = ['StreamReader', 'StreamReaderProtocol',
'open_connection', 'start_server',
]
import collections
......@@ -43,6 +45,42 @@ def open_connection(host=None, port=None, *,
return reader, writer
@tasks.coroutine
def start_server(client_connected_cb, host=None, port=None, *,
loop=None, limit=_DEFAULT_LIMIT, **kwds):
"""Start a socket server, call back for each client connected.
The first parameter, `client_connected_cb`, takes two parameters:
client_reader, client_writer. client_reader is a StreamReader
object, while client_writer is a StreamWriter object. This
parameter can either be a plain callback function or a coroutine;
if it is a coroutine, it will be automatically converted into a
Task.
The rest of the arguments are all the usual arguments to
loop.create_server() except protocol_factory; most common are
positional host and port, with various optional keyword arguments
following. The return value is the same as loop.create_server().
Additional optional keyword arguments are loop (to set the event loop
instance to use) and limit (to set the buffer limit passed to the
StreamReader).
The return value is the same as loop.create_server(), i.e. a
Server object which can be used to stop the service.
"""
if loop is None:
loop = events.get_event_loop()
def factory():
reader = StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, client_connected_cb,
loop=loop)
return protocol
return (yield from loop.create_server(factory, host, port, **kwds))
class StreamReaderProtocol(protocols.Protocol):
"""Trivial helper class to adapt between Protocol and StreamReader.
......@@ -52,13 +90,24 @@ class StreamReaderProtocol(protocols.Protocol):
call inappropriate methods of the protocol.)
"""
def __init__(self, stream_reader):
def __init__(self, stream_reader, client_connected_cb=None, loop=None):
self._stream_reader = stream_reader
self._stream_writer = None
self._drain_waiter = None
self._paused = False
self._client_connected_cb = client_connected_cb
self._loop = loop # May be None; we may never need it.
def connection_made(self, transport):
self._stream_reader.set_transport(transport)
if self._client_connected_cb is not None:
self._stream_writer = StreamWriter(transport, self,
self._stream_reader,
self._loop)
res = self._client_connected_cb(self._stream_reader,
self._stream_writer)
if tasks.iscoroutine(res):
tasks.Task(res, loop=self._loop)
def connection_lost(self, exc):
if exc is None:
......
......@@ -359,6 +359,72 @@ class StreamReaderTests(unittest.TestCase):
test_utils.run_briefly(self.loop)
self.assertIs(stream._waiter, None)
def test_start_server(self):
class MyServer:
def __init__(self, loop):
self.server = None
self.loop = loop
@tasks.coroutine
def handle_client(self, client_reader, client_writer):
data = yield from client_reader.readline()
client_writer.write(data)
def start(self):
self.server = self.loop.run_until_complete(
streams.start_server(self.handle_client,
'127.0.0.1', 12345,
loop=self.loop))
def handle_client_callback(self, client_reader, client_writer):
task = tasks.Task(client_reader.readline(), loop=self.loop)
def done(task):
client_writer.write(task.result())
task.add_done_callback(done)
def start_callback(self):
self.server = self.loop.run_until_complete(
streams.start_server(self.handle_client_callback,
'127.0.0.1', 12345,
loop=self.loop))
def stop(self):
if self.server is not None:
self.server.close()
self.loop.run_until_complete(self.server.wait_closed())
self.server = None
@tasks.coroutine
def client():
reader, writer = yield from streams.open_connection(
'127.0.0.1', 12345, loop=self.loop)
# send a line
writer.write(b"hello world!\n")
# read it back
msgback = yield from reader.readline()
writer.close()
return msgback
# test the server variant with a coroutine as client handler
server = MyServer(self.loop)
server.start()
msg = self.loop.run_until_complete(tasks.Task(client(),
loop=self.loop))
server.stop()
self.assertEqual(msg, b"hello world!\n")
# test the server variant with a callback as client handler
server = MyServer(self.loop)
server.start_callback()
msg = self.loop.run_until_complete(tasks.Task(client(),
loop=self.loop))
server.stop()
self.assertEqual(msg, b"hello world!\n")
if __name__ == '__main__':
unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment