Commit e40526cb authored by Daniel Borkmann's avatar Daniel Borkmann Committed by David S. Miller

packet: fix use after free race in send path when dev is released

Salam reported a use after free bug in PF_PACKET that occurs when
we're sending out frames on a socket bound device and suddenly the
net device is being unregistered. It appears that commit 827d9780
introduced a possible race condition between {t,}packet_snd() and
packet_notifier(). In the case of a bound socket, packet_notifier()
can drop the last reference to the net_device and {t,}packet_snd()
might end up suddenly sending a packet over a freed net_device.

To avoid reverting 827d9780 and thus introducing a performance
regression compared to the current state of things, we decided to
hold a cached RCU protected pointer to the net device and maintain
it on write side via bind spin_lock protected register_prot_hook()
and __unregister_prot_hook() calls.

In {t,}packet_snd() path, we access this pointer under rcu_read_lock
through packet_cached_dev_get() that holds reference to the device
to prevent it from being freed through packet_notifier() while
we're in send path. This is okay to do as dev_put()/dev_hold() are
per-cpu counters, so this should not be a performance issue. Also,
the code simplifies a bit as we don't need need_rls_dev anymore.

Fixes: 827d9780 ("af-packet: Use existing netdev reference for bound sockets.")
Reported-by: default avatarSalam Noureddine <noureddine@aristanetworks.com>
Signed-off-by: default avatarDaniel Borkmann <dborkman@redhat.com>
Signed-off-by: default avatarSalam Noureddine <noureddine@aristanetworks.com>
Cc: Ben Greear <greearb@candelatech.com>
Cc: Eric Dumazet <eric.dumazet@gmail.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent db739ef3
...@@ -244,11 +244,15 @@ static void __fanout_link(struct sock *sk, struct packet_sock *po); ...@@ -244,11 +244,15 @@ static void __fanout_link(struct sock *sk, struct packet_sock *po);
static void register_prot_hook(struct sock *sk) static void register_prot_hook(struct sock *sk)
{ {
struct packet_sock *po = pkt_sk(sk); struct packet_sock *po = pkt_sk(sk);
if (!po->running) { if (!po->running) {
if (po->fanout) if (po->fanout) {
__fanout_link(sk, po); __fanout_link(sk, po);
else } else {
dev_add_pack(&po->prot_hook); dev_add_pack(&po->prot_hook);
rcu_assign_pointer(po->cached_dev, po->prot_hook.dev);
}
sock_hold(sk); sock_hold(sk);
po->running = 1; po->running = 1;
} }
...@@ -266,10 +270,13 @@ static void __unregister_prot_hook(struct sock *sk, bool sync) ...@@ -266,10 +270,13 @@ static void __unregister_prot_hook(struct sock *sk, bool sync)
struct packet_sock *po = pkt_sk(sk); struct packet_sock *po = pkt_sk(sk);
po->running = 0; po->running = 0;
if (po->fanout) if (po->fanout) {
__fanout_unlink(sk, po); __fanout_unlink(sk, po);
else } else {
__dev_remove_pack(&po->prot_hook); __dev_remove_pack(&po->prot_hook);
RCU_INIT_POINTER(po->cached_dev, NULL);
}
__sock_put(sk); __sock_put(sk);
if (sync) { if (sync) {
...@@ -2052,12 +2059,24 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb, ...@@ -2052,12 +2059,24 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
return tp_len; return tp_len;
} }
static struct net_device *packet_cached_dev_get(struct packet_sock *po)
{
struct net_device *dev;
rcu_read_lock();
dev = rcu_dereference(po->cached_dev);
if (dev)
dev_hold(dev);
rcu_read_unlock();
return dev;
}
static int tpacket_snd(struct packet_sock *po, struct msghdr *msg) static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
{ {
struct sk_buff *skb; struct sk_buff *skb;
struct net_device *dev; struct net_device *dev;
__be16 proto; __be16 proto;
bool need_rls_dev = false;
int err, reserve = 0; int err, reserve = 0;
void *ph; void *ph;
struct sockaddr_ll *saddr = (struct sockaddr_ll *)msg->msg_name; struct sockaddr_ll *saddr = (struct sockaddr_ll *)msg->msg_name;
...@@ -2070,7 +2089,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg) ...@@ -2070,7 +2089,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
mutex_lock(&po->pg_vec_lock); mutex_lock(&po->pg_vec_lock);
if (saddr == NULL) { if (saddr == NULL) {
dev = po->prot_hook.dev; dev = packet_cached_dev_get(po);
proto = po->num; proto = po->num;
addr = NULL; addr = NULL;
} else { } else {
...@@ -2084,19 +2103,17 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg) ...@@ -2084,19 +2103,17 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
proto = saddr->sll_protocol; proto = saddr->sll_protocol;
addr = saddr->sll_addr; addr = saddr->sll_addr;
dev = dev_get_by_index(sock_net(&po->sk), saddr->sll_ifindex); dev = dev_get_by_index(sock_net(&po->sk), saddr->sll_ifindex);
need_rls_dev = true;
} }
err = -ENXIO; err = -ENXIO;
if (unlikely(dev == NULL)) if (unlikely(dev == NULL))
goto out; goto out;
reserve = dev->hard_header_len;
err = -ENETDOWN; err = -ENETDOWN;
if (unlikely(!(dev->flags & IFF_UP))) if (unlikely(!(dev->flags & IFF_UP)))
goto out_put; goto out_put;
reserve = dev->hard_header_len;
size_max = po->tx_ring.frame_size size_max = po->tx_ring.frame_size
- (po->tp_hdrlen - sizeof(struct sockaddr_ll)); - (po->tp_hdrlen - sizeof(struct sockaddr_ll));
...@@ -2173,7 +2190,6 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg) ...@@ -2173,7 +2190,6 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
__packet_set_status(po, ph, status); __packet_set_status(po, ph, status);
kfree_skb(skb); kfree_skb(skb);
out_put: out_put:
if (need_rls_dev)
dev_put(dev); dev_put(dev);
out: out:
mutex_unlock(&po->pg_vec_lock); mutex_unlock(&po->pg_vec_lock);
...@@ -2212,7 +2228,6 @@ static int packet_snd(struct socket *sock, ...@@ -2212,7 +2228,6 @@ static int packet_snd(struct socket *sock,
struct sk_buff *skb; struct sk_buff *skb;
struct net_device *dev; struct net_device *dev;
__be16 proto; __be16 proto;
bool need_rls_dev = false;
unsigned char *addr; unsigned char *addr;
int err, reserve = 0; int err, reserve = 0;
struct virtio_net_hdr vnet_hdr = { 0 }; struct virtio_net_hdr vnet_hdr = { 0 };
...@@ -2228,7 +2243,7 @@ static int packet_snd(struct socket *sock, ...@@ -2228,7 +2243,7 @@ static int packet_snd(struct socket *sock,
*/ */
if (saddr == NULL) { if (saddr == NULL) {
dev = po->prot_hook.dev; dev = packet_cached_dev_get(po);
proto = po->num; proto = po->num;
addr = NULL; addr = NULL;
} else { } else {
...@@ -2240,19 +2255,17 @@ static int packet_snd(struct socket *sock, ...@@ -2240,19 +2255,17 @@ static int packet_snd(struct socket *sock,
proto = saddr->sll_protocol; proto = saddr->sll_protocol;
addr = saddr->sll_addr; addr = saddr->sll_addr;
dev = dev_get_by_index(sock_net(sk), saddr->sll_ifindex); dev = dev_get_by_index(sock_net(sk), saddr->sll_ifindex);
need_rls_dev = true;
} }
err = -ENXIO; err = -ENXIO;
if (dev == NULL) if (unlikely(dev == NULL))
goto out_unlock; goto out_unlock;
if (sock->type == SOCK_RAW)
reserve = dev->hard_header_len;
err = -ENETDOWN; err = -ENETDOWN;
if (!(dev->flags & IFF_UP)) if (unlikely(!(dev->flags & IFF_UP)))
goto out_unlock; goto out_unlock;
if (sock->type == SOCK_RAW)
reserve = dev->hard_header_len;
if (po->has_vnet_hdr) { if (po->has_vnet_hdr) {
vnet_hdr_len = sizeof(vnet_hdr); vnet_hdr_len = sizeof(vnet_hdr);
...@@ -2386,7 +2399,6 @@ static int packet_snd(struct socket *sock, ...@@ -2386,7 +2399,6 @@ static int packet_snd(struct socket *sock,
if (err > 0 && (err = net_xmit_errno(err)) != 0) if (err > 0 && (err = net_xmit_errno(err)) != 0)
goto out_unlock; goto out_unlock;
if (need_rls_dev)
dev_put(dev); dev_put(dev);
return len; return len;
...@@ -2394,7 +2406,7 @@ static int packet_snd(struct socket *sock, ...@@ -2394,7 +2406,7 @@ static int packet_snd(struct socket *sock,
out_free: out_free:
kfree_skb(skb); kfree_skb(skb);
out_unlock: out_unlock:
if (dev && need_rls_dev) if (dev)
dev_put(dev); dev_put(dev);
out: out:
return err; return err;
...@@ -2614,6 +2626,7 @@ static int packet_create(struct net *net, struct socket *sock, int protocol, ...@@ -2614,6 +2626,7 @@ static int packet_create(struct net *net, struct socket *sock, int protocol,
po = pkt_sk(sk); po = pkt_sk(sk);
sk->sk_family = PF_PACKET; sk->sk_family = PF_PACKET;
po->num = proto; po->num = proto;
RCU_INIT_POINTER(po->cached_dev, NULL);
sk->sk_destruct = packet_sock_destruct; sk->sk_destruct = packet_sock_destruct;
sk_refcnt_debug_inc(sk); sk_refcnt_debug_inc(sk);
......
...@@ -113,6 +113,7 @@ struct packet_sock { ...@@ -113,6 +113,7 @@ struct packet_sock {
unsigned int tp_loss:1; unsigned int tp_loss:1;
unsigned int tp_tx_has_off:1; unsigned int tp_tx_has_off:1;
unsigned int tp_tstamp; unsigned int tp_tstamp;
struct net_device __rcu *cached_dev;
struct packet_type prot_hook ____cacheline_aligned_in_smp; struct packet_type prot_hook ____cacheline_aligned_in_smp;
}; };
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment