Commit ad4ed872 authored by Andrew Svetlov's avatar Andrew Svetlov Committed by GitHub

Forbid creating of stream objects outside of asyncio (#13101)

parent 2cc0223f
...@@ -4,6 +4,7 @@ __all__ = ( ...@@ -4,6 +4,7 @@ __all__ = (
import socket import socket
import sys import sys
import warnings
import weakref import weakref
if hasattr(socket, 'AF_UNIX'): if hasattr(socket, 'AF_UNIX'):
...@@ -42,11 +43,14 @@ async def open_connection(host=None, port=None, *, ...@@ -42,11 +43,14 @@ async def open_connection(host=None, port=None, *,
""" """
if loop is None: if loop is None:
loop = events.get_event_loop() loop = events.get_event_loop()
reader = StreamReader(limit=limit, loop=loop) reader = StreamReader(limit=limit, loop=loop,
protocol = StreamReaderProtocol(reader, loop=loop) _asyncio_internal=True)
protocol = StreamReaderProtocol(reader, loop=loop,
_asyncio_internal=True)
transport, _ = await loop.create_connection( transport, _ = await loop.create_connection(
lambda: protocol, host, port, **kwds) lambda: protocol, host, port, **kwds)
writer = StreamWriter(transport, protocol, reader, loop) writer = StreamWriter(transport, protocol, reader, loop,
_asyncio_internal=True)
return reader, writer return reader, writer
...@@ -77,9 +81,11 @@ async def start_server(client_connected_cb, host=None, port=None, *, ...@@ -77,9 +81,11 @@ async def start_server(client_connected_cb, host=None, port=None, *,
loop = events.get_event_loop() loop = events.get_event_loop()
def factory(): def factory():
reader = StreamReader(limit=limit, loop=loop) reader = StreamReader(limit=limit, loop=loop,
_asyncio_internal=True)
protocol = StreamReaderProtocol(reader, client_connected_cb, protocol = StreamReaderProtocol(reader, client_connected_cb,
loop=loop) loop=loop,
_asyncio_internal=True)
return protocol return protocol
return await loop.create_server(factory, host, port, **kwds) return await loop.create_server(factory, host, port, **kwds)
...@@ -93,11 +99,14 @@ if hasattr(socket, 'AF_UNIX'): ...@@ -93,11 +99,14 @@ if hasattr(socket, 'AF_UNIX'):
"""Similar to `open_connection` but works with UNIX Domain Sockets.""" """Similar to `open_connection` but works with UNIX Domain Sockets."""
if loop is None: if loop is None:
loop = events.get_event_loop() loop = events.get_event_loop()
reader = StreamReader(limit=limit, loop=loop) reader = StreamReader(limit=limit, loop=loop,
protocol = StreamReaderProtocol(reader, loop=loop) _asyncio_internal=True)
protocol = StreamReaderProtocol(reader, loop=loop,
_asyncio_internal=True)
transport, _ = await loop.create_unix_connection( transport, _ = await loop.create_unix_connection(
lambda: protocol, path, **kwds) lambda: protocol, path, **kwds)
writer = StreamWriter(transport, protocol, reader, loop) writer = StreamWriter(transport, protocol, reader, loop,
_asyncio_internal=True)
return reader, writer return reader, writer
async def start_unix_server(client_connected_cb, path=None, *, async def start_unix_server(client_connected_cb, path=None, *,
...@@ -107,9 +116,11 @@ if hasattr(socket, 'AF_UNIX'): ...@@ -107,9 +116,11 @@ if hasattr(socket, 'AF_UNIX'):
loop = events.get_event_loop() loop = events.get_event_loop()
def factory(): def factory():
reader = StreamReader(limit=limit, loop=loop) reader = StreamReader(limit=limit, loop=loop,
_asyncio_internal=True)
protocol = StreamReaderProtocol(reader, client_connected_cb, protocol = StreamReaderProtocol(reader, client_connected_cb,
loop=loop) loop=loop,
_asyncio_internal=True)
return protocol return protocol
return await loop.create_unix_server(factory, path, **kwds) return await loop.create_unix_server(factory, path, **kwds)
...@@ -125,11 +136,20 @@ class FlowControlMixin(protocols.Protocol): ...@@ -125,11 +136,20 @@ class FlowControlMixin(protocols.Protocol):
StreamWriter.drain() must wait for _drain_helper() coroutine. StreamWriter.drain() must wait for _drain_helper() coroutine.
""" """
def __init__(self, loop=None): def __init__(self, loop=None, *, _asyncio_internal=False):
if loop is None: if loop is None:
self._loop = events.get_event_loop() self._loop = events.get_event_loop()
else: else:
self._loop = loop self._loop = loop
if not _asyncio_internal:
# NOTE:
# Avoid inheritance from FlowControlMixin
# Copy-paste the code to your project
# if you need flow control helpers
warnings.warn(f"{self.__class__} should be instaniated "
"by asyncio internals only, "
"please avoid its creation from user code",
DeprecationWarning)
self._paused = False self._paused = False
self._drain_waiter = None self._drain_waiter = None
self._connection_lost = False self._connection_lost = False
...@@ -191,8 +211,9 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): ...@@ -191,8 +211,9 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
_source_traceback = None _source_traceback = None
def __init__(self, stream_reader, client_connected_cb=None, loop=None): def __init__(self, stream_reader, client_connected_cb=None, loop=None,
super().__init__(loop=loop) *, _asyncio_internal=False):
super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
if stream_reader is not None: if stream_reader is not None:
self._stream_reader_wr = weakref.ref(stream_reader, self._stream_reader_wr = weakref.ref(stream_reader,
self._on_reader_gc) self._on_reader_gc)
...@@ -253,7 +274,8 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): ...@@ -253,7 +274,8 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
if self._client_connected_cb is not None: if self._client_connected_cb is not None:
self._stream_writer = StreamWriter(transport, self, self._stream_writer = StreamWriter(transport, self,
reader, reader,
self._loop) self._loop,
_asyncio_internal=True)
res = self._client_connected_cb(reader, res = self._client_connected_cb(reader,
self._stream_writer) self._stream_writer)
if coroutines.iscoroutine(res): if coroutines.iscoroutine(res):
...@@ -311,7 +333,13 @@ class StreamWriter: ...@@ -311,7 +333,13 @@ class StreamWriter:
directly. directly.
""" """
def __init__(self, transport, protocol, reader, loop): def __init__(self, transport, protocol, reader, loop,
*, _asyncio_internal=False):
if not _asyncio_internal:
warnings.warn(f"{self.__class__} should be instaniated "
"by asyncio internals only, "
"please avoid its creation from user code",
DeprecationWarning)
self._transport = transport self._transport = transport
self._protocol = protocol self._protocol = protocol
# drain() expects that the reader has an exception() method # drain() expects that the reader has an exception() method
...@@ -388,7 +416,14 @@ class StreamReader: ...@@ -388,7 +416,14 @@ class StreamReader:
_source_traceback = None _source_traceback = None
def __init__(self, limit=_DEFAULT_LIMIT, loop=None): def __init__(self, limit=_DEFAULT_LIMIT, loop=None,
*, _asyncio_internal=False):
if not _asyncio_internal:
warnings.warn(f"{self.__class__} should be instaniated "
"by asyncio internals only, "
"please avoid its creation from user code",
DeprecationWarning)
# The line length limit is a security feature; # The line length limit is a security feature;
# it also doubles as half the buffer limit. # it also doubles as half the buffer limit.
......
__all__ = 'create_subprocess_exec', 'create_subprocess_shell' __all__ = 'create_subprocess_exec', 'create_subprocess_shell'
import subprocess import subprocess
import warnings
from . import events from . import events
from . import protocols from . import protocols
...@@ -18,8 +19,8 @@ class SubprocessStreamProtocol(streams.FlowControlMixin, ...@@ -18,8 +19,8 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
protocols.SubprocessProtocol): protocols.SubprocessProtocol):
"""Like StreamReaderProtocol, but for a subprocess.""" """Like StreamReaderProtocol, but for a subprocess."""
def __init__(self, limit, loop): def __init__(self, limit, loop, *, _asyncio_internal=False):
super().__init__(loop=loop) super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
self._limit = limit self._limit = limit
self.stdin = self.stdout = self.stderr = None self.stdin = self.stdout = self.stderr = None
self._transport = None self._transport = None
...@@ -42,14 +43,16 @@ class SubprocessStreamProtocol(streams.FlowControlMixin, ...@@ -42,14 +43,16 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
stdout_transport = transport.get_pipe_transport(1) stdout_transport = transport.get_pipe_transport(1)
if stdout_transport is not None: if stdout_transport is not None:
self.stdout = streams.StreamReader(limit=self._limit, self.stdout = streams.StreamReader(limit=self._limit,
loop=self._loop) loop=self._loop,
_asyncio_internal=True)
self.stdout.set_transport(stdout_transport) self.stdout.set_transport(stdout_transport)
self._pipe_fds.append(1) self._pipe_fds.append(1)
stderr_transport = transport.get_pipe_transport(2) stderr_transport = transport.get_pipe_transport(2)
if stderr_transport is not None: if stderr_transport is not None:
self.stderr = streams.StreamReader(limit=self._limit, self.stderr = streams.StreamReader(limit=self._limit,
loop=self._loop) loop=self._loop,
_asyncio_internal=True)
self.stderr.set_transport(stderr_transport) self.stderr.set_transport(stderr_transport)
self._pipe_fds.append(2) self._pipe_fds.append(2)
...@@ -58,7 +61,8 @@ class SubprocessStreamProtocol(streams.FlowControlMixin, ...@@ -58,7 +61,8 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
self.stdin = streams.StreamWriter(stdin_transport, self.stdin = streams.StreamWriter(stdin_transport,
protocol=self, protocol=self,
reader=None, reader=None,
loop=self._loop) loop=self._loop,
_asyncio_internal=True)
def pipe_data_received(self, fd, data): def pipe_data_received(self, fd, data):
if fd == 1: if fd == 1:
...@@ -104,7 +108,13 @@ class SubprocessStreamProtocol(streams.FlowControlMixin, ...@@ -104,7 +108,13 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
class Process: class Process:
def __init__(self, transport, protocol, loop): def __init__(self, transport, protocol, loop, *, _asyncio_internal=False):
if not _asyncio_internal:
warnings.warn(f"{self.__class__} should be instaniated "
"by asyncio internals only, "
"please avoid its creation from user code",
DeprecationWarning)
self._transport = transport self._transport = transport
self._protocol = protocol self._protocol = protocol
self._loop = loop self._loop = loop
...@@ -195,12 +205,13 @@ async def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None, ...@@ -195,12 +205,13 @@ async def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None,
if loop is None: if loop is None:
loop = events.get_event_loop() loop = events.get_event_loop()
protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, protocol_factory = lambda: SubprocessStreamProtocol(limit=limit,
loop=loop) loop=loop,
_asyncio_internal=True)
transport, protocol = await loop.subprocess_shell( transport, protocol = await loop.subprocess_shell(
protocol_factory, protocol_factory,
cmd, stdin=stdin, stdout=stdout, cmd, stdin=stdin, stdout=stdout,
stderr=stderr, **kwds) stderr=stderr, **kwds)
return Process(transport, protocol, loop) return Process(transport, protocol, loop, _asyncio_internal=True)
async def create_subprocess_exec(program, *args, stdin=None, stdout=None, async def create_subprocess_exec(program, *args, stdin=None, stdout=None,
...@@ -209,10 +220,11 @@ async def create_subprocess_exec(program, *args, stdin=None, stdout=None, ...@@ -209,10 +220,11 @@ async def create_subprocess_exec(program, *args, stdin=None, stdout=None,
if loop is None: if loop is None:
loop = events.get_event_loop() loop = events.get_event_loop()
protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, protocol_factory = lambda: SubprocessStreamProtocol(limit=limit,
loop=loop) loop=loop,
_asyncio_internal=True)
transport, protocol = await loop.subprocess_exec( transport, protocol = await loop.subprocess_exec(
protocol_factory, protocol_factory,
program, *args, program, *args,
stdin=stdin, stdout=stdout, stdin=stdin, stdout=stdout,
stderr=stderr, **kwds) stderr=stderr, **kwds)
return Process(transport, protocol, loop) return Process(transport, protocol, loop, _asyncio_internal=True)
This diff is collapsed.
...@@ -510,6 +510,18 @@ class SubprocessMixin: ...@@ -510,6 +510,18 @@ class SubprocessMixin:
self.loop.run_until_complete(execute()) self.loop.run_until_complete(execute())
def test_subprocess_protocol_create_warning(self):
with self.assertWarns(DeprecationWarning):
subprocess.SubprocessStreamProtocol(limit=10, loop=self.loop)
def test_process_create_warning(self):
proto = subprocess.SubprocessStreamProtocol(limit=10, loop=self.loop,
_asyncio_internal=True)
transp = mock.Mock()
with self.assertWarns(DeprecationWarning):
subprocess.Process(transp, proto, loop=self.loop)
if sys.platform != 'win32': if sys.platform != 'win32':
# Unix # Unix
......
Forbid creation of asyncio stream objects like StreamReader, StreamWriter,
Process, and their protocols outside of asyncio package.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment