Commit e946f8e3 authored by Jason Gunthorpe's avatar Jason Gunthorpe Committed by Joerg Roedel

iommu: Remove useless group refcounting

Several functions obtain the group reference and then release it before
returning. This gives the impression that the refcount is protecting
something for the duration of the function.

In truth all of these functions are called in places that know a device
driver is probed to the device and our locking rules already require
that dev->iommu_group cannot change while a driver is attached to the
struct device.

If this was not the case then this code is already at risk of triggering
UAF as it is racy if the dev->iommu_group is concurrently going to
NULL/free. refcount debugging will throw a WARN if kobject_get() is
called on a 0 refcount object to highlight the bug.

Remove the confusing refcounting and leave behind a comment about the
restriction.
Reviewed-by: default avatarLu Baolu <baolu.lu@linux.intel.com>
Reviewed-by: default avatarKevin Tian <kevin.tian@intel.com>
Signed-off-by: default avatarJason Gunthorpe <jgg@nvidia.com>
Link: https://lore.kernel.org/r/1-v1-c869a95191f2+5e8-iommu_single_grp_jgg@nvidia.comSigned-off-by: default avatarJoerg Roedel <jroedel@suse.de>
parent 4efd98d4
...@@ -2152,10 +2152,10 @@ static int __iommu_attach_device(struct iommu_domain *domain, ...@@ -2152,10 +2152,10 @@ static int __iommu_attach_device(struct iommu_domain *domain,
*/ */
int iommu_attach_device(struct iommu_domain *domain, struct device *dev) int iommu_attach_device(struct iommu_domain *domain, struct device *dev)
{ {
struct iommu_group *group; /* Caller must be a probed driver on dev */
struct iommu_group *group = dev->iommu_group;
int ret; int ret;
group = iommu_group_get(dev);
if (!group) if (!group)
return -ENODEV; return -ENODEV;
...@@ -2172,8 +2172,6 @@ int iommu_attach_device(struct iommu_domain *domain, struct device *dev) ...@@ -2172,8 +2172,6 @@ int iommu_attach_device(struct iommu_domain *domain, struct device *dev)
out_unlock: out_unlock:
mutex_unlock(&group->mutex); mutex_unlock(&group->mutex);
iommu_group_put(group);
return ret; return ret;
} }
EXPORT_SYMBOL_GPL(iommu_attach_device); EXPORT_SYMBOL_GPL(iommu_attach_device);
...@@ -2188,9 +2186,9 @@ int iommu_deferred_attach(struct device *dev, struct iommu_domain *domain) ...@@ -2188,9 +2186,9 @@ int iommu_deferred_attach(struct device *dev, struct iommu_domain *domain)
void iommu_detach_device(struct iommu_domain *domain, struct device *dev) void iommu_detach_device(struct iommu_domain *domain, struct device *dev)
{ {
struct iommu_group *group; /* Caller must be a probed driver on dev */
struct iommu_group *group = dev->iommu_group;
group = iommu_group_get(dev);
if (!group) if (!group)
return; return;
...@@ -2202,24 +2200,18 @@ void iommu_detach_device(struct iommu_domain *domain, struct device *dev) ...@@ -2202,24 +2200,18 @@ void iommu_detach_device(struct iommu_domain *domain, struct device *dev)
out_unlock: out_unlock:
mutex_unlock(&group->mutex); mutex_unlock(&group->mutex);
iommu_group_put(group);
} }
EXPORT_SYMBOL_GPL(iommu_detach_device); EXPORT_SYMBOL_GPL(iommu_detach_device);
struct iommu_domain *iommu_get_domain_for_dev(struct device *dev) struct iommu_domain *iommu_get_domain_for_dev(struct device *dev)
{ {
struct iommu_domain *domain; /* Caller must be a probed driver on dev */
struct iommu_group *group; struct iommu_group *group = dev->iommu_group;
group = iommu_group_get(dev);
if (!group) if (!group)
return NULL; return NULL;
domain = group->domain; return group->domain;
iommu_group_put(group);
return domain;
} }
EXPORT_SYMBOL_GPL(iommu_get_domain_for_dev); EXPORT_SYMBOL_GPL(iommu_get_domain_for_dev);
...@@ -3203,7 +3195,8 @@ static bool iommu_is_default_domain(struct iommu_group *group) ...@@ -3203,7 +3195,8 @@ static bool iommu_is_default_domain(struct iommu_group *group)
*/ */
int iommu_device_use_default_domain(struct device *dev) int iommu_device_use_default_domain(struct device *dev)
{ {
struct iommu_group *group = iommu_group_get(dev); /* Caller is the driver core during the pre-probe path */
struct iommu_group *group = dev->iommu_group;
int ret = 0; int ret = 0;
if (!group) if (!group)
...@@ -3222,8 +3215,6 @@ int iommu_device_use_default_domain(struct device *dev) ...@@ -3222,8 +3215,6 @@ int iommu_device_use_default_domain(struct device *dev)
unlock_out: unlock_out:
mutex_unlock(&group->mutex); mutex_unlock(&group->mutex);
iommu_group_put(group);
return ret; return ret;
} }
...@@ -3237,7 +3228,8 @@ int iommu_device_use_default_domain(struct device *dev) ...@@ -3237,7 +3228,8 @@ int iommu_device_use_default_domain(struct device *dev)
*/ */
void iommu_device_unuse_default_domain(struct device *dev) void iommu_device_unuse_default_domain(struct device *dev)
{ {
struct iommu_group *group = iommu_group_get(dev); /* Caller is the driver core during the post-probe path */
struct iommu_group *group = dev->iommu_group;
if (!group) if (!group)
return; return;
...@@ -3247,7 +3239,6 @@ void iommu_device_unuse_default_domain(struct device *dev) ...@@ -3247,7 +3239,6 @@ void iommu_device_unuse_default_domain(struct device *dev)
group->owner_cnt--; group->owner_cnt--;
mutex_unlock(&group->mutex); mutex_unlock(&group->mutex);
iommu_group_put(group);
} }
static int __iommu_group_alloc_blocking_domain(struct iommu_group *group) static int __iommu_group_alloc_blocking_domain(struct iommu_group *group)
...@@ -3331,13 +3322,13 @@ EXPORT_SYMBOL_GPL(iommu_group_claim_dma_owner); ...@@ -3331,13 +3322,13 @@ EXPORT_SYMBOL_GPL(iommu_group_claim_dma_owner);
*/ */
int iommu_device_claim_dma_owner(struct device *dev, void *owner) int iommu_device_claim_dma_owner(struct device *dev, void *owner)
{ {
struct iommu_group *group; /* Caller must be a probed driver on dev */
struct iommu_group *group = dev->iommu_group;
int ret = 0; int ret = 0;
if (WARN_ON(!owner)) if (WARN_ON(!owner))
return -EINVAL; return -EINVAL;
group = iommu_group_get(dev);
if (!group) if (!group)
return -ENODEV; return -ENODEV;
...@@ -3354,8 +3345,6 @@ int iommu_device_claim_dma_owner(struct device *dev, void *owner) ...@@ -3354,8 +3345,6 @@ int iommu_device_claim_dma_owner(struct device *dev, void *owner)
ret = __iommu_take_dma_ownership(group, owner); ret = __iommu_take_dma_ownership(group, owner);
unlock_out: unlock_out:
mutex_unlock(&group->mutex); mutex_unlock(&group->mutex);
iommu_group_put(group);
return ret; return ret;
} }
EXPORT_SYMBOL_GPL(iommu_device_claim_dma_owner); EXPORT_SYMBOL_GPL(iommu_device_claim_dma_owner);
...@@ -3393,7 +3382,8 @@ EXPORT_SYMBOL_GPL(iommu_group_release_dma_owner); ...@@ -3393,7 +3382,8 @@ EXPORT_SYMBOL_GPL(iommu_group_release_dma_owner);
*/ */
void iommu_device_release_dma_owner(struct device *dev) void iommu_device_release_dma_owner(struct device *dev)
{ {
struct iommu_group *group = iommu_group_get(dev); /* Caller must be a probed driver on dev */
struct iommu_group *group = dev->iommu_group;
mutex_lock(&group->mutex); mutex_lock(&group->mutex);
if (group->owner_cnt > 1) if (group->owner_cnt > 1)
...@@ -3401,7 +3391,6 @@ void iommu_device_release_dma_owner(struct device *dev) ...@@ -3401,7 +3391,6 @@ void iommu_device_release_dma_owner(struct device *dev)
else else
__iommu_release_dma_ownership(group); __iommu_release_dma_ownership(group);
mutex_unlock(&group->mutex); mutex_unlock(&group->mutex);
iommu_group_put(group);
} }
EXPORT_SYMBOL_GPL(iommu_device_release_dma_owner); EXPORT_SYMBOL_GPL(iommu_device_release_dma_owner);
...@@ -3462,14 +3451,14 @@ static void __iommu_remove_group_pasid(struct iommu_group *group, ...@@ -3462,14 +3451,14 @@ static void __iommu_remove_group_pasid(struct iommu_group *group,
int iommu_attach_device_pasid(struct iommu_domain *domain, int iommu_attach_device_pasid(struct iommu_domain *domain,
struct device *dev, ioasid_t pasid) struct device *dev, ioasid_t pasid)
{ {
struct iommu_group *group; /* Caller must be a probed driver on dev */
struct iommu_group *group = dev->iommu_group;
void *curr; void *curr;
int ret; int ret;
if (!domain->ops->set_dev_pasid) if (!domain->ops->set_dev_pasid)
return -EOPNOTSUPP; return -EOPNOTSUPP;
group = iommu_group_get(dev);
if (!group) if (!group)
return -ENODEV; return -ENODEV;
...@@ -3487,8 +3476,6 @@ int iommu_attach_device_pasid(struct iommu_domain *domain, ...@@ -3487,8 +3476,6 @@ int iommu_attach_device_pasid(struct iommu_domain *domain,
} }
out_unlock: out_unlock:
mutex_unlock(&group->mutex); mutex_unlock(&group->mutex);
iommu_group_put(group);
return ret; return ret;
} }
EXPORT_SYMBOL_GPL(iommu_attach_device_pasid); EXPORT_SYMBOL_GPL(iommu_attach_device_pasid);
...@@ -3505,14 +3492,13 @@ EXPORT_SYMBOL_GPL(iommu_attach_device_pasid); ...@@ -3505,14 +3492,13 @@ EXPORT_SYMBOL_GPL(iommu_attach_device_pasid);
void iommu_detach_device_pasid(struct iommu_domain *domain, struct device *dev, void iommu_detach_device_pasid(struct iommu_domain *domain, struct device *dev,
ioasid_t pasid) ioasid_t pasid)
{ {
struct iommu_group *group = iommu_group_get(dev); /* Caller must be a probed driver on dev */
struct iommu_group *group = dev->iommu_group;
mutex_lock(&group->mutex); mutex_lock(&group->mutex);
__iommu_remove_group_pasid(group, pasid); __iommu_remove_group_pasid(group, pasid);
WARN_ON(xa_erase(&group->pasid_array, pasid) != domain); WARN_ON(xa_erase(&group->pasid_array, pasid) != domain);
mutex_unlock(&group->mutex); mutex_unlock(&group->mutex);
iommu_group_put(group);
} }
EXPORT_SYMBOL_GPL(iommu_detach_device_pasid); EXPORT_SYMBOL_GPL(iommu_detach_device_pasid);
...@@ -3534,10 +3520,10 @@ struct iommu_domain *iommu_get_domain_for_dev_pasid(struct device *dev, ...@@ -3534,10 +3520,10 @@ struct iommu_domain *iommu_get_domain_for_dev_pasid(struct device *dev,
ioasid_t pasid, ioasid_t pasid,
unsigned int type) unsigned int type)
{ {
/* Caller must be a probed driver on dev */
struct iommu_group *group = dev->iommu_group;
struct iommu_domain *domain; struct iommu_domain *domain;
struct iommu_group *group;
group = iommu_group_get(dev);
if (!group) if (!group)
return NULL; return NULL;
...@@ -3546,7 +3532,6 @@ struct iommu_domain *iommu_get_domain_for_dev_pasid(struct device *dev, ...@@ -3546,7 +3532,6 @@ struct iommu_domain *iommu_get_domain_for_dev_pasid(struct device *dev,
if (type && domain && domain->type != type) if (type && domain && domain->type != type)
domain = ERR_PTR(-EBUSY); domain = ERR_PTR(-EBUSY);
xa_unlock(&group->pasid_array); xa_unlock(&group->pasid_array);
iommu_group_put(group);
return domain; return domain;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment