Commit d2f77c53 authored by Paolo Abeni's avatar Paolo Abeni Committed by David S. Miller

mptcp: check for plain TCP sock at accept time

This cleanup the code a bit and avoid corrupted states
on weird syscall sequence (accept(), connect()).
Signed-off-by: default avatarPaolo Abeni <pabeni@redhat.com>
Signed-off-by: default avatarDavide Caratti <dcaratti@redhat.com>
Reviewed-by: default avatarMat Martineau <mathew.j.martineau@linux.intel.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 8fd73804
...@@ -52,13 +52,10 @@ static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk) ...@@ -52,13 +52,10 @@ static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
return msk->subflow; return msk->subflow;
} }
static struct socket *mptcp_is_tcpsk(struct sock *sk) static bool mptcp_is_tcpsk(struct sock *sk)
{ {
struct socket *sock = sk->sk_socket; struct socket *sock = sk->sk_socket;
if (sock->sk != sk)
return NULL;
if (unlikely(sk->sk_prot == &tcp_prot)) { if (unlikely(sk->sk_prot == &tcp_prot)) {
/* we are being invoked after mptcp_accept() has /* we are being invoked after mptcp_accept() has
* accepted a non-mp-capable flow: sk is a tcp_sk, * accepted a non-mp-capable flow: sk is a tcp_sk,
...@@ -68,27 +65,21 @@ static struct socket *mptcp_is_tcpsk(struct sock *sk) ...@@ -68,27 +65,21 @@ static struct socket *mptcp_is_tcpsk(struct sock *sk)
* bypass mptcp. * bypass mptcp.
*/ */
sock->ops = &inet_stream_ops; sock->ops = &inet_stream_ops;
return sock; return true;
#if IS_ENABLED(CONFIG_MPTCP_IPV6) #if IS_ENABLED(CONFIG_MPTCP_IPV6)
} else if (unlikely(sk->sk_prot == &tcpv6_prot)) { } else if (unlikely(sk->sk_prot == &tcpv6_prot)) {
sock->ops = &inet6_stream_ops; sock->ops = &inet6_stream_ops;
return sock; return true;
#endif #endif
} }
return NULL; return false;
} }
static struct socket *__mptcp_tcp_fallback(struct mptcp_sock *msk) static struct socket *__mptcp_tcp_fallback(struct mptcp_sock *msk)
{ {
struct socket *sock;
sock_owned_by_me((const struct sock *)msk); sock_owned_by_me((const struct sock *)msk);
sock = mptcp_is_tcpsk((struct sock *)msk);
if (unlikely(sock))
return sock;
if (likely(!__mptcp_check_fallback(msk))) if (likely(!__mptcp_check_fallback(msk)))
return NULL; return NULL;
...@@ -1466,7 +1457,6 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err, ...@@ -1466,7 +1457,6 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
return NULL; return NULL;
pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk)); pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));
if (sk_is_mptcp(newsk)) { if (sk_is_mptcp(newsk)) {
struct mptcp_subflow_context *subflow; struct mptcp_subflow_context *subflow;
struct sock *new_mptcp_sock; struct sock *new_mptcp_sock;
...@@ -1821,42 +1811,6 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr, ...@@ -1821,42 +1811,6 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
return err; return err;
} }
static int mptcp_v4_getname(struct socket *sock, struct sockaddr *uaddr,
int peer)
{
if (sock->sk->sk_prot == &tcp_prot) {
/* we are being invoked from __sys_accept4, after
* mptcp_accept() has just accepted a non-mp-capable
* flow: sk is a tcp_sk, not an mptcp one.
*
* Hand the socket over to tcp so all further socket ops
* bypass mptcp.
*/
sock->ops = &inet_stream_ops;
}
return inet_getname(sock, uaddr, peer);
}
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
static int mptcp_v6_getname(struct socket *sock, struct sockaddr *uaddr,
int peer)
{
if (sock->sk->sk_prot == &tcpv6_prot) {
/* we are being invoked from __sys_accept4 after
* mptcp_accept() has accepted a non-mp-capable
* subflow: sk is a tcp_sk, not mptcp.
*
* Hand the socket over to tcp so all further
* socket ops bypass mptcp.
*/
sock->ops = &inet6_stream_ops;
}
return inet6_getname(sock, uaddr, peer);
}
#endif
static int mptcp_listen(struct socket *sock, int backlog) static int mptcp_listen(struct socket *sock, int backlog)
{ {
struct mptcp_sock *msk = mptcp_sk(sock->sk); struct mptcp_sock *msk = mptcp_sk(sock->sk);
...@@ -1885,15 +1839,6 @@ static int mptcp_listen(struct socket *sock, int backlog) ...@@ -1885,15 +1839,6 @@ static int mptcp_listen(struct socket *sock, int backlog)
return err; return err;
} }
static bool is_tcp_proto(const struct proto *p)
{
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
return p == &tcp_prot || p == &tcpv6_prot;
#else
return p == &tcp_prot;
#endif
}
static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
int flags, bool kern) int flags, bool kern)
{ {
...@@ -1915,7 +1860,7 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, ...@@ -1915,7 +1860,7 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
release_sock(sock->sk); release_sock(sock->sk);
err = ssock->ops->accept(sock, newsock, flags, kern); err = ssock->ops->accept(sock, newsock, flags, kern);
if (err == 0 && !is_tcp_proto(newsock->sk->sk_prot)) { if (err == 0 && !mptcp_is_tcpsk(newsock->sk)) {
struct mptcp_sock *msk = mptcp_sk(newsock->sk); struct mptcp_sock *msk = mptcp_sk(newsock->sk);
struct mptcp_subflow_context *subflow; struct mptcp_subflow_context *subflow;
...@@ -2011,7 +1956,7 @@ static const struct proto_ops mptcp_stream_ops = { ...@@ -2011,7 +1956,7 @@ static const struct proto_ops mptcp_stream_ops = {
.connect = mptcp_stream_connect, .connect = mptcp_stream_connect,
.socketpair = sock_no_socketpair, .socketpair = sock_no_socketpair,
.accept = mptcp_stream_accept, .accept = mptcp_stream_accept,
.getname = mptcp_v4_getname, .getname = inet_getname,
.poll = mptcp_poll, .poll = mptcp_poll,
.ioctl = inet_ioctl, .ioctl = inet_ioctl,
.gettstamp = sock_gettstamp, .gettstamp = sock_gettstamp,
...@@ -2065,7 +2010,7 @@ static const struct proto_ops mptcp_v6_stream_ops = { ...@@ -2065,7 +2010,7 @@ static const struct proto_ops mptcp_v6_stream_ops = {
.connect = mptcp_stream_connect, .connect = mptcp_stream_connect,
.socketpair = sock_no_socketpair, .socketpair = sock_no_socketpair,
.accept = mptcp_stream_accept, .accept = mptcp_stream_accept,
.getname = mptcp_v6_getname, .getname = inet6_getname,
.poll = mptcp_poll, .poll = mptcp_poll,
.ioctl = inet6_ioctl, .ioctl = inet6_ioctl,
.gettstamp = sock_gettstamp, .gettstamp = sock_gettstamp,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment