Commit e341694e authored by Thomas Graf, committed by David S. Miller

netlink: Convert netlink_lookup() to use RCU protected hash table

Heavy Netlink users such as Open vSwitch spend a considerable amount of
time in netlink_lookup() due to the read-lock on nl_table_lock. Use of
RCU relieves the lock contention.

Makes use of the new resizable hash table to avoid locking on the
lookup.

The hash table will grow if the number of entries exceeds 75% of the
table size, up to a total table size of 64K. It will automatically
shrink if usage falls below 30%.

Also splits a separate mutex (nl_sk_hash_lock) out of nl_table_lock to
protect hash table mutations. Using a mutex allows synchronize_rcu() to
sleep while waiting for readers during expansion and shrinking.

Before:
   9.16%  kpktgend_0  [openvswitch]      [k] masked_flow_lookup
   6.42%  kpktgend_0  [pktgen]           [k] mod_cur_headers
   6.26%  kpktgend_0  [pktgen]           [k] pktgen_thread_worker
   6.23%  kpktgend_0  [kernel.kallsyms]  [k] memset
   4.79%  kpktgend_0  [kernel.kallsyms]  [k] netlink_lookup
   4.37%  kpktgend_0  [kernel.kallsyms]  [k] memcpy
   3.60%  kpktgend_0  [openvswitch]      [k] ovs_flow_extract
   2.69%  kpktgend_0  [kernel.kallsyms]  [k] jhash2

After:
  15.26%  kpktgend_0  [openvswitch]      [k] masked_flow_lookup
   8.12%  kpktgend_0  [pktgen]           [k] pktgen_thread_worker
   7.92%  kpktgend_0  [pktgen]           [k] mod_cur_headers
   5.11%  kpktgend_0  [kernel.kallsyms]  [k] memset
   4.11%  kpktgend_0  [openvswitch]      [k] ovs_flow_extract
   4.06%  kpktgend_0  [kernel.kallsyms]  [k] _raw_spin_lock
   3.90%  kpktgend_0  [kernel.kallsyms]  [k] jhash2
   [...]
   0.67%  kpktgend_0  [kernel.kallsyms]  [k] netlink_lookup
Signed-off-by: Thomas Graf <tgraf@suug.ch>
Reviewed-by: Nikolay Aleksandrov <nikolay@redhat.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
parent 7e1e7763
...@@ -58,7 +58,9 @@ ...@@ -58,7 +58,9 @@
#include <linux/mutex.h> #include <linux/mutex.h>
#include <linux/vmalloc.h> #include <linux/vmalloc.h>
#include <linux/if_arp.h> #include <linux/if_arp.h>
#include <linux/rhashtable.h>
#include <asm/cacheflush.h> #include <asm/cacheflush.h>
#include <linux/hash.h>
#include <net/net_namespace.h> #include <net/net_namespace.h>
#include <net/sock.h> #include <net/sock.h>
...@@ -100,6 +102,18 @@ static atomic_t nl_table_users = ATOMIC_INIT(0); ...@@ -100,6 +102,18 @@ static atomic_t nl_table_users = ATOMIC_INIT(0);
#define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock)); #define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock));
/* Protects netlink socket hash table mutations */
DEFINE_MUTEX(nl_sk_hash_lock);
/*
 * Lockdep helper for nl_sk_hash_lock; it is installed as the
 * .mutex_is_held callback of the netlink rhashtable parameters.
 *
 * With CONFIG_LOCKDEP enabled and debug_locks non-zero, this returns
 * whatever lockdep_is_held() reports for nl_sk_hash_lock. In every
 * other case it returns 1, i.e. the mutex is reported as held.
 */
static int lockdep_nl_sk_hash_is_held(void)
{
#ifdef CONFIG_LOCKDEP
return (debug_locks) ? lockdep_is_held(&nl_sk_hash_lock) : 1;
#else
return 1;
#endif
}
static ATOMIC_NOTIFIER_HEAD(netlink_chain); static ATOMIC_NOTIFIER_HEAD(netlink_chain);
static DEFINE_SPINLOCK(netlink_tap_lock); static DEFINE_SPINLOCK(netlink_tap_lock);
...@@ -110,11 +124,6 @@ static inline u32 netlink_group_mask(u32 group) ...@@ -110,11 +124,6 @@ static inline u32 netlink_group_mask(u32 group)
return group ? 1 << (group - 1) : 0; return group ? 1 << (group - 1) : 0;
} }
/*
 * Map a portid to its bucket head in a struct nl_portid_hash.
 *
 * The portid is hashed with jhash_1word() using hash->rnd as the seed,
 * the result is masked with hash->mask, and the address of the
 * corresponding slot in hash->table is returned.
 */
static inline struct hlist_head *nl_portid_hashfn(struct nl_portid_hash *hash, u32 portid)
{
return &hash->table[jhash_1word(portid, hash->rnd) & hash->mask];
}
int netlink_add_tap(struct netlink_tap *nt) int netlink_add_tap(struct netlink_tap *nt)
{ {
if (unlikely(nt->dev->type != ARPHRD_NETLINK)) if (unlikely(nt->dev->type != ARPHRD_NETLINK))
...@@ -983,105 +992,48 @@ netlink_unlock_table(void) ...@@ -983,105 +992,48 @@ netlink_unlock_table(void)
wake_up(&nl_table_wait); wake_up(&nl_table_wait);
} }
static bool netlink_compare(struct net *net, struct sock *sk) struct netlink_compare_arg
{
return net_eq(sock_net(sk), net);
}
static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
{ {
struct netlink_table *table = &nl_table[protocol]; struct net *net;
struct nl_portid_hash *hash = &table->hash; u32 portid;
struct hlist_head *head; };
struct sock *sk;
read_lock(&nl_table_lock);
head = nl_portid_hashfn(hash, portid);
sk_for_each(sk, head) {
if (table->compare(net, sk) &&
(nlk_sk(sk)->portid == portid)) {
sock_hold(sk);
goto found;
}
}
sk = NULL;
found:
read_unlock(&nl_table_lock);
return sk;
}
static struct hlist_head *nl_portid_hash_zalloc(size_t size) static bool netlink_compare(void *ptr, void *arg)
{ {
if (size <= PAGE_SIZE) struct netlink_compare_arg *x = arg;
return kzalloc(size, GFP_ATOMIC); struct sock *sk = ptr;
else
return (struct hlist_head *)
__get_free_pages(GFP_ATOMIC | __GFP_ZERO,
get_order(size));
}
static void nl_portid_hash_free(struct hlist_head *table, size_t size) return nlk_sk(sk)->portid == x->portid &&
{ net_eq(sock_net(sk), x->net);
if (size <= PAGE_SIZE)
kfree(table);
else
free_pages((unsigned long)table, get_order(size));
} }
static int nl_portid_hash_rehash(struct nl_portid_hash *hash, int grow) static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid,
struct net *net)
{ {
unsigned int omask, mask, shift; struct netlink_compare_arg arg = {
size_t osize, size; .net = net,
struct hlist_head *otable, *table; .portid = portid,
int i; };
u32 hash;
omask = mask = hash->mask;
osize = size = (mask + 1) * sizeof(*table);
shift = hash->shift;
if (grow) {
if (++shift > hash->max_shift)
return 0;
mask = mask * 2 + 1;
size *= 2;
}
table = nl_portid_hash_zalloc(size); hash = rhashtable_hashfn(&table->hash, &portid, sizeof(portid));
if (!table)
return 0;
otable = hash->table; return rhashtable_lookup_compare(&table->hash, hash,
hash->table = table; &netlink_compare, &arg);
hash->mask = mask;
hash->shift = shift;
get_random_bytes(&hash->rnd, sizeof(hash->rnd));
for (i = 0; i <= omask; i++) {
struct sock *sk;
struct hlist_node *tmp;
sk_for_each_safe(sk, tmp, &otable[i])
__sk_add_node(sk, nl_portid_hashfn(hash, nlk_sk(sk)->portid));
}
nl_portid_hash_free(otable, osize);
hash->rehash_time = jiffies + 10 * 60 * HZ;
return 1;
} }
static inline int nl_portid_hash_dilute(struct nl_portid_hash *hash, int len) static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
{ {
int avg = hash->entries >> hash->shift; struct netlink_table *table = &nl_table[protocol];
struct sock *sk;
if (unlikely(avg > 1) && nl_portid_hash_rehash(hash, 1))
return 1;
if (unlikely(len > avg) && time_after(jiffies, hash->rehash_time)) { rcu_read_lock();
nl_portid_hash_rehash(hash, 0); sk = __netlink_lookup(table, portid, net);
return 1; if (sk)
} sock_hold(sk);
rcu_read_unlock();
return 0; return sk;
} }
static const struct proto_ops netlink_ops; static const struct proto_ops netlink_ops;
...@@ -1113,22 +1065,10 @@ netlink_update_listeners(struct sock *sk) ...@@ -1113,22 +1065,10 @@ netlink_update_listeners(struct sock *sk)
static int netlink_insert(struct sock *sk, struct net *net, u32 portid) static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
{ {
struct netlink_table *table = &nl_table[sk->sk_protocol]; struct netlink_table *table = &nl_table[sk->sk_protocol];
struct nl_portid_hash *hash = &table->hash;
struct hlist_head *head;
int err = -EADDRINUSE; int err = -EADDRINUSE;
struct sock *osk;
int len;
netlink_table_grab(); mutex_lock(&nl_sk_hash_lock);
head = nl_portid_hashfn(hash, portid); if (__netlink_lookup(table, portid, net))
len = 0;
sk_for_each(osk, head) {
if (table->compare(net, osk) &&
(nlk_sk(osk)->portid == portid))
break;
len++;
}
if (osk)
goto err; goto err;
err = -EBUSY; err = -EBUSY;
...@@ -1136,26 +1076,31 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid) ...@@ -1136,26 +1076,31 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
goto err; goto err;
err = -ENOMEM; err = -ENOMEM;
if (BITS_PER_LONG > 32 && unlikely(hash->entries >= UINT_MAX)) if (BITS_PER_LONG > 32 && unlikely(table->hash.nelems >= UINT_MAX))
goto err; goto err;
if (len && nl_portid_hash_dilute(hash, len))
head = nl_portid_hashfn(hash, portid);
hash->entries++;
nlk_sk(sk)->portid = portid; nlk_sk(sk)->portid = portid;
sk_add_node(sk, head); sock_hold(sk);
rhashtable_insert(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL);
err = 0; err = 0;
err: err:
netlink_table_ungrab(); mutex_unlock(&nl_sk_hash_lock);
return err; return err;
} }
static void netlink_remove(struct sock *sk) static void netlink_remove(struct sock *sk)
{ {
struct netlink_table *table;
mutex_lock(&nl_sk_hash_lock);
table = &nl_table[sk->sk_protocol];
if (rhashtable_remove(&table->hash, &nlk_sk(sk)->node, GFP_KERNEL)) {
WARN_ON(atomic_read(&sk->sk_refcnt) == 1);
__sock_put(sk);
}
mutex_unlock(&nl_sk_hash_lock);
netlink_table_grab(); netlink_table_grab();
if (sk_del_node_init(sk))
nl_table[sk->sk_protocol].hash.entries--;
if (nlk_sk(sk)->subscriptions) if (nlk_sk(sk)->subscriptions)
__sk_del_bind_node(sk); __sk_del_bind_node(sk);
netlink_table_ungrab(); netlink_table_ungrab();
...@@ -1311,6 +1256,9 @@ static int netlink_release(struct socket *sock) ...@@ -1311,6 +1256,9 @@ static int netlink_release(struct socket *sock)
} }
netlink_table_ungrab(); netlink_table_ungrab();
/* Wait for readers to complete */
synchronize_net();
kfree(nlk->groups); kfree(nlk->groups);
nlk->groups = NULL; nlk->groups = NULL;
...@@ -1326,30 +1274,22 @@ static int netlink_autobind(struct socket *sock) ...@@ -1326,30 +1274,22 @@ static int netlink_autobind(struct socket *sock)
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct net *net = sock_net(sk); struct net *net = sock_net(sk);
struct netlink_table *table = &nl_table[sk->sk_protocol]; struct netlink_table *table = &nl_table[sk->sk_protocol];
struct nl_portid_hash *hash = &table->hash;
struct hlist_head *head;
struct sock *osk;
s32 portid = task_tgid_vnr(current); s32 portid = task_tgid_vnr(current);
int err; int err;
static s32 rover = -4097; static s32 rover = -4097;
retry: retry:
cond_resched(); cond_resched();
netlink_table_grab(); rcu_read_lock();
head = nl_portid_hashfn(hash, portid); if (__netlink_lookup(table, portid, net)) {
sk_for_each(osk, head) { /* Bind collision, search negative portid values. */
if (!table->compare(net, osk)) portid = rover--;
continue; if (rover > -4097)
if (nlk_sk(osk)->portid == portid) { rover = -4097;
/* Bind collision, search negative portid values. */ rcu_read_unlock();
portid = rover--; goto retry;
if (rover > -4097)
rover = -4097;
netlink_table_ungrab();
goto retry;
}
} }
netlink_table_ungrab(); rcu_read_unlock();
err = netlink_insert(sk, net, portid); err = netlink_insert(sk, net, portid);
if (err == -EADDRINUSE) if (err == -EADDRINUSE)
...@@ -2953,14 +2893,18 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos) ...@@ -2953,14 +2893,18 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos)
{ {
struct nl_seq_iter *iter = seq->private; struct nl_seq_iter *iter = seq->private;
int i, j; int i, j;
struct netlink_sock *nlk;
struct sock *s; struct sock *s;
loff_t off = 0; loff_t off = 0;
for (i = 0; i < MAX_LINKS; i++) { for (i = 0; i < MAX_LINKS; i++) {
struct nl_portid_hash *hash = &nl_table[i].hash; struct rhashtable *ht = &nl_table[i].hash;
const struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
for (j = 0; j < tbl->size; j++) {
rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) {
s = (struct sock *)nlk;
for (j = 0; j <= hash->mask; j++) {
sk_for_each(s, &hash->table[j]) {
if (sock_net(s) != seq_file_net(seq)) if (sock_net(s) != seq_file_net(seq))
continue; continue;
if (off == pos) { if (off == pos) {
...@@ -2976,15 +2920,14 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos) ...@@ -2976,15 +2920,14 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos)
} }
static void *netlink_seq_start(struct seq_file *seq, loff_t *pos) static void *netlink_seq_start(struct seq_file *seq, loff_t *pos)
__acquires(nl_table_lock)
{ {
read_lock(&nl_table_lock); rcu_read_lock();
return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN; return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN;
} }
static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
{ {
struct sock *s; struct netlink_sock *nlk;
struct nl_seq_iter *iter; struct nl_seq_iter *iter;
struct net *net; struct net *net;
int i, j; int i, j;
...@@ -2996,28 +2939,26 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) ...@@ -2996,28 +2939,26 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
net = seq_file_net(seq); net = seq_file_net(seq);
iter = seq->private; iter = seq->private;
s = v; nlk = v;
do {
s = sk_next(s); rht_for_each_entry_rcu(nlk, nlk->node.next, node)
} while (s && !nl_table[s->sk_protocol].compare(net, s)); if (net_eq(sock_net((struct sock *)nlk), net))
if (s) return nlk;
return s;
i = iter->link; i = iter->link;
j = iter->hash_idx + 1; j = iter->hash_idx + 1;
do { do {
struct nl_portid_hash *hash = &nl_table[i].hash; struct rhashtable *ht = &nl_table[i].hash;
const struct bucket_table *tbl = rht_dereference(ht->tbl, ht);
for (; j <= hash->mask; j++) {
s = sk_head(&hash->table[j]);
while (s && !nl_table[s->sk_protocol].compare(net, s)) for (; j < tbl->size; j++) {
s = sk_next(s); rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) {
if (s) { if (net_eq(sock_net((struct sock *)nlk), net)) {
iter->link = i; iter->link = i;
iter->hash_idx = j; iter->hash_idx = j;
return s; return nlk;
}
} }
} }
...@@ -3028,9 +2969,8 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) ...@@ -3028,9 +2969,8 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
} }
static void netlink_seq_stop(struct seq_file *seq, void *v) static void netlink_seq_stop(struct seq_file *seq, void *v)
__releases(nl_table_lock)
{ {
read_unlock(&nl_table_lock); rcu_read_unlock();
} }
...@@ -3168,9 +3108,17 @@ static struct pernet_operations __net_initdata netlink_net_ops = { ...@@ -3168,9 +3108,17 @@ static struct pernet_operations __net_initdata netlink_net_ops = {
static int __init netlink_proto_init(void) static int __init netlink_proto_init(void)
{ {
int i; int i;
unsigned long limit;
unsigned int order;
int err = proto_register(&netlink_proto, 0); int err = proto_register(&netlink_proto, 0);
struct rhashtable_params ht_params = {
.head_offset = offsetof(struct netlink_sock, node),
.key_offset = offsetof(struct netlink_sock, portid),
.key_len = sizeof(u32), /* portid */
.hashfn = arch_fast_hash,
.max_shift = 16, /* 64K */
.grow_decision = rht_grow_above_75,
.shrink_decision = rht_shrink_below_30,
.mutex_is_held = lockdep_nl_sk_hash_is_held,
};
if (err != 0) if (err != 0)
goto out; goto out;
...@@ -3181,32 +3129,13 @@ static int __init netlink_proto_init(void) ...@@ -3181,32 +3129,13 @@ static int __init netlink_proto_init(void)
if (!nl_table) if (!nl_table)
goto panic; goto panic;
if (totalram_pages >= (128 * 1024))
limit = totalram_pages >> (21 - PAGE_SHIFT);
else
limit = totalram_pages >> (23 - PAGE_SHIFT);
order = get_bitmask_order(limit) - 1 + PAGE_SHIFT;
limit = (1UL << order) / sizeof(struct hlist_head);
order = get_bitmask_order(min(limit, (unsigned long)UINT_MAX)) - 1;
for (i = 0; i < MAX_LINKS; i++) { for (i = 0; i < MAX_LINKS; i++) {
struct nl_portid_hash *hash = &nl_table[i].hash; if (rhashtable_init(&nl_table[i].hash, &ht_params) < 0) {
while (--i > 0)
hash->table = nl_portid_hash_zalloc(1 * sizeof(*hash->table)); rhashtable_destroy(&nl_table[i].hash);
if (!hash->table) {
while (i-- > 0)
nl_portid_hash_free(nl_table[i].hash.table,
1 * sizeof(*hash->table));
kfree(nl_table); kfree(nl_table);
goto panic; goto panic;
} }
hash->max_shift = order;
hash->shift = 0;
hash->mask = 0;
hash->rehash_time = jiffies;
nl_table[i].compare = netlink_compare;
} }
INIT_LIST_HEAD(&netlink_tap_all); INIT_LIST_HEAD(&netlink_tap_all);
......
#ifndef _AF_NETLINK_H #ifndef _AF_NETLINK_H
#define _AF_NETLINK_H #define _AF_NETLINK_H
#include <linux/rhashtable.h>
#include <net/sock.h> #include <net/sock.h>
#define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8) #define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8)
...@@ -47,6 +48,8 @@ struct netlink_sock { ...@@ -47,6 +48,8 @@ struct netlink_sock {
struct netlink_ring tx_ring; struct netlink_ring tx_ring;
atomic_t mapped; atomic_t mapped;
#endif /* CONFIG_NETLINK_MMAP */ #endif /* CONFIG_NETLINK_MMAP */
struct rhash_head node;
}; };
static inline struct netlink_sock *nlk_sk(struct sock *sk) static inline struct netlink_sock *nlk_sk(struct sock *sk)
...@@ -54,21 +57,8 @@ static inline struct netlink_sock *nlk_sk(struct sock *sk) ...@@ -54,21 +57,8 @@ static inline struct netlink_sock *nlk_sk(struct sock *sk)
return container_of(sk, struct netlink_sock, sk); return container_of(sk, struct netlink_sock, sk);
} }
/*
 * State of the hlist-based portid hash table embedded in each
 * struct netlink_table (one per netlink protocol).
 */
struct nl_portid_hash {
/* Bucket array; holds mask + 1 hlist heads. */
struct hlist_head *table;
/*
 * jiffies value after which a same-size rehash may be triggered;
 * set to jiffies + 10 * 60 * HZ at the end of every rehash.
 */
unsigned long rehash_time;
/* Bucket index mask applied to the jhash value (bucket count - 1). */
unsigned int mask;
/* Starts at 0 and is incremented each time the table is doubled. */
unsigned int shift;
/* Number of sockets in the table; bumped on insert, dropped on remove. */
unsigned int entries;
/* Upper bound for shift; the table is not grown beyond it. */
unsigned int max_shift;
/* Seed passed to jhash_1word(); refilled with random bytes on rehash. */
u32 rnd;
};
struct netlink_table { struct netlink_table {
struct nl_portid_hash hash; struct rhashtable hash;
struct hlist_head mc_list; struct hlist_head mc_list;
struct listeners __rcu *listeners; struct listeners __rcu *listeners;
unsigned int flags; unsigned int flags;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <linux/netlink.h> #include <linux/netlink.h>
#include <linux/sock_diag.h> #include <linux/sock_diag.h>
#include <linux/netlink_diag.h> #include <linux/netlink_diag.h>
#include <linux/rhashtable.h>
#include "af_netlink.h" #include "af_netlink.h"
...@@ -101,16 +102,20 @@ static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, ...@@ -101,16 +102,20 @@ static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
int protocol, int s_num) int protocol, int s_num)
{ {
struct netlink_table *tbl = &nl_table[protocol]; struct netlink_table *tbl = &nl_table[protocol];
struct nl_portid_hash *hash = &tbl->hash; struct rhashtable *ht = &tbl->hash;
const struct bucket_table *htbl = rht_dereference(ht->tbl, ht);
struct net *net = sock_net(skb->sk); struct net *net = sock_net(skb->sk);
struct netlink_diag_req *req; struct netlink_diag_req *req;
struct netlink_sock *nlsk;
struct sock *sk; struct sock *sk;
int ret = 0, num = 0, i; int ret = 0, num = 0, i;
req = nlmsg_data(cb->nlh); req = nlmsg_data(cb->nlh);
for (i = 0; i <= hash->mask; i++) { for (i = 0; i < htbl->size; i++) {
sk_for_each(sk, &hash->table[i]) { rht_for_each_entry(nlsk, htbl->buckets[i], ht, node) {
sk = (struct sock *)nlsk;
if (!net_eq(sock_net(sk), net)) if (!net_eq(sock_net(sk), net))
continue; continue;
if (num < s_num) { if (num < s_num) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment