Commit c7d8b782 authored by Jason Gunthorpe's avatar Jason Gunthorpe

hmm: use mmu_notifier_get/put for 'struct hmm'

This is a significant simplification, it eliminates all the remaining
'hmm' stuff in mm_struct, eliminates krefing along the critical notifier
paths, and takes away all the ugly locking and abuse of page_table_lock.

mmu_notifier_get() provides the single struct hmm per struct mm which
eliminates mm->hmm.

It also directly guarantees that no mmu_notifier op callback is callable
while concurrent free is possible, this eliminates all the krefs inside
the mmu_notifier callbacks.

The remaining krefs in the range code were overly cautious, drivers are
already not permitted to free the mirror while a range exists.

Link: https://lore.kernel.org/r/20190806231548.25242-6-jgg@ziepe.caReviewed-by: default avatarChristoph Hellwig <hch@lst.de>
Reviewed-by: default avatarRalph Campbell <rcampbell@nvidia.com>
Tested-by: default avatarRalph Campbell <rcampbell@nvidia.com>
Signed-off-by: default avatarJason Gunthorpe <jgg@mellanox.com>
parent e4c057d0
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include <linux/pm_runtime.h> #include <linux/pm_runtime.h>
#include <linux/vga_switcheroo.h> #include <linux/vga_switcheroo.h>
#include <drm/drm_probe_helper.h> #include <drm/drm_probe_helper.h>
#include <linux/mmu_notifier.h>
#include "amdgpu.h" #include "amdgpu.h"
#include "amdgpu_irq.h" #include "amdgpu_irq.h"
...@@ -1464,6 +1465,7 @@ static void __exit amdgpu_exit(void) ...@@ -1464,6 +1465,7 @@ static void __exit amdgpu_exit(void)
amdgpu_unregister_atpx_handler(); amdgpu_unregister_atpx_handler();
amdgpu_sync_fini(); amdgpu_sync_fini();
amdgpu_fence_slab_fini(); amdgpu_fence_slab_fini();
mmu_notifier_synchronize();
} }
module_init(amdgpu_init); module_init(amdgpu_init);
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <linux/pci.h> #include <linux/pci.h>
#include <linux/pm_runtime.h> #include <linux/pm_runtime.h>
#include <linux/vga_switcheroo.h> #include <linux/vga_switcheroo.h>
#include <linux/mmu_notifier.h>
#include <drm/drmP.h> #include <drm/drmP.h>
#include <drm/drm_crtc_helper.h> #include <drm/drm_crtc_helper.h>
...@@ -1292,6 +1293,8 @@ nouveau_drm_exit(void) ...@@ -1292,6 +1293,8 @@ nouveau_drm_exit(void)
#ifdef CONFIG_NOUVEAU_PLATFORM_DRIVER #ifdef CONFIG_NOUVEAU_PLATFORM_DRIVER
platform_driver_unregister(&nouveau_platform_driver); platform_driver_unregister(&nouveau_platform_driver);
#endif #endif
if (IS_ENABLED(CONFIG_DRM_NOUVEAU_SVM))
mmu_notifier_synchronize();
} }
module_init(nouveau_drm_init); module_init(nouveau_drm_init);
......
...@@ -84,15 +84,12 @@ ...@@ -84,15 +84,12 @@
* @notifiers: count of active mmu notifiers * @notifiers: count of active mmu notifiers
*/ */
struct hmm { struct hmm {
struct mm_struct *mm; struct mmu_notifier mmu_notifier;
struct kref kref;
spinlock_t ranges_lock; spinlock_t ranges_lock;
struct list_head ranges; struct list_head ranges;
struct list_head mirrors; struct list_head mirrors;
struct mmu_notifier mmu_notifier;
struct rw_semaphore mirrors_sem; struct rw_semaphore mirrors_sem;
wait_queue_head_t wq; wait_queue_head_t wq;
struct rcu_head rcu;
long notifiers; long notifiers;
}; };
...@@ -409,13 +406,6 @@ long hmm_range_dma_unmap(struct hmm_range *range, ...@@ -409,13 +406,6 @@ long hmm_range_dma_unmap(struct hmm_range *range,
*/ */
#define HMM_RANGE_DEFAULT_TIMEOUT 1000 #define HMM_RANGE_DEFAULT_TIMEOUT 1000
/* Below are for HMM internal use only! Not to be used by device driver! */
static inline void hmm_mm_init(struct mm_struct *mm)
{
mm->hmm = NULL;
}
#else /* IS_ENABLED(CONFIG_HMM_MIRROR) */
static inline void hmm_mm_init(struct mm_struct *mm) {}
#endif /* IS_ENABLED(CONFIG_HMM_MIRROR) */ #endif /* IS_ENABLED(CONFIG_HMM_MIRROR) */
#endif /* LINUX_HMM_H */ #endif /* LINUX_HMM_H */
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
struct address_space; struct address_space;
struct mem_cgroup; struct mem_cgroup;
struct hmm;
/* /*
* Each physical page in the system has a struct page associated with * Each physical page in the system has a struct page associated with
...@@ -502,11 +501,6 @@ struct mm_struct { ...@@ -502,11 +501,6 @@ struct mm_struct {
atomic_long_t hugetlb_usage; atomic_long_t hugetlb_usage;
#endif #endif
struct work_struct async_put_work; struct work_struct async_put_work;
#ifdef CONFIG_HMM_MIRROR
/* HMM needs to track a few things per mm */
struct hmm *hmm;
#endif
} __randomize_layout; } __randomize_layout;
/* /*
......
...@@ -1007,7 +1007,6 @@ static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p, ...@@ -1007,7 +1007,6 @@ static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p,
mm_init_owner(mm, p); mm_init_owner(mm, p);
RCU_INIT_POINTER(mm->exe_file, NULL); RCU_INIT_POINTER(mm->exe_file, NULL);
mmu_notifier_mm_init(mm); mmu_notifier_mm_init(mm);
hmm_mm_init(mm);
init_tlb_flush_pending(mm); init_tlb_flush_pending(mm);
#if defined(CONFIG_TRANSPARENT_HUGEPAGE) && !USE_SPLIT_PMD_PTLOCKS #if defined(CONFIG_TRANSPARENT_HUGEPAGE) && !USE_SPLIT_PMD_PTLOCKS
mm->pmd_huge_pte = NULL; mm->pmd_huge_pte = NULL;
......
...@@ -26,101 +26,37 @@ ...@@ -26,101 +26,37 @@
#include <linux/mmu_notifier.h> #include <linux/mmu_notifier.h>
#include <linux/memory_hotplug.h> #include <linux/memory_hotplug.h>
static const struct mmu_notifier_ops hmm_mmu_notifier_ops; static struct mmu_notifier *hmm_alloc_notifier(struct mm_struct *mm)
/**
* hmm_get_or_create - register HMM against an mm (HMM internal)
*
* @mm: mm struct to attach to
* Return: an HMM object, either by referencing the existing
* (per-process) object, or by creating a new one.
*
* This is not intended to be used directly by device drivers. If mm already
* has an HMM struct then it get a reference on it and returns it. Otherwise
* it allocates an HMM struct, initializes it, associate it with the mm and
* returns it.
*/
static struct hmm *hmm_get_or_create(struct mm_struct *mm)
{ {
struct hmm *hmm; struct hmm *hmm;
lockdep_assert_held_write(&mm->mmap_sem); hmm = kzalloc(sizeof(*hmm), GFP_KERNEL);
/* Abuse the page_table_lock to also protect mm->hmm. */
spin_lock(&mm->page_table_lock);
hmm = mm->hmm;
if (mm->hmm && kref_get_unless_zero(&mm->hmm->kref))
goto out_unlock;
spin_unlock(&mm->page_table_lock);
hmm = kmalloc(sizeof(*hmm), GFP_KERNEL);
if (!hmm) if (!hmm)
return NULL; return ERR_PTR(-ENOMEM);
init_waitqueue_head(&hmm->wq); init_waitqueue_head(&hmm->wq);
INIT_LIST_HEAD(&hmm->mirrors); INIT_LIST_HEAD(&hmm->mirrors);
init_rwsem(&hmm->mirrors_sem); init_rwsem(&hmm->mirrors_sem);
hmm->mmu_notifier.ops = NULL;
INIT_LIST_HEAD(&hmm->ranges); INIT_LIST_HEAD(&hmm->ranges);
spin_lock_init(&hmm->ranges_lock); spin_lock_init(&hmm->ranges_lock);
kref_init(&hmm->kref);
hmm->notifiers = 0; hmm->notifiers = 0;
hmm->mm = mm; return &hmm->mmu_notifier;
hmm->mmu_notifier.ops = &hmm_mmu_notifier_ops;
if (__mmu_notifier_register(&hmm->mmu_notifier, mm)) {
kfree(hmm);
return NULL;
}
mmgrab(hmm->mm);
/*
* We hold the exclusive mmap_sem here so we know that mm->hmm is
* still NULL or 0 kref, and is safe to update.
*/
spin_lock(&mm->page_table_lock);
mm->hmm = hmm;
out_unlock:
spin_unlock(&mm->page_table_lock);
return hmm;
} }
static void hmm_free_rcu(struct rcu_head *rcu) static void hmm_free_notifier(struct mmu_notifier *mn)
{ {
struct hmm *hmm = container_of(rcu, struct hmm, rcu); struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
mmdrop(hmm->mm); WARN_ON(!list_empty(&hmm->ranges));
WARN_ON(!list_empty(&hmm->mirrors));
kfree(hmm); kfree(hmm);
} }
static void hmm_free(struct kref *kref)
{
struct hmm *hmm = container_of(kref, struct hmm, kref);
spin_lock(&hmm->mm->page_table_lock);
if (hmm->mm->hmm == hmm)
hmm->mm->hmm = NULL;
spin_unlock(&hmm->mm->page_table_lock);
mmu_notifier_unregister_no_release(&hmm->mmu_notifier, hmm->mm);
mmu_notifier_call_srcu(&hmm->rcu, hmm_free_rcu);
}
static inline void hmm_put(struct hmm *hmm)
{
kref_put(&hmm->kref, hmm_free);
}
static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm) static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
{ {
struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier); struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
struct hmm_mirror *mirror; struct hmm_mirror *mirror;
/* Bail out if hmm is in the process of being freed */
if (!kref_get_unless_zero(&hmm->kref))
return;
/* /*
* Since hmm_range_register() holds the mmget() lock hmm_release() is * Since hmm_range_register() holds the mmget() lock hmm_release() is
* prevented as long as a range exists. * prevented as long as a range exists.
...@@ -137,8 +73,6 @@ static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm) ...@@ -137,8 +73,6 @@ static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
mirror->ops->release(mirror); mirror->ops->release(mirror);
} }
up_read(&hmm->mirrors_sem); up_read(&hmm->mirrors_sem);
hmm_put(hmm);
} }
static void notifiers_decrement(struct hmm *hmm) static void notifiers_decrement(struct hmm *hmm)
...@@ -169,9 +103,6 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn, ...@@ -169,9 +103,6 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
unsigned long flags; unsigned long flags;
int ret = 0; int ret = 0;
if (!kref_get_unless_zero(&hmm->kref))
return 0;
spin_lock_irqsave(&hmm->ranges_lock, flags); spin_lock_irqsave(&hmm->ranges_lock, flags);
hmm->notifiers++; hmm->notifiers++;
list_for_each_entry(range, &hmm->ranges, list) { list_for_each_entry(range, &hmm->ranges, list) {
...@@ -206,7 +137,6 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn, ...@@ -206,7 +137,6 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
out: out:
if (ret) if (ret)
notifiers_decrement(hmm); notifiers_decrement(hmm);
hmm_put(hmm);
return ret; return ret;
} }
...@@ -215,17 +145,15 @@ static void hmm_invalidate_range_end(struct mmu_notifier *mn, ...@@ -215,17 +145,15 @@ static void hmm_invalidate_range_end(struct mmu_notifier *mn,
{ {
struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier); struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
if (!kref_get_unless_zero(&hmm->kref))
return;
notifiers_decrement(hmm); notifiers_decrement(hmm);
hmm_put(hmm);
} }
static const struct mmu_notifier_ops hmm_mmu_notifier_ops = { static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
.release = hmm_release, .release = hmm_release,
.invalidate_range_start = hmm_invalidate_range_start, .invalidate_range_start = hmm_invalidate_range_start,
.invalidate_range_end = hmm_invalidate_range_end, .invalidate_range_end = hmm_invalidate_range_end,
.alloc_notifier = hmm_alloc_notifier,
.free_notifier = hmm_free_notifier,
}; };
/* /*
...@@ -237,18 +165,27 @@ static const struct mmu_notifier_ops hmm_mmu_notifier_ops = { ...@@ -237,18 +165,27 @@ static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
* *
* To start mirroring a process address space, the device driver must register * To start mirroring a process address space, the device driver must register
* an HMM mirror struct. * an HMM mirror struct.
*
* The caller cannot unregister the hmm_mirror while any ranges are
* registered.
*
* Callers using this function must put a call to mmu_notifier_synchronize()
* in their module exit functions.
*/ */
int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm) int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm)
{ {
struct mmu_notifier *mn;
lockdep_assert_held_write(&mm->mmap_sem); lockdep_assert_held_write(&mm->mmap_sem);
/* Sanity check */ /* Sanity check */
if (!mm || !mirror || !mirror->ops) if (!mm || !mirror || !mirror->ops)
return -EINVAL; return -EINVAL;
mirror->hmm = hmm_get_or_create(mm); mn = mmu_notifier_get_locked(&hmm_mmu_notifier_ops, mm);
if (!mirror->hmm) if (IS_ERR(mn))
return -ENOMEM; return PTR_ERR(mn);
mirror->hmm = container_of(mn, struct hmm, mmu_notifier);
down_write(&mirror->hmm->mirrors_sem); down_write(&mirror->hmm->mirrors_sem);
list_add(&mirror->list, &mirror->hmm->mirrors); list_add(&mirror->list, &mirror->hmm->mirrors);
...@@ -272,7 +209,7 @@ void hmm_mirror_unregister(struct hmm_mirror *mirror) ...@@ -272,7 +209,7 @@ void hmm_mirror_unregister(struct hmm_mirror *mirror)
down_write(&hmm->mirrors_sem); down_write(&hmm->mirrors_sem);
list_del(&mirror->list); list_del(&mirror->list);
up_write(&hmm->mirrors_sem); up_write(&hmm->mirrors_sem);
hmm_put(hmm); mmu_notifier_put(&hmm->mmu_notifier);
} }
EXPORT_SYMBOL(hmm_mirror_unregister); EXPORT_SYMBOL(hmm_mirror_unregister);
...@@ -854,14 +791,13 @@ int hmm_range_register(struct hmm_range *range, struct hmm_mirror *mirror) ...@@ -854,14 +791,13 @@ int hmm_range_register(struct hmm_range *range, struct hmm_mirror *mirror)
return -EINVAL; return -EINVAL;
/* Prevent hmm_release() from running while the range is valid */ /* Prevent hmm_release() from running while the range is valid */
if (!mmget_not_zero(hmm->mm)) if (!mmget_not_zero(hmm->mmu_notifier.mm))
return -EFAULT; return -EFAULT;
/* Initialize range to track CPU page table updates. */ /* Initialize range to track CPU page table updates. */
spin_lock_irqsave(&hmm->ranges_lock, flags); spin_lock_irqsave(&hmm->ranges_lock, flags);
range->hmm = hmm; range->hmm = hmm;
kref_get(&hmm->kref);
list_add(&range->list, &hmm->ranges); list_add(&range->list, &hmm->ranges);
/* /*
...@@ -893,8 +829,7 @@ void hmm_range_unregister(struct hmm_range *range) ...@@ -893,8 +829,7 @@ void hmm_range_unregister(struct hmm_range *range)
spin_unlock_irqrestore(&hmm->ranges_lock, flags); spin_unlock_irqrestore(&hmm->ranges_lock, flags);
/* Drop reference taken by hmm_range_register() */ /* Drop reference taken by hmm_range_register() */
mmput(hmm->mm); mmput(hmm->mmu_notifier.mm);
hmm_put(hmm);
/* /*
* The range is now invalid and the ref on the hmm is dropped, so * The range is now invalid and the ref on the hmm is dropped, so
...@@ -944,14 +879,14 @@ long hmm_range_fault(struct hmm_range *range, unsigned int flags) ...@@ -944,14 +879,14 @@ long hmm_range_fault(struct hmm_range *range, unsigned int flags)
struct mm_walk mm_walk; struct mm_walk mm_walk;
int ret; int ret;
lockdep_assert_held(&hmm->mm->mmap_sem); lockdep_assert_held(&hmm->mmu_notifier.mm->mmap_sem);
do { do {
/* If range is no longer valid force retry. */ /* If range is no longer valid force retry. */
if (!range->valid) if (!range->valid)
return -EBUSY; return -EBUSY;
vma = find_vma(hmm->mm, start); vma = find_vma(hmm->mmu_notifier.mm, start);
if (vma == NULL || (vma->vm_flags & device_vma)) if (vma == NULL || (vma->vm_flags & device_vma))
return -EFAULT; return -EFAULT;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment