Commit 39e9d6f3 authored by Jakub Kicinski's avatar Jakub Kicinski

Merge branch 'net-tcp-dynamically-disable-tcp-md5-static-key'

Dmitry Safonov says:

====================
net/tcp: Dynamically disable TCP-MD5 static key

The static key introduced by commit 6015c71e ("tcp: md5: add
tcp_md5_needed jump label") is a fast-path optimization aimed at
avoiding a cache line miss.
Once an MD5 key is introduced in the system the static key is enabled
and never disabled. Address this by disabling the static key when
the last tcp_md5sig_info in system is destroyed.

Previously it was submitted as a part of TCP-AO patches set [1].
Now in attempt to split 36 patches submission, I send this independently.

Version 5:
https://lore.kernel.org/all/20221122185534.308643-1-dima@arista.com/T/#u
Version 4:
https://lore.kernel.org/all/20221115211905.1685426-1-dima@arista.com/T/#u
Version 3:
https://lore.kernel.org/all/20221111212320.1386566-1-dima@arista.com/T/#u
Version 2:
https://lore.kernel.org/all/20221103212524.865762-1-dima@arista.com/T/#u
Version 1:
https://lore.kernel.org/all/20221102211350.625011-1-dima@arista.com/T/#u

[1]: https://lore.kernel.org/all/20221027204347.529913-1-dima@arista.com/T/#u
====================

Link: https://lore.kernel.org/r/20221123173859.473629-1-dima@arista.comSigned-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parents 19833ae2 c5b8b515
......@@ -224,9 +224,10 @@ extern bool arch_jump_label_transform_queue(struct jump_entry *entry,
enum jump_label_type type);
extern void arch_jump_label_transform_apply(void);
extern int jump_label_text_reserved(void *start, void *end);
extern void static_key_slow_inc(struct static_key *key);
extern bool static_key_slow_inc(struct static_key *key);
extern bool static_key_fast_inc_not_disabled(struct static_key *key);
extern void static_key_slow_dec(struct static_key *key);
extern void static_key_slow_inc_cpuslocked(struct static_key *key);
extern bool static_key_slow_inc_cpuslocked(struct static_key *key);
extern void static_key_slow_dec_cpuslocked(struct static_key *key);
extern int static_key_count(struct static_key *key);
extern void static_key_enable(struct static_key *key);
......@@ -278,11 +279,23 @@ static __always_inline bool static_key_true(struct static_key *key)
return false;
}
static inline void static_key_slow_inc(struct static_key *key)
static inline bool static_key_fast_inc_not_disabled(struct static_key *key)
{
int v;
STATIC_KEY_CHECK_USE(key);
atomic_inc(&key->enabled);
/*
* Prevent key->enabled getting negative to follow the same semantics
* as for CONFIG_JUMP_LABEL=y, see kernel/jump_label.c comment.
*/
v = atomic_read(&key->enabled);
do {
if (v < 0 || (v + 1) < 0)
return false;
} while (!likely(atomic_try_cmpxchg(&key->enabled, &v, v + 1)));
return true;
}
#define static_key_slow_inc(key) static_key_fast_inc_not_disabled(key)
static inline void static_key_slow_dec(struct static_key *key)
{
......
......@@ -1675,7 +1675,11 @@ int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key,
const struct sock *sk, const struct sk_buff *skb);
int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen, int l3index, u8 flags,
const u8 *newkey, u8 newkeylen, gfp_t gfp);
const u8 *newkey, u8 newkeylen);
int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen, int l3index,
struct tcp_md5sig_key *key);
int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen, int l3index, u8 flags);
struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
......@@ -1683,7 +1687,7 @@ struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
#ifdef CONFIG_TCP_MD5SIG
#include <linux/jump_label.h>
extern struct static_key_false tcp_md5_needed;
extern struct static_key_false_deferred tcp_md5_needed;
struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index,
const union tcp_md5_addr *addr,
int family);
......@@ -1691,7 +1695,7 @@ static inline struct tcp_md5sig_key *
tcp_md5_do_lookup(const struct sock *sk, int l3index,
const union tcp_md5_addr *addr, int family)
{
if (!static_branch_unlikely(&tcp_md5_needed))
if (!static_branch_unlikely(&tcp_md5_needed.key))
return NULL;
return __tcp_md5_do_lookup(sk, l3index, addr, family);
}
......
......@@ -113,9 +113,40 @@ int static_key_count(struct static_key *key)
}
EXPORT_SYMBOL_GPL(static_key_count);
void static_key_slow_inc_cpuslocked(struct static_key *key)
/*
* static_key_fast_inc_not_disabled - adds a user for a static key
* @key: static key that must be already enabled
*
* The caller must make sure that the static key can't get disabled while
* in this function. It doesn't patch jump labels, only adds a user to
* an already enabled static key.
*
* Returns true if the increment was done. Unlike refcount_t the ref counter
* is not saturated, but will fail to increment on overflow.
*/
bool static_key_fast_inc_not_disabled(struct static_key *key)
{
int v;
STATIC_KEY_CHECK_USE(key);
/*
* Negative key->enabled has a special meaning: it sends
* static_key_slow_inc() down the slow path, and it is non-zero
* so it counts as "enabled" in jump_label_update(). Note that
* atomic_inc_unless_negative() checks >= 0, so roll our own.
*/
v = atomic_read(&key->enabled);
do {
if (v <= 0 || (v + 1) < 0)
return false;
} while (!likely(atomic_try_cmpxchg(&key->enabled, &v, v + 1)));
return true;
}
EXPORT_SYMBOL_GPL(static_key_fast_inc_not_disabled);
bool static_key_slow_inc_cpuslocked(struct static_key *key)
{
lockdep_assert_cpus_held();
/*
......@@ -124,15 +155,9 @@ void static_key_slow_inc_cpuslocked(struct static_key *key)
* jump_label_update() process. At the same time, however,
* the jump_label_update() call below wants to see
* static_key_enabled(&key) for jumps to be updated properly.
*
* So give a special meaning to negative key->enabled: it sends
* static_key_slow_inc() down the slow path, and it is non-zero
* so it counts as "enabled" in jump_label_update(). Note that
* atomic_inc_unless_negative() checks >= 0, so roll our own.
*/
for (int v = atomic_read(&key->enabled); v > 0; )
if (likely(atomic_try_cmpxchg(&key->enabled, &v, v + 1)))
return;
if (static_key_fast_inc_not_disabled(key))
return true;
jump_label_lock();
if (atomic_read(&key->enabled) == 0) {
......@@ -144,16 +169,23 @@ void static_key_slow_inc_cpuslocked(struct static_key *key)
*/
atomic_set_release(&key->enabled, 1);
} else {
atomic_inc(&key->enabled);
if (WARN_ON_ONCE(!static_key_fast_inc_not_disabled(key))) {
jump_label_unlock();
return false;
}
}
jump_label_unlock();
return true;
}
void static_key_slow_inc(struct static_key *key)
bool static_key_slow_inc(struct static_key *key)
{
bool ret;
cpus_read_lock();
static_key_slow_inc_cpuslocked(key);
ret = static_key_slow_inc_cpuslocked(key);
cpus_read_unlock();
return ret;
}
EXPORT_SYMBOL_GPL(static_key_slow_inc);
......
......@@ -4464,11 +4464,8 @@ bool tcp_alloc_md5sig_pool(void)
if (unlikely(!READ_ONCE(tcp_md5sig_pool_populated))) {
mutex_lock(&tcp_md5sig_mutex);
if (!tcp_md5sig_pool_populated) {
if (!tcp_md5sig_pool_populated)
__tcp_alloc_md5sig_pool();
if (tcp_md5sig_pool_populated)
static_branch_inc(&tcp_md5_needed);
}
mutex_unlock(&tcp_md5sig_mutex);
}
......
......@@ -1053,7 +1053,7 @@ static void tcp_v4_reqsk_destructor(struct request_sock *req)
* We need to maintain these in the sk structure.
*/
DEFINE_STATIC_KEY_FALSE(tcp_md5_needed);
DEFINE_STATIC_KEY_DEFERRED_FALSE(tcp_md5_needed, HZ);
EXPORT_SYMBOL(tcp_md5_needed);
static bool better_md5_match(struct tcp_md5sig_key *old, struct tcp_md5sig_key *new)
......@@ -1161,10 +1161,25 @@ struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
}
EXPORT_SYMBOL(tcp_v4_md5_lookup);
static int tcp_md5sig_info_add(struct sock *sk, gfp_t gfp)
{
struct tcp_sock *tp = tcp_sk(sk);
struct tcp_md5sig_info *md5sig;
md5sig = kmalloc(sizeof(*md5sig), gfp);
if (!md5sig)
return -ENOMEM;
sk_gso_disable(sk);
INIT_HLIST_HEAD(&md5sig->head);
rcu_assign_pointer(tp->md5sig_info, md5sig);
return 0;
}
/* This can be called on a newly created socket, from other files */
int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen, int l3index, u8 flags,
const u8 *newkey, u8 newkeylen, gfp_t gfp)
static int __tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen, int l3index, u8 flags,
const u8 *newkey, u8 newkeylen, gfp_t gfp)
{
/* Add Key to the list */
struct tcp_md5sig_key *key;
......@@ -1193,15 +1208,6 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
md5sig = rcu_dereference_protected(tp->md5sig_info,
lockdep_sock_is_held(sk));
if (!md5sig) {
md5sig = kmalloc(sizeof(*md5sig), gfp);
if (!md5sig)
return -ENOMEM;
sk_gso_disable(sk);
INIT_HLIST_HEAD(&md5sig->head);
rcu_assign_pointer(tp->md5sig_info, md5sig);
}
key = sock_kmalloc(sk, sizeof(*key), gfp | __GFP_ZERO);
if (!key)
......@@ -1223,8 +1229,59 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
hlist_add_head_rcu(&key->node, &md5sig->head);
return 0;
}
int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen, int l3index, u8 flags,
const u8 *newkey, u8 newkeylen)
{
struct tcp_sock *tp = tcp_sk(sk);
if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) {
if (tcp_md5sig_info_add(sk, GFP_KERNEL))
return -ENOMEM;
if (!static_branch_inc(&tcp_md5_needed.key)) {
struct tcp_md5sig_info *md5sig;
md5sig = rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk));
rcu_assign_pointer(tp->md5sig_info, NULL);
kfree_rcu(md5sig);
return -EUSERS;
}
}
return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index, flags,
newkey, newkeylen, GFP_KERNEL);
}
EXPORT_SYMBOL(tcp_md5_do_add);
int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr,
int family, u8 prefixlen, int l3index,
struct tcp_md5sig_key *key)
{
struct tcp_sock *tp = tcp_sk(sk);
if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) {
if (tcp_md5sig_info_add(sk, sk_gfp_mask(sk, GFP_ATOMIC)))
return -ENOMEM;
if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key)) {
struct tcp_md5sig_info *md5sig;
md5sig = rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk));
net_warn_ratelimited("Too many TCP-MD5 keys in the system\n");
rcu_assign_pointer(tp->md5sig_info, NULL);
kfree_rcu(md5sig);
return -EUSERS;
}
}
return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index,
key->flags, key->key, key->keylen,
sk_gfp_mask(sk, GFP_ATOMIC));
}
EXPORT_SYMBOL(tcp_md5_key_copy);
int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, int family,
u8 prefixlen, int l3index, u8 flags)
{
......@@ -1311,7 +1368,7 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname,
return -EINVAL;
return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index, flags,
cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL);
cmd.tcpm_key, cmd.tcpm_keylen);
}
static int tcp_v4_md5_hash_headers(struct tcp_md5sig_pool *hp,
......@@ -1562,14 +1619,8 @@ struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb,
addr = (union tcp_md5_addr *)&newinet->inet_daddr;
key = tcp_md5_do_lookup(sk, l3index, addr, AF_INET);
if (key) {
/*
* We're using one, so create a matching key
* on the newsk structure. If we fail to get
* memory, then we end up not copying the key
* across. Shucks.
*/
tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index, key->flags,
key->key, key->keylen, GFP_ATOMIC);
if (tcp_md5_key_copy(newsk, addr, AF_INET, 32, l3index, key))
goto put_and_exit;
sk_gso_disable(newsk);
}
#endif
......@@ -2261,6 +2312,7 @@ void tcp_v4_destroy_sock(struct sock *sk)
tcp_clear_md5_list(sk);
kfree_rcu(rcu_dereference_protected(tp->md5sig_info, 1), rcu);
tp->md5sig_info = NULL;
static_branch_slow_dec_deferred(&tcp_md5_needed);
}
#endif
......
......@@ -240,6 +240,40 @@ tcp_timewait_state_process(struct inet_timewait_sock *tw, struct sk_buff *skb,
}
EXPORT_SYMBOL(tcp_timewait_state_process);
static void tcp_time_wait_init(struct sock *sk, struct tcp_timewait_sock *tcptw)
{
#ifdef CONFIG_TCP_MD5SIG
const struct tcp_sock *tp = tcp_sk(sk);
struct tcp_md5sig_key *key;
/*
* The timewait bucket does not have the key DB from the
* sock structure. We just make a quick copy of the
* md5 key being used (if indeed we are using one)
* so the timewait ack generating code has the key.
*/
tcptw->tw_md5_key = NULL;
if (!static_branch_unlikely(&tcp_md5_needed.key))
return;
key = tp->af_specific->md5_lookup(sk, sk);
if (key) {
tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC);
if (!tcptw->tw_md5_key)
return;
if (!tcp_alloc_md5sig_pool())
goto out_free;
if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key))
goto out_free;
}
return;
out_free:
WARN_ON_ONCE(1);
kfree(tcptw->tw_md5_key);
tcptw->tw_md5_key = NULL;
#endif
}
/*
* Move a socket to time-wait or dead fin-wait-2 state.
*/
......@@ -282,26 +316,7 @@ void tcp_time_wait(struct sock *sk, int state, int timeo)
}
#endif
#ifdef CONFIG_TCP_MD5SIG
/*
* The timewait bucket does not have the key DB from the
* sock structure. We just make a quick copy of the
* md5 key being used (if indeed we are using one)
* so the timewait ack generating code has the key.
*/
do {
tcptw->tw_md5_key = NULL;
if (static_branch_unlikely(&tcp_md5_needed)) {
struct tcp_md5sig_key *key;
key = tp->af_specific->md5_lookup(sk, sk);
if (key) {
tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC);
BUG_ON(tcptw->tw_md5_key && !tcp_alloc_md5sig_pool());
}
}
} while (0);
#endif
tcp_time_wait_init(sk, tcptw);
/* Get the TIME_WAIT timeout firing. */
if (timeo < rto)
......@@ -337,11 +352,13 @@ EXPORT_SYMBOL(tcp_time_wait);
void tcp_twsk_destructor(struct sock *sk)
{
#ifdef CONFIG_TCP_MD5SIG
if (static_branch_unlikely(&tcp_md5_needed)) {
if (static_branch_unlikely(&tcp_md5_needed.key)) {
struct tcp_timewait_sock *twsk = tcp_twsk(sk);
if (twsk->tw_md5_key)
if (twsk->tw_md5_key) {
kfree_rcu(twsk->tw_md5_key, rcu);
static_branch_slow_dec_deferred(&tcp_md5_needed);
}
}
#endif
}
......
......@@ -766,7 +766,7 @@ static unsigned int tcp_syn_options(struct sock *sk, struct sk_buff *skb,
*md5 = NULL;
#ifdef CONFIG_TCP_MD5SIG
if (static_branch_unlikely(&tcp_md5_needed) &&
if (static_branch_unlikely(&tcp_md5_needed.key) &&
rcu_access_pointer(tp->md5sig_info)) {
*md5 = tp->af_specific->md5_lookup(sk, sk);
if (*md5) {
......@@ -922,7 +922,7 @@ static unsigned int tcp_established_options(struct sock *sk, struct sk_buff *skb
*md5 = NULL;
#ifdef CONFIG_TCP_MD5SIG
if (static_branch_unlikely(&tcp_md5_needed) &&
if (static_branch_unlikely(&tcp_md5_needed.key) &&
rcu_access_pointer(tp->md5sig_info)) {
*md5 = tp->af_specific->md5_lookup(sk, sk);
if (*md5) {
......
......@@ -665,12 +665,11 @@ static int tcp_v6_parse_md5_keys(struct sock *sk, int optname,
if (ipv6_addr_v4mapped(&sin6->sin6_addr))
return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3],
AF_INET, prefixlen, l3index, flags,
cmd.tcpm_key, cmd.tcpm_keylen,
GFP_KERNEL);
cmd.tcpm_key, cmd.tcpm_keylen);
return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr,
AF_INET6, prefixlen, l3index, flags,
cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL);
cmd.tcpm_key, cmd.tcpm_keylen);
}
static int tcp_v6_md5_hash_headers(struct tcp_md5sig_pool *hp,
......@@ -1365,14 +1364,14 @@ static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *
/* Copy over the MD5 key from the original socket */
key = tcp_v6_md5_do_lookup(sk, &newsk->sk_v6_daddr, l3index);
if (key) {
/* We're using one, so create a matching key
* on the newsk structure. If we fail to get
* memory, then we end up not copying the key
* across. Shucks.
*/
tcp_md5_do_add(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr,
AF_INET6, 128, l3index, key->flags, key->key, key->keylen,
sk_gfp_mask(sk, GFP_ATOMIC));
const union tcp_md5_addr *addr;
addr = (union tcp_md5_addr *)&newsk->sk_v6_daddr;
if (tcp_md5_key_copy(newsk, addr, AF_INET6, 128, l3index, key)) {
inet_csk_prepare_forced_close(newsk);
tcp_done(newsk);
goto out;
}
}
#endif
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment