Commit 870c7ad4 authored by Jakub Kicinski's avatar Jakub Kicinski Committed by David S. Miller

devlink: protect devlink->dev by the instance lock

devlink->dev is assumed to be always valid as long as any
outstanding reference to the devlink instance exists.

In prep for weakening of the references take the instance lock.
Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
Reviewed-by: default avatarJiri Pirko <jiri@nvidia.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 7a54a519
...@@ -131,7 +131,8 @@ struct devlink_gen_cmd { ...@@ -131,7 +131,8 @@ struct devlink_gen_cmd {
extern const struct genl_small_ops devlink_nl_ops[56]; extern const struct genl_small_ops devlink_nl_ops[56];
struct devlink *devlink_get_from_attrs(struct net *net, struct nlattr **attrs); struct devlink *
devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs);
void devlink_notify_unregister(struct devlink *devlink); void devlink_notify_unregister(struct devlink *devlink);
void devlink_notify_register(struct devlink *devlink); void devlink_notify_register(struct devlink *devlink);
......
...@@ -6314,12 +6314,10 @@ static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb, ...@@ -6314,12 +6314,10 @@ static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb,
start_offset = state->start_offset; start_offset = state->start_offset;
devlink = devlink_get_from_attrs(sock_net(cb->skb->sk), attrs); devlink = devlink_get_from_attrs_lock(sock_net(cb->skb->sk), attrs);
if (IS_ERR(devlink)) if (IS_ERR(devlink))
return PTR_ERR(devlink); return PTR_ERR(devlink);
devl_lock(devlink);
if (!attrs[DEVLINK_ATTR_REGION_NAME]) { if (!attrs[DEVLINK_ATTR_REGION_NAME]) {
NL_SET_ERR_MSG(cb->extack, "No region name provided"); NL_SET_ERR_MSG(cb->extack, "No region name provided");
err = -EINVAL; err = -EINVAL;
...@@ -7735,9 +7733,10 @@ devlink_health_reporter_get_from_cb(struct netlink_callback *cb) ...@@ -7735,9 +7733,10 @@ devlink_health_reporter_get_from_cb(struct netlink_callback *cb)
struct nlattr **attrs = info->attrs; struct nlattr **attrs = info->attrs;
struct devlink *devlink; struct devlink *devlink;
devlink = devlink_get_from_attrs(sock_net(cb->skb->sk), attrs); devlink = devlink_get_from_attrs_lock(sock_net(cb->skb->sk), attrs);
if (IS_ERR(devlink)) if (IS_ERR(devlink))
return NULL; return NULL;
devl_unlock(devlink);
reporter = devlink_health_reporter_get_from_attrs(devlink, attrs); reporter = devlink_health_reporter_get_from_attrs(devlink, attrs);
devlink_put(devlink); devlink_put(devlink);
......
...@@ -82,7 +82,8 @@ static const struct nla_policy devlink_nl_policy[DEVLINK_ATTR_MAX + 1] = { ...@@ -82,7 +82,8 @@ static const struct nla_policy devlink_nl_policy[DEVLINK_ATTR_MAX + 1] = {
[DEVLINK_ATTR_REGION_DIRECT] = { .type = NLA_FLAG }, [DEVLINK_ATTR_REGION_DIRECT] = { .type = NLA_FLAG },
}; };
struct devlink *devlink_get_from_attrs(struct net *net, struct nlattr **attrs) struct devlink *
devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs)
{ {
struct devlink *devlink; struct devlink *devlink;
unsigned long index; unsigned long index;
...@@ -96,9 +97,11 @@ struct devlink *devlink_get_from_attrs(struct net *net, struct nlattr **attrs) ...@@ -96,9 +97,11 @@ struct devlink *devlink_get_from_attrs(struct net *net, struct nlattr **attrs)
devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]); devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]);
devlinks_xa_for_each_registered_get(net, index, devlink) { devlinks_xa_for_each_registered_get(net, index, devlink) {
devl_lock(devlink);
if (strcmp(devlink->dev->bus->name, busname) == 0 && if (strcmp(devlink->dev->bus->name, busname) == 0 &&
strcmp(dev_name(devlink->dev), devname) == 0) strcmp(dev_name(devlink->dev), devname) == 0)
return devlink; return devlink;
devl_unlock(devlink);
devlink_put(devlink); devlink_put(devlink);
} }
...@@ -113,10 +116,10 @@ static int devlink_nl_pre_doit(const struct genl_split_ops *ops, ...@@ -113,10 +116,10 @@ static int devlink_nl_pre_doit(const struct genl_split_ops *ops,
struct devlink *devlink; struct devlink *devlink;
int err; int err;
devlink = devlink_get_from_attrs(genl_info_net(info), info->attrs); devlink = devlink_get_from_attrs_lock(genl_info_net(info), info->attrs);
if (IS_ERR(devlink)) if (IS_ERR(devlink))
return PTR_ERR(devlink); return PTR_ERR(devlink);
devl_lock(devlink);
info->user_ptr[0] = devlink; info->user_ptr[0] = devlink;
if (ops->internal_flags & DEVLINK_NL_FLAG_NEED_PORT) { if (ops->internal_flags & DEVLINK_NL_FLAG_NEED_PORT) {
devlink_port = devlink_port_get_from_info(devlink, info); devlink_port = devlink_port_get_from_info(devlink, info);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment