Commit 87e9f031 authored by WANG Cong's avatar WANG Cong Committed by David S. Miller

ipv4: fix a potential deadlock in mcast getsockopt() path

Sasha reported the following lockdep warning:

  Possible unsafe locking scenario:

        CPU0                    CPU1
        ----                    ----
   lock(sk_lock-AF_INET);
                                lock(rtnl_mutex);
                                lock(sk_lock-AF_INET);
   lock(rtnl_mutex);

This is due to that for IP_MSFILTER and MCAST_MSFILTER, we take
rtnl lock before the socket lock in setsockopt() path, but take
the socket lock before rtnl lock in getsockopt() path. All the
rest optnames are setsockopt()-only.

Fix this by aligning the getsockopt() path with the setsockopt()
path, so that all mcast socket path would be locked in the same
order.

Note, IPv6 part is different where rtnl lock is not held.

Fixes: 54ff9ef3 ("ipv4, ipv6: kill ip_mc_{join, leave}_group and ipv6_sock_mc_{join, drop}")
Reported-by: default avatarSasha Levin <sasha.levin@oracle.com>
Cc: Marcelo Ricardo Leitner <marcelo.leitner@gmail.com>
Signed-off-by: default avatarCong Wang <xiyou.wangcong@gmail.com>
Reviewed-by: default avatarMarcelo Ricardo Leitner <marcelo.leitner@gmail.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 4ee3bd4a
...@@ -2392,11 +2392,11 @@ int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf, ...@@ -2392,11 +2392,11 @@ int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf,
struct ip_sf_socklist *psl; struct ip_sf_socklist *psl;
struct net *net = sock_net(sk); struct net *net = sock_net(sk);
ASSERT_RTNL();
if (!ipv4_is_multicast(addr)) if (!ipv4_is_multicast(addr))
return -EINVAL; return -EINVAL;
rtnl_lock();
imr.imr_multiaddr.s_addr = msf->imsf_multiaddr; imr.imr_multiaddr.s_addr = msf->imsf_multiaddr;
imr.imr_address.s_addr = msf->imsf_interface; imr.imr_address.s_addr = msf->imsf_interface;
imr.imr_ifindex = 0; imr.imr_ifindex = 0;
...@@ -2417,7 +2417,6 @@ int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf, ...@@ -2417,7 +2417,6 @@ int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf,
goto done; goto done;
msf->imsf_fmode = pmc->sfmode; msf->imsf_fmode = pmc->sfmode;
psl = rtnl_dereference(pmc->sflist); psl = rtnl_dereference(pmc->sflist);
rtnl_unlock();
if (!psl) { if (!psl) {
len = 0; len = 0;
count = 0; count = 0;
...@@ -2436,7 +2435,6 @@ int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf, ...@@ -2436,7 +2435,6 @@ int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf,
return -EFAULT; return -EFAULT;
return 0; return 0;
done: done:
rtnl_unlock();
return err; return err;
} }
...@@ -2450,6 +2448,8 @@ int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf, ...@@ -2450,6 +2448,8 @@ int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf,
struct inet_sock *inet = inet_sk(sk); struct inet_sock *inet = inet_sk(sk);
struct ip_sf_socklist *psl; struct ip_sf_socklist *psl;
ASSERT_RTNL();
psin = (struct sockaddr_in *)&gsf->gf_group; psin = (struct sockaddr_in *)&gsf->gf_group;
if (psin->sin_family != AF_INET) if (psin->sin_family != AF_INET)
return -EINVAL; return -EINVAL;
...@@ -2457,8 +2457,6 @@ int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf, ...@@ -2457,8 +2457,6 @@ int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf,
if (!ipv4_is_multicast(addr)) if (!ipv4_is_multicast(addr))
return -EINVAL; return -EINVAL;
rtnl_lock();
err = -EADDRNOTAVAIL; err = -EADDRNOTAVAIL;
for_each_pmc_rtnl(inet, pmc) { for_each_pmc_rtnl(inet, pmc) {
...@@ -2470,7 +2468,6 @@ int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf, ...@@ -2470,7 +2468,6 @@ int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf,
goto done; goto done;
gsf->gf_fmode = pmc->sfmode; gsf->gf_fmode = pmc->sfmode;
psl = rtnl_dereference(pmc->sflist); psl = rtnl_dereference(pmc->sflist);
rtnl_unlock();
count = psl ? psl->sl_count : 0; count = psl ? psl->sl_count : 0;
copycount = count < gsf->gf_numsrc ? count : gsf->gf_numsrc; copycount = count < gsf->gf_numsrc ? count : gsf->gf_numsrc;
gsf->gf_numsrc = count; gsf->gf_numsrc = count;
...@@ -2490,7 +2487,6 @@ int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf, ...@@ -2490,7 +2487,6 @@ int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf,
} }
return 0; return 0;
done: done:
rtnl_unlock();
return err; return err;
} }
......
...@@ -1251,11 +1251,22 @@ EXPORT_SYMBOL(compat_ip_setsockopt); ...@@ -1251,11 +1251,22 @@ EXPORT_SYMBOL(compat_ip_setsockopt);
* the _received_ ones. The set sets the _sent_ ones. * the _received_ ones. The set sets the _sent_ ones.
*/ */
static bool getsockopt_needs_rtnl(int optname)
{
switch (optname) {
case IP_MSFILTER:
case MCAST_MSFILTER:
return true;
}
return false;
}
static int do_ip_getsockopt(struct sock *sk, int level, int optname, static int do_ip_getsockopt(struct sock *sk, int level, int optname,
char __user *optval, int __user *optlen, unsigned int flags) char __user *optval, int __user *optlen, unsigned int flags)
{ {
struct inet_sock *inet = inet_sk(sk); struct inet_sock *inet = inet_sk(sk);
int val; bool needs_rtnl = getsockopt_needs_rtnl(optname);
int val, err = 0;
int len; int len;
if (level != SOL_IP) if (level != SOL_IP)
...@@ -1269,6 +1280,8 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, ...@@ -1269,6 +1280,8 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname,
if (len < 0) if (len < 0)
return -EINVAL; return -EINVAL;
if (needs_rtnl)
rtnl_lock();
lock_sock(sk); lock_sock(sk);
switch (optname) { switch (optname) {
...@@ -1386,39 +1399,35 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, ...@@ -1386,39 +1399,35 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname,
case IP_MSFILTER: case IP_MSFILTER:
{ {
struct ip_msfilter msf; struct ip_msfilter msf;
int err;
if (len < IP_MSFILTER_SIZE(0)) { if (len < IP_MSFILTER_SIZE(0)) {
release_sock(sk); err = -EINVAL;
return -EINVAL; goto out;
} }
if (copy_from_user(&msf, optval, IP_MSFILTER_SIZE(0))) { if (copy_from_user(&msf, optval, IP_MSFILTER_SIZE(0))) {
release_sock(sk); err = -EFAULT;
return -EFAULT; goto out;
} }
err = ip_mc_msfget(sk, &msf, err = ip_mc_msfget(sk, &msf,
(struct ip_msfilter __user *)optval, optlen); (struct ip_msfilter __user *)optval, optlen);
release_sock(sk); goto out;
return err;
} }
case MCAST_MSFILTER: case MCAST_MSFILTER:
{ {
struct group_filter gsf; struct group_filter gsf;
int err;
if (len < GROUP_FILTER_SIZE(0)) { if (len < GROUP_FILTER_SIZE(0)) {
release_sock(sk); err = -EINVAL;
return -EINVAL; goto out;
} }
if (copy_from_user(&gsf, optval, GROUP_FILTER_SIZE(0))) { if (copy_from_user(&gsf, optval, GROUP_FILTER_SIZE(0))) {
release_sock(sk); err = -EFAULT;
return -EFAULT; goto out;
} }
err = ip_mc_gsfget(sk, &gsf, err = ip_mc_gsfget(sk, &gsf,
(struct group_filter __user *)optval, (struct group_filter __user *)optval,
optlen); optlen);
release_sock(sk); goto out;
return err;
} }
case IP_MULTICAST_ALL: case IP_MULTICAST_ALL:
val = inet->mc_all; val = inet->mc_all;
...@@ -1485,6 +1494,12 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname, ...@@ -1485,6 +1494,12 @@ static int do_ip_getsockopt(struct sock *sk, int level, int optname,
return -EFAULT; return -EFAULT;
} }
return 0; return 0;
out:
release_sock(sk);
if (needs_rtnl)
rtnl_unlock();
return err;
} }
int ip_getsockopt(struct sock *sk, int level, int ip_getsockopt(struct sock *sk, int level,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment