Commit bf2b866a authored by Daniel Borkmann's avatar Daniel Borkmann

Merge branch 'bpf-sockmap-fixes'

John Fastabend says:

====================
This addresses two syzbot issues that lead to identifying (by Eric and
Wei) a class of bugs where we don't correctly check for IPv4/v6
sockets and their associated state. The second issue was a locking
omission in sockhash.

The first patch addresses IPv6 socks and fixing an error where
sockhash would overwrite the prot pointer with IPv4 prot. To fix
this build similar solution to TLS ULP. Although we continue to
allow socks in all states not just ESTABLISH in this patch set
because as Martin points out there should be no issue with this
on the sockmap ULP because we don't use the ctx in this code. Once
multiple ULPs coexist we may need to revisit this. However we
can do this in *next trees.

The other issue syzbot found that the tcp_close() handler missed
locking the hash bucket lock which could result in corrupting the
sockhash bucket list if delete and close ran at the same time.
And also the smap_list_remove() routine was not working correctly
at all. This was not caught in my testing because in general my
tests (to date at least lets add some more robust selftest in
bpf-next) do things in the "expected" order, create map, add socks,
delete socks, then tear down maps. The tests we have that do the
ops out of this order where only working on single maps not multi-
maps so we never saw the issue. Thanks syzbot. The fix is to
restructure the tcp_close() lock handling. And fix the obvious
bug in smap_list_remove().

Finally, during review I noticed the release handler was omitted
from the upstream code (patch 4) due to an incorrect merge conflict
fix when I ported the code to latest bpf-next before submitting.
This would leave references to the map around if the user never
closes the map.

v3: rework patches, dropping ESTABLISH check and adding rcu
    annotation along with the smap_list_remove fix

v4: missed one more case where maps was being accessed without
    the sk_callback_lock, spoted by Martin as well.

v5: changed to use a specific lock for maps and reduced callback
    lock so that it is only used to gaurd sk callbacks. I think
    this makes the logic a bit cleaner and avoids confusion
    ovoer what each lock is doing.

Also big thanks to Martin for thorough review he caught at least
one case where I missed a rcu_call().
====================
Signed-off-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
parents ca09cb04 caac76a5
...@@ -72,6 +72,7 @@ struct bpf_htab { ...@@ -72,6 +72,7 @@ struct bpf_htab {
u32 n_buckets; u32 n_buckets;
u32 elem_size; u32 elem_size;
struct bpf_sock_progs progs; struct bpf_sock_progs progs;
struct rcu_head rcu;
}; };
struct htab_elem { struct htab_elem {
...@@ -89,8 +90,8 @@ enum smap_psock_state { ...@@ -89,8 +90,8 @@ enum smap_psock_state {
struct smap_psock_map_entry { struct smap_psock_map_entry {
struct list_head list; struct list_head list;
struct sock **entry; struct sock **entry;
struct htab_elem *hash_link; struct htab_elem __rcu *hash_link;
struct bpf_htab *htab; struct bpf_htab __rcu *htab;
}; };
struct smap_psock { struct smap_psock {
...@@ -120,6 +121,7 @@ struct smap_psock { ...@@ -120,6 +121,7 @@ struct smap_psock {
struct bpf_prog *bpf_parse; struct bpf_prog *bpf_parse;
struct bpf_prog *bpf_verdict; struct bpf_prog *bpf_verdict;
struct list_head maps; struct list_head maps;
spinlock_t maps_lock;
/* Back reference used when sock callback trigger sockmap operations */ /* Back reference used when sock callback trigger sockmap operations */
struct sock *sock; struct sock *sock;
...@@ -140,6 +142,7 @@ static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -140,6 +142,7 @@ static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
static int bpf_tcp_sendpage(struct sock *sk, struct page *page, static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
int offset, size_t size, int flags); int offset, size_t size, int flags);
static void bpf_tcp_close(struct sock *sk, long timeout);
static inline struct smap_psock *smap_psock_sk(const struct sock *sk) static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
{ {
...@@ -161,7 +164,42 @@ static bool bpf_tcp_stream_read(const struct sock *sk) ...@@ -161,7 +164,42 @@ static bool bpf_tcp_stream_read(const struct sock *sk)
return !empty; return !empty;
} }
static struct proto tcp_bpf_proto; enum {
SOCKMAP_IPV4,
SOCKMAP_IPV6,
SOCKMAP_NUM_PROTS,
};
enum {
SOCKMAP_BASE,
SOCKMAP_TX,
SOCKMAP_NUM_CONFIGS,
};
static struct proto *saved_tcpv6_prot __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto bpf_tcp_prots[SOCKMAP_NUM_PROTS][SOCKMAP_NUM_CONFIGS];
static void build_protos(struct proto prot[SOCKMAP_NUM_CONFIGS],
struct proto *base)
{
prot[SOCKMAP_BASE] = *base;
prot[SOCKMAP_BASE].close = bpf_tcp_close;
prot[SOCKMAP_BASE].recvmsg = bpf_tcp_recvmsg;
prot[SOCKMAP_BASE].stream_memory_read = bpf_tcp_stream_read;
prot[SOCKMAP_TX] = prot[SOCKMAP_BASE];
prot[SOCKMAP_TX].sendmsg = bpf_tcp_sendmsg;
prot[SOCKMAP_TX].sendpage = bpf_tcp_sendpage;
}
static void update_sk_prot(struct sock *sk, struct smap_psock *psock)
{
int family = sk->sk_family == AF_INET6 ? SOCKMAP_IPV6 : SOCKMAP_IPV4;
int conf = psock->bpf_tx_msg ? SOCKMAP_TX : SOCKMAP_BASE;
sk->sk_prot = &bpf_tcp_prots[family][conf];
}
static int bpf_tcp_init(struct sock *sk) static int bpf_tcp_init(struct sock *sk)
{ {
struct smap_psock *psock; struct smap_psock *psock;
...@@ -181,14 +219,17 @@ static int bpf_tcp_init(struct sock *sk) ...@@ -181,14 +219,17 @@ static int bpf_tcp_init(struct sock *sk)
psock->save_close = sk->sk_prot->close; psock->save_close = sk->sk_prot->close;
psock->sk_proto = sk->sk_prot; psock->sk_proto = sk->sk_prot;
if (psock->bpf_tx_msg) { /* Build IPv6 sockmap whenever the address of tcpv6_prot changes */
tcp_bpf_proto.sendmsg = bpf_tcp_sendmsg; if (sk->sk_family == AF_INET6 &&
tcp_bpf_proto.sendpage = bpf_tcp_sendpage; unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
tcp_bpf_proto.recvmsg = bpf_tcp_recvmsg; spin_lock_bh(&tcpv6_prot_lock);
tcp_bpf_proto.stream_memory_read = bpf_tcp_stream_read; if (likely(sk->sk_prot != saved_tcpv6_prot)) {
build_protos(bpf_tcp_prots[SOCKMAP_IPV6], sk->sk_prot);
smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
}
spin_unlock_bh(&tcpv6_prot_lock);
} }
update_sk_prot(sk, psock);
sk->sk_prot = &tcp_bpf_proto;
rcu_read_unlock(); rcu_read_unlock();
return 0; return 0;
} }
...@@ -219,16 +260,54 @@ static void bpf_tcp_release(struct sock *sk) ...@@ -219,16 +260,54 @@ static void bpf_tcp_release(struct sock *sk)
rcu_read_unlock(); rcu_read_unlock();
} }
static struct htab_elem *lookup_elem_raw(struct hlist_head *head,
u32 hash, void *key, u32 key_size)
{
struct htab_elem *l;
hlist_for_each_entry_rcu(l, head, hash_node) {
if (l->hash == hash && !memcmp(&l->key, key, key_size))
return l;
}
return NULL;
}
static inline struct bucket *__select_bucket(struct bpf_htab *htab, u32 hash)
{
return &htab->buckets[hash & (htab->n_buckets - 1)];
}
static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash)
{
return &__select_bucket(htab, hash)->head;
}
static void free_htab_elem(struct bpf_htab *htab, struct htab_elem *l) static void free_htab_elem(struct bpf_htab *htab, struct htab_elem *l)
{ {
atomic_dec(&htab->count); atomic_dec(&htab->count);
kfree_rcu(l, rcu); kfree_rcu(l, rcu);
} }
static struct smap_psock_map_entry *psock_map_pop(struct sock *sk,
struct smap_psock *psock)
{
struct smap_psock_map_entry *e;
spin_lock_bh(&psock->maps_lock);
e = list_first_entry_or_null(&psock->maps,
struct smap_psock_map_entry,
list);
if (e)
list_del(&e->list);
spin_unlock_bh(&psock->maps_lock);
return e;
}
static void bpf_tcp_close(struct sock *sk, long timeout) static void bpf_tcp_close(struct sock *sk, long timeout)
{ {
void (*close_fun)(struct sock *sk, long timeout); void (*close_fun)(struct sock *sk, long timeout);
struct smap_psock_map_entry *e, *tmp; struct smap_psock_map_entry *e;
struct sk_msg_buff *md, *mtmp; struct sk_msg_buff *md, *mtmp;
struct smap_psock *psock; struct smap_psock *psock;
struct sock *osk; struct sock *osk;
...@@ -247,7 +326,6 @@ static void bpf_tcp_close(struct sock *sk, long timeout) ...@@ -247,7 +326,6 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
*/ */
close_fun = psock->save_close; close_fun = psock->save_close;
write_lock_bh(&sk->sk_callback_lock);
if (psock->cork) { if (psock->cork) {
free_start_sg(psock->sock, psock->cork); free_start_sg(psock->sock, psock->cork);
kfree(psock->cork); kfree(psock->cork);
...@@ -260,20 +338,38 @@ static void bpf_tcp_close(struct sock *sk, long timeout) ...@@ -260,20 +338,38 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
kfree(md); kfree(md);
} }
list_for_each_entry_safe(e, tmp, &psock->maps, list) { e = psock_map_pop(sk, psock);
while (e) {
if (e->entry) { if (e->entry) {
osk = cmpxchg(e->entry, sk, NULL); osk = cmpxchg(e->entry, sk, NULL);
if (osk == sk) { if (osk == sk) {
list_del(&e->list);
smap_release_sock(psock, sk); smap_release_sock(psock, sk);
} }
} else { } else {
hlist_del_rcu(&e->hash_link->hash_node); struct htab_elem *link = rcu_dereference(e->hash_link);
smap_release_sock(psock, e->hash_link->sk); struct bpf_htab *htab = rcu_dereference(e->htab);
free_htab_elem(e->htab, e->hash_link); struct hlist_head *head;
struct htab_elem *l;
struct bucket *b;
b = __select_bucket(htab, link->hash);
head = &b->head;
raw_spin_lock_bh(&b->lock);
l = lookup_elem_raw(head,
link->hash, link->key,
htab->map.key_size);
/* If another thread deleted this object skip deletion.
* The refcnt on psock may or may not be zero.
*/
if (l) {
hlist_del_rcu(&link->hash_node);
smap_release_sock(psock, link->sk);
free_htab_elem(htab, link);
}
raw_spin_unlock_bh(&b->lock);
} }
e = psock_map_pop(sk, psock);
} }
write_unlock_bh(&sk->sk_callback_lock);
rcu_read_unlock(); rcu_read_unlock();
close_fun(sk, timeout); close_fun(sk, timeout);
} }
...@@ -1111,8 +1207,7 @@ static void bpf_tcp_msg_add(struct smap_psock *psock, ...@@ -1111,8 +1207,7 @@ static void bpf_tcp_msg_add(struct smap_psock *psock,
static int bpf_tcp_ulp_register(void) static int bpf_tcp_ulp_register(void)
{ {
tcp_bpf_proto = tcp_prot; build_protos(bpf_tcp_prots[SOCKMAP_IPV4], &tcp_prot);
tcp_bpf_proto.close = bpf_tcp_close;
/* Once BPF TX ULP is registered it is never unregistered. It /* Once BPF TX ULP is registered it is never unregistered. It
* will be in the ULP list for the lifetime of the system. Doing * will be in the ULP list for the lifetime of the system. Doing
* duplicate registers is not a problem. * duplicate registers is not a problem.
...@@ -1357,7 +1452,9 @@ static void smap_release_sock(struct smap_psock *psock, struct sock *sock) ...@@ -1357,7 +1452,9 @@ static void smap_release_sock(struct smap_psock *psock, struct sock *sock)
{ {
if (refcount_dec_and_test(&psock->refcnt)) { if (refcount_dec_and_test(&psock->refcnt)) {
tcp_cleanup_ulp(sock); tcp_cleanup_ulp(sock);
write_lock_bh(&sock->sk_callback_lock);
smap_stop_sock(psock, sock); smap_stop_sock(psock, sock);
write_unlock_bh(&sock->sk_callback_lock);
clear_bit(SMAP_TX_RUNNING, &psock->state); clear_bit(SMAP_TX_RUNNING, &psock->state);
rcu_assign_sk_user_data(sock, NULL); rcu_assign_sk_user_data(sock, NULL);
call_rcu_sched(&psock->rcu, smap_destroy_psock); call_rcu_sched(&psock->rcu, smap_destroy_psock);
...@@ -1508,6 +1605,7 @@ static struct smap_psock *smap_init_psock(struct sock *sock, int node) ...@@ -1508,6 +1605,7 @@ static struct smap_psock *smap_init_psock(struct sock *sock, int node)
INIT_LIST_HEAD(&psock->maps); INIT_LIST_HEAD(&psock->maps);
INIT_LIST_HEAD(&psock->ingress); INIT_LIST_HEAD(&psock->ingress);
refcount_set(&psock->refcnt, 1); refcount_set(&psock->refcnt, 1);
spin_lock_init(&psock->maps_lock);
rcu_assign_sk_user_data(sock, psock); rcu_assign_sk_user_data(sock, psock);
sock_hold(sock); sock_hold(sock);
...@@ -1564,18 +1662,32 @@ static struct bpf_map *sock_map_alloc(union bpf_attr *attr) ...@@ -1564,18 +1662,32 @@ static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
return ERR_PTR(err); return ERR_PTR(err);
} }
static void smap_list_remove(struct smap_psock *psock, static void smap_list_map_remove(struct smap_psock *psock,
struct sock **entry, struct sock **entry)
struct htab_elem *hash_link)
{ {
struct smap_psock_map_entry *e, *tmp; struct smap_psock_map_entry *e, *tmp;
spin_lock_bh(&psock->maps_lock);
list_for_each_entry_safe(e, tmp, &psock->maps, list) { list_for_each_entry_safe(e, tmp, &psock->maps, list) {
if (e->entry == entry || e->hash_link == hash_link) { if (e->entry == entry)
list_del(&e->list); list_del(&e->list);
break;
}
} }
spin_unlock_bh(&psock->maps_lock);
}
static void smap_list_hash_remove(struct smap_psock *psock,
struct htab_elem *hash_link)
{
struct smap_psock_map_entry *e, *tmp;
spin_lock_bh(&psock->maps_lock);
list_for_each_entry_safe(e, tmp, &psock->maps, list) {
struct htab_elem *c = rcu_dereference(e->hash_link);
if (c == hash_link)
list_del(&e->list);
}
spin_unlock_bh(&psock->maps_lock);
} }
static void sock_map_free(struct bpf_map *map) static void sock_map_free(struct bpf_map *map)
...@@ -1601,7 +1713,6 @@ static void sock_map_free(struct bpf_map *map) ...@@ -1601,7 +1713,6 @@ static void sock_map_free(struct bpf_map *map)
if (!sock) if (!sock)
continue; continue;
write_lock_bh(&sock->sk_callback_lock);
psock = smap_psock_sk(sock); psock = smap_psock_sk(sock);
/* This check handles a racing sock event that can get the /* This check handles a racing sock event that can get the
* sk_callback_lock before this case but after xchg happens * sk_callback_lock before this case but after xchg happens
...@@ -1609,10 +1720,9 @@ static void sock_map_free(struct bpf_map *map) ...@@ -1609,10 +1720,9 @@ static void sock_map_free(struct bpf_map *map)
* to be null and queued for garbage collection. * to be null and queued for garbage collection.
*/ */
if (likely(psock)) { if (likely(psock)) {
smap_list_remove(psock, &stab->sock_map[i], NULL); smap_list_map_remove(psock, &stab->sock_map[i]);
smap_release_sock(psock, sock); smap_release_sock(psock, sock);
} }
write_unlock_bh(&sock->sk_callback_lock);
} }
rcu_read_unlock(); rcu_read_unlock();
...@@ -1661,17 +1771,15 @@ static int sock_map_delete_elem(struct bpf_map *map, void *key) ...@@ -1661,17 +1771,15 @@ static int sock_map_delete_elem(struct bpf_map *map, void *key)
if (!sock) if (!sock)
return -EINVAL; return -EINVAL;
write_lock_bh(&sock->sk_callback_lock);
psock = smap_psock_sk(sock); psock = smap_psock_sk(sock);
if (!psock) if (!psock)
goto out; goto out;
if (psock->bpf_parse) if (psock->bpf_parse)
smap_stop_sock(psock, sock); smap_stop_sock(psock, sock);
smap_list_remove(psock, &stab->sock_map[k], NULL); smap_list_map_remove(psock, &stab->sock_map[k]);
smap_release_sock(psock, sock); smap_release_sock(psock, sock);
out: out:
write_unlock_bh(&sock->sk_callback_lock);
return 0; return 0;
} }
...@@ -1752,7 +1860,6 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map, ...@@ -1752,7 +1860,6 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
} }
} }
write_lock_bh(&sock->sk_callback_lock);
psock = smap_psock_sk(sock); psock = smap_psock_sk(sock);
/* 2. Do not allow inheriting programs if psock exists and has /* 2. Do not allow inheriting programs if psock exists and has
...@@ -1809,7 +1916,9 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map, ...@@ -1809,7 +1916,9 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
if (err) if (err)
goto out_free; goto out_free;
smap_init_progs(psock, verdict, parse); smap_init_progs(psock, verdict, parse);
write_lock_bh(&sock->sk_callback_lock);
smap_start_sock(psock, sock); smap_start_sock(psock, sock);
write_unlock_bh(&sock->sk_callback_lock);
} }
/* 4. Place psock in sockmap for use and stop any programs on /* 4. Place psock in sockmap for use and stop any programs on
...@@ -1819,9 +1928,10 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map, ...@@ -1819,9 +1928,10 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
*/ */
if (map_link) { if (map_link) {
e->entry = map_link; e->entry = map_link;
spin_lock_bh(&psock->maps_lock);
list_add_tail(&e->list, &psock->maps); list_add_tail(&e->list, &psock->maps);
spin_unlock_bh(&psock->maps_lock);
} }
write_unlock_bh(&sock->sk_callback_lock);
return err; return err;
out_free: out_free:
smap_release_sock(psock, sock); smap_release_sock(psock, sock);
...@@ -1832,7 +1942,6 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map, ...@@ -1832,7 +1942,6 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
} }
if (tx_msg) if (tx_msg)
bpf_prog_put(tx_msg); bpf_prog_put(tx_msg);
write_unlock_bh(&sock->sk_callback_lock);
kfree(e); kfree(e);
return err; return err;
} }
...@@ -1869,10 +1978,8 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops, ...@@ -1869,10 +1978,8 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
if (osock) { if (osock) {
struct smap_psock *opsock = smap_psock_sk(osock); struct smap_psock *opsock = smap_psock_sk(osock);
write_lock_bh(&osock->sk_callback_lock); smap_list_map_remove(opsock, &stab->sock_map[i]);
smap_list_remove(opsock, &stab->sock_map[i], NULL);
smap_release_sock(opsock, osock); smap_release_sock(opsock, osock);
write_unlock_bh(&osock->sk_callback_lock);
} }
out: out:
return err; return err;
...@@ -2061,14 +2168,13 @@ static struct bpf_map *sock_hash_alloc(union bpf_attr *attr) ...@@ -2061,14 +2168,13 @@ static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
return ERR_PTR(err); return ERR_PTR(err);
} }
static inline struct bucket *__select_bucket(struct bpf_htab *htab, u32 hash) static void __bpf_htab_free(struct rcu_head *rcu)
{ {
return &htab->buckets[hash & (htab->n_buckets - 1)]; struct bpf_htab *htab;
}
static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash) htab = container_of(rcu, struct bpf_htab, rcu);
{ bpf_map_area_free(htab->buckets);
return &__select_bucket(htab, hash)->head; kfree(htab);
} }
static void sock_hash_free(struct bpf_map *map) static void sock_hash_free(struct bpf_map *map)
...@@ -2087,16 +2193,18 @@ static void sock_hash_free(struct bpf_map *map) ...@@ -2087,16 +2193,18 @@ static void sock_hash_free(struct bpf_map *map)
*/ */
rcu_read_lock(); rcu_read_lock();
for (i = 0; i < htab->n_buckets; i++) { for (i = 0; i < htab->n_buckets; i++) {
struct hlist_head *head = select_bucket(htab, i); struct bucket *b = __select_bucket(htab, i);
struct hlist_head *head;
struct hlist_node *n; struct hlist_node *n;
struct htab_elem *l; struct htab_elem *l;
raw_spin_lock_bh(&b->lock);
head = &b->head;
hlist_for_each_entry_safe(l, n, head, hash_node) { hlist_for_each_entry_safe(l, n, head, hash_node) {
struct sock *sock = l->sk; struct sock *sock = l->sk;
struct smap_psock *psock; struct smap_psock *psock;
hlist_del_rcu(&l->hash_node); hlist_del_rcu(&l->hash_node);
write_lock_bh(&sock->sk_callback_lock);
psock = smap_psock_sk(sock); psock = smap_psock_sk(sock);
/* This check handles a racing sock event that can get /* This check handles a racing sock event that can get
* the sk_callback_lock before this case but after xchg * the sk_callback_lock before this case but after xchg
...@@ -2104,16 +2212,15 @@ static void sock_hash_free(struct bpf_map *map) ...@@ -2104,16 +2212,15 @@ static void sock_hash_free(struct bpf_map *map)
* (psock) to be null and queued for garbage collection. * (psock) to be null and queued for garbage collection.
*/ */
if (likely(psock)) { if (likely(psock)) {
smap_list_remove(psock, NULL, l); smap_list_hash_remove(psock, l);
smap_release_sock(psock, sock); smap_release_sock(psock, sock);
} }
write_unlock_bh(&sock->sk_callback_lock); free_htab_elem(htab, l);
kfree(l);
} }
raw_spin_unlock_bh(&b->lock);
} }
rcu_read_unlock(); rcu_read_unlock();
bpf_map_area_free(htab->buckets); call_rcu(&htab->rcu, __bpf_htab_free);
kfree(htab);
} }
static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab, static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab,
...@@ -2140,19 +2247,6 @@ static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab, ...@@ -2140,19 +2247,6 @@ static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab,
return l_new; return l_new;
} }
static struct htab_elem *lookup_elem_raw(struct hlist_head *head,
u32 hash, void *key, u32 key_size)
{
struct htab_elem *l;
hlist_for_each_entry_rcu(l, head, hash_node) {
if (l->hash == hash && !memcmp(&l->key, key, key_size))
return l;
}
return NULL;
}
static inline u32 htab_map_hash(const void *key, u32 key_len) static inline u32 htab_map_hash(const void *key, u32 key_len)
{ {
return jhash(key, key_len, 0); return jhash(key, key_len, 0);
...@@ -2272,9 +2366,12 @@ static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops, ...@@ -2272,9 +2366,12 @@ static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
goto bucket_err; goto bucket_err;
} }
e->hash_link = l_new; rcu_assign_pointer(e->hash_link, l_new);
e->htab = container_of(map, struct bpf_htab, map); rcu_assign_pointer(e->htab,
container_of(map, struct bpf_htab, map));
spin_lock_bh(&psock->maps_lock);
list_add_tail(&e->list, &psock->maps); list_add_tail(&e->list, &psock->maps);
spin_unlock_bh(&psock->maps_lock);
/* add new element to the head of the list, so that /* add new element to the head of the list, so that
* concurrent search will find it before old elem * concurrent search will find it before old elem
...@@ -2284,7 +2381,7 @@ static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops, ...@@ -2284,7 +2381,7 @@ static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
psock = smap_psock_sk(l_old->sk); psock = smap_psock_sk(l_old->sk);
hlist_del_rcu(&l_old->hash_node); hlist_del_rcu(&l_old->hash_node);
smap_list_remove(psock, NULL, l_old); smap_list_hash_remove(psock, l_old);
smap_release_sock(psock, l_old->sk); smap_release_sock(psock, l_old->sk);
free_htab_elem(htab, l_old); free_htab_elem(htab, l_old);
} }
...@@ -2344,7 +2441,6 @@ static int sock_hash_delete_elem(struct bpf_map *map, void *key) ...@@ -2344,7 +2441,6 @@ static int sock_hash_delete_elem(struct bpf_map *map, void *key)
struct smap_psock *psock; struct smap_psock *psock;
hlist_del_rcu(&l->hash_node); hlist_del_rcu(&l->hash_node);
write_lock_bh(&sock->sk_callback_lock);
psock = smap_psock_sk(sock); psock = smap_psock_sk(sock);
/* This check handles a racing sock event that can get the /* This check handles a racing sock event that can get the
* sk_callback_lock before this case but after xchg happens * sk_callback_lock before this case but after xchg happens
...@@ -2352,10 +2448,9 @@ static int sock_hash_delete_elem(struct bpf_map *map, void *key) ...@@ -2352,10 +2448,9 @@ static int sock_hash_delete_elem(struct bpf_map *map, void *key)
* to be null and queued for garbage collection. * to be null and queued for garbage collection.
*/ */
if (likely(psock)) { if (likely(psock)) {
smap_list_remove(psock, NULL, l); smap_list_hash_remove(psock, l);
smap_release_sock(psock, sock); smap_release_sock(psock, sock);
} }
write_unlock_bh(&sock->sk_callback_lock);
free_htab_elem(htab, l); free_htab_elem(htab, l);
ret = 0; ret = 0;
} }
...@@ -2401,6 +2496,7 @@ const struct bpf_map_ops sock_hash_ops = { ...@@ -2401,6 +2496,7 @@ const struct bpf_map_ops sock_hash_ops = {
.map_get_next_key = sock_hash_get_next_key, .map_get_next_key = sock_hash_get_next_key,
.map_update_elem = sock_hash_update_elem, .map_update_elem = sock_hash_update_elem,
.map_delete_elem = sock_hash_delete_elem, .map_delete_elem = sock_hash_delete_elem,
.map_release_uref = sock_map_release,
}; };
BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, bpf_sock, BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, bpf_sock,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment