Commit e8051688 authored by Eric Dumazet's avatar Eric Dumazet Committed by David S. Miller

bridge: add RCU annotation to bridge multicast table

Add modern __rcu annotatations to bridge multicast table.
Use newer hlist macros to avoid direct access to hlist internals.
Signed-off-by: default avatarEric Dumazet <eric.dumazet@gmail.com>
Signed-off-by: default avatarStephen Hemminger <shemminger@vyatta.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 8a22c99a
...@@ -223,7 +223,7 @@ static void br_multicast_flood(struct net_bridge_mdb_entry *mdst, ...@@ -223,7 +223,7 @@ static void br_multicast_flood(struct net_bridge_mdb_entry *mdst,
struct net_bridge_port_group *p; struct net_bridge_port_group *p;
struct hlist_node *rp; struct hlist_node *rp;
rp = rcu_dereference(br->router_list.first); rp = rcu_dereference(hlist_first_rcu(&br->router_list));
p = mdst ? rcu_dereference(mdst->ports) : NULL; p = mdst ? rcu_dereference(mdst->ports) : NULL;
while (p || rp) { while (p || rp) {
struct net_bridge_port *port, *lport, *rport; struct net_bridge_port *port, *lport, *rport;
...@@ -242,7 +242,7 @@ static void br_multicast_flood(struct net_bridge_mdb_entry *mdst, ...@@ -242,7 +242,7 @@ static void br_multicast_flood(struct net_bridge_mdb_entry *mdst,
if ((unsigned long)lport >= (unsigned long)port) if ((unsigned long)lport >= (unsigned long)port)
p = rcu_dereference(p->next); p = rcu_dereference(p->next);
if ((unsigned long)rport >= (unsigned long)port) if ((unsigned long)rport >= (unsigned long)port)
rp = rcu_dereference(rp->next); rp = rcu_dereference(hlist_next_rcu(rp));
} }
if (!prev) if (!prev)
......
...@@ -33,6 +33,9 @@ ...@@ -33,6 +33,9 @@
#include "br_private.h" #include "br_private.h"
#define mlock_dereference(X, br) \
rcu_dereference_protected(X, lockdep_is_held(&br->multicast_lock))
#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
static inline int ipv6_is_local_multicast(const struct in6_addr *addr) static inline int ipv6_is_local_multicast(const struct in6_addr *addr)
{ {
...@@ -135,7 +138,7 @@ static struct net_bridge_mdb_entry *br_mdb_ip6_get( ...@@ -135,7 +138,7 @@ static struct net_bridge_mdb_entry *br_mdb_ip6_get(
struct net_bridge_mdb_entry *br_mdb_get(struct net_bridge *br, struct net_bridge_mdb_entry *br_mdb_get(struct net_bridge *br,
struct sk_buff *skb) struct sk_buff *skb)
{ {
struct net_bridge_mdb_htable *mdb = br->mdb; struct net_bridge_mdb_htable *mdb = rcu_dereference(br->mdb);
struct br_ip ip; struct br_ip ip;
if (br->multicast_disabled) if (br->multicast_disabled)
...@@ -235,7 +238,8 @@ static void br_multicast_group_expired(unsigned long data) ...@@ -235,7 +238,8 @@ static void br_multicast_group_expired(unsigned long data)
if (mp->ports) if (mp->ports)
goto out; goto out;
mdb = br->mdb; mdb = mlock_dereference(br->mdb, br);
hlist_del_rcu(&mp->hlist[mdb->ver]); hlist_del_rcu(&mp->hlist[mdb->ver]);
mdb->size--; mdb->size--;
...@@ -249,16 +253,20 @@ static void br_multicast_group_expired(unsigned long data) ...@@ -249,16 +253,20 @@ static void br_multicast_group_expired(unsigned long data)
static void br_multicast_del_pg(struct net_bridge *br, static void br_multicast_del_pg(struct net_bridge *br,
struct net_bridge_port_group *pg) struct net_bridge_port_group *pg)
{ {
struct net_bridge_mdb_htable *mdb = br->mdb; struct net_bridge_mdb_htable *mdb;
struct net_bridge_mdb_entry *mp; struct net_bridge_mdb_entry *mp;
struct net_bridge_port_group *p; struct net_bridge_port_group *p;
struct net_bridge_port_group **pp; struct net_bridge_port_group __rcu **pp;
mdb = mlock_dereference(br->mdb, br);
mp = br_mdb_ip_get(mdb, &pg->addr); mp = br_mdb_ip_get(mdb, &pg->addr);
if (WARN_ON(!mp)) if (WARN_ON(!mp))
return; return;
for (pp = &mp->ports; (p = *pp); pp = &p->next) { for (pp = &mp->ports;
(p = mlock_dereference(*pp, br)) != NULL;
pp = &p->next) {
if (p != pg) if (p != pg)
continue; continue;
...@@ -294,10 +302,10 @@ static void br_multicast_port_group_expired(unsigned long data) ...@@ -294,10 +302,10 @@ static void br_multicast_port_group_expired(unsigned long data)
spin_unlock(&br->multicast_lock); spin_unlock(&br->multicast_lock);
} }
static int br_mdb_rehash(struct net_bridge_mdb_htable **mdbp, int max, static int br_mdb_rehash(struct net_bridge_mdb_htable __rcu **mdbp, int max,
int elasticity) int elasticity)
{ {
struct net_bridge_mdb_htable *old = *mdbp; struct net_bridge_mdb_htable *old = rcu_dereference_protected(*mdbp, 1);
struct net_bridge_mdb_htable *mdb; struct net_bridge_mdb_htable *mdb;
int err; int err;
...@@ -569,7 +577,7 @@ static struct net_bridge_mdb_entry *br_multicast_get_group( ...@@ -569,7 +577,7 @@ static struct net_bridge_mdb_entry *br_multicast_get_group(
struct net_bridge *br, struct net_bridge_port *port, struct net_bridge *br, struct net_bridge_port *port,
struct br_ip *group, int hash) struct br_ip *group, int hash)
{ {
struct net_bridge_mdb_htable *mdb = br->mdb; struct net_bridge_mdb_htable *mdb;
struct net_bridge_mdb_entry *mp; struct net_bridge_mdb_entry *mp;
struct hlist_node *p; struct hlist_node *p;
unsigned count = 0; unsigned count = 0;
...@@ -577,6 +585,7 @@ static struct net_bridge_mdb_entry *br_multicast_get_group( ...@@ -577,6 +585,7 @@ static struct net_bridge_mdb_entry *br_multicast_get_group(
int elasticity; int elasticity;
int err; int err;
mdb = rcu_dereference_protected(br->mdb, 1);
hlist_for_each_entry(mp, p, &mdb->mhash[hash], hlist[mdb->ver]) { hlist_for_each_entry(mp, p, &mdb->mhash[hash], hlist[mdb->ver]) {
count++; count++;
if (unlikely(br_ip_equal(group, &mp->addr))) if (unlikely(br_ip_equal(group, &mp->addr)))
...@@ -642,10 +651,11 @@ static struct net_bridge_mdb_entry *br_multicast_new_group( ...@@ -642,10 +651,11 @@ static struct net_bridge_mdb_entry *br_multicast_new_group(
struct net_bridge *br, struct net_bridge_port *port, struct net_bridge *br, struct net_bridge_port *port,
struct br_ip *group) struct br_ip *group)
{ {
struct net_bridge_mdb_htable *mdb = br->mdb; struct net_bridge_mdb_htable *mdb;
struct net_bridge_mdb_entry *mp; struct net_bridge_mdb_entry *mp;
int hash; int hash;
mdb = rcu_dereference_protected(br->mdb, 1);
if (!mdb) { if (!mdb) {
if (br_mdb_rehash(&br->mdb, BR_HASH_SIZE, 0)) if (br_mdb_rehash(&br->mdb, BR_HASH_SIZE, 0))
return NULL; return NULL;
...@@ -660,7 +670,7 @@ static struct net_bridge_mdb_entry *br_multicast_new_group( ...@@ -660,7 +670,7 @@ static struct net_bridge_mdb_entry *br_multicast_new_group(
case -EAGAIN: case -EAGAIN:
rehash: rehash:
mdb = br->mdb; mdb = rcu_dereference_protected(br->mdb, 1);
hash = br_ip_hash(mdb, group); hash = br_ip_hash(mdb, group);
break; break;
...@@ -692,7 +702,7 @@ static int br_multicast_add_group(struct net_bridge *br, ...@@ -692,7 +702,7 @@ static int br_multicast_add_group(struct net_bridge *br,
{ {
struct net_bridge_mdb_entry *mp; struct net_bridge_mdb_entry *mp;
struct net_bridge_port_group *p; struct net_bridge_port_group *p;
struct net_bridge_port_group **pp; struct net_bridge_port_group __rcu **pp;
unsigned long now = jiffies; unsigned long now = jiffies;
int err; int err;
...@@ -712,7 +722,9 @@ static int br_multicast_add_group(struct net_bridge *br, ...@@ -712,7 +722,9 @@ static int br_multicast_add_group(struct net_bridge *br,
goto out; goto out;
} }
for (pp = &mp->ports; (p = *pp); pp = &p->next) { for (pp = &mp->ports;
(p = mlock_dereference(*pp, br)) != NULL;
pp = &p->next) {
if (p->port == port) if (p->port == port)
goto found; goto found;
if ((unsigned long)p->port < (unsigned long)port) if ((unsigned long)p->port < (unsigned long)port)
...@@ -1106,7 +1118,7 @@ static int br_ip4_multicast_query(struct net_bridge *br, ...@@ -1106,7 +1118,7 @@ static int br_ip4_multicast_query(struct net_bridge *br,
struct net_bridge_mdb_entry *mp; struct net_bridge_mdb_entry *mp;
struct igmpv3_query *ih3; struct igmpv3_query *ih3;
struct net_bridge_port_group *p; struct net_bridge_port_group *p;
struct net_bridge_port_group **pp; struct net_bridge_port_group __rcu **pp;
unsigned long max_delay; unsigned long max_delay;
unsigned long now = jiffies; unsigned long now = jiffies;
__be32 group; __be32 group;
...@@ -1145,7 +1157,7 @@ static int br_ip4_multicast_query(struct net_bridge *br, ...@@ -1145,7 +1157,7 @@ static int br_ip4_multicast_query(struct net_bridge *br,
if (!group) if (!group)
goto out; goto out;
mp = br_mdb_ip4_get(br->mdb, group); mp = br_mdb_ip4_get(mlock_dereference(br->mdb, br), group);
if (!mp) if (!mp)
goto out; goto out;
...@@ -1157,7 +1169,9 @@ static int br_ip4_multicast_query(struct net_bridge *br, ...@@ -1157,7 +1169,9 @@ static int br_ip4_multicast_query(struct net_bridge *br,
try_to_del_timer_sync(&mp->timer) >= 0)) try_to_del_timer_sync(&mp->timer) >= 0))
mod_timer(&mp->timer, now + max_delay); mod_timer(&mp->timer, now + max_delay);
for (pp = &mp->ports; (p = *pp); pp = &p->next) { for (pp = &mp->ports;
(p = mlock_dereference(*pp, br)) != NULL;
pp = &p->next) {
if (timer_pending(&p->timer) ? if (timer_pending(&p->timer) ?
time_after(p->timer.expires, now + max_delay) : time_after(p->timer.expires, now + max_delay) :
try_to_del_timer_sync(&p->timer) >= 0) try_to_del_timer_sync(&p->timer) >= 0)
...@@ -1178,7 +1192,8 @@ static int br_ip6_multicast_query(struct net_bridge *br, ...@@ -1178,7 +1192,8 @@ static int br_ip6_multicast_query(struct net_bridge *br,
struct mld_msg *mld = (struct mld_msg *) icmp6_hdr(skb); struct mld_msg *mld = (struct mld_msg *) icmp6_hdr(skb);
struct net_bridge_mdb_entry *mp; struct net_bridge_mdb_entry *mp;
struct mld2_query *mld2q; struct mld2_query *mld2q;
struct net_bridge_port_group *p, **pp; struct net_bridge_port_group *p;
struct net_bridge_port_group __rcu **pp;
unsigned long max_delay; unsigned long max_delay;
unsigned long now = jiffies; unsigned long now = jiffies;
struct in6_addr *group = NULL; struct in6_addr *group = NULL;
...@@ -1214,7 +1229,7 @@ static int br_ip6_multicast_query(struct net_bridge *br, ...@@ -1214,7 +1229,7 @@ static int br_ip6_multicast_query(struct net_bridge *br,
if (!group) if (!group)
goto out; goto out;
mp = br_mdb_ip6_get(br->mdb, group); mp = br_mdb_ip6_get(mlock_dereference(br->mdb, br), group);
if (!mp) if (!mp)
goto out; goto out;
...@@ -1225,7 +1240,9 @@ static int br_ip6_multicast_query(struct net_bridge *br, ...@@ -1225,7 +1240,9 @@ static int br_ip6_multicast_query(struct net_bridge *br,
try_to_del_timer_sync(&mp->timer) >= 0)) try_to_del_timer_sync(&mp->timer) >= 0))
mod_timer(&mp->timer, now + max_delay); mod_timer(&mp->timer, now + max_delay);
for (pp = &mp->ports; (p = *pp); pp = &p->next) { for (pp = &mp->ports;
(p = mlock_dereference(*pp, br)) != NULL;
pp = &p->next) {
if (timer_pending(&p->timer) ? if (timer_pending(&p->timer) ?
time_after(p->timer.expires, now + max_delay) : time_after(p->timer.expires, now + max_delay) :
try_to_del_timer_sync(&p->timer) >= 0) try_to_del_timer_sync(&p->timer) >= 0)
...@@ -1254,7 +1271,7 @@ static void br_multicast_leave_group(struct net_bridge *br, ...@@ -1254,7 +1271,7 @@ static void br_multicast_leave_group(struct net_bridge *br,
timer_pending(&br->multicast_querier_timer)) timer_pending(&br->multicast_querier_timer))
goto out; goto out;
mdb = br->mdb; mdb = mlock_dereference(br->mdb, br);
mp = br_mdb_ip_get(mdb, group); mp = br_mdb_ip_get(mdb, group);
if (!mp) if (!mp)
goto out; goto out;
...@@ -1277,7 +1294,9 @@ static void br_multicast_leave_group(struct net_bridge *br, ...@@ -1277,7 +1294,9 @@ static void br_multicast_leave_group(struct net_bridge *br,
goto out; goto out;
} }
for (p = mp->ports; p; p = p->next) { for (p = mlock_dereference(mp->ports, br);
p != NULL;
p = mlock_dereference(p->next, br)) {
if (p->port != port) if (p->port != port)
continue; continue;
...@@ -1625,7 +1644,7 @@ void br_multicast_stop(struct net_bridge *br) ...@@ -1625,7 +1644,7 @@ void br_multicast_stop(struct net_bridge *br)
del_timer_sync(&br->multicast_query_timer); del_timer_sync(&br->multicast_query_timer);
spin_lock_bh(&br->multicast_lock); spin_lock_bh(&br->multicast_lock);
mdb = br->mdb; mdb = mlock_dereference(br->mdb, br);
if (!mdb) if (!mdb)
goto out; goto out;
...@@ -1729,6 +1748,7 @@ int br_multicast_toggle(struct net_bridge *br, unsigned long val) ...@@ -1729,6 +1748,7 @@ int br_multicast_toggle(struct net_bridge *br, unsigned long val)
{ {
struct net_bridge_port *port; struct net_bridge_port *port;
int err = 0; int err = 0;
struct net_bridge_mdb_htable *mdb;
spin_lock(&br->multicast_lock); spin_lock(&br->multicast_lock);
if (br->multicast_disabled == !val) if (br->multicast_disabled == !val)
...@@ -1741,15 +1761,16 @@ int br_multicast_toggle(struct net_bridge *br, unsigned long val) ...@@ -1741,15 +1761,16 @@ int br_multicast_toggle(struct net_bridge *br, unsigned long val)
if (!netif_running(br->dev)) if (!netif_running(br->dev))
goto unlock; goto unlock;
if (br->mdb) { mdb = mlock_dereference(br->mdb, br);
if (br->mdb->old) { if (mdb) {
if (mdb->old) {
err = -EEXIST; err = -EEXIST;
rollback: rollback:
br->multicast_disabled = !!val; br->multicast_disabled = !!val;
goto unlock; goto unlock;
} }
err = br_mdb_rehash(&br->mdb, br->mdb->max, err = br_mdb_rehash(&br->mdb, mdb->max,
br->hash_elasticity); br->hash_elasticity);
if (err) if (err)
goto rollback; goto rollback;
...@@ -1774,6 +1795,7 @@ int br_multicast_set_hash_max(struct net_bridge *br, unsigned long val) ...@@ -1774,6 +1795,7 @@ int br_multicast_set_hash_max(struct net_bridge *br, unsigned long val)
{ {
int err = -ENOENT; int err = -ENOENT;
u32 old; u32 old;
struct net_bridge_mdb_htable *mdb;
spin_lock(&br->multicast_lock); spin_lock(&br->multicast_lock);
if (!netif_running(br->dev)) if (!netif_running(br->dev))
...@@ -1782,7 +1804,9 @@ int br_multicast_set_hash_max(struct net_bridge *br, unsigned long val) ...@@ -1782,7 +1804,9 @@ int br_multicast_set_hash_max(struct net_bridge *br, unsigned long val)
err = -EINVAL; err = -EINVAL;
if (!is_power_of_2(val)) if (!is_power_of_2(val))
goto unlock; goto unlock;
if (br->mdb && val < br->mdb->size)
mdb = mlock_dereference(br->mdb, br);
if (mdb && val < mdb->size)
goto unlock; goto unlock;
err = 0; err = 0;
...@@ -1790,8 +1814,8 @@ int br_multicast_set_hash_max(struct net_bridge *br, unsigned long val) ...@@ -1790,8 +1814,8 @@ int br_multicast_set_hash_max(struct net_bridge *br, unsigned long val)
old = br->hash_max; old = br->hash_max;
br->hash_max = val; br->hash_max = val;
if (br->mdb) { if (mdb) {
if (br->mdb->old) { if (mdb->old) {
err = -EEXIST; err = -EEXIST;
rollback: rollback:
br->hash_max = old; br->hash_max = old;
......
...@@ -72,7 +72,7 @@ struct net_bridge_fdb_entry ...@@ -72,7 +72,7 @@ struct net_bridge_fdb_entry
struct net_bridge_port_group { struct net_bridge_port_group {
struct net_bridge_port *port; struct net_bridge_port *port;
struct net_bridge_port_group *next; struct net_bridge_port_group __rcu *next;
struct hlist_node mglist; struct hlist_node mglist;
struct rcu_head rcu; struct rcu_head rcu;
struct timer_list timer; struct timer_list timer;
...@@ -86,7 +86,7 @@ struct net_bridge_mdb_entry ...@@ -86,7 +86,7 @@ struct net_bridge_mdb_entry
struct hlist_node hlist[2]; struct hlist_node hlist[2];
struct hlist_node mglist; struct hlist_node mglist;
struct net_bridge *br; struct net_bridge *br;
struct net_bridge_port_group *ports; struct net_bridge_port_group __rcu *ports;
struct rcu_head rcu; struct rcu_head rcu;
struct timer_list timer; struct timer_list timer;
struct timer_list query_timer; struct timer_list query_timer;
...@@ -227,7 +227,7 @@ struct net_bridge ...@@ -227,7 +227,7 @@ struct net_bridge
unsigned long multicast_startup_query_interval; unsigned long multicast_startup_query_interval;
spinlock_t multicast_lock; spinlock_t multicast_lock;
struct net_bridge_mdb_htable *mdb; struct net_bridge_mdb_htable __rcu *mdb;
struct hlist_head router_list; struct hlist_head router_list;
struct hlist_head mglist; struct hlist_head mglist;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment