Commit e562d086 authored by David S. Miller's avatar David S. Miller

Merge branch 'mptcp-refactor-token-container'

Paolo Abeni says:

====================
mptcp: refactor token container

Currently the msk sockets are stored in a single radix tree, protected by a
global spin_lock. This series moves to an hash table, allocated at boot time,
with per bucker spin_lock - alike inet_hashtables, but using a different key:
the token itself.

The above improves scalability, as write operations will have a far later chance
to compete for lock acquisition, allows lockless lookup, and will allow
easier msk traversing - e.g. for diag interface implementation's sake.

This also introduces trivial, related, kunit tests and move the existing in
kernel's one to kunit.

v1 -> v2:
 - fixed a few extra and sparse warns
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents be7aa9fa a8ee9c9b
...@@ -18,12 +18,20 @@ config MPTCP_IPV6 ...@@ -18,12 +18,20 @@ config MPTCP_IPV6
select IPV6 select IPV6
default y default y
config MPTCP_HMAC_TEST endif
bool "Tests for MPTCP HMAC implementation"
config MPTCP_KUNIT_TESTS
tristate "This builds the MPTCP KUnit tests" if !KUNIT_ALL_TESTS
select MPTCP
depends on KUNIT
default KUNIT_ALL_TESTS
help help
This option enable boot time self-test for the HMAC implementation Currently covers the MPTCP crypto and token helpers.
used by the MPTCP code Only useful for kernel devs running KUnit test harness and are not
for inclusion into a production build.
Say N if you are unsure. For more information on KUnit and unit tests in general please refer
to the KUnit documentation in Documentation/dev-tools/kunit/.
If unsure, say N.
endif
...@@ -3,3 +3,7 @@ obj-$(CONFIG_MPTCP) += mptcp.o ...@@ -3,3 +3,7 @@ obj-$(CONFIG_MPTCP) += mptcp.o
mptcp-y := protocol.o subflow.o options.o token.o crypto.o ctrl.o pm.o diag.o \ mptcp-y := protocol.o subflow.o options.o token.o crypto.o ctrl.o pm.o diag.o \
mib.o pm_netlink.o mib.o pm_netlink.o
mptcp_crypto_test-objs := crypto_test.o
mptcp_token_test-objs := token_test.o
obj-$(CONFIG_MPTCP_KUNIT_TESTS) += mptcp_crypto_test.o mptcp_token_test.o
...@@ -87,65 +87,6 @@ void mptcp_crypto_hmac_sha(u64 key1, u64 key2, u8 *msg, int len, void *hmac) ...@@ -87,65 +87,6 @@ void mptcp_crypto_hmac_sha(u64 key1, u64 key2, u8 *msg, int len, void *hmac)
sha256_final(&state, (u8 *)hmac); sha256_final(&state, (u8 *)hmac);
} }
#ifdef CONFIG_MPTCP_HMAC_TEST #if IS_MODULE(CONFIG_MPTCP_KUNIT_TESTS)
struct test_cast { EXPORT_SYMBOL_GPL(mptcp_crypto_hmac_sha);
char *key;
char *msg;
char *result;
};
/* we can't reuse RFC 4231 test vectors, as we have constraint on the
* input and key size.
*/
static struct test_cast tests[] = {
{
.key = "0b0b0b0b0b0b0b0b",
.msg = "48692054",
.result = "8385e24fb4235ac37556b6b886db106284a1da671699f46db1f235ec622dcafa",
},
{
.key = "aaaaaaaaaaaaaaaa",
.msg = "dddddddd",
.result = "2c5e219164ff1dca1c4a92318d847bb6b9d44492984e1eb71aff9022f71046e9",
},
{
.key = "0102030405060708",
.msg = "cdcdcdcd",
.result = "e73b9ba9969969cefb04aa0d6df18ec2fcc075b6f23b4d8c4da736a5dbbc6e7d",
},
};
static int __init test_mptcp_crypto(void)
{
char hmac[32], hmac_hex[65];
u32 nonce1, nonce2;
u64 key1, key2;
u8 msg[8];
int i, j;
for (i = 0; i < ARRAY_SIZE(tests); ++i) {
/* mptcp hmap will convert to be before computing the hmac */
key1 = be64_to_cpu(*((__be64 *)&tests[i].key[0]));
key2 = be64_to_cpu(*((__be64 *)&tests[i].key[8]));
nonce1 = be32_to_cpu(*((__be32 *)&tests[i].msg[0]));
nonce2 = be32_to_cpu(*((__be32 *)&tests[i].msg[4]));
put_unaligned_be32(nonce1, &msg[0]);
put_unaligned_be32(nonce2, &msg[4]);
mptcp_crypto_hmac_sha(key1, key2, msg, 8, hmac);
for (j = 0; j < 32; ++j)
sprintf(&hmac_hex[j << 1], "%02x", hmac[j] & 0xff);
hmac_hex[64] = 0;
if (memcmp(hmac_hex, tests[i].result, 64))
pr_err("test %d failed, got %s expected %s", i,
hmac_hex, tests[i].result);
else
pr_info("test %d [ ok ]", i);
}
return 0;
}
late_initcall(test_mptcp_crypto);
#endif #endif
// SPDX-License-Identifier: GPL-2.0
#include <kunit/test.h>
#include "protocol.h"
struct test_case {
char *key;
char *msg;
char *result;
};
/* we can't reuse RFC 4231 test vectors, as we have constraint on the
* input and key size.
*/
static struct test_case tests[] = {
{
.key = "0b0b0b0b0b0b0b0b",
.msg = "48692054",
.result = "8385e24fb4235ac37556b6b886db106284a1da671699f46db1f235ec622dcafa",
},
{
.key = "aaaaaaaaaaaaaaaa",
.msg = "dddddddd",
.result = "2c5e219164ff1dca1c4a92318d847bb6b9d44492984e1eb71aff9022f71046e9",
},
{
.key = "0102030405060708",
.msg = "cdcdcdcd",
.result = "e73b9ba9969969cefb04aa0d6df18ec2fcc075b6f23b4d8c4da736a5dbbc6e7d",
},
};
static void mptcp_crypto_test_basic(struct kunit *test)
{
char hmac[32], hmac_hex[65];
u32 nonce1, nonce2;
u64 key1, key2;
u8 msg[8];
int i, j;
for (i = 0; i < ARRAY_SIZE(tests); ++i) {
/* mptcp hmap will convert to be before computing the hmac */
key1 = be64_to_cpu(*((__be64 *)&tests[i].key[0]));
key2 = be64_to_cpu(*((__be64 *)&tests[i].key[8]));
nonce1 = be32_to_cpu(*((__be32 *)&tests[i].msg[0]));
nonce2 = be32_to_cpu(*((__be32 *)&tests[i].msg[4]));
put_unaligned_be32(nonce1, &msg[0]);
put_unaligned_be32(nonce2, &msg[4]);
mptcp_crypto_hmac_sha(key1, key2, msg, 8, hmac);
for (j = 0; j < 32; ++j)
sprintf(&hmac_hex[j << 1], "%02x", hmac[j] & 0xff);
hmac_hex[64] = 0;
KUNIT_EXPECT_STREQ(test, &hmac_hex[0], tests[i].result);
}
}
static struct kunit_case mptcp_crypto_test_cases[] = {
KUNIT_CASE(mptcp_crypto_test_basic),
{}
};
static struct kunit_suite mptcp_crypto_suite = {
.name = "mptcp-crypto",
.test_cases = mptcp_crypto_test_cases,
};
kunit_test_suite(mptcp_crypto_suite);
MODULE_LICENSE("GPL");
...@@ -234,7 +234,7 @@ void mptcp_pm_close(struct mptcp_sock *msk) ...@@ -234,7 +234,7 @@ void mptcp_pm_close(struct mptcp_sock *msk)
sock_put((struct sock *)msk); sock_put((struct sock *)msk);
} }
void mptcp_pm_init(void) void __init mptcp_pm_init(void)
{ {
pm_wq = alloc_workqueue("pm_wq", WQ_UNBOUND | WQ_MEM_RECLAIM, 8); pm_wq = alloc_workqueue("pm_wq", WQ_UNBOUND | WQ_MEM_RECLAIM, 8);
if (!pm_wq) if (!pm_wq)
......
...@@ -851,7 +851,7 @@ static struct pernet_operations mptcp_pm_pernet_ops = { ...@@ -851,7 +851,7 @@ static struct pernet_operations mptcp_pm_pernet_ops = {
.size = sizeof(struct pm_nl_pernet), .size = sizeof(struct pm_nl_pernet),
}; };
void mptcp_pm_nl_init(void) void __init mptcp_pm_nl_init(void)
{ {
if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0) if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0)
panic("Failed to register MPTCP PM pernet subsystem.\n"); panic("Failed to register MPTCP PM pernet subsystem.\n");
......
...@@ -1448,20 +1448,6 @@ struct sock *mptcp_sk_clone(const struct sock *sk, ...@@ -1448,20 +1448,6 @@ struct sock *mptcp_sk_clone(const struct sock *sk,
msk->token = subflow_req->token; msk->token = subflow_req->token;
msk->subflow = NULL; msk->subflow = NULL;
if (unlikely(mptcp_token_new_accept(subflow_req->token, nsk))) {
nsk->sk_state = TCP_CLOSE;
bh_unlock_sock(nsk);
/* we can't call into mptcp_close() here - possible BH context
* free the sock directly.
* sk_clone_lock() sets nsk refcnt to two, hence call sk_free()
* too.
*/
sk_common_release(nsk);
sk_free(nsk);
return NULL;
}
msk->write_seq = subflow_req->idsn + 1; msk->write_seq = subflow_req->idsn + 1;
atomic64_set(&msk->snd_una, msk->write_seq); atomic64_set(&msk->snd_una, msk->write_seq);
if (mp_opt->mp_capable) { if (mp_opt->mp_capable) {
...@@ -1547,7 +1533,7 @@ static void mptcp_destroy(struct sock *sk) ...@@ -1547,7 +1533,7 @@ static void mptcp_destroy(struct sock *sk)
{ {
struct mptcp_sock *msk = mptcp_sk(sk); struct mptcp_sock *msk = mptcp_sk(sk);
mptcp_token_destroy(msk->token); mptcp_token_destroy(msk);
if (msk->cached_ext) if (msk->cached_ext)
__skb_ext_put(msk->cached_ext); __skb_ext_put(msk->cached_ext);
...@@ -1636,6 +1622,20 @@ static void mptcp_release_cb(struct sock *sk) ...@@ -1636,6 +1622,20 @@ static void mptcp_release_cb(struct sock *sk)
} }
} }
static int mptcp_hash(struct sock *sk)
{
/* should never be called,
* we hash the TCP subflows not the master socket
*/
WARN_ON_ONCE(1);
return 0;
}
static void mptcp_unhash(struct sock *sk)
{
/* called from sk_common_release(), but nothing to do here */
}
static int mptcp_get_port(struct sock *sk, unsigned short snum) static int mptcp_get_port(struct sock *sk, unsigned short snum)
{ {
struct mptcp_sock *msk = mptcp_sk(sk); struct mptcp_sock *msk = mptcp_sk(sk);
...@@ -1679,7 +1679,6 @@ void mptcp_finish_connect(struct sock *ssk) ...@@ -1679,7 +1679,6 @@ void mptcp_finish_connect(struct sock *ssk)
*/ */
WRITE_ONCE(msk->remote_key, subflow->remote_key); WRITE_ONCE(msk->remote_key, subflow->remote_key);
WRITE_ONCE(msk->local_key, subflow->local_key); WRITE_ONCE(msk->local_key, subflow->local_key);
WRITE_ONCE(msk->token, subflow->token);
WRITE_ONCE(msk->write_seq, subflow->idsn + 1); WRITE_ONCE(msk->write_seq, subflow->idsn + 1);
WRITE_ONCE(msk->ack_seq, ack_seq); WRITE_ONCE(msk->ack_seq, ack_seq);
WRITE_ONCE(msk->can_ack, 1); WRITE_ONCE(msk->can_ack, 1);
...@@ -1761,8 +1760,8 @@ static struct proto mptcp_prot = { ...@@ -1761,8 +1760,8 @@ static struct proto mptcp_prot = {
.sendmsg = mptcp_sendmsg, .sendmsg = mptcp_sendmsg,
.recvmsg = mptcp_recvmsg, .recvmsg = mptcp_recvmsg,
.release_cb = mptcp_release_cb, .release_cb = mptcp_release_cb,
.hash = inet_hash, .hash = mptcp_hash,
.unhash = inet_unhash, .unhash = mptcp_unhash,
.get_port = mptcp_get_port, .get_port = mptcp_get_port,
.sockets_allocated = &mptcp_sockets_allocated, .sockets_allocated = &mptcp_sockets_allocated,
.memory_allocated = &tcp_memory_allocated, .memory_allocated = &tcp_memory_allocated,
...@@ -1771,6 +1770,7 @@ static struct proto mptcp_prot = { ...@@ -1771,6 +1770,7 @@ static struct proto mptcp_prot = {
.sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_tcp_wmem), .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_tcp_wmem),
.sysctl_mem = sysctl_tcp_mem, .sysctl_mem = sysctl_tcp_mem,
.obj_size = sizeof(struct mptcp_sock), .obj_size = sizeof(struct mptcp_sock),
.slab_flags = SLAB_TYPESAFE_BY_RCU,
.no_autobind = true, .no_autobind = true,
}; };
...@@ -1800,6 +1800,7 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr, ...@@ -1800,6 +1800,7 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
int addr_len, int flags) int addr_len, int flags)
{ {
struct mptcp_sock *msk = mptcp_sk(sock->sk); struct mptcp_sock *msk = mptcp_sk(sock->sk);
struct mptcp_subflow_context *subflow;
struct socket *ssock; struct socket *ssock;
int err; int err;
...@@ -1812,19 +1813,23 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr, ...@@ -1812,19 +1813,23 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
goto do_connect; goto do_connect;
} }
mptcp_token_destroy(msk);
ssock = __mptcp_socket_create(msk, TCP_SYN_SENT); ssock = __mptcp_socket_create(msk, TCP_SYN_SENT);
if (IS_ERR(ssock)) { if (IS_ERR(ssock)) {
err = PTR_ERR(ssock); err = PTR_ERR(ssock);
goto unlock; goto unlock;
} }
subflow = mptcp_subflow_ctx(ssock->sk);
#ifdef CONFIG_TCP_MD5SIG #ifdef CONFIG_TCP_MD5SIG
/* no MPTCP if MD5SIG is enabled on this socket or we may run out of /* no MPTCP if MD5SIG is enabled on this socket or we may run out of
* TCP option space. * TCP option space.
*/ */
if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info)) if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info))
mptcp_subflow_ctx(ssock->sk)->request_mptcp = 0; subflow->request_mptcp = 0;
#endif #endif
if (subflow->request_mptcp && mptcp_token_new_connect(ssock->sk))
subflow->request_mptcp = 0;
do_connect: do_connect:
err = ssock->ops->connect(ssock, uaddr, addr_len, flags); err = ssock->ops->connect(ssock, uaddr, addr_len, flags);
...@@ -1888,6 +1893,7 @@ static int mptcp_listen(struct socket *sock, int backlog) ...@@ -1888,6 +1893,7 @@ static int mptcp_listen(struct socket *sock, int backlog)
pr_debug("msk=%p", msk); pr_debug("msk=%p", msk);
lock_sock(sock->sk); lock_sock(sock->sk);
mptcp_token_destroy(msk);
ssock = __mptcp_socket_create(msk, TCP_LISTEN); ssock = __mptcp_socket_create(msk, TCP_LISTEN);
if (IS_ERR(ssock)) { if (IS_ERR(ssock)) {
err = PTR_ERR(ssock); err = PTR_ERR(ssock);
...@@ -2077,7 +2083,7 @@ static struct inet_protosw mptcp_protosw = { ...@@ -2077,7 +2083,7 @@ static struct inet_protosw mptcp_protosw = {
.flags = INET_PROTOSW_ICSK, .flags = INET_PROTOSW_ICSK,
}; };
void mptcp_proto_init(void) void __init mptcp_proto_init(void)
{ {
mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo; mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo;
...@@ -2086,6 +2092,7 @@ void mptcp_proto_init(void) ...@@ -2086,6 +2092,7 @@ void mptcp_proto_init(void)
mptcp_subflow_init(); mptcp_subflow_init();
mptcp_pm_init(); mptcp_pm_init();
mptcp_token_init();
if (proto_register(&mptcp_prot, 1) != 0) if (proto_register(&mptcp_prot, 1) != 0)
panic("Failed to register MPTCP proto.\n"); panic("Failed to register MPTCP proto.\n");
...@@ -2139,7 +2146,7 @@ static struct inet_protosw mptcp_v6_protosw = { ...@@ -2139,7 +2146,7 @@ static struct inet_protosw mptcp_v6_protosw = {
.flags = INET_PROTOSW_ICSK, .flags = INET_PROTOSW_ICSK,
}; };
int mptcp_proto_v6_init(void) int __init mptcp_proto_v6_init(void)
{ {
int err; int err;
......
...@@ -250,6 +250,7 @@ struct mptcp_subflow_request_sock { ...@@ -250,6 +250,7 @@ struct mptcp_subflow_request_sock {
u32 local_nonce; u32 local_nonce;
u32 remote_nonce; u32 remote_nonce;
struct mptcp_sock *msk; struct mptcp_sock *msk;
struct hlist_nulls_node token_node;
}; };
static inline struct mptcp_subflow_request_sock * static inline struct mptcp_subflow_request_sock *
...@@ -337,7 +338,7 @@ mptcp_subflow_get_mapped_dsn(const struct mptcp_subflow_context *subflow) ...@@ -337,7 +338,7 @@ mptcp_subflow_get_mapped_dsn(const struct mptcp_subflow_context *subflow)
int mptcp_is_enabled(struct net *net); int mptcp_is_enabled(struct net *net);
bool mptcp_subflow_data_available(struct sock *sk); bool mptcp_subflow_data_available(struct sock *sk);
void mptcp_subflow_init(void); void __init mptcp_subflow_init(void);
/* called with sk socket lock held */ /* called with sk socket lock held */
int __mptcp_subflow_connect(struct sock *sk, int ifindex, int __mptcp_subflow_connect(struct sock *sk, int ifindex,
...@@ -355,9 +356,9 @@ static inline void mptcp_subflow_tcp_fallback(struct sock *sk, ...@@ -355,9 +356,9 @@ static inline void mptcp_subflow_tcp_fallback(struct sock *sk,
inet_csk(sk)->icsk_af_ops = ctx->icsk_af_ops; inet_csk(sk)->icsk_af_ops = ctx->icsk_af_ops;
} }
void mptcp_proto_init(void); void __init mptcp_proto_init(void);
#if IS_ENABLED(CONFIG_MPTCP_IPV6) #if IS_ENABLED(CONFIG_MPTCP_IPV6)
int mptcp_proto_v6_init(void); int __init mptcp_proto_v6_init(void);
#endif #endif
struct sock *mptcp_sk_clone(const struct sock *sk, struct sock *mptcp_sk_clone(const struct sock *sk,
...@@ -372,12 +373,19 @@ bool mptcp_finish_join(struct sock *sk); ...@@ -372,12 +373,19 @@ bool mptcp_finish_join(struct sock *sk);
void mptcp_data_acked(struct sock *sk); void mptcp_data_acked(struct sock *sk);
void mptcp_subflow_eof(struct sock *sk); void mptcp_subflow_eof(struct sock *sk);
void __init mptcp_token_init(void);
static inline void mptcp_token_init_request(struct request_sock *req)
{
mptcp_subflow_rsk(req)->token_node.pprev = NULL;
}
int mptcp_token_new_request(struct request_sock *req); int mptcp_token_new_request(struct request_sock *req);
void mptcp_token_destroy_request(u32 token); void mptcp_token_destroy_request(struct request_sock *req);
int mptcp_token_new_connect(struct sock *sk); int mptcp_token_new_connect(struct sock *sk);
int mptcp_token_new_accept(u32 token, struct sock *conn); void mptcp_token_accept(struct mptcp_subflow_request_sock *r,
struct mptcp_sock *msk);
struct mptcp_sock *mptcp_token_get_sock(u32 token); struct mptcp_sock *mptcp_token_get_sock(u32 token);
void mptcp_token_destroy(u32 token); void mptcp_token_destroy(struct mptcp_sock *msk);
void mptcp_crypto_key_sha(u64 key, u32 *token, u64 *idsn); void mptcp_crypto_key_sha(u64 key, u32 *token, u64 *idsn);
static inline void mptcp_crypto_key_gen_sha(u64 *key, u32 *token, u64 *idsn) static inline void mptcp_crypto_key_gen_sha(u64 *key, u32 *token, u64 *idsn)
...@@ -394,7 +402,7 @@ static inline void mptcp_crypto_key_gen_sha(u64 *key, u32 *token, u64 *idsn) ...@@ -394,7 +402,7 @@ static inline void mptcp_crypto_key_gen_sha(u64 *key, u32 *token, u64 *idsn)
void mptcp_crypto_hmac_sha(u64 key1, u64 key2, u8 *msg, int len, void *hmac); void mptcp_crypto_hmac_sha(u64 key1, u64 key2, u8 *msg, int len, void *hmac);
void mptcp_pm_init(void); void __init mptcp_pm_init(void);
void mptcp_pm_data_init(struct mptcp_sock *msk); void mptcp_pm_data_init(struct mptcp_sock *msk);
void mptcp_pm_close(struct mptcp_sock *msk); void mptcp_pm_close(struct mptcp_sock *msk);
void mptcp_pm_new_connection(struct mptcp_sock *msk, int server_side); void mptcp_pm_new_connection(struct mptcp_sock *msk, int server_side);
...@@ -428,7 +436,7 @@ bool mptcp_pm_addr_signal(struct mptcp_sock *msk, unsigned int remaining, ...@@ -428,7 +436,7 @@ bool mptcp_pm_addr_signal(struct mptcp_sock *msk, unsigned int remaining,
struct mptcp_addr_info *saddr); struct mptcp_addr_info *saddr);
int mptcp_pm_get_local_id(struct mptcp_sock *msk, struct sock_common *skc); int mptcp_pm_get_local_id(struct mptcp_sock *msk, struct sock_common *skc);
void mptcp_pm_nl_init(void); void __init mptcp_pm_nl_init(void);
void mptcp_pm_nl_data_init(struct mptcp_sock *msk); void mptcp_pm_nl_data_init(struct mptcp_sock *msk);
void mptcp_pm_nl_fully_established(struct mptcp_sock *msk); void mptcp_pm_nl_fully_established(struct mptcp_sock *msk);
void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk); void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk);
......
...@@ -32,12 +32,9 @@ static void SUBFLOW_REQ_INC_STATS(struct request_sock *req, ...@@ -32,12 +32,9 @@ static void SUBFLOW_REQ_INC_STATS(struct request_sock *req,
static int subflow_rebuild_header(struct sock *sk) static int subflow_rebuild_header(struct sock *sk)
{ {
struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk);
int local_id, err = 0; int local_id;
if (subflow->request_mptcp && !subflow->token) { if (subflow->request_join && !subflow->local_nonce) {
pr_debug("subflow=%p", sk);
err = mptcp_token_new_connect(sk);
} else if (subflow->request_join && !subflow->local_nonce) {
struct mptcp_sock *msk = (struct mptcp_sock *)subflow->conn; struct mptcp_sock *msk = (struct mptcp_sock *)subflow->conn;
pr_debug("subflow=%p", sk); pr_debug("subflow=%p", sk);
...@@ -57,9 +54,6 @@ static int subflow_rebuild_header(struct sock *sk) ...@@ -57,9 +54,6 @@ static int subflow_rebuild_header(struct sock *sk)
} }
out: out:
if (err)
return err;
return subflow->icsk_af_ops->rebuild_header(sk); return subflow->icsk_af_ops->rebuild_header(sk);
} }
...@@ -72,8 +66,7 @@ static void subflow_req_destructor(struct request_sock *req) ...@@ -72,8 +66,7 @@ static void subflow_req_destructor(struct request_sock *req)
if (subflow_req->msk) if (subflow_req->msk)
sock_put((struct sock *)subflow_req->msk); sock_put((struct sock *)subflow_req->msk);
if (subflow_req->mp_capable) mptcp_token_destroy_request(req);
mptcp_token_destroy_request(subflow_req->token);
tcp_request_sock_ops.destructor(req); tcp_request_sock_ops.destructor(req);
} }
...@@ -135,6 +128,7 @@ static void subflow_init_req(struct request_sock *req, ...@@ -135,6 +128,7 @@ static void subflow_init_req(struct request_sock *req,
subflow_req->mp_capable = 0; subflow_req->mp_capable = 0;
subflow_req->mp_join = 0; subflow_req->mp_join = 0;
subflow_req->msk = NULL; subflow_req->msk = NULL;
mptcp_token_init_request(req);
#ifdef CONFIG_TCP_MD5SIG #ifdef CONFIG_TCP_MD5SIG
/* no MPTCP if MD5SIG is enabled on this socket or we may run out of /* no MPTCP if MD5SIG is enabled on this socket or we may run out of
...@@ -250,7 +244,7 @@ static void subflow_finish_connect(struct sock *sk, const struct sk_buff *skb) ...@@ -250,7 +244,7 @@ static void subflow_finish_connect(struct sock *sk, const struct sk_buff *skb)
subflow->remote_nonce = mp_opt.nonce; subflow->remote_nonce = mp_opt.nonce;
pr_debug("subflow=%p, thmac=%llu, remote_nonce=%u", subflow, pr_debug("subflow=%p, thmac=%llu, remote_nonce=%u", subflow,
subflow->thmac, subflow->remote_nonce); subflow->thmac, subflow->remote_nonce);
} else if (subflow->request_mptcp) { } else {
tp->is_mptcp = 0; tp->is_mptcp = 0;
} }
...@@ -386,7 +380,7 @@ static void mptcp_sock_destruct(struct sock *sk) ...@@ -386,7 +380,7 @@ static void mptcp_sock_destruct(struct sock *sk)
sock_orphan(sk); sock_orphan(sk);
} }
mptcp_token_destroy(mptcp_sk(sk)->token); mptcp_token_destroy(mptcp_sk(sk));
inet_sock_destruct(sk); inet_sock_destruct(sk);
} }
...@@ -505,6 +499,7 @@ static struct sock *subflow_syn_recv_sock(const struct sock *sk, ...@@ -505,6 +499,7 @@ static struct sock *subflow_syn_recv_sock(const struct sock *sk,
*/ */
new_msk->sk_destruct = mptcp_sock_destruct; new_msk->sk_destruct = mptcp_sock_destruct;
mptcp_pm_new_connection(mptcp_sk(new_msk), 1); mptcp_pm_new_connection(mptcp_sk(new_msk), 1);
mptcp_token_accept(subflow_req, mptcp_sk(new_msk));
ctx->conn = new_msk; ctx->conn = new_msk;
new_msk = NULL; new_msk = NULL;
...@@ -1255,7 +1250,7 @@ static int subflow_ops_init(struct request_sock_ops *subflow_ops) ...@@ -1255,7 +1250,7 @@ static int subflow_ops_init(struct request_sock_ops *subflow_ops)
return 0; return 0;
} }
void mptcp_subflow_init(void) void __init mptcp_subflow_init(void)
{ {
subflow_request_sock_ops = tcp_request_sock_ops; subflow_request_sock_ops = tcp_request_sock_ops;
if (subflow_ops_init(&subflow_request_sock_ops) != 0) if (subflow_ops_init(&subflow_request_sock_ops) != 0)
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <linux/kernel.h> #include <linux/kernel.h>
#include <linux/module.h> #include <linux/module.h>
#include <linux/radix-tree.h> #include <linux/memblock.h>
#include <linux/ip.h> #include <linux/ip.h>
#include <linux/tcp.h> #include <linux/tcp.h>
#include <net/sock.h> #include <net/sock.h>
...@@ -33,10 +33,55 @@ ...@@ -33,10 +33,55 @@
#include <net/mptcp.h> #include <net/mptcp.h>
#include "protocol.h" #include "protocol.h"
static RADIX_TREE(token_tree, GFP_ATOMIC); #define TOKEN_MAX_RETRIES 4
static RADIX_TREE(token_req_tree, GFP_ATOMIC); #define TOKEN_MAX_CHAIN_LEN 4
static DEFINE_SPINLOCK(token_tree_lock);
static int token_used __read_mostly; struct token_bucket {
spinlock_t lock;
int chain_len;
struct hlist_nulls_head req_chain;
struct hlist_nulls_head msk_chain;
};
static struct token_bucket *token_hash __read_mostly;
static unsigned int token_mask __read_mostly;
static struct token_bucket *token_bucket(u32 token)
{
return &token_hash[token & token_mask];
}
/* called with bucket lock held */
static struct mptcp_subflow_request_sock *
__token_lookup_req(struct token_bucket *t, u32 token)
{
struct mptcp_subflow_request_sock *req;
struct hlist_nulls_node *pos;
hlist_nulls_for_each_entry_rcu(req, pos, &t->req_chain, token_node)
if (req->token == token)
return req;
return NULL;
}
/* called with bucket lock held */
static struct mptcp_sock *
__token_lookup_msk(struct token_bucket *t, u32 token)
{
struct hlist_nulls_node *pos;
struct sock *sk;
sk_nulls_for_each_rcu(sk, pos, &t->msk_chain)
if (mptcp_sk(sk)->token == token)
return mptcp_sk(sk);
return NULL;
}
static bool __token_bucket_busy(struct token_bucket *t, u32 token)
{
return !token || t->chain_len >= TOKEN_MAX_CHAIN_LEN ||
__token_lookup_req(t, token) || __token_lookup_msk(t, token);
}
/** /**
* mptcp_token_new_request - create new key/idsn/token for subflow_request * mptcp_token_new_request - create new key/idsn/token for subflow_request
...@@ -52,30 +97,32 @@ static int token_used __read_mostly; ...@@ -52,30 +97,32 @@ static int token_used __read_mostly;
int mptcp_token_new_request(struct request_sock *req) int mptcp_token_new_request(struct request_sock *req)
{ {
struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req);
int err; int retries = TOKEN_MAX_RETRIES;
struct token_bucket *bucket;
while (1) { u32 token;
u32 token;
again:
mptcp_crypto_key_gen_sha(&subflow_req->local_key, mptcp_crypto_key_gen_sha(&subflow_req->local_key,
&subflow_req->token, &subflow_req->token,
&subflow_req->idsn); &subflow_req->idsn);
pr_debug("req=%p local_key=%llu, token=%u, idsn=%llu\n", pr_debug("req=%p local_key=%llu, token=%u, idsn=%llu\n",
req, subflow_req->local_key, subflow_req->token, req, subflow_req->local_key, subflow_req->token,
subflow_req->idsn); subflow_req->idsn);
token = subflow_req->token; token = subflow_req->token;
spin_lock_bh(&token_tree_lock); bucket = token_bucket(token);
if (!radix_tree_lookup(&token_req_tree, token) && spin_lock_bh(&bucket->lock);
!radix_tree_lookup(&token_tree, token)) if (__token_bucket_busy(bucket, token)) {
break; spin_unlock_bh(&bucket->lock);
spin_unlock_bh(&token_tree_lock); if (!--retries)
return -EBUSY;
goto again;
} }
err = radix_tree_insert(&token_req_tree, hlist_nulls_add_head_rcu(&subflow_req->token_node, &bucket->req_chain);
subflow_req->token, &token_used); bucket->chain_len++;
spin_unlock_bh(&token_tree_lock); spin_unlock_bh(&bucket->lock);
return err; return 0;
} }
/** /**
...@@ -97,48 +144,56 @@ int mptcp_token_new_request(struct request_sock *req) ...@@ -97,48 +144,56 @@ int mptcp_token_new_request(struct request_sock *req)
int mptcp_token_new_connect(struct sock *sk) int mptcp_token_new_connect(struct sock *sk)
{ {
struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk);
struct sock *mptcp_sock = subflow->conn; struct mptcp_sock *msk = mptcp_sk(subflow->conn);
int err; int retries = TOKEN_MAX_RETRIES;
struct token_bucket *bucket;
while (1) { pr_debug("ssk=%p, local_key=%llu, token=%u, idsn=%llu\n",
u32 token; sk, subflow->local_key, subflow->token, subflow->idsn);
mptcp_crypto_key_gen_sha(&subflow->local_key, &subflow->token, again:
&subflow->idsn); mptcp_crypto_key_gen_sha(&subflow->local_key, &subflow->token,
&subflow->idsn);
pr_debug("ssk=%p, local_key=%llu, token=%u, idsn=%llu\n", bucket = token_bucket(subflow->token);
sk, subflow->local_key, subflow->token, subflow->idsn); spin_lock_bh(&bucket->lock);
if (__token_bucket_busy(bucket, subflow->token)) {
token = subflow->token; spin_unlock_bh(&bucket->lock);
spin_lock_bh(&token_tree_lock); if (!--retries)
if (!radix_tree_lookup(&token_req_tree, token) && return -EBUSY;
!radix_tree_lookup(&token_tree, token)) goto again;
break;
spin_unlock_bh(&token_tree_lock);
} }
err = radix_tree_insert(&token_tree, subflow->token, mptcp_sock);
spin_unlock_bh(&token_tree_lock);
return err; WRITE_ONCE(msk->token, subflow->token);
__sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain);
bucket->chain_len++;
spin_unlock_bh(&bucket->lock);
return 0;
} }
/** /**
* mptcp_token_new_accept - insert token for later processing * mptcp_token_accept - replace a req sk with full sock in token hash
* @token: the token to insert to the tree * @req: the request socket to be removed
* @conn: the just cloned socket linked to the new connection * @msk: the just cloned socket linked to the new connection
* *
* Called when a SYN packet creates a new logical connection, i.e. * Called when a SYN packet creates a new logical connection, i.e.
* is not a join request. * is not a join request.
*/ */
int mptcp_token_new_accept(u32 token, struct sock *conn) void mptcp_token_accept(struct mptcp_subflow_request_sock *req,
struct mptcp_sock *msk)
{ {
int err; struct mptcp_subflow_request_sock *pos;
struct token_bucket *bucket;
spin_lock_bh(&token_tree_lock); bucket = token_bucket(req->token);
err = radix_tree_insert(&token_tree, token, conn); spin_lock_bh(&bucket->lock);
spin_unlock_bh(&token_tree_lock);
return err; /* pedantic lookup check for the moved token */
pos = __token_lookup_req(bucket, req->token);
if (!WARN_ON_ONCE(pos != req))
hlist_nulls_del_init_rcu(&req->token_node);
__sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain);
spin_unlock_bh(&bucket->lock);
} }
/** /**
...@@ -152,45 +207,112 @@ int mptcp_token_new_accept(u32 token, struct sock *conn) ...@@ -152,45 +207,112 @@ int mptcp_token_new_accept(u32 token, struct sock *conn)
*/ */
struct mptcp_sock *mptcp_token_get_sock(u32 token) struct mptcp_sock *mptcp_token_get_sock(u32 token)
{ {
struct sock *conn; struct hlist_nulls_node *pos;
struct token_bucket *bucket;
spin_lock_bh(&token_tree_lock); struct mptcp_sock *msk;
conn = radix_tree_lookup(&token_tree, token); struct sock *sk;
if (conn) {
/* token still reserved? */ rcu_read_lock();
if (conn == (struct sock *)&token_used) bucket = token_bucket(token);
conn = NULL;
else again:
sock_hold(conn); sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) {
msk = mptcp_sk(sk);
if (READ_ONCE(msk->token) != token)
continue;
if (!refcount_inc_not_zero(&sk->sk_refcnt))
goto not_found;
if (READ_ONCE(msk->token) != token) {
sock_put(sk);
goto again;
}
goto found;
} }
spin_unlock_bh(&token_tree_lock); if (get_nulls_value(pos) != (token & token_mask))
goto again;
return mptcp_sk(conn); not_found:
msk = NULL;
found:
rcu_read_unlock();
return msk;
} }
/** /**
* mptcp_token_destroy_request - remove mptcp connection/token * mptcp_token_destroy_request - remove mptcp connection/token
* @token: token of mptcp connection to remove * @req: mptcp request socket dropping the token
* *
* Remove not-yet-fully-established incoming connection identified * Remove the token associated to @req.
* by @token.
*/ */
void mptcp_token_destroy_request(u32 token) void mptcp_token_destroy_request(struct request_sock *req)
{ {
spin_lock_bh(&token_tree_lock); struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req);
radix_tree_delete(&token_req_tree, token); struct mptcp_subflow_request_sock *pos;
spin_unlock_bh(&token_tree_lock); struct token_bucket *bucket;
if (hlist_nulls_unhashed(&subflow_req->token_node))
return;
bucket = token_bucket(subflow_req->token);
spin_lock_bh(&bucket->lock);
pos = __token_lookup_req(bucket, subflow_req->token);
if (!WARN_ON_ONCE(pos != subflow_req)) {
hlist_nulls_del_init_rcu(&pos->token_node);
bucket->chain_len--;
}
spin_unlock_bh(&bucket->lock);
} }
/** /**
* mptcp_token_destroy - remove mptcp connection/token * mptcp_token_destroy - remove mptcp connection/token
* @token: token of mptcp connection to remove * @msk: mptcp connection dropping the token
* *
* Remove the connection identified by @token. * Remove the token associated to @msk
*/ */
void mptcp_token_destroy(u32 token) void mptcp_token_destroy(struct mptcp_sock *msk)
{ {
spin_lock_bh(&token_tree_lock); struct token_bucket *bucket;
radix_tree_delete(&token_tree, token); struct mptcp_sock *pos;
spin_unlock_bh(&token_tree_lock);
if (sk_unhashed((struct sock *)msk))
return;
bucket = token_bucket(msk->token);
spin_lock_bh(&bucket->lock);
pos = __token_lookup_msk(bucket, msk->token);
if (!WARN_ON_ONCE(pos != msk)) {
__sk_nulls_del_node_init_rcu((struct sock *)pos);
bucket->chain_len--;
}
spin_unlock_bh(&bucket->lock);
}
void __init mptcp_token_init(void)
{
int i;
token_hash = alloc_large_system_hash("MPTCP token",
sizeof(struct token_bucket),
0,
20,/* one slot per 1MB of memory */
0,
NULL,
&token_mask,
0,
64 * 1024);
for (i = 0; i < token_mask + 1; ++i) {
INIT_HLIST_NULLS_HEAD(&token_hash[i].req_chain, i);
INIT_HLIST_NULLS_HEAD(&token_hash[i].msk_chain, i);
spin_lock_init(&token_hash[i].lock);
}
} }
#if IS_MODULE(CONFIG_MPTCP_KUNIT_TESTS)
EXPORT_SYMBOL_GPL(mptcp_token_new_request);
EXPORT_SYMBOL_GPL(mptcp_token_new_connect);
EXPORT_SYMBOL_GPL(mptcp_token_accept);
EXPORT_SYMBOL_GPL(mptcp_token_get_sock);
EXPORT_SYMBOL_GPL(mptcp_token_destroy_request);
EXPORT_SYMBOL_GPL(mptcp_token_destroy);
#endif
// SPDX-License-Identifier: GPL-2.0
#include <kunit/test.h>
#include "protocol.h"
static struct mptcp_subflow_request_sock *build_req_sock(struct kunit *test)
{
struct mptcp_subflow_request_sock *req;
req = kunit_kzalloc(test, sizeof(struct mptcp_subflow_request_sock),
GFP_USER);
KUNIT_EXPECT_NOT_ERR_OR_NULL(test, req);
mptcp_token_init_request((struct request_sock *)req);
return req;
}
static void mptcp_token_test_req_basic(struct kunit *test)
{
struct mptcp_subflow_request_sock *req = build_req_sock(test);
struct mptcp_sock *null_msk = NULL;
KUNIT_ASSERT_EQ(test, 0,
mptcp_token_new_request((struct request_sock *)req));
KUNIT_EXPECT_NE(test, 0, (int)req->token);
KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(req->token));
/* cleanup */
mptcp_token_destroy_request((struct request_sock *)req);
}
static struct inet_connection_sock *build_icsk(struct kunit *test)
{
struct inet_connection_sock *icsk;
icsk = kunit_kzalloc(test, sizeof(struct inet_connection_sock),
GFP_USER);
KUNIT_EXPECT_NOT_ERR_OR_NULL(test, icsk);
return icsk;
}
static struct mptcp_subflow_context *build_ctx(struct kunit *test)
{
struct mptcp_subflow_context *ctx;
ctx = kunit_kzalloc(test, sizeof(struct mptcp_subflow_context),
GFP_USER);
KUNIT_EXPECT_NOT_ERR_OR_NULL(test, ctx);
return ctx;
}
static struct mptcp_sock *build_msk(struct kunit *test)
{
struct mptcp_sock *msk;
msk = kunit_kzalloc(test, sizeof(struct mptcp_sock), GFP_USER);
KUNIT_EXPECT_NOT_ERR_OR_NULL(test, msk);
refcount_set(&((struct sock *)msk)->sk_refcnt, 1);
return msk;
}
static void mptcp_token_test_msk_basic(struct kunit *test)
{
struct inet_connection_sock *icsk = build_icsk(test);
struct mptcp_subflow_context *ctx = build_ctx(test);
struct mptcp_sock *msk = build_msk(test);
struct mptcp_sock *null_msk = NULL;
struct sock *sk;
rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
ctx->conn = (struct sock *)msk;
sk = (struct sock *)msk;
KUNIT_ASSERT_EQ(test, 0,
mptcp_token_new_connect((struct sock *)icsk));
KUNIT_EXPECT_NE(test, 0, (int)ctx->token);
KUNIT_EXPECT_EQ(test, ctx->token, msk->token);
KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(ctx->token));
KUNIT_EXPECT_EQ(test, 2, (int)refcount_read(&sk->sk_refcnt));
mptcp_token_destroy(msk);
KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(ctx->token));
}
static void mptcp_token_test_accept(struct kunit *test)
{
struct mptcp_subflow_request_sock *req = build_req_sock(test);
struct mptcp_sock *msk = build_msk(test);
KUNIT_ASSERT_EQ(test, 0,
mptcp_token_new_request((struct request_sock *)req));
msk->token = req->token;
mptcp_token_accept(req, msk);
KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(msk->token));
/* this is now a no-op */
mptcp_token_destroy_request((struct request_sock *)req);
KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(msk->token));
/* cleanup */
mptcp_token_destroy(msk);
}
static void mptcp_token_test_destroyed(struct kunit *test)
{
struct mptcp_subflow_request_sock *req = build_req_sock(test);
struct mptcp_sock *msk = build_msk(test);
struct mptcp_sock *null_msk = NULL;
struct sock *sk;
sk = (struct sock *)msk;
KUNIT_ASSERT_EQ(test, 0,
mptcp_token_new_request((struct request_sock *)req));
msk->token = req->token;
mptcp_token_accept(req, msk);
/* simulate race on removal */
refcount_set(&sk->sk_refcnt, 0);
KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(msk->token));
/* cleanup */
mptcp_token_destroy(msk);
}
static struct kunit_case mptcp_token_test_cases[] = {
KUNIT_CASE(mptcp_token_test_req_basic),
KUNIT_CASE(mptcp_token_test_msk_basic),
KUNIT_CASE(mptcp_token_test_accept),
KUNIT_CASE(mptcp_token_test_destroyed),
{}
};
static struct kunit_suite mptcp_token_suite = {
.name = "mptcp-token",
.test_cases = mptcp_token_test_cases,
};
kunit_test_suite(mptcp_token_suite);
MODULE_LICENSE("GPL");
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment