Commit 8217ca65 authored by Martin KaFai Lau's avatar Martin KaFai Lau Committed by Daniel Borkmann

bpf: Enable BPF_PROG_TYPE_SK_REUSEPORT bpf prog in reuseport selection

This patch allows a BPF_PROG_TYPE_SK_REUSEPORT bpf prog to select a
SO_REUSEPORT sk from a BPF_MAP_TYPE_REUSEPORT_ARRAY introduced in
the earlier patch.  "bpf_run_sk_reuseport()" will return -ECONNREFUSED
when the BPF_PROG_TYPE_SK_REUSEPORT prog returns SK_DROP.
The callers, in inet[6]_hashtable.c and ipv[46]/udp.c, are modified to
handle this case and return NULL immediately instead of continuing the
sk search from its hashtable.

It re-uses the existing SO_ATTACH_REUSEPORT_EBPF setsockopt to attach
BPF_PROG_TYPE_SK_REUSEPORT.  The "sk_reuseport_attach_bpf()" will check
if the attaching bpf prog is in the new SK_REUSEPORT or the existing
SOCKET_FILTER type and then check different things accordingly.

One level of "__reuseport_attach_prog()" call is removed.  The
"sk_unhashed() && ..." and "sk->sk_reuseport_cb" tests are pushed
back to "reuseport_attach_prog()" in sock_reuseport.c.  sock_reuseport.c
seems to have more knowledge on those test requirements than filter.c.
In "reuseport_attach_prog()", after new_prog is attached to reuse->prog,
the old_prog (if any) is also directly freed instead of returning the
old_prog to the caller and asking the caller to free.

The sysctl_optmem_max check is moved back to the
"sk_reuseport_attach_filter()" and "sk_reuseport_attach_bpf()".
As with other bpf prog types, the new BPF_PROG_TYPE_SK_REUSEPORT is
bounded only by the usual "bpf_prog_charge_memlock()" during load time
instead of by both bpf_prog_charge_memlock and sysctl_optmem_max.
Signed-off-by: Martin KaFai Lau <kafai@fb.com>
Acked-by: Alexei Starovoitov <ast@kernel.org>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
parent 2dbb9b9e
...@@ -753,6 +753,7 @@ int sk_attach_filter(struct sock_fprog *fprog, struct sock *sk); ...@@ -753,6 +753,7 @@ int sk_attach_filter(struct sock_fprog *fprog, struct sock *sk);
int sk_attach_bpf(u32 ufd, struct sock *sk); int sk_attach_bpf(u32 ufd, struct sock *sk);
int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk); int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk);
int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk); int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk);
void sk_reuseport_prog_free(struct bpf_prog *prog);
int sk_detach_filter(struct sock *sk); int sk_detach_filter(struct sock *sk);
int sk_get_filter(struct sock *sk, struct sock_filter __user *filter, int sk_get_filter(struct sock *sk, struct sock_filter __user *filter,
unsigned int len); unsigned int len);
......
...@@ -34,8 +34,7 @@ extern struct sock *reuseport_select_sock(struct sock *sk, ...@@ -34,8 +34,7 @@ extern struct sock *reuseport_select_sock(struct sock *sk,
u32 hash, u32 hash,
struct sk_buff *skb, struct sk_buff *skb,
int hdr_len); int hdr_len);
extern struct bpf_prog *reuseport_attach_prog(struct sock *sk, extern int reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog);
struct bpf_prog *prog);
int reuseport_get_id(struct sock_reuseport *reuse); int reuseport_get_id(struct sock_reuseport *reuse);
#endif /* _SOCK_REUSEPORT_H */ #endif /* _SOCK_REUSEPORT_H */
...@@ -1453,30 +1453,6 @@ static int __sk_attach_prog(struct bpf_prog *prog, struct sock *sk) ...@@ -1453,30 +1453,6 @@ static int __sk_attach_prog(struct bpf_prog *prog, struct sock *sk)
return 0; return 0;
} }
/* Attach a classic- or eBPF-derived filter prog to sk's reuseport group.
 * (Removed by this commit: the sk_reuseport_cb setup/validation below is
 * pushed down into reuseport_attach_prog() in sock_reuseport.c, and the
 * sysctl_optmem_max check moves to the sk_reuseport_attach_{filter,bpf}()
 * callers.)
 *
 * Returns 0 on success, -ENOMEM if the prog exceeds sysctl_optmem_max,
 * -EINVAL if the socket was not bound with SO_REUSEPORT, or the error
 * from reuseport_alloc().
 */
static int __reuseport_attach_prog(struct bpf_prog *prog, struct sock *sk)
{
struct bpf_prog *old_prog;
int err;
/* Socket-attached filters are bounded by the optmem limit. */
if (bpf_prog_size(prog->len) > sysctl_optmem_max)
return -ENOMEM;
/* Not yet hashed but marked SO_REUSEPORT: lazily allocate the
 * sock_reuseport group so the prog has somewhere to live.
 */
if (sk_unhashed(sk) && sk->sk_reuseport) {
err = reuseport_alloc(sk, false);
if (err)
return err;
} else if (!rcu_access_pointer(sk->sk_reuseport_cb)) {
/* The socket wasn't bound with SO_REUSEPORT */
return -EINVAL;
}
/* Swap in the new prog; the previous one (if any) is returned to
 * us and must be freed here.
 */
old_prog = reuseport_attach_prog(sk, prog);
if (old_prog)
bpf_prog_destroy(old_prog);
return 0;
}
static static
struct bpf_prog *__get_filter(struct sock_fprog *fprog, struct sock *sk) struct bpf_prog *__get_filter(struct sock_fprog *fprog, struct sock *sk)
{ {
...@@ -1550,13 +1526,15 @@ int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk) ...@@ -1550,13 +1526,15 @@ int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk)
if (IS_ERR(prog)) if (IS_ERR(prog))
return PTR_ERR(prog); return PTR_ERR(prog);
err = __reuseport_attach_prog(prog, sk); if (bpf_prog_size(prog->len) > sysctl_optmem_max)
if (err < 0) { err = -ENOMEM;
else
err = reuseport_attach_prog(sk, prog);
if (err)
__bpf_prog_release(prog); __bpf_prog_release(prog);
return err;
}
return 0; return err;
} }
static struct bpf_prog *__get_bpf(u32 ufd, struct sock *sk) static struct bpf_prog *__get_bpf(u32 ufd, struct sock *sk)
...@@ -1586,19 +1564,58 @@ int sk_attach_bpf(u32 ufd, struct sock *sk) ...@@ -1586,19 +1564,58 @@ int sk_attach_bpf(u32 ufd, struct sock *sk)
int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk) int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk)
{ {
struct bpf_prog *prog = __get_bpf(ufd, sk); struct bpf_prog *prog;
int err; int err;
if (sock_flag(sk, SOCK_FILTER_LOCKED))
return -EPERM;
prog = bpf_prog_get_type(ufd, BPF_PROG_TYPE_SOCKET_FILTER);
if (IS_ERR(prog) && PTR_ERR(prog) == -EINVAL)
prog = bpf_prog_get_type(ufd, BPF_PROG_TYPE_SK_REUSEPORT);
if (IS_ERR(prog)) if (IS_ERR(prog))
return PTR_ERR(prog); return PTR_ERR(prog);
err = __reuseport_attach_prog(prog, sk); if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT) {
if (err < 0) { /* Like other non BPF_PROG_TYPE_SOCKET_FILTER
bpf_prog_put(prog); * bpf prog (e.g. sockmap). It depends on the
return err; * limitation imposed by bpf_prog_load().
* Hence, sysctl_optmem_max is not checked.
*/
if ((sk->sk_type != SOCK_STREAM &&
sk->sk_type != SOCK_DGRAM) ||
(sk->sk_protocol != IPPROTO_UDP &&
sk->sk_protocol != IPPROTO_TCP) ||
(sk->sk_family != AF_INET &&
sk->sk_family != AF_INET6)) {
err = -ENOTSUPP;
goto err_prog_put;
}
} else {
/* BPF_PROG_TYPE_SOCKET_FILTER */
if (bpf_prog_size(prog->len) > sysctl_optmem_max) {
err = -ENOMEM;
goto err_prog_put;
}
} }
return 0; err = reuseport_attach_prog(sk, prog);
err_prog_put:
if (err)
bpf_prog_put(prog);
return err;
}
/* Free a prog that was attached to a reuseport group, picking the
 * release path that matches how it was acquired: SK_REUSEPORT progs
 * were taken with bpf_prog_get_type() and drop a plain reference via
 * bpf_prog_put(); SOCKET_FILTER progs may carry an original
 * classic-BPF filter and go through bpf_prog_destroy().
 * NULL is tolerated so callers can pass the result of a raw
 * rcu_dereference_protected() unconditionally.
 */
void sk_reuseport_prog_free(struct bpf_prog *prog)
{
if (!prog)
return;
if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT)
bpf_prog_put(prog);
else
bpf_prog_destroy(prog);
} }
struct bpf_scratchpad { struct bpf_scratchpad {
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <net/sock_reuseport.h> #include <net/sock_reuseport.h>
#include <linux/bpf.h> #include <linux/bpf.h>
#include <linux/idr.h> #include <linux/idr.h>
#include <linux/filter.h>
#include <linux/rcupdate.h> #include <linux/rcupdate.h>
#define INIT_SOCKS 128 #define INIT_SOCKS 128
...@@ -133,8 +134,7 @@ static void reuseport_free_rcu(struct rcu_head *head) ...@@ -133,8 +134,7 @@ static void reuseport_free_rcu(struct rcu_head *head)
struct sock_reuseport *reuse; struct sock_reuseport *reuse;
reuse = container_of(head, struct sock_reuseport, rcu); reuse = container_of(head, struct sock_reuseport, rcu);
if (reuse->prog) sk_reuseport_prog_free(rcu_dereference_protected(reuse->prog, 1));
bpf_prog_destroy(reuse->prog);
if (reuse->reuseport_id) if (reuse->reuseport_id)
ida_simple_remove(&reuseport_ida, reuse->reuseport_id); ida_simple_remove(&reuseport_ida, reuse->reuseport_id);
kfree(reuse); kfree(reuse);
...@@ -219,9 +219,9 @@ void reuseport_detach_sock(struct sock *sk) ...@@ -219,9 +219,9 @@ void reuseport_detach_sock(struct sock *sk)
} }
EXPORT_SYMBOL(reuseport_detach_sock); EXPORT_SYMBOL(reuseport_detach_sock);
static struct sock *run_bpf(struct sock_reuseport *reuse, u16 socks, static struct sock *run_bpf_filter(struct sock_reuseport *reuse, u16 socks,
struct bpf_prog *prog, struct sk_buff *skb, struct bpf_prog *prog, struct sk_buff *skb,
int hdr_len) int hdr_len)
{ {
struct sk_buff *nskb = NULL; struct sk_buff *nskb = NULL;
u32 index; u32 index;
...@@ -282,9 +282,15 @@ struct sock *reuseport_select_sock(struct sock *sk, ...@@ -282,9 +282,15 @@ struct sock *reuseport_select_sock(struct sock *sk,
/* paired with smp_wmb() in reuseport_add_sock() */ /* paired with smp_wmb() in reuseport_add_sock() */
smp_rmb(); smp_rmb();
if (prog && skb) if (!prog || !skb)
sk2 = run_bpf(reuse, socks, prog, skb, hdr_len); goto select_by_hash;
if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT)
sk2 = bpf_run_sk_reuseport(reuse, sk, prog, skb, hash);
else
sk2 = run_bpf_filter(reuse, socks, prog, skb, hdr_len);
select_by_hash:
/* no bpf or invalid bpf result: fall back to hash usage */ /* no bpf or invalid bpf result: fall back to hash usage */
if (!sk2) if (!sk2)
sk2 = reuse->socks[reciprocal_scale(hash, socks)]; sk2 = reuse->socks[reciprocal_scale(hash, socks)];
...@@ -296,12 +302,21 @@ struct sock *reuseport_select_sock(struct sock *sk, ...@@ -296,12 +302,21 @@ struct sock *reuseport_select_sock(struct sock *sk,
} }
EXPORT_SYMBOL(reuseport_select_sock); EXPORT_SYMBOL(reuseport_select_sock);
struct bpf_prog * int reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog)
reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog)
{ {
struct sock_reuseport *reuse; struct sock_reuseport *reuse;
struct bpf_prog *old_prog; struct bpf_prog *old_prog;
if (sk_unhashed(sk) && sk->sk_reuseport) {
int err = reuseport_alloc(sk, false);
if (err)
return err;
} else if (!rcu_access_pointer(sk->sk_reuseport_cb)) {
/* The socket wasn't bound with SO_REUSEPORT */
return -EINVAL;
}
spin_lock_bh(&reuseport_lock); spin_lock_bh(&reuseport_lock);
reuse = rcu_dereference_protected(sk->sk_reuseport_cb, reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
lockdep_is_held(&reuseport_lock)); lockdep_is_held(&reuseport_lock));
...@@ -310,6 +325,7 @@ reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog) ...@@ -310,6 +325,7 @@ reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog)
rcu_assign_pointer(reuse->prog, prog); rcu_assign_pointer(reuse->prog, prog);
spin_unlock_bh(&reuseport_lock); spin_unlock_bh(&reuseport_lock);
return old_prog; sk_reuseport_prog_free(old_prog);
return 0;
} }
EXPORT_SYMBOL(reuseport_attach_prog); EXPORT_SYMBOL(reuseport_attach_prog);
...@@ -328,7 +328,7 @@ struct sock *__inet_lookup_listener(struct net *net, ...@@ -328,7 +328,7 @@ struct sock *__inet_lookup_listener(struct net *net,
saddr, sport, daddr, hnum, saddr, sport, daddr, hnum,
dif, sdif); dif, sdif);
if (result) if (result)
return result; goto done;
/* Lookup lhash2 with INADDR_ANY */ /* Lookup lhash2 with INADDR_ANY */
...@@ -337,9 +337,10 @@ struct sock *__inet_lookup_listener(struct net *net, ...@@ -337,9 +337,10 @@ struct sock *__inet_lookup_listener(struct net *net,
if (ilb2->count > ilb->count) if (ilb2->count > ilb->count)
goto port_lookup; goto port_lookup;
return inet_lhash2_lookup(net, ilb2, skb, doff, result = inet_lhash2_lookup(net, ilb2, skb, doff,
saddr, sport, daddr, hnum, saddr, sport, daddr, hnum,
dif, sdif); dif, sdif);
goto done;
port_lookup: port_lookup:
sk_for_each_rcu(sk, &ilb->head) { sk_for_each_rcu(sk, &ilb->head) {
...@@ -352,12 +353,15 @@ struct sock *__inet_lookup_listener(struct net *net, ...@@ -352,12 +353,15 @@ struct sock *__inet_lookup_listener(struct net *net,
result = reuseport_select_sock(sk, phash, result = reuseport_select_sock(sk, phash,
skb, doff); skb, doff);
if (result) if (result)
return result; goto done;
} }
result = sk; result = sk;
hiscore = score; hiscore = score;
} }
} }
done:
if (unlikely(IS_ERR(result)))
return NULL;
return result; return result;
} }
EXPORT_SYMBOL_GPL(__inet_lookup_listener); EXPORT_SYMBOL_GPL(__inet_lookup_listener);
......
...@@ -499,6 +499,8 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, ...@@ -499,6 +499,8 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
daddr, hnum, dif, sdif, daddr, hnum, dif, sdif,
exact_dif, hslot2, skb); exact_dif, hslot2, skb);
} }
if (unlikely(IS_ERR(result)))
return NULL;
return result; return result;
} }
begin: begin:
...@@ -513,6 +515,8 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, ...@@ -513,6 +515,8 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
saddr, sport); saddr, sport);
result = reuseport_select_sock(sk, hash, skb, result = reuseport_select_sock(sk, hash, skb,
sizeof(struct udphdr)); sizeof(struct udphdr));
if (unlikely(IS_ERR(result)))
return NULL;
if (result) if (result)
return result; return result;
} }
......
...@@ -191,7 +191,7 @@ struct sock *inet6_lookup_listener(struct net *net, ...@@ -191,7 +191,7 @@ struct sock *inet6_lookup_listener(struct net *net,
saddr, sport, daddr, hnum, saddr, sport, daddr, hnum,
dif, sdif); dif, sdif);
if (result) if (result)
return result; goto done;
/* Lookup lhash2 with in6addr_any */ /* Lookup lhash2 with in6addr_any */
...@@ -200,9 +200,10 @@ struct sock *inet6_lookup_listener(struct net *net, ...@@ -200,9 +200,10 @@ struct sock *inet6_lookup_listener(struct net *net,
if (ilb2->count > ilb->count) if (ilb2->count > ilb->count)
goto port_lookup; goto port_lookup;
return inet6_lhash2_lookup(net, ilb2, skb, doff, result = inet6_lhash2_lookup(net, ilb2, skb, doff,
saddr, sport, daddr, hnum, saddr, sport, daddr, hnum,
dif, sdif); dif, sdif);
goto done;
port_lookup: port_lookup:
sk_for_each(sk, &ilb->head) { sk_for_each(sk, &ilb->head) {
...@@ -214,12 +215,15 @@ struct sock *inet6_lookup_listener(struct net *net, ...@@ -214,12 +215,15 @@ struct sock *inet6_lookup_listener(struct net *net,
result = reuseport_select_sock(sk, phash, result = reuseport_select_sock(sk, phash,
skb, doff); skb, doff);
if (result) if (result)
return result; goto done;
} }
result = sk; result = sk;
hiscore = score; hiscore = score;
} }
} }
done:
if (unlikely(IS_ERR(result)))
return NULL;
return result; return result;
} }
EXPORT_SYMBOL_GPL(inet6_lookup_listener); EXPORT_SYMBOL_GPL(inet6_lookup_listener);
......
...@@ -235,6 +235,8 @@ struct sock *__udp6_lib_lookup(struct net *net, ...@@ -235,6 +235,8 @@ struct sock *__udp6_lib_lookup(struct net *net,
exact_dif, hslot2, exact_dif, hslot2,
skb); skb);
} }
if (unlikely(IS_ERR(result)))
return NULL;
return result; return result;
} }
begin: begin:
...@@ -249,6 +251,8 @@ struct sock *__udp6_lib_lookup(struct net *net, ...@@ -249,6 +251,8 @@ struct sock *__udp6_lib_lookup(struct net *net,
saddr, sport); saddr, sport);
result = reuseport_select_sock(sk, hash, skb, result = reuseport_select_sock(sk, hash, skb,
sizeof(struct udphdr)); sizeof(struct udphdr));
if (unlikely(IS_ERR(result)))
return NULL;
if (result) if (result)
return result; return result;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment