Commit cb2344c6 authored by Eric Dumazet's avatar Eric Dumazet Committed by Thadeu Lima de Souza Cascardo

packet: fix races in fanout_add()

BugLink: http://bugs.launchpad.net/bugs/1669016

[ Upstream commit d199fab6 ]

Multiple threads can call fanout_add() at the same time.

We need to grab fanout_mutex earlier to avoid races that could
lead to one thread freeing po->rollover that was set by another thread.

Do the same in fanout_release(), for peace of mind, and to help us
finding lockdep issues earlier.

Fixes: dc99f600 ("packet: Add fanout support.")
Fixes: 0648ab70 ("packet: rollover prepare: per-socket state")
Signed-off-by: default avatarEric Dumazet <edumazet@google.com>
Cc: Willem de Bruijn <willemb@google.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
Signed-off-by: default avatarGreg Kroah-Hartman <gregkh@linuxfoundation.org>
Signed-off-by: default avatarTim Gardner <tim.gardner@canonical.com>
Signed-off-by: default avatarThadeu Lima de Souza Cascardo <cascardo@canonical.com>
parent ffc84846
...@@ -1623,6 +1623,7 @@ static void fanout_release_data(struct packet_fanout *f) ...@@ -1623,6 +1623,7 @@ static void fanout_release_data(struct packet_fanout *f)
static int fanout_add(struct sock *sk, u16 id, u16 type_flags) static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
{ {
struct packet_rollover *rollover = NULL;
struct packet_sock *po = pkt_sk(sk); struct packet_sock *po = pkt_sk(sk);
struct packet_fanout *f, *match; struct packet_fanout *f, *match;
u8 type = type_flags & 0xff; u8 type = type_flags & 0xff;
...@@ -1645,23 +1646,28 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags) ...@@ -1645,23 +1646,28 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
return -EINVAL; return -EINVAL;
} }
mutex_lock(&fanout_mutex);
err = -EINVAL;
if (!po->running) if (!po->running)
return -EINVAL; goto out;
err = -EALREADY;
if (po->fanout) if (po->fanout)
return -EALREADY; goto out;
if (type == PACKET_FANOUT_ROLLOVER || if (type == PACKET_FANOUT_ROLLOVER ||
(type_flags & PACKET_FANOUT_FLAG_ROLLOVER)) { (type_flags & PACKET_FANOUT_FLAG_ROLLOVER)) {
po->rollover = kzalloc(sizeof(*po->rollover), GFP_KERNEL); err = -ENOMEM;
if (!po->rollover) rollover = kzalloc(sizeof(*rollover), GFP_KERNEL);
return -ENOMEM; if (!rollover)
atomic_long_set(&po->rollover->num, 0); goto out;
atomic_long_set(&po->rollover->num_huge, 0); atomic_long_set(&rollover->num, 0);
atomic_long_set(&po->rollover->num_failed, 0); atomic_long_set(&rollover->num_huge, 0);
atomic_long_set(&rollover->num_failed, 0);
po->rollover = rollover;
} }
mutex_lock(&fanout_mutex);
match = NULL; match = NULL;
list_for_each_entry(f, &fanout_list, list) { list_for_each_entry(f, &fanout_list, list) {
if (f->id == id && if (f->id == id &&
...@@ -1708,11 +1714,11 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags) ...@@ -1708,11 +1714,11 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
} }
} }
out: out:
mutex_unlock(&fanout_mutex); if (err && rollover) {
if (err) { kfree(rollover);
kfree(po->rollover);
po->rollover = NULL; po->rollover = NULL;
} }
mutex_unlock(&fanout_mutex);
return err; return err;
} }
...@@ -1721,11 +1727,9 @@ static void fanout_release(struct sock *sk) ...@@ -1721,11 +1727,9 @@ static void fanout_release(struct sock *sk)
struct packet_sock *po = pkt_sk(sk); struct packet_sock *po = pkt_sk(sk);
struct packet_fanout *f; struct packet_fanout *f;
f = po->fanout;
if (!f)
return;
mutex_lock(&fanout_mutex); mutex_lock(&fanout_mutex);
f = po->fanout;
if (f) {
po->fanout = NULL; po->fanout = NULL;
if (atomic_dec_and_test(&f->sk_ref)) { if (atomic_dec_and_test(&f->sk_ref)) {
...@@ -1734,10 +1738,11 @@ static void fanout_release(struct sock *sk) ...@@ -1734,10 +1738,11 @@ static void fanout_release(struct sock *sk)
fanout_release_data(f); fanout_release_data(f);
kfree(f); kfree(f);
} }
mutex_unlock(&fanout_mutex);
if (po->rollover) if (po->rollover)
kfree_rcu(po->rollover, rcu); kfree_rcu(po->rollover, rcu);
}
mutex_unlock(&fanout_mutex);
} }
static bool packet_extra_vlan_len_allowed(const struct net_device *dev, static bool packet_extra_vlan_len_allowed(const struct net_device *dev,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment