Commit 9a349219 authored by David S. Miller's avatar David S. Miller

Merge branch 'AF_VSOCK-missed-wakeups'

Claudio Imbrenda says:

====================
AF_VSOCK: Shrink the area influenced by prepare_to_wait

This patchset applies on net-next.

I think I found a problem with the patch submitted by Laura Abbott
( https://lkml.org/lkml/2016/2/4/711 ): we might miss wakeups.
Since the condition is not checked between the prepare_to_wait and the
schedule(), if a wakeup happens after the condition is checked but before
the sleep happens, and we miss it. ( A description of the problem can be
found here: http://www.makelinux.net/ldd3/chp-6-sect-2 ).

The first patch reverts the previous broken patch, while the second patch
properly fixes the sleep-while-waiting issue.
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 9a0384c0 f7f9b5e7
...@@ -1209,10 +1209,14 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, ...@@ -1209,10 +1209,14 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
if (signal_pending(current)) { if (signal_pending(current)) {
err = sock_intr_errno(timeout); err = sock_intr_errno(timeout);
goto out_wait_error; sk->sk_state = SS_UNCONNECTED;
sock->state = SS_UNCONNECTED;
goto out_wait;
} else if (timeout == 0) { } else if (timeout == 0) {
err = -ETIMEDOUT; err = -ETIMEDOUT;
goto out_wait_error; sk->sk_state = SS_UNCONNECTED;
sock->state = SS_UNCONNECTED;
goto out_wait;
} }
prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
...@@ -1220,20 +1224,17 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, ...@@ -1220,20 +1224,17 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
if (sk->sk_err) { if (sk->sk_err) {
err = -sk->sk_err; err = -sk->sk_err;
goto out_wait_error; sk->sk_state = SS_UNCONNECTED;
} else sock->state = SS_UNCONNECTED;
} else {
err = 0; err = 0;
}
out_wait: out_wait:
finish_wait(sk_sleep(sk), &wait); finish_wait(sk_sleep(sk), &wait);
out: out:
release_sock(sk); release_sock(sk);
return err; return err;
out_wait_error:
sk->sk_state = SS_UNCONNECTED;
sock->state = SS_UNCONNECTED;
goto out_wait;
} }
static int vsock_accept(struct socket *sock, struct socket *newsock, int flags) static int vsock_accept(struct socket *sock, struct socket *newsock, int flags)
...@@ -1270,18 +1271,20 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags) ...@@ -1270,18 +1271,20 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags)
listener->sk_err == 0) { listener->sk_err == 0) {
release_sock(listener); release_sock(listener);
timeout = schedule_timeout(timeout); timeout = schedule_timeout(timeout);
finish_wait(sk_sleep(listener), &wait);
lock_sock(listener); lock_sock(listener);
if (signal_pending(current)) { if (signal_pending(current)) {
err = sock_intr_errno(timeout); err = sock_intr_errno(timeout);
goto out_wait; goto out;
} else if (timeout == 0) { } else if (timeout == 0) {
err = -EAGAIN; err = -EAGAIN;
goto out_wait; goto out;
} }
prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE); prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE);
} }
finish_wait(sk_sleep(listener), &wait);
if (listener->sk_err) if (listener->sk_err)
err = -listener->sk_err; err = -listener->sk_err;
...@@ -1301,19 +1304,15 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags) ...@@ -1301,19 +1304,15 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags)
*/ */
if (err) { if (err) {
vconnected->rejected = true; vconnected->rejected = true;
release_sock(connected); } else {
sock_put(connected); newsock->state = SS_CONNECTED;
goto out_wait; sock_graft(connected, newsock);
} }
newsock->state = SS_CONNECTED;
sock_graft(connected, newsock);
release_sock(connected); release_sock(connected);
sock_put(connected); sock_put(connected);
} }
out_wait:
finish_wait(sk_sleep(listener), &wait);
out: out:
release_sock(listener); release_sock(listener);
return err; return err;
...@@ -1557,9 +1556,11 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, ...@@ -1557,9 +1556,11 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
if (err < 0) if (err < 0)
goto out; goto out;
while (total_written < len) { while (total_written < len) {
ssize_t written; ssize_t written;
prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
while (vsock_stream_has_space(vsk) == 0 && while (vsock_stream_has_space(vsk) == 0 &&
sk->sk_err == 0 && sk->sk_err == 0 &&
!(sk->sk_shutdown & SEND_SHUTDOWN) && !(sk->sk_shutdown & SEND_SHUTDOWN) &&
...@@ -1568,27 +1569,33 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, ...@@ -1568,27 +1569,33 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
/* Don't wait for non-blocking sockets. */ /* Don't wait for non-blocking sockets. */
if (timeout == 0) { if (timeout == 0) {
err = -EAGAIN; err = -EAGAIN;
goto out_wait; finish_wait(sk_sleep(sk), &wait);
goto out_err;
} }
err = transport->notify_send_pre_block(vsk, &send_data); err = transport->notify_send_pre_block(vsk, &send_data);
if (err < 0) if (err < 0) {
goto out_wait; finish_wait(sk_sleep(sk), &wait);
goto out_err;
}
release_sock(sk); release_sock(sk);
prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
timeout = schedule_timeout(timeout); timeout = schedule_timeout(timeout);
finish_wait(sk_sleep(sk), &wait);
lock_sock(sk); lock_sock(sk);
if (signal_pending(current)) { if (signal_pending(current)) {
err = sock_intr_errno(timeout); err = sock_intr_errno(timeout);
goto out_wait; finish_wait(sk_sleep(sk), &wait);
goto out_err;
} else if (timeout == 0) { } else if (timeout == 0) {
err = -EAGAIN; err = -EAGAIN;
goto out_wait; finish_wait(sk_sleep(sk), &wait);
goto out_err;
} }
prepare_to_wait(sk_sleep(sk), &wait,
TASK_INTERRUPTIBLE);
} }
finish_wait(sk_sleep(sk), &wait);
/* These checks occur both as part of and after the loop /* These checks occur both as part of and after the loop
* conditional since we need to check before and after * conditional since we need to check before and after
...@@ -1596,16 +1603,16 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, ...@@ -1596,16 +1603,16 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
*/ */
if (sk->sk_err) { if (sk->sk_err) {
err = -sk->sk_err; err = -sk->sk_err;
goto out_wait; goto out_err;
} else if ((sk->sk_shutdown & SEND_SHUTDOWN) || } else if ((sk->sk_shutdown & SEND_SHUTDOWN) ||
(vsk->peer_shutdown & RCV_SHUTDOWN)) { (vsk->peer_shutdown & RCV_SHUTDOWN)) {
err = -EPIPE; err = -EPIPE;
goto out_wait; goto out_err;
} }
err = transport->notify_send_pre_enqueue(vsk, &send_data); err = transport->notify_send_pre_enqueue(vsk, &send_data);
if (err < 0) if (err < 0)
goto out_wait; goto out_err;
/* Note that enqueue will only write as many bytes as are free /* Note that enqueue will only write as many bytes as are free
* in the produce queue, so we don't need to ensure len is * in the produce queue, so we don't need to ensure len is
...@@ -1618,7 +1625,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, ...@@ -1618,7 +1625,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
len - total_written); len - total_written);
if (written < 0) { if (written < 0) {
err = -ENOMEM; err = -ENOMEM;
goto out_wait; goto out_err;
} }
total_written += written; total_written += written;
...@@ -1626,11 +1633,11 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg, ...@@ -1626,11 +1633,11 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
err = transport->notify_send_post_enqueue( err = transport->notify_send_post_enqueue(
vsk, written, &send_data); vsk, written, &send_data);
if (err < 0) if (err < 0)
goto out_wait; goto out_err;
} }
out_wait: out_err:
if (total_written > 0) if (total_written > 0)
err = total_written; err = total_written;
out: out:
...@@ -1715,18 +1722,59 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, ...@@ -1715,18 +1722,59 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
while (1) { while (1) {
s64 ready = vsock_stream_has_data(vsk); s64 ready;
if (ready < 0) { prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
/* Invalid queue pair content. XXX This should be ready = vsock_stream_has_data(vsk);
* changed to a connection reset in a later change.
*/
err = -ENOMEM; if (ready == 0) {
goto out; if (sk->sk_err != 0 ||
} else if (ready > 0) { (sk->sk_shutdown & RCV_SHUTDOWN) ||
(vsk->peer_shutdown & SEND_SHUTDOWN)) {
finish_wait(sk_sleep(sk), &wait);
break;
}
/* Don't wait for non-blocking sockets. */
if (timeout == 0) {
err = -EAGAIN;
finish_wait(sk_sleep(sk), &wait);
break;
}
err = transport->notify_recv_pre_block(
vsk, target, &recv_data);
if (err < 0) {
finish_wait(sk_sleep(sk), &wait);
break;
}
release_sock(sk);
timeout = schedule_timeout(timeout);
lock_sock(sk);
if (signal_pending(current)) {
err = sock_intr_errno(timeout);
finish_wait(sk_sleep(sk), &wait);
break;
} else if (timeout == 0) {
err = -EAGAIN;
finish_wait(sk_sleep(sk), &wait);
break;
}
} else {
ssize_t read; ssize_t read;
finish_wait(sk_sleep(sk), &wait);
if (ready < 0) {
/* Invalid queue pair content. XXX This should
* be changed to a connection reset in a later
* change.
*/
err = -ENOMEM;
goto out;
}
err = transport->notify_recv_pre_dequeue( err = transport->notify_recv_pre_dequeue(
vsk, target, &recv_data); vsk, target, &recv_data);
if (err < 0) if (err < 0)
...@@ -1752,35 +1800,6 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, ...@@ -1752,35 +1800,6 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
break; break;
target -= read; target -= read;
} else {
if (sk->sk_err != 0 || (sk->sk_shutdown & RCV_SHUTDOWN)
|| (vsk->peer_shutdown & SEND_SHUTDOWN)) {
break;
}
/* Don't wait for non-blocking sockets. */
if (timeout == 0) {
err = -EAGAIN;
break;
}
err = transport->notify_recv_pre_block(
vsk, target, &recv_data);
if (err < 0)
break;
release_sock(sk);
prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
timeout = schedule_timeout(timeout);
finish_wait(sk_sleep(sk), &wait);
lock_sock(sk);
if (signal_pending(current)) {
err = sock_intr_errno(timeout);
break;
} else if (timeout == 0) {
err = -EAGAIN;
break;
}
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment