Commit 456b61bc authored by Eric Dumazet's avatar Eric Dumazet Committed by David S. Miller

ipv6: mcast: RCU conversion

ipv6_sk_mc_lock rwlock becomes a spinlock.

readers (inet6_mc_check()) now takes rcu_read_lock() instead of read
lock. Writers dont need to disable BH anymore.

struct ipv6_mc_socklist objects are reclaimed after one RCU grace
period.
Signed-off-by: default avatarEric Dumazet <eric.dumazet@gmail.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 2757a15f
...@@ -364,7 +364,7 @@ struct ipv6_pinfo { ...@@ -364,7 +364,7 @@ struct ipv6_pinfo {
__u32 dst_cookie; __u32 dst_cookie;
struct ipv6_mc_socklist *ipv6_mc_list; struct ipv6_mc_socklist __rcu *ipv6_mc_list;
struct ipv6_ac_socklist *ipv6_ac_list; struct ipv6_ac_socklist *ipv6_ac_list;
struct ipv6_fl_socklist *ipv6_fl_list; struct ipv6_fl_socklist *ipv6_fl_list;
......
...@@ -89,10 +89,11 @@ struct ip6_sf_socklist { ...@@ -89,10 +89,11 @@ struct ip6_sf_socklist {
struct ipv6_mc_socklist { struct ipv6_mc_socklist {
struct in6_addr addr; struct in6_addr addr;
int ifindex; int ifindex;
struct ipv6_mc_socklist *next; struct ipv6_mc_socklist __rcu *next;
rwlock_t sflock; rwlock_t sflock;
unsigned int sfmode; /* MCAST_{INCLUDE,EXCLUDE} */ unsigned int sfmode; /* MCAST_{INCLUDE,EXCLUDE} */
struct ip6_sf_socklist *sflist; struct ip6_sf_socklist *sflist;
struct rcu_head rcu;
}; };
struct ip6_sf_list { struct ip6_sf_list {
......
...@@ -82,7 +82,7 @@ static void *__mld2_query_bugs[] __attribute__((__unused__)) = { ...@@ -82,7 +82,7 @@ static void *__mld2_query_bugs[] __attribute__((__unused__)) = {
static struct in6_addr mld2_all_mcr = MLD2_ALL_MCR_INIT; static struct in6_addr mld2_all_mcr = MLD2_ALL_MCR_INIT;
/* Big mc list lock for all the sockets */ /* Big mc list lock for all the sockets */
static DEFINE_RWLOCK(ipv6_sk_mc_lock); static DEFINE_SPINLOCK(ipv6_sk_mc_lock);
static void igmp6_join_group(struct ifmcaddr6 *ma); static void igmp6_join_group(struct ifmcaddr6 *ma);
static void igmp6_leave_group(struct ifmcaddr6 *ma); static void igmp6_leave_group(struct ifmcaddr6 *ma);
...@@ -123,6 +123,11 @@ int sysctl_mld_max_msf __read_mostly = IPV6_MLD_MAX_MSF; ...@@ -123,6 +123,11 @@ int sysctl_mld_max_msf __read_mostly = IPV6_MLD_MAX_MSF;
* socket join on multicast group * socket join on multicast group
*/ */
#define for_each_pmc_rcu(np, pmc) \
for (pmc = rcu_dereference(np->ipv6_mc_list); \
pmc != NULL; \
pmc = rcu_dereference(pmc->next))
int ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr) int ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr)
{ {
struct net_device *dev = NULL; struct net_device *dev = NULL;
...@@ -134,15 +139,15 @@ int ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr) ...@@ -134,15 +139,15 @@ int ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr)
if (!ipv6_addr_is_multicast(addr)) if (!ipv6_addr_is_multicast(addr))
return -EINVAL; return -EINVAL;
read_lock_bh(&ipv6_sk_mc_lock); rcu_read_lock();
for (mc_lst=np->ipv6_mc_list; mc_lst; mc_lst=mc_lst->next) { for_each_pmc_rcu(np, mc_lst) {
if ((ifindex == 0 || mc_lst->ifindex == ifindex) && if ((ifindex == 0 || mc_lst->ifindex == ifindex) &&
ipv6_addr_equal(&mc_lst->addr, addr)) { ipv6_addr_equal(&mc_lst->addr, addr)) {
read_unlock_bh(&ipv6_sk_mc_lock); rcu_read_unlock();
return -EADDRINUSE; return -EADDRINUSE;
} }
} }
read_unlock_bh(&ipv6_sk_mc_lock); rcu_read_unlock();
mc_lst = sock_kmalloc(sk, sizeof(struct ipv6_mc_socklist), GFP_KERNEL); mc_lst = sock_kmalloc(sk, sizeof(struct ipv6_mc_socklist), GFP_KERNEL);
...@@ -186,33 +191,41 @@ int ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr) ...@@ -186,33 +191,41 @@ int ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr)
return err; return err;
} }
write_lock_bh(&ipv6_sk_mc_lock); spin_lock(&ipv6_sk_mc_lock);
mc_lst->next = np->ipv6_mc_list; mc_lst->next = np->ipv6_mc_list;
np->ipv6_mc_list = mc_lst; rcu_assign_pointer(np->ipv6_mc_list, mc_lst);
write_unlock_bh(&ipv6_sk_mc_lock); spin_unlock(&ipv6_sk_mc_lock);
rcu_read_unlock(); rcu_read_unlock();
return 0; return 0;
} }
static void ipv6_mc_socklist_reclaim(struct rcu_head *head)
{
kfree(container_of(head, struct ipv6_mc_socklist, rcu));
}
/* /*
* socket leave on multicast group * socket leave on multicast group
*/ */
int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr) int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr)
{ {
struct ipv6_pinfo *np = inet6_sk(sk); struct ipv6_pinfo *np = inet6_sk(sk);
struct ipv6_mc_socklist *mc_lst, **lnk; struct ipv6_mc_socklist *mc_lst;
struct ipv6_mc_socklist __rcu **lnk;
struct net *net = sock_net(sk); struct net *net = sock_net(sk);
write_lock_bh(&ipv6_sk_mc_lock); spin_lock(&ipv6_sk_mc_lock);
for (lnk = &np->ipv6_mc_list; (mc_lst = *lnk) !=NULL ; lnk = &mc_lst->next) { for (lnk = &np->ipv6_mc_list;
(mc_lst = rcu_dereference_protected(*lnk,
lockdep_is_held(&ipv6_sk_mc_lock))) !=NULL ;
lnk = &mc_lst->next) {
if ((ifindex == 0 || mc_lst->ifindex == ifindex) && if ((ifindex == 0 || mc_lst->ifindex == ifindex) &&
ipv6_addr_equal(&mc_lst->addr, addr)) { ipv6_addr_equal(&mc_lst->addr, addr)) {
struct net_device *dev; struct net_device *dev;
*lnk = mc_lst->next; *lnk = mc_lst->next;
write_unlock_bh(&ipv6_sk_mc_lock); spin_unlock(&ipv6_sk_mc_lock);
rcu_read_lock(); rcu_read_lock();
dev = dev_get_by_index_rcu(net, mc_lst->ifindex); dev = dev_get_by_index_rcu(net, mc_lst->ifindex);
...@@ -225,11 +238,12 @@ int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr) ...@@ -225,11 +238,12 @@ int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr)
} else } else
(void) ip6_mc_leave_src(sk, mc_lst, NULL); (void) ip6_mc_leave_src(sk, mc_lst, NULL);
rcu_read_unlock(); rcu_read_unlock();
sock_kfree_s(sk, mc_lst, sizeof(*mc_lst)); atomic_sub(sizeof(*mc_lst), &sk->sk_omem_alloc);
call_rcu(&mc_lst->rcu, ipv6_mc_socklist_reclaim);
return 0; return 0;
} }
} }
write_unlock_bh(&ipv6_sk_mc_lock); spin_unlock(&ipv6_sk_mc_lock);
return -EADDRNOTAVAIL; return -EADDRNOTAVAIL;
} }
...@@ -272,12 +286,13 @@ void ipv6_sock_mc_close(struct sock *sk) ...@@ -272,12 +286,13 @@ void ipv6_sock_mc_close(struct sock *sk)
struct ipv6_mc_socklist *mc_lst; struct ipv6_mc_socklist *mc_lst;
struct net *net = sock_net(sk); struct net *net = sock_net(sk);
write_lock_bh(&ipv6_sk_mc_lock); spin_lock(&ipv6_sk_mc_lock);
while ((mc_lst = np->ipv6_mc_list) != NULL) { while ((mc_lst = rcu_dereference_protected(np->ipv6_mc_list,
lockdep_is_held(&ipv6_sk_mc_lock))) != NULL) {
struct net_device *dev; struct net_device *dev;
np->ipv6_mc_list = mc_lst->next; np->ipv6_mc_list = mc_lst->next;
write_unlock_bh(&ipv6_sk_mc_lock); spin_unlock(&ipv6_sk_mc_lock);
rcu_read_lock(); rcu_read_lock();
dev = dev_get_by_index_rcu(net, mc_lst->ifindex); dev = dev_get_by_index_rcu(net, mc_lst->ifindex);
...@@ -290,11 +305,13 @@ void ipv6_sock_mc_close(struct sock *sk) ...@@ -290,11 +305,13 @@ void ipv6_sock_mc_close(struct sock *sk)
} else } else
(void) ip6_mc_leave_src(sk, mc_lst, NULL); (void) ip6_mc_leave_src(sk, mc_lst, NULL);
rcu_read_unlock(); rcu_read_unlock();
sock_kfree_s(sk, mc_lst, sizeof(*mc_lst));
write_lock_bh(&ipv6_sk_mc_lock); atomic_sub(sizeof(*mc_lst), &sk->sk_omem_alloc);
call_rcu(&mc_lst->rcu, ipv6_mc_socklist_reclaim);
spin_lock(&ipv6_sk_mc_lock);
} }
write_unlock_bh(&ipv6_sk_mc_lock); spin_unlock(&ipv6_sk_mc_lock);
} }
int ip6_mc_source(int add, int omode, struct sock *sk, int ip6_mc_source(int add, int omode, struct sock *sk,
...@@ -328,8 +345,7 @@ int ip6_mc_source(int add, int omode, struct sock *sk, ...@@ -328,8 +345,7 @@ int ip6_mc_source(int add, int omode, struct sock *sk,
err = -EADDRNOTAVAIL; err = -EADDRNOTAVAIL;
read_lock(&ipv6_sk_mc_lock); for_each_pmc_rcu(inet6, pmc) {
for (pmc=inet6->ipv6_mc_list; pmc; pmc=pmc->next) {
if (pgsr->gsr_interface && pmc->ifindex != pgsr->gsr_interface) if (pgsr->gsr_interface && pmc->ifindex != pgsr->gsr_interface)
continue; continue;
if (ipv6_addr_equal(&pmc->addr, group)) if (ipv6_addr_equal(&pmc->addr, group))
...@@ -428,7 +444,6 @@ int ip6_mc_source(int add, int omode, struct sock *sk, ...@@ -428,7 +444,6 @@ int ip6_mc_source(int add, int omode, struct sock *sk,
done: done:
if (pmclocked) if (pmclocked)
write_unlock(&pmc->sflock); write_unlock(&pmc->sflock);
read_unlock(&ipv6_sk_mc_lock);
read_unlock_bh(&idev->lock); read_unlock_bh(&idev->lock);
rcu_read_unlock(); rcu_read_unlock();
if (leavegroup) if (leavegroup)
...@@ -466,14 +481,13 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf) ...@@ -466,14 +481,13 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf)
dev = idev->dev; dev = idev->dev;
err = 0; err = 0;
read_lock(&ipv6_sk_mc_lock);
if (gsf->gf_fmode == MCAST_INCLUDE && gsf->gf_numsrc == 0) { if (gsf->gf_fmode == MCAST_INCLUDE && gsf->gf_numsrc == 0) {
leavegroup = 1; leavegroup = 1;
goto done; goto done;
} }
for (pmc=inet6->ipv6_mc_list; pmc; pmc=pmc->next) { for_each_pmc_rcu(inet6, pmc) {
if (pmc->ifindex != gsf->gf_interface) if (pmc->ifindex != gsf->gf_interface)
continue; continue;
if (ipv6_addr_equal(&pmc->addr, group)) if (ipv6_addr_equal(&pmc->addr, group))
...@@ -521,7 +535,6 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf) ...@@ -521,7 +535,6 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf)
write_unlock(&pmc->sflock); write_unlock(&pmc->sflock);
err = 0; err = 0;
done: done:
read_unlock(&ipv6_sk_mc_lock);
read_unlock_bh(&idev->lock); read_unlock_bh(&idev->lock);
rcu_read_unlock(); rcu_read_unlock();
if (leavegroup) if (leavegroup)
...@@ -562,7 +575,7 @@ int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf, ...@@ -562,7 +575,7 @@ int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf,
* so reading the list is safe. * so reading the list is safe.
*/ */
for (pmc=inet6->ipv6_mc_list; pmc; pmc=pmc->next) { for_each_pmc_rcu(inet6, pmc) {
if (pmc->ifindex != gsf->gf_interface) if (pmc->ifindex != gsf->gf_interface)
continue; continue;
if (ipv6_addr_equal(group, &pmc->addr)) if (ipv6_addr_equal(group, &pmc->addr))
...@@ -612,13 +625,13 @@ int inet6_mc_check(struct sock *sk, const struct in6_addr *mc_addr, ...@@ -612,13 +625,13 @@ int inet6_mc_check(struct sock *sk, const struct in6_addr *mc_addr,
struct ip6_sf_socklist *psl; struct ip6_sf_socklist *psl;
int rv = 1; int rv = 1;
read_lock(&ipv6_sk_mc_lock); rcu_read_lock();
for (mc = np->ipv6_mc_list; mc; mc = mc->next) { for_each_pmc_rcu(np, mc) {
if (ipv6_addr_equal(&mc->addr, mc_addr)) if (ipv6_addr_equal(&mc->addr, mc_addr))
break; break;
} }
if (!mc) { if (!mc) {
read_unlock(&ipv6_sk_mc_lock); rcu_read_unlock();
return 1; return 1;
} }
read_lock(&mc->sflock); read_lock(&mc->sflock);
...@@ -638,7 +651,7 @@ int inet6_mc_check(struct sock *sk, const struct in6_addr *mc_addr, ...@@ -638,7 +651,7 @@ int inet6_mc_check(struct sock *sk, const struct in6_addr *mc_addr,
rv = 0; rv = 0;
} }
read_unlock(&mc->sflock); read_unlock(&mc->sflock);
read_unlock(&ipv6_sk_mc_lock); rcu_read_unlock();
return rv; return rv;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment