Commit a19fb3c6 authored by Andrew Svetlov's avatar Andrew Svetlov Committed by GitHub

bpo-32622: Native sendfile on windows (#5565)

* Support sendfile on Windows Proactor event loop naively.
parent 5fb632e8
...@@ -6,11 +6,14 @@ proactor is only implemented on Windows with IOCP. ...@@ -6,11 +6,14 @@ proactor is only implemented on Windows with IOCP.
__all__ = 'BaseProactorEventLoop', __all__ = 'BaseProactorEventLoop',
import io
import os
import socket import socket
import warnings import warnings
from . import base_events from . import base_events
from . import constants from . import constants
from . import events
from . import futures from . import futures
from . import protocols from . import protocols
from . import sslproto from . import sslproto
...@@ -107,6 +110,11 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin, ...@@ -107,6 +110,11 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin,
self._force_close(exc) self._force_close(exc)
def _force_close(self, exc): def _force_close(self, exc):
if self._empty_waiter is not None:
if exc is None:
self._empty_waiter.set_result(None)
else:
self._empty_waiter.set_exception(exc)
if self._closing: if self._closing:
return return
self._closing = True self._closing = True
...@@ -327,6 +335,10 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, ...@@ -327,6 +335,10 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
_start_tls_compatible = True _start_tls_compatible = True
def __init__(self, *args, **kw):
super().__init__(*args, **kw)
self._empty_waiter = None
def write(self, data): def write(self, data):
if not isinstance(data, (bytes, bytearray, memoryview)): if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError( raise TypeError(
...@@ -334,6 +346,8 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, ...@@ -334,6 +346,8 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
f"not {type(data).__name__}") f"not {type(data).__name__}")
if self._eof_written: if self._eof_written:
raise RuntimeError('write_eof() already called') raise RuntimeError('write_eof() already called')
if self._empty_waiter is not None:
raise RuntimeError('unable to write; sendfile is in progress')
if not data: if not data:
return return
...@@ -393,6 +407,8 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, ...@@ -393,6 +407,8 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
self._maybe_pause_protocol() self._maybe_pause_protocol()
else: else:
self._write_fut.add_done_callback(self._loop_writing) self._write_fut.add_done_callback(self._loop_writing)
if self._empty_waiter is not None and self._write_fut is None:
self._empty_waiter.set_result(None)
except ConnectionResetError as exc: except ConnectionResetError as exc:
self._force_close(exc) self._force_close(exc)
except OSError as exc: except OSError as exc:
...@@ -407,6 +423,17 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, ...@@ -407,6 +423,17 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
def abort(self): def abort(self):
self._force_close(None) self._force_close(None)
def _make_empty_waiter(self):
if self._empty_waiter is not None:
raise RuntimeError("Empty waiter is already set")
self._empty_waiter = self._loop.create_future()
if self._write_fut is None:
self._empty_waiter.set_result(None)
return self._empty_waiter
def _reset_empty_waiter(self):
self._empty_waiter = None
class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport): class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport):
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
...@@ -447,7 +474,7 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport, ...@@ -447,7 +474,7 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport,
transports.Transport): transports.Transport):
"""Transport for connected sockets.""" """Transport for connected sockets."""
_sendfile_compatible = constants._SendfileMode.FALLBACK _sendfile_compatible = constants._SendfileMode.TRY_NATIVE
def _set_extra(self, sock): def _set_extra(self, sock):
self._extra['socket'] = sock self._extra['socket'] = sock
...@@ -556,6 +583,47 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): ...@@ -556,6 +583,47 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
async def sock_accept(self, sock): async def sock_accept(self, sock):
return await self._proactor.accept(sock) return await self._proactor.accept(sock)
async def _sock_sendfile_native(self, sock, file, offset, count):
try:
fileno = file.fileno()
except (AttributeError, io.UnsupportedOperation) as err:
raise events.SendfileNotAvailableError("not a regular file")
try:
fsize = os.fstat(fileno).st_size
except OSError as err:
raise events.SendfileNotAvailableError("not a regular file")
blocksize = count if count else fsize
if not blocksize:
return 0 # empty file
blocksize = min(blocksize, 0xffff_ffff)
end_pos = min(offset + count, fsize) if count else fsize
offset = min(offset, fsize)
total_sent = 0
try:
while True:
blocksize = min(end_pos - offset, blocksize)
if blocksize <= 0:
return total_sent
await self._proactor.sendfile(sock, file, offset, blocksize)
offset += blocksize
total_sent += blocksize
finally:
if total_sent > 0:
file.seek(offset)
async def _sendfile_native(self, transp, file, offset, count):
resume_reading = transp.is_reading()
transp.pause_reading()
await transp._make_empty_waiter()
try:
return await self.sock_sendfile(transp._sock, file, offset, count,
fallback=False)
finally:
transp._reset_empty_waiter()
if resume_reading:
transp.resume_reading()
def _close_self_pipe(self): def _close_self_pipe(self):
if self._self_reading_future is not None: if self._self_reading_future is not None:
self._self_reading_future.cancel() self._self_reading_future.cancel()
......
...@@ -4,6 +4,7 @@ import _overlapped ...@@ -4,6 +4,7 @@ import _overlapped
import _winapi import _winapi
import errno import errno
import math import math
import msvcrt
import socket import socket
import struct import struct
import weakref import weakref
...@@ -527,6 +528,27 @@ class IocpProactor: ...@@ -527,6 +528,27 @@ class IocpProactor:
return self._register(ov, conn, finish_connect) return self._register(ov, conn, finish_connect)
def sendfile(self, sock, file, offset, count):
self._register_with_iocp(sock)
ov = _overlapped.Overlapped(NULL)
offset_low = offset & 0xffff_ffff
offset_high = (offset >> 32) & 0xffff_ffff
ov.TransmitFile(sock.fileno(),
msvcrt.get_osfhandle(file.fileno()),
offset_low, offset_high,
count, 0, 0)
def finish_sendfile(trans, key, ov):
try:
return ov.getresult()
except OSError as exc:
if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
_overlapped.ERROR_OPERATION_ABORTED):
raise ConnectionResetError(*exc.args)
else:
raise
return self._register(ov, sock, finish_sendfile)
def accept_pipe(self, pipe): def accept_pipe(self, pipe):
self._register_with_iocp(pipe) self._register_with_iocp(pipe)
ov = _overlapped.Overlapped(NULL) ov = _overlapped.Overlapped(NULL)
......
This diff is collapsed.
"""Tests for proactor_events.py""" """Tests for proactor_events.py"""
import io
import socket import socket
import unittest import unittest
import sys
from unittest import mock from unittest import mock
import asyncio import asyncio
from asyncio import events
from asyncio.proactor_events import BaseProactorEventLoop from asyncio.proactor_events import BaseProactorEventLoop
from asyncio.proactor_events import _ProactorSocketTransport from asyncio.proactor_events import _ProactorSocketTransport
from asyncio.proactor_events import _ProactorWritePipeTransport from asyncio.proactor_events import _ProactorWritePipeTransport
from asyncio.proactor_events import _ProactorDuplexPipeTransport from asyncio.proactor_events import _ProactorDuplexPipeTransport
from test import support
from test.test_asyncio import utils as test_utils from test.test_asyncio import utils as test_utils
...@@ -775,5 +779,117 @@ class BaseProactorEventLoopTests(test_utils.TestCase): ...@@ -775,5 +779,117 @@ class BaseProactorEventLoopTests(test_utils.TestCase):
self.assertFalse(future2.cancel.called) self.assertFalse(future2.cancel.called)
@unittest.skipIf(sys.platform != 'win32',
'Proactor is supported on Windows only')
class ProactorEventLoopUnixSockSendfileTests(test_utils.TestCase):
DATA = b"12345abcde" * 16 * 1024 # 160 KiB
class MyProto(asyncio.Protocol):
def __init__(self, loop):
self.started = False
self.closed = False
self.data = bytearray()
self.fut = loop.create_future()
self.transport = None
def connection_made(self, transport):
self.started = True
self.transport = transport
def data_received(self, data):
self.data.extend(data)
def connection_lost(self, exc):
self.closed = True
self.fut.set_result(None)
async def wait_closed(self):
await self.fut
@classmethod
def setUpClass(cls):
with open(support.TESTFN, 'wb') as fp:
fp.write(cls.DATA)
super().setUpClass()
@classmethod
def tearDownClass(cls):
support.unlink(support.TESTFN)
super().tearDownClass()
def setUp(self):
self.loop = asyncio.ProactorEventLoop()
self.set_event_loop(self.loop)
self.addCleanup(self.loop.close)
self.file = open(support.TESTFN, 'rb')
self.addCleanup(self.file.close)
super().setUp()
def make_socket(self, cleanup=True):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setblocking(False)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024)
if cleanup:
self.addCleanup(sock.close)
return sock
def run_loop(self, coro):
return self.loop.run_until_complete(coro)
def prepare(self):
sock = self.make_socket()
proto = self.MyProto(self.loop)
port = support.find_unused_port()
srv_sock = self.make_socket(cleanup=False)
srv_sock.bind(('127.0.0.1', port))
server = self.run_loop(self.loop.create_server(
lambda: proto, sock=srv_sock))
self.run_loop(self.loop.sock_connect(sock, srv_sock.getsockname()))
def cleanup():
if proto.transport is not None:
# can be None if the task was cancelled before
# connection_made callback
proto.transport.close()
self.run_loop(proto.wait_closed())
server.close()
self.run_loop(server.wait_closed())
self.addCleanup(cleanup)
return sock, proto
def test_sock_sendfile_not_a_file(self):
sock, proto = self.prepare()
f = object()
with self.assertRaisesRegex(events.SendfileNotAvailableError,
"not a regular file"):
self.run_loop(self.loop._sock_sendfile_native(sock, f,
0, None))
self.assertEqual(self.file.tell(), 0)
def test_sock_sendfile_iobuffer(self):
sock, proto = self.prepare()
f = io.BytesIO()
with self.assertRaisesRegex(events.SendfileNotAvailableError,
"not a regular file"):
self.run_loop(self.loop._sock_sendfile_native(sock, f,
0, None))
self.assertEqual(self.file.tell(), 0)
def test_sock_sendfile_not_regular_file(self):
sock, proto = self.prepare()
f = mock.Mock()
f.fileno.return_value = -1
with self.assertRaisesRegex(events.SendfileNotAvailableError,
"not a regular file"):
self.run_loop(self.loop._sock_sendfile_native(sock, f,
0, None))
self.assertEqual(self.file.tell(), 0)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -466,10 +466,13 @@ class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase): ...@@ -466,10 +466,13 @@ class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase):
self.addCleanup(self.file.close) self.addCleanup(self.file.close)
super().setUp() super().setUp()
def make_socket(self, blocking=False): def make_socket(self, cleanup=True):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setblocking(blocking) sock.setblocking(False)
self.addCleanup(sock.close) sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024)
if cleanup:
self.addCleanup(sock.close)
return sock return sock
def run_loop(self, coro): def run_loop(self, coro):
...@@ -479,8 +482,10 @@ class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase): ...@@ -479,8 +482,10 @@ class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase):
sock = self.make_socket() sock = self.make_socket()
proto = self.MyProto(self.loop) proto = self.MyProto(self.loop)
port = support.find_unused_port() port = support.find_unused_port()
srv_sock = self.make_socket(cleanup=False)
srv_sock.bind((support.HOST, port))
server = self.run_loop(self.loop.create_server( server = self.run_loop(self.loop.create_server(
lambda: proto, support.HOST, port)) lambda: proto, sock=srv_sock))
self.run_loop(self.loop.sock_connect(sock, (support.HOST, port))) self.run_loop(self.loop.sock_connect(sock, (support.HOST, port)))
def cleanup(): def cleanup():
...@@ -497,27 +502,6 @@ class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase): ...@@ -497,27 +502,6 @@ class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase):
return sock, proto return sock, proto
def test_sock_sendfile_success(self):
sock, proto = self.prepare()
ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
sock.close()
self.run_loop(proto.wait_closed())
self.assertEqual(ret, len(self.DATA))
self.assertEqual(proto.data, self.DATA)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sock_sendfile_with_offset_and_count(self):
sock, proto = self.prepare()
ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
1000, 2000))
sock.close()
self.run_loop(proto.wait_closed())
self.assertEqual(proto.data, self.DATA[1000:3000])
self.assertEqual(self.file.tell(), 3000)
self.assertEqual(ret, 2000)
def test_sock_sendfile_not_available(self): def test_sock_sendfile_not_available(self):
sock, proto = self.prepare() sock, proto = self.prepare()
with mock.patch('asyncio.unix_events.os', spec=[]): with mock.patch('asyncio.unix_events.os', spec=[]):
...@@ -555,36 +539,6 @@ class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase): ...@@ -555,36 +539,6 @@ class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase):
0, None)) 0, None))
self.assertEqual(self.file.tell(), 0) self.assertEqual(self.file.tell(), 0)
def test_sock_sendfile_zero_size(self):
sock, proto = self.prepare()
fname = support.TESTFN + '.suffix'
with open(fname, 'wb') as f:
pass # make zero sized file
f = open(fname, 'rb')
self.addCleanup(f.close)
self.addCleanup(support.unlink, fname)
ret = self.run_loop(self.loop._sock_sendfile_native(sock, f,
0, None))
sock.close()
self.run_loop(proto.wait_closed())
self.assertEqual(ret, 0)
self.assertEqual(self.file.tell(), 0)
def test_sock_sendfile_mix_with_regular_send(self):
buf = b'1234567890' * 1024 * 1024 # 10 MB
sock, proto = self.prepare()
self.run_loop(self.loop.sock_sendall(sock, buf))
ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
self.run_loop(self.loop.sock_sendall(sock, buf))
sock.close()
self.run_loop(proto.wait_closed())
self.assertEqual(ret, len(self.DATA))
expected = buf + self.DATA + buf
self.assertEqual(proto.data, expected)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sock_sendfile_cancel1(self): def test_sock_sendfile_cancel1(self):
sock, proto = self.prepare() sock, proto = self.prepare()
......
Implement native fast sendfile for Windows proactor event loop.
...@@ -39,7 +39,7 @@ ...@@ -39,7 +39,7 @@
enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_READINTO, TYPE_WRITE, enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_READINTO, TYPE_WRITE,
TYPE_ACCEPT, TYPE_CONNECT, TYPE_DISCONNECT, TYPE_CONNECT_NAMED_PIPE, TYPE_ACCEPT, TYPE_CONNECT, TYPE_DISCONNECT, TYPE_CONNECT_NAMED_PIPE,
TYPE_WAIT_NAMED_PIPE_AND_CONNECT}; TYPE_WAIT_NAMED_PIPE_AND_CONNECT, TYPE_TRANSMIT_FILE};
typedef struct { typedef struct {
PyObject_HEAD PyObject_HEAD
...@@ -89,6 +89,7 @@ SetFromWindowsErr(DWORD err) ...@@ -89,6 +89,7 @@ SetFromWindowsErr(DWORD err)
static LPFN_ACCEPTEX Py_AcceptEx = NULL; static LPFN_ACCEPTEX Py_AcceptEx = NULL;
static LPFN_CONNECTEX Py_ConnectEx = NULL; static LPFN_CONNECTEX Py_ConnectEx = NULL;
static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; static LPFN_DISCONNECTEX Py_DisconnectEx = NULL;
static LPFN_TRANSMITFILE Py_TransmitFile = NULL;
static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL;
#define GET_WSA_POINTER(s, x) \ #define GET_WSA_POINTER(s, x) \
...@@ -102,6 +103,7 @@ initialize_function_pointers(void) ...@@ -102,6 +103,7 @@ initialize_function_pointers(void)
GUID GuidAcceptEx = WSAID_ACCEPTEX; GUID GuidAcceptEx = WSAID_ACCEPTEX;
GUID GuidConnectEx = WSAID_CONNECTEX; GUID GuidConnectEx = WSAID_CONNECTEX;
GUID GuidDisconnectEx = WSAID_DISCONNECTEX; GUID GuidDisconnectEx = WSAID_DISCONNECTEX;
GUID GuidTransmitFile = WSAID_TRANSMITFILE;
HINSTANCE hKernel32; HINSTANCE hKernel32;
SOCKET s; SOCKET s;
DWORD dwBytes; DWORD dwBytes;
...@@ -114,7 +116,8 @@ initialize_function_pointers(void) ...@@ -114,7 +116,8 @@ initialize_function_pointers(void)
if (!GET_WSA_POINTER(s, AcceptEx) || if (!GET_WSA_POINTER(s, AcceptEx) ||
!GET_WSA_POINTER(s, ConnectEx) || !GET_WSA_POINTER(s, ConnectEx) ||
!GET_WSA_POINTER(s, DisconnectEx)) !GET_WSA_POINTER(s, DisconnectEx) ||
!GET_WSA_POINTER(s, TransmitFile))
{ {
closesocket(s); closesocket(s);
SetFromWindowsErr(WSAGetLastError()); SetFromWindowsErr(WSAGetLastError());
...@@ -1194,6 +1197,61 @@ Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) ...@@ -1194,6 +1197,61 @@ Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args)
} }
} }
PyDoc_STRVAR(
Overlapped_TransmitFile_doc,
"TransmitFile(socket, file, offset, offset_high, "
"count_to_write, count_per_send, flags) "
"-> Overlapped[None]\n\n"
"Transmit file data over a connected socket.");
static PyObject *
Overlapped_TransmitFile(OverlappedObject *self, PyObject *args)
{
SOCKET Socket;
HANDLE File;
DWORD offset;
DWORD offset_high;
DWORD count_to_write;
DWORD count_per_send;
DWORD flags;
BOOL ret;
DWORD err;
if (!PyArg_ParseTuple(args,
F_HANDLE F_HANDLE F_DWORD F_DWORD
F_DWORD F_DWORD F_DWORD,
&Socket, &File, &offset, &offset_high,
&count_to_write, &count_per_send,
&flags))
return NULL;
if (self->type != TYPE_NONE) {
PyErr_SetString(PyExc_ValueError, "operation already attempted");
return NULL;
}
self->type = TYPE_TRANSMIT_FILE;
self->handle = (HANDLE)Socket;
self->overlapped.Offset = offset;
self->overlapped.OffsetHigh = offset_high;
Py_BEGIN_ALLOW_THREADS
ret = Py_TransmitFile(Socket, File, count_to_write, count_per_send,
&self->overlapped,
NULL, flags);
Py_END_ALLOW_THREADS
self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError();
switch (err) {
case ERROR_SUCCESS:
case ERROR_IO_PENDING:
Py_RETURN_NONE;
default:
self->type = TYPE_NOT_STARTED;
return SetFromWindowsErr(err);
}
}
PyDoc_STRVAR( PyDoc_STRVAR(
Overlapped_ConnectNamedPipe_doc, Overlapped_ConnectNamedPipe_doc,
"ConnectNamedPipe(handle) -> Overlapped[None]\n\n" "ConnectNamedPipe(handle) -> Overlapped[None]\n\n"
...@@ -1303,6 +1361,8 @@ static PyMethodDef Overlapped_methods[] = { ...@@ -1303,6 +1361,8 @@ static PyMethodDef Overlapped_methods[] = {
METH_VARARGS, Overlapped_ConnectEx_doc}, METH_VARARGS, Overlapped_ConnectEx_doc},
{"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx,
METH_VARARGS, Overlapped_DisconnectEx_doc}, METH_VARARGS, Overlapped_DisconnectEx_doc},
{"TransmitFile", (PyCFunction) Overlapped_TransmitFile,
METH_VARARGS, Overlapped_TransmitFile_doc},
{"ConnectNamedPipe", (PyCFunction) Overlapped_ConnectNamedPipe, {"ConnectNamedPipe", (PyCFunction) Overlapped_ConnectNamedPipe,
METH_VARARGS, Overlapped_ConnectNamedPipe_doc}, METH_VARARGS, Overlapped_ConnectNamedPipe_doc},
{NULL} {NULL}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment