Commit f8978bd9 authored by Leon Romanovsky's avatar Leon Romanovsky Committed by Jason Gunthorpe

RDMA/netlink: Fix locking around __ib_get_device_by_index

Holding locks is mandatory when calling __ib_device_get_by_index,
otherwise there are races during the list iteration with device removal.

Since the locks are static to device.c, __ib_device_get_by_index can
never be called correctly by any user out side the file.

Make the function static and provide a safe function that gets the
correct locks and returns a kref'd pointer. Fix all callers.

Fixes: e5c9469e ("RDMA/netlink: Add nldev device doit implementation")
Fixes: c3f66f7b ("RDMA/netlink: Implement nldev port doit callback")
Fixes: 7d02f605 ("RDMA/netlink: Add nldev port dumpit implementation")
Reviewed-by: default avatarMark Bloch <markb@mellanox.com>
Signed-off-by: default avatarLeon Romanovsky <leonro@mellanox.com>
Signed-off-by: default avatarJason Gunthorpe <jgg@mellanox.com>
parent 16ba3def
...@@ -314,7 +314,7 @@ static inline int ib_mad_enforce_security(struct ib_mad_agent_private *map, ...@@ -314,7 +314,7 @@ static inline int ib_mad_enforce_security(struct ib_mad_agent_private *map,
} }
#endif #endif
struct ib_device *__ib_device_get_by_index(u32 ifindex); struct ib_device *ib_device_get_by_index(u32 ifindex);
/* RDMA device netlink */ /* RDMA device netlink */
void nldev_init(void); void nldev_init(void);
void nldev_exit(void); void nldev_exit(void);
......
...@@ -134,7 +134,7 @@ static int ib_device_check_mandatory(struct ib_device *device) ...@@ -134,7 +134,7 @@ static int ib_device_check_mandatory(struct ib_device *device)
return 0; return 0;
} }
struct ib_device *__ib_device_get_by_index(u32 index) static struct ib_device *__ib_device_get_by_index(u32 index)
{ {
struct ib_device *device; struct ib_device *device;
...@@ -145,6 +145,22 @@ struct ib_device *__ib_device_get_by_index(u32 index) ...@@ -145,6 +145,22 @@ struct ib_device *__ib_device_get_by_index(u32 index)
return NULL; return NULL;
} }
/*
* Caller is responsible to return refrerence count by calling put_device()
*/
struct ib_device *ib_device_get_by_index(u32 index)
{
struct ib_device *device;
down_read(&lists_rwsem);
device = __ib_device_get_by_index(index);
if (device)
get_device(&device->dev);
up_read(&lists_rwsem);
return device;
}
static struct ib_device *__ib_device_get_by_name(const char *name) static struct ib_device *__ib_device_get_by_name(const char *name)
{ {
struct ib_device *device; struct ib_device *device;
......
...@@ -142,27 +142,34 @@ static int nldev_get_doit(struct sk_buff *skb, struct nlmsghdr *nlh, ...@@ -142,27 +142,34 @@ static int nldev_get_doit(struct sk_buff *skb, struct nlmsghdr *nlh,
index = nla_get_u32(tb[RDMA_NLDEV_ATTR_DEV_INDEX]); index = nla_get_u32(tb[RDMA_NLDEV_ATTR_DEV_INDEX]);
device = __ib_device_get_by_index(index); device = ib_device_get_by_index(index);
if (!device) if (!device)
return -EINVAL; return -EINVAL;
msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
if (!msg) if (!msg) {
return -ENOMEM; err = -ENOMEM;
goto err;
}
nlh = nlmsg_put(msg, NETLINK_CB(skb).portid, nlh->nlmsg_seq, nlh = nlmsg_put(msg, NETLINK_CB(skb).portid, nlh->nlmsg_seq,
RDMA_NL_GET_TYPE(RDMA_NL_NLDEV, RDMA_NLDEV_CMD_GET), RDMA_NL_GET_TYPE(RDMA_NL_NLDEV, RDMA_NLDEV_CMD_GET),
0, 0); 0, 0);
err = fill_dev_info(msg, device); err = fill_dev_info(msg, device);
if (err) { if (err)
nlmsg_free(msg); goto err_free;
return err;
}
nlmsg_end(msg, nlh); nlmsg_end(msg, nlh);
put_device(&device->dev);
return rdma_nl_unicast(msg, NETLINK_CB(skb).portid); return rdma_nl_unicast(msg, NETLINK_CB(skb).portid);
err_free:
nlmsg_free(msg);
err:
put_device(&device->dev);
return err;
} }
static int _nldev_get_dumpit(struct ib_device *device, static int _nldev_get_dumpit(struct ib_device *device,
...@@ -220,31 +227,40 @@ static int nldev_port_get_doit(struct sk_buff *skb, struct nlmsghdr *nlh, ...@@ -220,31 +227,40 @@ static int nldev_port_get_doit(struct sk_buff *skb, struct nlmsghdr *nlh,
return -EINVAL; return -EINVAL;
index = nla_get_u32(tb[RDMA_NLDEV_ATTR_DEV_INDEX]); index = nla_get_u32(tb[RDMA_NLDEV_ATTR_DEV_INDEX]);
device = __ib_device_get_by_index(index); device = ib_device_get_by_index(index);
if (!device) if (!device)
return -EINVAL; return -EINVAL;
port = nla_get_u32(tb[RDMA_NLDEV_ATTR_PORT_INDEX]); port = nla_get_u32(tb[RDMA_NLDEV_ATTR_PORT_INDEX]);
if (!rdma_is_port_valid(device, port)) if (!rdma_is_port_valid(device, port)) {
return -EINVAL; err = -EINVAL;
goto err;
}
msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
if (!msg) if (!msg) {
return -ENOMEM; err = -ENOMEM;
goto err;
}
nlh = nlmsg_put(msg, NETLINK_CB(skb).portid, nlh->nlmsg_seq, nlh = nlmsg_put(msg, NETLINK_CB(skb).portid, nlh->nlmsg_seq,
RDMA_NL_GET_TYPE(RDMA_NL_NLDEV, RDMA_NLDEV_CMD_GET), RDMA_NL_GET_TYPE(RDMA_NL_NLDEV, RDMA_NLDEV_CMD_GET),
0, 0); 0, 0);
err = fill_port_info(msg, device, port); err = fill_port_info(msg, device, port);
if (err) { if (err)
nlmsg_free(msg); goto err_free;
return err;
}
nlmsg_end(msg, nlh); nlmsg_end(msg, nlh);
put_device(&device->dev);
return rdma_nl_unicast(msg, NETLINK_CB(skb).portid); return rdma_nl_unicast(msg, NETLINK_CB(skb).portid);
err_free:
nlmsg_free(msg);
err:
put_device(&device->dev);
return err;
} }
static int nldev_port_get_dumpit(struct sk_buff *skb, static int nldev_port_get_dumpit(struct sk_buff *skb,
...@@ -265,7 +281,7 @@ static int nldev_port_get_dumpit(struct sk_buff *skb, ...@@ -265,7 +281,7 @@ static int nldev_port_get_dumpit(struct sk_buff *skb,
return -EINVAL; return -EINVAL;
ifindex = nla_get_u32(tb[RDMA_NLDEV_ATTR_DEV_INDEX]); ifindex = nla_get_u32(tb[RDMA_NLDEV_ATTR_DEV_INDEX]);
device = __ib_device_get_by_index(ifindex); device = ib_device_get_by_index(ifindex);
if (!device) if (!device)
return -EINVAL; return -EINVAL;
...@@ -299,7 +315,9 @@ static int nldev_port_get_dumpit(struct sk_buff *skb, ...@@ -299,7 +315,9 @@ static int nldev_port_get_dumpit(struct sk_buff *skb,
nlmsg_end(skb, nlh); nlmsg_end(skb, nlh);
} }
out: cb->args[0] = idx; out:
put_device(&device->dev);
cb->args[0] = idx;
return skb->len; return skb->len;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment