Commit 71489e21 authored by Joe Stringer's avatar Joe Stringer Committed by Alexei Starovoitov

net: Track socket refcounts in skb_steal_sock()

Refactor the UDP/TCP handlers slightly to allow skb_steal_sock() to make
the determination of whether the socket is reference counted in the case
where it is prefetched by earlier logic such as early_demux.
Signed-off-by: default avatarJoe Stringer <joe@wand.net.nz>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Acked-by: default avatarMartin KaFai Lau <kafai@fb.com>
Link: https://lore.kernel.org/bpf/20200329225342.16317-3-joe@wand.net.nz
parent cf7fbe66
...@@ -85,9 +85,8 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo, ...@@ -85,9 +85,8 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
int iif, int sdif, int iif, int sdif,
bool *refcounted) bool *refcounted)
{ {
struct sock *sk = skb_steal_sock(skb); struct sock *sk = skb_steal_sock(skb, refcounted);
*refcounted = true;
if (sk) if (sk)
return sk; return sk;
......
...@@ -379,10 +379,9 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo, ...@@ -379,10 +379,9 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
const int sdif, const int sdif,
bool *refcounted) bool *refcounted)
{ {
struct sock *sk = skb_steal_sock(skb); struct sock *sk = skb_steal_sock(skb, refcounted);
const struct iphdr *iph = ip_hdr(skb); const struct iphdr *iph = ip_hdr(skb);
*refcounted = true;
if (sk) if (sk)
return sk; return sk;
......
...@@ -2537,15 +2537,23 @@ skb_sk_is_prefetched(struct sk_buff *skb) ...@@ -2537,15 +2537,23 @@ skb_sk_is_prefetched(struct sk_buff *skb)
#endif /* CONFIG_INET */ #endif /* CONFIG_INET */
} }
static inline struct sock *skb_steal_sock(struct sk_buff *skb) /**
* skb_steal_sock
* @skb to steal the socket from
* @refcounted is set to true if the socket is reference-counted
*/
static inline struct sock *
skb_steal_sock(struct sk_buff *skb, bool *refcounted)
{ {
if (skb->sk) { if (skb->sk) {
struct sock *sk = skb->sk; struct sock *sk = skb->sk;
*refcounted = true;
skb->destructor = NULL; skb->destructor = NULL;
skb->sk = NULL; skb->sk = NULL;
return sk; return sk;
} }
*refcounted = false;
return NULL; return NULL;
} }
......
...@@ -2288,6 +2288,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, ...@@ -2288,6 +2288,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
struct rtable *rt = skb_rtable(skb); struct rtable *rt = skb_rtable(skb);
__be32 saddr, daddr; __be32 saddr, daddr;
struct net *net = dev_net(skb->dev); struct net *net = dev_net(skb->dev);
bool refcounted;
/* /*
* Validate the packet. * Validate the packet.
...@@ -2313,7 +2314,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, ...@@ -2313,7 +2314,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
if (udp4_csum_init(skb, uh, proto)) if (udp4_csum_init(skb, uh, proto))
goto csum_error; goto csum_error;
sk = skb_steal_sock(skb); sk = skb_steal_sock(skb, &refcounted);
if (sk) { if (sk) {
struct dst_entry *dst = skb_dst(skb); struct dst_entry *dst = skb_dst(skb);
int ret; int ret;
...@@ -2322,7 +2323,8 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, ...@@ -2322,7 +2323,8 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
udp_sk_rx_dst_set(sk, dst); udp_sk_rx_dst_set(sk, dst);
ret = udp_unicast_rcv_skb(sk, skb, uh); ret = udp_unicast_rcv_skb(sk, skb, uh);
sock_put(sk); if (refcounted)
sock_put(sk);
return ret; return ret;
} }
......
...@@ -843,6 +843,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, ...@@ -843,6 +843,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
struct net *net = dev_net(skb->dev); struct net *net = dev_net(skb->dev);
struct udphdr *uh; struct udphdr *uh;
struct sock *sk; struct sock *sk;
bool refcounted;
u32 ulen = 0; u32 ulen = 0;
if (!pskb_may_pull(skb, sizeof(struct udphdr))) if (!pskb_may_pull(skb, sizeof(struct udphdr)))
...@@ -879,7 +880,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, ...@@ -879,7 +880,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
goto csum_error; goto csum_error;
/* Check if the socket is already available, e.g. due to early demux */ /* Check if the socket is already available, e.g. due to early demux */
sk = skb_steal_sock(skb); sk = skb_steal_sock(skb, &refcounted);
if (sk) { if (sk) {
struct dst_entry *dst = skb_dst(skb); struct dst_entry *dst = skb_dst(skb);
int ret; int ret;
...@@ -888,12 +889,14 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, ...@@ -888,12 +889,14 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
udp6_sk_rx_dst_set(sk, dst); udp6_sk_rx_dst_set(sk, dst);
if (!uh->check && !udp_sk(sk)->no_check6_rx) { if (!uh->check && !udp_sk(sk)->no_check6_rx) {
sock_put(sk); if (refcounted)
sock_put(sk);
goto report_csum_error; goto report_csum_error;
} }
ret = udp6_unicast_rcv_skb(sk, skb, uh); ret = udp6_unicast_rcv_skb(sk, skb, uh);
sock_put(sk); if (refcounted)
sock_put(sk);
return ret; return ret;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment