Commit 97b94329 authored by David S. Miller's avatar David S. Miller

Merge branch 'vsock-fixes'

Filippo Storniolo says:

====================
vsock: fix server prevents clients from reconnecting

This patch series introduce fix and tests for the following vsock bug:
If the same remote peer, using the same port, tries to connect
to a server on a listening port more than once, the server will
reject the connection, causing a "connection reset by peer"
error on the remote peer. This is due to the presence of a
dangling socket from a previous connection in both the connected
and bound socket lists.
The inconsistency of the above lists only occurs when the remote
peer disconnects and the server remains active.
This bug does not occur when the server socket is closed.

More details on the first patch changelog.
The remaining patches are refactoring and test.
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 7425627b d80f63f6
...@@ -1369,11 +1369,17 @@ virtio_transport_recv_connected(struct sock *sk, ...@@ -1369,11 +1369,17 @@ virtio_transport_recv_connected(struct sock *sk,
vsk->peer_shutdown |= RCV_SHUTDOWN; vsk->peer_shutdown |= RCV_SHUTDOWN;
if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_SEND) if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
vsk->peer_shutdown |= SEND_SHUTDOWN; vsk->peer_shutdown |= SEND_SHUTDOWN;
if (vsk->peer_shutdown == SHUTDOWN_MASK && if (vsk->peer_shutdown == SHUTDOWN_MASK) {
vsock_stream_has_data(vsk) <= 0 && if (vsock_stream_has_data(vsk) <= 0 && !sock_flag(sk, SOCK_DONE)) {
!sock_flag(sk, SOCK_DONE)) { (void)virtio_transport_reset(vsk, NULL);
(void)virtio_transport_reset(vsk, NULL); virtio_transport_do_close(vsk, true);
virtio_transport_do_close(vsk, true); }
/* Remove this socket anyway because the remote peer sent
* the shutdown. This way a new connection will succeed
* if the remote peer uses the same source port,
* even if the old socket is still unreleased, but now disconnected.
*/
vsock_remove_sock(vsk);
} }
if (le32_to_cpu(virtio_vsock_hdr(skb)->flags)) if (le32_to_cpu(virtio_vsock_hdr(skb)->flags))
sk->sk_state_change(sk); sk->sk_state_change(sk);
......
...@@ -85,6 +85,48 @@ void vsock_wait_remote_close(int fd) ...@@ -85,6 +85,48 @@ void vsock_wait_remote_close(int fd)
close(epollfd); close(epollfd);
} }
/* Bind to <bind_port>, connect to <cid, port> and return the file descriptor. */
int vsock_bind_connect(unsigned int cid, unsigned int port, unsigned int bind_port, int type)
{
struct sockaddr_vm sa_client = {
.svm_family = AF_VSOCK,
.svm_cid = VMADDR_CID_ANY,
.svm_port = bind_port,
};
struct sockaddr_vm sa_server = {
.svm_family = AF_VSOCK,
.svm_cid = cid,
.svm_port = port,
};
int client_fd, ret;
client_fd = socket(AF_VSOCK, type, 0);
if (client_fd < 0) {
perror("socket");
exit(EXIT_FAILURE);
}
if (bind(client_fd, (struct sockaddr *)&sa_client, sizeof(sa_client))) {
perror("bind");
exit(EXIT_FAILURE);
}
timeout_begin(TIMEOUT);
do {
ret = connect(client_fd, (struct sockaddr *)&sa_server, sizeof(sa_server));
timeout_check("connect");
} while (ret < 0 && errno == EINTR);
timeout_end();
if (ret < 0) {
perror("connect");
exit(EXIT_FAILURE);
}
return client_fd;
}
/* Connect to <cid, port> and return the file descriptor. */ /* Connect to <cid, port> and return the file descriptor. */
static int vsock_connect(unsigned int cid, unsigned int port, int type) static int vsock_connect(unsigned int cid, unsigned int port, int type)
{ {
...@@ -104,6 +146,10 @@ static int vsock_connect(unsigned int cid, unsigned int port, int type) ...@@ -104,6 +146,10 @@ static int vsock_connect(unsigned int cid, unsigned int port, int type)
control_expectln("LISTENING"); control_expectln("LISTENING");
fd = socket(AF_VSOCK, type, 0); fd = socket(AF_VSOCK, type, 0);
if (fd < 0) {
perror("socket");
exit(EXIT_FAILURE);
}
timeout_begin(TIMEOUT); timeout_begin(TIMEOUT);
do { do {
...@@ -132,11 +178,8 @@ int vsock_seqpacket_connect(unsigned int cid, unsigned int port) ...@@ -132,11 +178,8 @@ int vsock_seqpacket_connect(unsigned int cid, unsigned int port)
return vsock_connect(cid, port, SOCK_SEQPACKET); return vsock_connect(cid, port, SOCK_SEQPACKET);
} }
/* Listen on <cid, port> and return the first incoming connection. The remote /* Listen on <cid, port> and return the file descriptor. */
* address is stored to clientaddrp. clientaddrp may be NULL. static int vsock_listen(unsigned int cid, unsigned int port, int type)
*/
static int vsock_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp, int type)
{ {
union { union {
struct sockaddr sa; struct sockaddr sa;
...@@ -148,16 +191,13 @@ static int vsock_accept(unsigned int cid, unsigned int port, ...@@ -148,16 +191,13 @@ static int vsock_accept(unsigned int cid, unsigned int port,
.svm_cid = cid, .svm_cid = cid,
}, },
}; };
union {
struct sockaddr sa;
struct sockaddr_vm svm;
} clientaddr;
socklen_t clientaddr_len = sizeof(clientaddr.svm);
int fd; int fd;
int client_fd;
int old_errno;
fd = socket(AF_VSOCK, type, 0); fd = socket(AF_VSOCK, type, 0);
if (fd < 0) {
perror("socket");
exit(EXIT_FAILURE);
}
if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) { if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
perror("bind"); perror("bind");
...@@ -169,6 +209,24 @@ static int vsock_accept(unsigned int cid, unsigned int port, ...@@ -169,6 +209,24 @@ static int vsock_accept(unsigned int cid, unsigned int port,
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
} }
return fd;
}
/* Listen on <cid, port> and return the first incoming connection. The remote
* address is stored to clientaddrp. clientaddrp may be NULL.
*/
static int vsock_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp, int type)
{
union {
struct sockaddr sa;
struct sockaddr_vm svm;
} clientaddr;
socklen_t clientaddr_len = sizeof(clientaddr.svm);
int fd, client_fd, old_errno;
fd = vsock_listen(cid, port, type);
control_writeln("LISTENING"); control_writeln("LISTENING");
timeout_begin(TIMEOUT); timeout_begin(TIMEOUT);
...@@ -207,6 +265,11 @@ int vsock_stream_accept(unsigned int cid, unsigned int port, ...@@ -207,6 +265,11 @@ int vsock_stream_accept(unsigned int cid, unsigned int port,
return vsock_accept(cid, port, clientaddrp, SOCK_STREAM); return vsock_accept(cid, port, clientaddrp, SOCK_STREAM);
} }
int vsock_stream_listen(unsigned int cid, unsigned int port)
{
return vsock_listen(cid, port, SOCK_STREAM);
}
int vsock_seqpacket_accept(unsigned int cid, unsigned int port, int vsock_seqpacket_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp) struct sockaddr_vm *clientaddrp)
{ {
......
...@@ -36,9 +36,12 @@ struct test_case { ...@@ -36,9 +36,12 @@ struct test_case {
void init_signals(void); void init_signals(void);
unsigned int parse_cid(const char *str); unsigned int parse_cid(const char *str);
int vsock_stream_connect(unsigned int cid, unsigned int port); int vsock_stream_connect(unsigned int cid, unsigned int port);
int vsock_bind_connect(unsigned int cid, unsigned int port,
unsigned int bind_port, int type);
int vsock_seqpacket_connect(unsigned int cid, unsigned int port); int vsock_seqpacket_connect(unsigned int cid, unsigned int port);
int vsock_stream_accept(unsigned int cid, unsigned int port, int vsock_stream_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp); struct sockaddr_vm *clientaddrp);
int vsock_stream_listen(unsigned int cid, unsigned int port);
int vsock_seqpacket_accept(unsigned int cid, unsigned int port, int vsock_seqpacket_accept(unsigned int cid, unsigned int port,
struct sockaddr_vm *clientaddrp); struct sockaddr_vm *clientaddrp);
void vsock_wait_remote_close(int fd); void vsock_wait_remote_close(int fd);
......
...@@ -1180,6 +1180,51 @@ static void test_stream_shutrd_server(const struct test_opts *opts) ...@@ -1180,6 +1180,51 @@ static void test_stream_shutrd_server(const struct test_opts *opts)
close(fd); close(fd);
} }
static void test_double_bind_connect_server(const struct test_opts *opts)
{
int listen_fd, client_fd, i;
struct sockaddr_vm sa_client;
socklen_t socklen_client = sizeof(sa_client);
listen_fd = vsock_stream_listen(VMADDR_CID_ANY, 1234);
for (i = 0; i < 2; i++) {
control_writeln("LISTENING");
timeout_begin(TIMEOUT);
do {
client_fd = accept(listen_fd, (struct sockaddr *)&sa_client,
&socklen_client);
timeout_check("accept");
} while (client_fd < 0 && errno == EINTR);
timeout_end();
if (client_fd < 0) {
perror("accept");
exit(EXIT_FAILURE);
}
/* Waiting for remote peer to close connection */
vsock_wait_remote_close(client_fd);
}
close(listen_fd);
}
static void test_double_bind_connect_client(const struct test_opts *opts)
{
int i, client_fd;
for (i = 0; i < 2; i++) {
/* Wait until server is ready to accept a new connection */
control_expectln("LISTENING");
client_fd = vsock_bind_connect(opts->peer_cid, 1234, 4321, SOCK_STREAM);
close(client_fd);
}
}
static struct test_case test_cases[] = { static struct test_case test_cases[] = {
{ {
.name = "SOCK_STREAM connection reset", .name = "SOCK_STREAM connection reset",
...@@ -1285,6 +1330,11 @@ static struct test_case test_cases[] = { ...@@ -1285,6 +1330,11 @@ static struct test_case test_cases[] = {
.run_client = test_stream_msgzcopy_empty_errq_client, .run_client = test_stream_msgzcopy_empty_errq_client,
.run_server = test_stream_msgzcopy_empty_errq_server, .run_server = test_stream_msgzcopy_empty_errq_server,
}, },
{
.name = "SOCK_STREAM double bind connect",
.run_client = test_double_bind_connect_client,
.run_server = test_double_bind_connect_server,
},
{}, {},
}; };
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment