Commit f4fb3b9c authored by Jason Gunthorpe

Merge 'notifier_get_put' into hmm.git

Jason Gunthorpe says:

====================
Add mmu_notifier_get/put for managing mmu notifier registrations

This series introduces a new registration flow for mmu_notifiers, based on
the idea that the user would like to get a single refcounted piece of
memory per mm, keyed to its use.

For instance, many users of mmu_notifiers use an interval tree or similar
to dispatch notifications to some object. There are many objects, but only
one notifier subscription per mm holding the tree.

Of the 12 places that call mmu_notifier_register:
 - 7 maintain some kind of obvious mapping of mm_struct to
   mmu_notifier registration, i.e. in some linked list or hash table. Of
   those 7, this series converts 4 (gru, hmm, RDMA, radeon).

 - 3 (hfi1, gntdev, vhost) register multiple notifiers, but each
   one immediately does some VA range filtering, i.e. with an interval tree.
   These would be better served by a global subsystem-wide range filter and
   could then convert to this API.

 - 2 (kvm, amd_iommu) deliberately use a single mm at a time and
   really can't use this API. One of intel-svm's modes is also in this
   list.

The 3 unconverted drivers (of the 7 above) are:
 - intel-svm
   This driver tracks mm's in a global linked list 'global_svm_list'
   and would benefit from this API.

   Its flow is a bit complex, since it also wants a set of non-shared
   notifiers.

 - i915_gem_userptr
   This driver tracks mm's in a per-device hash
   table (dev_priv->mm_structs), but its use of mmu_notifiers is only
   optional. Since it still seems to need the hash table, it is
   difficult to convert.

 - amdkfd/kfd_process
   This driver uses a global SRCU hash table to track mm's.

   The control flow here is very complicated and the driver relies on
   this hash table to be fast on the ioctl syscall path.

   It would definitely benefit, but only if the ioctl path didn't need to
   do the search so often.
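
To make the intended usage concrete, here is a rough driver-side sketch of
the get/put flow (not part of this series' diffstat; every "example_*" name
is made up purely for illustration, only the mmu_notifier_* calls and the
alloc_notifier/free_notifier ops come from this series):

#include <linux/err.h>
#include <linux/module.h>
#include <linux/mmu_notifier.h>
#include <linux/slab.h>

struct example_mn {
	struct mmu_notifier mn;	/* one subscription per mm */
	/* per-mm state, e.g. an interval tree of watched ranges */
};

/* Called by mmu_notifier_get() when no notifier with these ops exists yet */
static struct mmu_notifier *example_alloc_notifier(struct mm_struct *mm)
{
	struct example_mn *emn = kzalloc(sizeof(*emn), GFP_KERNEL);

	if (!emn)
		return ERR_PTR(-ENOMEM);
	return &emn->mn;
}

/* Called from a SRCU callback once the last reference has been put */
static void example_free_notifier(struct mmu_notifier *mn)
{
	kfree(container_of(mn, struct example_mn, mn));
}

static const struct mmu_notifier_ops example_mn_ops = {
	.alloc_notifier = example_alloc_notifier,
	.free_notifier = example_free_notifier,
};

static struct example_mn *example_get_mn(void)
{
	struct mmu_notifier *mn;

	/* Returns the existing notifier for current->mm or allocates one */
	mn = mmu_notifier_get(&example_mn_ops, current->mm);
	if (IS_ERR(mn))
		return ERR_CAST(mn);
	return container_of(mn, struct example_mn, mn);
}

static void example_put_mn(struct example_mn *emn)
{
	mmu_notifier_put(&emn->mn);
}

static void __exit example_exit(void)
{
	/* Required when using mmu_notifier_put(): flush the deferred frees */
	mmu_notifier_synchronize();
}

With this flow the per-driver mm_struct -> notifier hash tables and
refcounts go away: mmu_notifier_get() finds an existing registration for the
mm by its ops pointer and bumps its refcount, and the final
mmu_notifier_put() frees it via SRCU, which is why module exit paths must
call mmu_notifier_synchronize().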
====================
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
parents 9c240a7b 471f3902
@@ -35,6 +35,7 @@
 #include <linux/pm_runtime.h>
 #include <linux/vga_switcheroo.h>
 #include <drm/drm_probe_helper.h>
+#include <linux/mmu_notifier.h>
 #include "amdgpu.h"
 #include "amdgpu_irq.h"
@@ -1464,6 +1465,7 @@ static void __exit amdgpu_exit(void)
 	amdgpu_unregister_atpx_handler();
 	amdgpu_sync_fini();
 	amdgpu_fence_slab_fini();
+	mmu_notifier_synchronize();
 }
 module_init(amdgpu_init);
...
@@ -686,9 +686,6 @@ struct kfd_process {
 	/* We want to receive a notification when the mm_struct is destroyed */
 	struct mmu_notifier mmu_notifier;
-	/* Use for delayed freeing of kfd_process structure */
-	struct rcu_head rcu;
 	unsigned int pasid;
 	unsigned int doorbell_index;
...
@@ -62,8 +62,8 @@ static struct workqueue_struct *kfd_restore_wq;
 static struct kfd_process *find_process(const struct task_struct *thread);
 static void kfd_process_ref_release(struct kref *ref);
-static struct kfd_process *create_process(const struct task_struct *thread,
-					struct file *filep);
+static struct kfd_process *create_process(const struct task_struct *thread);
+static int kfd_process_init_cwsr_apu(struct kfd_process *p, struct file *filep);
 static void evict_process_worker(struct work_struct *work);
 static void restore_process_worker(struct work_struct *work);
@@ -289,7 +289,15 @@ struct kfd_process *kfd_create_process(struct file *filep)
 	if (process) {
 		pr_debug("Process already found\n");
 	} else {
-		process = create_process(thread, filep);
+		process = create_process(thread);
+		if (IS_ERR(process))
+			goto out;
+
+		ret = kfd_process_init_cwsr_apu(process, filep);
+		if (ret) {
+			process = ERR_PTR(ret);
+			goto out;
+		}
 		if (!procfs.kobj)
 			goto out;
@@ -478,11 +486,9 @@ static void kfd_process_ref_release(struct kref *ref)
 	queue_work(kfd_process_wq, &p->release_work);
 }
-static void kfd_process_destroy_delayed(struct rcu_head *rcu)
+static void kfd_process_free_notifier(struct mmu_notifier *mn)
 {
-	struct kfd_process *p = container_of(rcu, struct kfd_process, rcu);
-	kfd_unref_process(p);
+	kfd_unref_process(container_of(mn, struct kfd_process, mmu_notifier));
 }
 static void kfd_process_notifier_release(struct mmu_notifier *mn,
@@ -534,12 +540,12 @@ static void kfd_process_notifier_release(struct mmu_notifier *mn,
 	mutex_unlock(&p->mutex);
-	mmu_notifier_unregister_no_release(&p->mmu_notifier, mm);
-	mmu_notifier_call_srcu(&p->rcu, &kfd_process_destroy_delayed);
+	mmu_notifier_put(&p->mmu_notifier);
 }
 static const struct mmu_notifier_ops kfd_process_mmu_notifier_ops = {
 	.release = kfd_process_notifier_release,
+	.free_notifier = kfd_process_free_notifier,
 };
 static int kfd_process_init_cwsr_apu(struct kfd_process *p, struct file *filep)
@@ -609,81 +615,69 @@ static int kfd_process_device_init_cwsr_dgpu(struct kfd_process_device *pdd)
 	return 0;
 }
-static struct kfd_process *create_process(const struct task_struct *thread,
-					struct file *filep)
+/*
+ * On return the kfd_process is fully operational and will be freed when the
+ * mm is released
+ */
+static struct kfd_process *create_process(const struct task_struct *thread)
 {
 	struct kfd_process *process;
 	int err = -ENOMEM;
 	process = kzalloc(sizeof(*process), GFP_KERNEL);
 	if (!process)
 		goto err_alloc_process;
-	process->pasid = kfd_pasid_alloc();
-	if (process->pasid == 0)
-		goto err_alloc_pasid;
-	if (kfd_alloc_process_doorbells(process) < 0)
-		goto err_alloc_doorbells;
 	kref_init(&process->ref);
 	mutex_init(&process->mutex);
 	process->mm = thread->mm;
-	/* register notifier */
-	process->mmu_notifier.ops = &kfd_process_mmu_notifier_ops;
-	err = mmu_notifier_register(&process->mmu_notifier, process->mm);
-	if (err)
-		goto err_mmu_notifier;
-	hash_add_rcu(kfd_processes_table, &process->kfd_processes,
-			(uintptr_t)process->mm);
 	process->lead_thread = thread->group_leader;
-	get_task_struct(process->lead_thread);
 	INIT_LIST_HEAD(&process->per_device_data);
+	INIT_DELAYED_WORK(&process->eviction_work, evict_process_worker);
+	INIT_DELAYED_WORK(&process->restore_work, restore_process_worker);
+	process->last_restore_timestamp = get_jiffies_64();
 	kfd_event_init_process(process);
+	process->is_32bit_user_mode = in_compat_syscall();
+	process->pasid = kfd_pasid_alloc();
+	if (process->pasid == 0)
+		goto err_alloc_pasid;
+	if (kfd_alloc_process_doorbells(process) < 0)
+		goto err_alloc_doorbells;
 	err = pqm_init(&process->pqm, process);
 	if (err != 0)
 		goto err_process_pqm_init;
 	/* init process apertures*/
-	process->is_32bit_user_mode = in_compat_syscall();
 	err = kfd_init_apertures(process);
 	if (err != 0)
 		goto err_init_apertures;
-	INIT_DELAYED_WORK(&process->eviction_work, evict_process_worker);
-	INIT_DELAYED_WORK(&process->restore_work, restore_process_worker);
-	process->last_restore_timestamp = get_jiffies_64();
-	err = kfd_process_init_cwsr_apu(process, filep);
+	/* Must be last, have to use release destruction after this */
+	process->mmu_notifier.ops = &kfd_process_mmu_notifier_ops;
+	err = mmu_notifier_register(&process->mmu_notifier, process->mm);
 	if (err)
-		goto err_init_cwsr;
+		goto err_register_notifier;
+	get_task_struct(process->lead_thread);
+	hash_add_rcu(kfd_processes_table, &process->kfd_processes,
+			(uintptr_t)process->mm);
 	return process;
-err_init_cwsr:
+err_register_notifier:
 	kfd_process_free_outstanding_kfd_bos(process);
 	kfd_process_destroy_pdds(process);
 err_init_apertures:
 	pqm_uninit(&process->pqm);
 err_process_pqm_init:
-	hash_del_rcu(&process->kfd_processes);
-	synchronize_rcu();
-	mmu_notifier_unregister_no_release(&process->mmu_notifier, process->mm);
-err_mmu_notifier:
-	mutex_destroy(&process->mutex);
 	kfd_free_process_doorbells(process);
 err_alloc_doorbells:
 	kfd_pasid_free(process->pasid);
 err_alloc_pasid:
+	mutex_destroy(&process->mutex);
 	kfree(process);
 err_alloc_process:
 	return ERR_PTR(err);
...
@@ -28,6 +28,7 @@
 #include <linux/pci.h>
 #include <linux/pm_runtime.h>
 #include <linux/vga_switcheroo.h>
+#include <linux/mmu_notifier.h>
 #include <drm/drmP.h>
 #include <drm/drm_crtc_helper.h>
@@ -1292,6 +1293,8 @@ nouveau_drm_exit(void)
 #ifdef CONFIG_NOUVEAU_PLATFORM_DRIVER
 	platform_driver_unregister(&nouveau_platform_driver);
 #endif
+	if (IS_ENABLED(CONFIG_DRM_NOUVEAU_SVM))
+		mmu_notifier_synchronize();
 }
 module_init(nouveau_drm_init);
...
@@ -2451,9 +2451,6 @@ struct radeon_device {
 	/* tracking pinned memory */
 	u64 vram_pin_size;
 	u64 gart_pin_size;
-	struct mutex mn_lock;
-	DECLARE_HASHTABLE(mn_hash, 7);
 };
 bool radeon_is_px(struct drm_device *dev);
...
@@ -1325,8 +1325,6 @@ int radeon_device_init(struct radeon_device *rdev,
 	init_rwsem(&rdev->pm.mclk_lock);
 	init_rwsem(&rdev->exclusive_lock);
 	init_waitqueue_head(&rdev->irq.vblank_queue);
-	mutex_init(&rdev->mn_lock);
-	hash_init(rdev->mn_hash);
 	r = radeon_gem_init(rdev);
 	if (r)
 		return r;
...
@@ -35,6 +35,7 @@
 #include <linux/module.h>
 #include <linux/pm_runtime.h>
 #include <linux/vga_switcheroo.h>
+#include <linux/mmu_notifier.h>
 #include <drm/drm_crtc_helper.h>
 #include <drm/drm_drv.h>
@@ -624,6 +625,7 @@ static void __exit radeon_exit(void)
 {
 	pci_unregister_driver(pdriver);
 	radeon_unregister_atpx_handler();
+	mmu_notifier_synchronize();
 }
 module_init(radeon_init);
...
@@ -37,17 +37,8 @@
 #include "radeon.h"
 struct radeon_mn {
-	/* constant after initialisation */
-	struct radeon_device *rdev;
-	struct mm_struct *mm;
 	struct mmu_notifier mn;
-	/* only used on destruction */
-	struct work_struct work;
-	/* protected by rdev->mn_lock */
-	struct hlist_node node;
 	/* objects protected by lock */
 	struct mutex lock;
 	struct rb_root_cached objects;
@@ -58,55 +49,6 @@ struct radeon_mn_node {
 	struct list_head bos;
 };
-/**
- * radeon_mn_destroy - destroy the rmn
- *
- * @work: previously sheduled work item
- *
- * Lazy destroys the notifier from a work item
- */
-static void radeon_mn_destroy(struct work_struct *work)
-{
-	struct radeon_mn *rmn = container_of(work, struct radeon_mn, work);
-	struct radeon_device *rdev = rmn->rdev;
-	struct radeon_mn_node *node, *next_node;
-	struct radeon_bo *bo, *next_bo;
-	mutex_lock(&rdev->mn_lock);
-	mutex_lock(&rmn->lock);
-	hash_del(&rmn->node);
-	rbtree_postorder_for_each_entry_safe(node, next_node,
-					     &rmn->objects.rb_root, it.rb) {
-		interval_tree_remove(&node->it, &rmn->objects);
-		list_for_each_entry_safe(bo, next_bo, &node->bos, mn_list) {
-			bo->mn = NULL;
-			list_del_init(&bo->mn_list);
-		}
-		kfree(node);
-	}
-	mutex_unlock(&rmn->lock);
-	mutex_unlock(&rdev->mn_lock);
-	mmu_notifier_unregister(&rmn->mn, rmn->mm);
-	kfree(rmn);
-}
-/**
- * radeon_mn_release - callback to notify about mm destruction
- *
- * @mn: our notifier
- * @mn: the mm this callback is about
- *
- * Shedule a work item to lazy destroy our notifier.
- */
-static void radeon_mn_release(struct mmu_notifier *mn,
-			      struct mm_struct *mm)
-{
-	struct radeon_mn *rmn = container_of(mn, struct radeon_mn, mn);
-	INIT_WORK(&rmn->work, radeon_mn_destroy);
-	schedule_work(&rmn->work);
-}
 /**
  * radeon_mn_invalidate_range_start - callback to notify about mm change
  *
@@ -183,65 +125,44 @@ static int radeon_mn_invalidate_range_start(struct mmu_notifier *mn,
 	return ret;
 }
-static const struct mmu_notifier_ops radeon_mn_ops = {
-	.release = radeon_mn_release,
-	.invalidate_range_start = radeon_mn_invalidate_range_start,
-};
+static void radeon_mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
+{
+	struct mmu_notifier_range range = {
+		.mm = mm,
+		.start = 0,
+		.end = ULONG_MAX,
+		.flags = 0,
+		.event = MMU_NOTIFY_UNMAP,
+	};
+	radeon_mn_invalidate_range_start(mn, &range);
+}
-/**
- * radeon_mn_get - create notifier context
- *
- * @rdev: radeon device pointer
- *
- * Creates a notifier context for current->mm.
- */
-static struct radeon_mn *radeon_mn_get(struct radeon_device *rdev)
+static struct mmu_notifier *radeon_mn_alloc_notifier(struct mm_struct *mm)
 {
-	struct mm_struct *mm = current->mm;
 	struct radeon_mn *rmn;
-	int r;
-	if (down_write_killable(&mm->mmap_sem))
-		return ERR_PTR(-EINTR);
-	mutex_lock(&rdev->mn_lock);
-	hash_for_each_possible(rdev->mn_hash, rmn, node, (unsigned long)mm)
-		if (rmn->mm == mm)
-			goto release_locks;
 	rmn = kzalloc(sizeof(*rmn), GFP_KERNEL);
-	if (!rmn) {
-		rmn = ERR_PTR(-ENOMEM);
-		goto release_locks;
-	}
-	rmn->rdev = rdev;
-	rmn->mm = mm;
-	rmn->mn.ops = &radeon_mn_ops;
+	if (!rmn)
+		return ERR_PTR(-ENOMEM);
 	mutex_init(&rmn->lock);
 	rmn->objects = RB_ROOT_CACHED;
+	return &rmn->mn;
+}
-	r = __mmu_notifier_register(&rmn->mn, mm);
-	if (r)
-		goto free_rmn;
-	hash_add(rdev->mn_hash, &rmn->node, (unsigned long)mm);
-release_locks:
-	mutex_unlock(&rdev->mn_lock);
-	up_write(&mm->mmap_sem);
-	return rmn;
-free_rmn:
-	mutex_unlock(&rdev->mn_lock);
-	up_write(&mm->mmap_sem);
-	kfree(rmn);
-	return ERR_PTR(r);
+static void radeon_mn_free_notifier(struct mmu_notifier *mn)
+{
+	kfree(container_of(mn, struct radeon_mn, mn));
 }
+static const struct mmu_notifier_ops radeon_mn_ops = {
+	.release = radeon_mn_release,
+	.invalidate_range_start = radeon_mn_invalidate_range_start,
+	.alloc_notifier = radeon_mn_alloc_notifier,
+	.free_notifier = radeon_mn_free_notifier,
+};
 /**
  * radeon_mn_register - register a BO for notifier updates
  *
@@ -254,15 +175,16 @@ static struct radeon_mn *radeon_mn_get(struct radeon_device *rdev)
 int radeon_mn_register(struct radeon_bo *bo, unsigned long addr)
 {
 	unsigned long end = addr + radeon_bo_size(bo) - 1;
-	struct radeon_device *rdev = bo->rdev;
+	struct mmu_notifier *mn;
 	struct radeon_mn *rmn;
 	struct radeon_mn_node *node = NULL;
 	struct list_head bos;
 	struct interval_tree_node *it;
-	rmn = radeon_mn_get(rdev);
-	if (IS_ERR(rmn))
-		return PTR_ERR(rmn);
+	mn = mmu_notifier_get(&radeon_mn_ops, current->mm);
+	if (IS_ERR(mn))
+		return PTR_ERR(mn);
+	rmn = container_of(mn, struct radeon_mn, mn);
 	INIT_LIST_HEAD(&bos);
@@ -309,22 +231,13 @@ int radeon_mn_register(struct radeon_bo *bo, unsigned long addr)
  */
 void radeon_mn_unregister(struct radeon_bo *bo)
 {
-	struct radeon_device *rdev = bo->rdev;
-	struct radeon_mn *rmn;
+	struct radeon_mn *rmn = bo->mn;
 	struct list_head *head;
-	mutex_lock(&rdev->mn_lock);
-	rmn = bo->mn;
-	if (rmn == NULL) {
-		mutex_unlock(&rdev->mn_lock);
-		return;
-	}
 	mutex_lock(&rmn->lock);
 	/* save the next list entry for later */
 	head = bo->mn_list.next;
-	bo->mn = NULL;
 	list_del(&bo->mn_list);
 	if (list_empty(head)) {
@@ -335,5 +248,7 @@ void radeon_mn_unregister(struct radeon_bo *bo)
 	}
 	mutex_unlock(&rmn->lock);
-	mutex_unlock(&rdev->mn_lock);
+	mmu_notifier_put(&rmn->mn);
+	bo->mn = NULL;
 }
@@ -573,6 +573,7 @@ static void __exit gru_exit(void)
 	gru_free_tables();
 	misc_deregister(&gru_miscdev);
 	gru_proc_exit();
+	mmu_notifier_synchronize();
 }
 static const struct file_operations gru_fops = {
...
@@ -307,10 +307,8 @@ struct gru_mm_tracker {	/* pack to reduce size */
 struct gru_mm_struct {
 	struct mmu_notifier ms_notifier;
-	atomic_t ms_refcnt;
 	spinlock_t ms_asid_lock;	/* protects ASID assignment */
 	atomic_t ms_range_active;/* num range_invals active */
-	char ms_released;
 	wait_queue_head_t ms_wait_queue;
 	DECLARE_BITMAP(ms_asidmap, GRU_MAX_GRUS);
 	struct gru_mm_tracker ms_asids[GRU_MAX_GRUS];
...
@@ -235,83 +235,47 @@ static void gru_invalidate_range_end(struct mmu_notifier *mn,
 		gms, range->start, range->end);
 }
-static void gru_release(struct mmu_notifier *mn, struct mm_struct *mm)
+static struct mmu_notifier *gru_alloc_notifier(struct mm_struct *mm)
 {
-	struct gru_mm_struct *gms = container_of(mn, struct gru_mm_struct,
-						 ms_notifier);
-	gms->ms_released = 1;
-	gru_dbg(grudev, "gms %p\n", gms);
+	struct gru_mm_struct *gms;
+	gms = kzalloc(sizeof(*gms), GFP_KERNEL);
+	if (!gms)
+		return ERR_PTR(-ENOMEM);
+	STAT(gms_alloc);
+	spin_lock_init(&gms->ms_asid_lock);
+	init_waitqueue_head(&gms->ms_wait_queue);
+	return &gms->ms_notifier;
 }
+static void gru_free_notifier(struct mmu_notifier *mn)
+{
+	kfree(container_of(mn, struct gru_mm_struct, ms_notifier));
+	STAT(gms_free);
+}
 static const struct mmu_notifier_ops gru_mmuops = {
 	.invalidate_range_start = gru_invalidate_range_start,
 	.invalidate_range_end = gru_invalidate_range_end,
-	.release = gru_release,
+	.alloc_notifier = gru_alloc_notifier,
+	.free_notifier = gru_free_notifier,
 };
-/* Move this to the basic mmu_notifier file. But for now... */
-static struct mmu_notifier *mmu_find_ops(struct mm_struct *mm,
-			const struct mmu_notifier_ops *ops)
-{
-	struct mmu_notifier *mn, *gru_mn = NULL;
-	if (mm->mmu_notifier_mm) {
-		rcu_read_lock();
-		hlist_for_each_entry_rcu(mn, &mm->mmu_notifier_mm->list,
-					 hlist)
-			if (mn->ops == ops) {
-				gru_mn = mn;
-				break;
-			}
-		rcu_read_unlock();
-	}
-	return gru_mn;
-}
 struct gru_mm_struct *gru_register_mmu_notifier(void)
 {
-	struct gru_mm_struct *gms;
 	struct mmu_notifier *mn;
-	int err;
-	mn = mmu_find_ops(current->mm, &gru_mmuops);
-	if (mn) {
-		gms = container_of(mn, struct gru_mm_struct, ms_notifier);
-		atomic_inc(&gms->ms_refcnt);
-	} else {
-		gms = kzalloc(sizeof(*gms), GFP_KERNEL);
-		if (!gms)
-			return ERR_PTR(-ENOMEM);
-		STAT(gms_alloc);
-		spin_lock_init(&gms->ms_asid_lock);
-		gms->ms_notifier.ops = &gru_mmuops;
-		atomic_set(&gms->ms_refcnt, 1);
-		init_waitqueue_head(&gms->ms_wait_queue);
-		err = __mmu_notifier_register(&gms->ms_notifier, current->mm);
-		if (err)
-			goto error;
-	}
-	if (gms)
-		gru_dbg(grudev, "gms %p, refcnt %d\n", gms,
-			atomic_read(&gms->ms_refcnt));
-	return gms;
-error:
-	kfree(gms);
-	return ERR_PTR(err);
+	mn = mmu_notifier_get_locked(&gru_mmuops, current->mm);
+	if (IS_ERR(mn))
+		return ERR_CAST(mn);
+	return container_of(mn, struct gru_mm_struct, ms_notifier);
 }
 void gru_drop_mmu_notifier(struct gru_mm_struct *gms)
 {
-	gru_dbg(grudev, "gms %p, refcnt %d, released %d\n", gms,
-		atomic_read(&gms->ms_refcnt), gms->ms_released);
-	if (atomic_dec_return(&gms->ms_refcnt) == 0) {
-		if (!gms->ms_released)
-			mmu_notifier_unregister(&gms->ms_notifier, current->mm);
-		kfree(gms);
-		STAT(gms_free);
-	}
+	mmu_notifier_put(&gms->ms_notifier);
 }
 /*
...
@@ -84,15 +84,12 @@
  * @notifiers: count of active mmu notifiers
  */
 struct hmm {
-	struct mm_struct *mm;
-	struct kref kref;
+	struct mmu_notifier mmu_notifier;
 	spinlock_t ranges_lock;
 	struct list_head ranges;
 	struct list_head mirrors;
-	struct mmu_notifier mmu_notifier;
 	struct rw_semaphore mirrors_sem;
 	wait_queue_head_t wq;
-	struct rcu_head rcu;
 	long notifiers;
 };
@@ -409,13 +406,6 @@ long hmm_range_dma_unmap(struct hmm_range *range,
  */
 #define HMM_RANGE_DEFAULT_TIMEOUT 1000
-/* Below are for HMM internal use only! Not to be used by device driver! */
-static inline void hmm_mm_init(struct mm_struct *mm)
-{
-	mm->hmm = NULL;
-}
-#else /* IS_ENABLED(CONFIG_HMM_MIRROR) */
-static inline void hmm_mm_init(struct mm_struct *mm) {}
 #endif /* IS_ENABLED(CONFIG_HMM_MIRROR) */
 #endif /* LINUX_HMM_H */
@@ -25,7 +25,6 @@
 struct address_space;
 struct mem_cgroup;
-struct hmm;
 /*
  * Each physical page in the system has a struct page associated with
@@ -502,11 +501,6 @@ struct mm_struct {
 		atomic_long_t hugetlb_usage;
 #endif
 		struct work_struct async_put_work;
-#ifdef CONFIG_HMM_MIRROR
-		/* HMM needs to track a few things per mm */
-		struct hmm *hmm;
-#endif
 	} __randomize_layout;
 /*
...
@@ -211,6 +211,19 @@
 	 */
 	void (*invalidate_range)(struct mmu_notifier *mn, struct mm_struct *mm,
 				 unsigned long start, unsigned long end);
+
+	/*
+	 * These callbacks are used with the get/put interface to manage the
+	 * lifetime of the mmu_notifier memory. alloc_notifier() returns a new
+	 * notifier for use with the mm.
+	 *
+	 * free_notifier() is only called after the mmu_notifier has been
+	 * fully put, calls to any ops callback are prevented and no ops
+	 * callbacks are currently running. It is called from a SRCU callback
+	 * and cannot sleep.
+	 */
+	struct mmu_notifier *(*alloc_notifier)(struct mm_struct *mm);
+	void (*free_notifier)(struct mmu_notifier *mn);
 };
 /*
@@ -227,6 +240,9 @@
 struct mmu_notifier {
 	struct hlist_node hlist;
 	const struct mmu_notifier_ops *ops;
+	struct mm_struct *mm;
+	struct rcu_head rcu;
+	unsigned int users;
 };
 static inline int mm_has_notifiers(struct mm_struct *mm)
@@ -234,6 +250,21 @@ static inline int mm_has_notifiers(struct mm_struct *mm)
 	return unlikely(mm->mmu_notifier_mm);
 }
+struct mmu_notifier *mmu_notifier_get_locked(const struct mmu_notifier_ops *ops,
+					     struct mm_struct *mm);
+static inline struct mmu_notifier *
+mmu_notifier_get(const struct mmu_notifier_ops *ops, struct mm_struct *mm)
+{
+	struct mmu_notifier *ret;
+
+	down_write(&mm->mmap_sem);
+	ret = mmu_notifier_get_locked(ops, mm);
+	up_write(&mm->mmap_sem);
+	return ret;
+}
+void mmu_notifier_put(struct mmu_notifier *mn);
+void mmu_notifier_synchronize(void);
 extern int mmu_notifier_register(struct mmu_notifier *mn,
 				 struct mm_struct *mm);
 extern int __mmu_notifier_register(struct mmu_notifier *mn,
@@ -581,6 +612,10 @@ static inline void mmu_notifier_mm_destroy(struct mm_struct *mm)
 #define pudp_huge_clear_flush_notify pudp_huge_clear_flush
 #define set_pte_at_notify set_pte_at
+static inline void mmu_notifier_synchronize(void)
+{
+}
 #endif /* CONFIG_MMU_NOTIFIER */
 #endif /* _LINUX_MMU_NOTIFIER_H */
@@ -1007,7 +1007,6 @@ static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p,
 	mm_init_owner(mm, p);
 	RCU_INIT_POINTER(mm->exe_file, NULL);
 	mmu_notifier_mm_init(mm);
-	hmm_mm_init(mm);
 	init_tlb_flush_pending(mm);
 #if defined(CONFIG_TRANSPARENT_HUGEPAGE) && !USE_SPLIT_PMD_PTLOCKS
 	mm->pmd_huge_pte = NULL;
...
@@ -26,101 +26,37 @@
 #include <linux/mmu_notifier.h>
 #include <linux/memory_hotplug.h>
-static const struct mmu_notifier_ops hmm_mmu_notifier_ops;
-/**
- * hmm_get_or_create - register HMM against an mm (HMM internal)
- *
- * @mm: mm struct to attach to
- * Return: an HMM object, either by referencing the existing
- * (per-process) object, or by creating a new one.
- *
- * This is not intended to be used directly by device drivers. If mm already
- * has an HMM struct then it get a reference on it and returns it. Otherwise
- * it allocates an HMM struct, initializes it, associate it with the mm and
- * returns it.
- */
-static struct hmm *hmm_get_or_create(struct mm_struct *mm)
+static struct mmu_notifier *hmm_alloc_notifier(struct mm_struct *mm)
 {
 	struct hmm *hmm;
-	lockdep_assert_held_write(&mm->mmap_sem);
-	/* Abuse the page_table_lock to also protect mm->hmm. */
-	spin_lock(&mm->page_table_lock);
-	hmm = mm->hmm;
-	if (mm->hmm && kref_get_unless_zero(&mm->hmm->kref))
-		goto out_unlock;
-	spin_unlock(&mm->page_table_lock);
-	hmm = kmalloc(sizeof(*hmm), GFP_KERNEL);
+	hmm = kzalloc(sizeof(*hmm), GFP_KERNEL);
 	if (!hmm)
-		return NULL;
+		return ERR_PTR(-ENOMEM);
 	init_waitqueue_head(&hmm->wq);
 	INIT_LIST_HEAD(&hmm->mirrors);
 	init_rwsem(&hmm->mirrors_sem);
-	hmm->mmu_notifier.ops = NULL;
 	INIT_LIST_HEAD(&hmm->ranges);
 	spin_lock_init(&hmm->ranges_lock);
-	kref_init(&hmm->kref);
 	hmm->notifiers = 0;
-	hmm->mm = mm;
-	hmm->mmu_notifier.ops = &hmm_mmu_notifier_ops;
-	if (__mmu_notifier_register(&hmm->mmu_notifier, mm)) {
-		kfree(hmm);
-		return NULL;
-	}
-	mmgrab(hmm->mm);
-	/*
-	 * We hold the exclusive mmap_sem here so we know that mm->hmm is
-	 * still NULL or 0 kref, and is safe to update.
-	 */
-	spin_lock(&mm->page_table_lock);
-	mm->hmm = hmm;
-out_unlock:
-	spin_unlock(&mm->page_table_lock);
-	return hmm;
+	return &hmm->mmu_notifier;
 }
-static void hmm_free_rcu(struct rcu_head *rcu)
+static void hmm_free_notifier(struct mmu_notifier *mn)
 {
-	struct hmm *hmm = container_of(rcu, struct hmm, rcu);
-	mmdrop(hmm->mm);
+	struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
+	WARN_ON(!list_empty(&hmm->ranges));
+	WARN_ON(!list_empty(&hmm->mirrors));
 	kfree(hmm);
 }
-static void hmm_free(struct kref *kref)
-{
-	struct hmm *hmm = container_of(kref, struct hmm, kref);
-	spin_lock(&hmm->mm->page_table_lock);
-	if (hmm->mm->hmm == hmm)
-		hmm->mm->hmm = NULL;
-	spin_unlock(&hmm->mm->page_table_lock);
-	mmu_notifier_unregister_no_release(&hmm->mmu_notifier, hmm->mm);
-	mmu_notifier_call_srcu(&hmm->rcu, hmm_free_rcu);
-}
-static inline void hmm_put(struct hmm *hmm)
-{
-	kref_put(&hmm->kref, hmm_free);
-}
 static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 {
 	struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
 	struct hmm_mirror *mirror;
-	/* Bail out if hmm is in the process of being freed */
-	if (!kref_get_unless_zero(&hmm->kref))
-		return;
 	/*
 	 * Since hmm_range_register() holds the mmget() lock hmm_release() is
 	 * prevented as long as a range exists.
@@ -137,8 +73,6 @@ static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 		mirror->ops->release(mirror);
 	}
 	up_read(&hmm->mirrors_sem);
-	hmm_put(hmm);
 }
 static void notifiers_decrement(struct hmm *hmm)
@@ -169,9 +103,6 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
 	unsigned long flags;
 	int ret = 0;
-	if (!kref_get_unless_zero(&hmm->kref))
-		return 0;
 	spin_lock_irqsave(&hmm->ranges_lock, flags);
 	hmm->notifiers++;
 	list_for_each_entry(range, &hmm->ranges, list) {
@@ -206,7 +137,6 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
 out:
 	if (ret)
 		notifiers_decrement(hmm);
-	hmm_put(hmm);
 	return ret;
 }
@@ -215,17 +145,15 @@ static void hmm_invalidate_range_end(struct mmu_notifier *mn,
 {
 	struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
-	if (!kref_get_unless_zero(&hmm->kref))
-		return;
 	notifiers_decrement(hmm);
-	hmm_put(hmm);
 }
 static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
 	.release = hmm_release,
 	.invalidate_range_start = hmm_invalidate_range_start,
 	.invalidate_range_end = hmm_invalidate_range_end,
+	.alloc_notifier = hmm_alloc_notifier,
+	.free_notifier = hmm_free_notifier,
 };
 /*
@@ -237,18 +165,27 @@ static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
  *
  * To start mirroring a process address space, the device driver must register
  * an HMM mirror struct.
+ *
+ * The caller cannot unregister the hmm_mirror while any ranges are
+ * registered.
+ *
+ * Callers using this function must put a call to mmu_notifier_synchronize()
+ * in their module exit functions.
  */
 int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm)
 {
+	struct mmu_notifier *mn;
 	lockdep_assert_held_write(&mm->mmap_sem);
 	/* Sanity check */
 	if (!mm || !mirror || !mirror->ops)
 		return -EINVAL;
-	mirror->hmm = hmm_get_or_create(mm);
-	if (!mirror->hmm)
-		return -ENOMEM;
+	mn = mmu_notifier_get_locked(&hmm_mmu_notifier_ops, mm);
+	if (IS_ERR(mn))
+		return PTR_ERR(mn);
+	mirror->hmm = container_of(mn, struct hmm, mmu_notifier);
 	down_write(&mirror->hmm->mirrors_sem);
 	list_add(&mirror->list, &mirror->hmm->mirrors);
@@ -272,7 +209,7 @@ void hmm_mirror_unregister(struct hmm_mirror *mirror)
 	down_write(&hmm->mirrors_sem);
 	list_del(&mirror->list);
 	up_write(&hmm->mirrors_sem);
-	hmm_put(hmm);
+	mmu_notifier_put(&hmm->mmu_notifier);
 }
 EXPORT_SYMBOL(hmm_mirror_unregister);
@@ -854,14 +791,13 @@ int hmm_range_register(struct hmm_range *range, struct hmm_mirror *mirror)
 		return -EINVAL;
 	/* Prevent hmm_release() from running while the range is valid */
-	if (!mmget_not_zero(hmm->mm))
+	if (!mmget_not_zero(hmm->mmu_notifier.mm))
 		return -EFAULT;
 	/* Initialize range to track CPU page table updates. */
 	spin_lock_irqsave(&hmm->ranges_lock, flags);
 	range->hmm = hmm;
-	kref_get(&hmm->kref);
 	list_add(&range->list, &hmm->ranges);
 	/*
@@ -893,8 +829,7 @@ void hmm_range_unregister(struct hmm_range *range)
 	spin_unlock_irqrestore(&hmm->ranges_lock, flags);
 	/* Drop reference taken by hmm_range_register() */
-	mmput(hmm->mm);
-	hmm_put(hmm);
+	mmput(hmm->mmu_notifier.mm);
 	/*
 	 * The range is now invalid and the ref on the hmm is dropped, so
@@ -944,14 +879,14 @@ long hmm_range_fault(struct hmm_range *range, unsigned int flags)
 	struct mm_walk mm_walk;
 	int ret;
-	lockdep_assert_held(&hmm->mm->mmap_sem);
+	lockdep_assert_held(&hmm->mmu_notifier.mm->mmap_sem);
 	do {
 		/* If range is no longer valid force retry. */
 		if (!range->valid)
 			return -EBUSY;
-		vma = find_vma(hmm->mm, start);
+		vma = find_vma(hmm->mmu_notifier.mm, start);
 		if (vma == NULL || (vma->vm_flags & device_vma))
 			return -EFAULT;
...
@@ -236,33 +236,41 @@ void __mmu_notifier_invalidate_range(struct mm_struct *mm,
 }
 EXPORT_SYMBOL_GPL(__mmu_notifier_invalidate_range);
-static int do_mmu_notifier_register(struct mmu_notifier *mn,
-				    struct mm_struct *mm,
-				    int take_mmap_sem)
+/*
+ * Same as mmu_notifier_register but here the caller must hold the
+ * mmap_sem in write mode.
+ */
+int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
 {
-	struct mmu_notifier_mm *mmu_notifier_mm;
+	struct mmu_notifier_mm *mmu_notifier_mm = NULL;
 	int ret;
+	lockdep_assert_held_write(&mm->mmap_sem);
 	BUG_ON(atomic_read(&mm->mm_users) <= 0);
-	ret = -ENOMEM;
-	mmu_notifier_mm = kmalloc(sizeof(struct mmu_notifier_mm), GFP_KERNEL);
-	if (unlikely(!mmu_notifier_mm))
-		goto out;
-	if (take_mmap_sem)
-		down_write(&mm->mmap_sem);
-	ret = mm_take_all_locks(mm);
-	if (unlikely(ret))
-		goto out_clean;
-	if (!mm_has_notifiers(mm)) {
+	mn->mm = mm;
+	mn->users = 1;
+	if (!mm->mmu_notifier_mm) {
+		/*
+		 * kmalloc cannot be called under mm_take_all_locks(), but we
+		 * know that mm->mmu_notifier_mm can't change while we hold
+		 * the write side of the mmap_sem.
+		 */
+		mmu_notifier_mm =
+			kmalloc(sizeof(struct mmu_notifier_mm), GFP_KERNEL);
+		if (!mmu_notifier_mm)
+			return -ENOMEM;
 		INIT_HLIST_HEAD(&mmu_notifier_mm->list);
 		spin_lock_init(&mmu_notifier_mm->lock);
-		mm->mmu_notifier_mm = mmu_notifier_mm;
-		mmu_notifier_mm = NULL;
 	}
+	ret = mm_take_all_locks(mm);
+	if (unlikely(ret))
+		goto out_clean;
+	/* Pairs with the mmdrop in mmu_notifier_unregister_* */
 	mmgrab(mm);
 	/*
@@ -273,48 +281,118 @@ static int do_mmu_notifier_register(struct mmu_notifier *mn,
 	 * We can't race against any other mmu notifier method either
 	 * thanks to mm_take_all_locks().
 	 */
+	if (mmu_notifier_mm)
+		mm->mmu_notifier_mm = mmu_notifier_mm;
 	spin_lock(&mm->mmu_notifier_mm->lock);
 	hlist_add_head_rcu(&mn->hlist, &mm->mmu_notifier_mm->list);
 	spin_unlock(&mm->mmu_notifier_mm->lock);
 	mm_drop_all_locks(mm);
+	BUG_ON(atomic_read(&mm->mm_users) <= 0);
+	return 0;
 out_clean:
-	if (take_mmap_sem)
-		up_write(&mm->mmap_sem);
 	kfree(mmu_notifier_mm);
-out:
-	BUG_ON(atomic_read(&mm->mm_users) <= 0);
 	return ret;
 }
+EXPORT_SYMBOL_GPL(__mmu_notifier_register);
-/*
+/**
+ * mmu_notifier_register - Register a notifier on a mm
+ * @mn: The notifier to attach
+ * @mm: The mm to attach the notifier to
+ *
  * Must not hold mmap_sem nor any other VM related lock when calling
  * this registration function. Must also ensure mm_users can't go down
  * to zero while this runs to avoid races with mmu_notifier_release,
  * so mm has to be current->mm or the mm should be pinned safely such
  * as with get_task_mm(). If the mm is not current->mm, the mm_users
  * pin should be released by calling mmput after mmu_notifier_register
- * returns. mmu_notifier_unregister must be always called to
- * unregister the notifier. mm_count is automatically pinned to allow
- * mmu_notifier_unregister to safely run at any time later, before or
- * after exit_mmap. ->release will always be called before exit_mmap
- * frees the pages.
+ * returns.
+ *
+ * mmu_notifier_unregister() or mmu_notifier_put() must be always called to
+ * unregister the notifier.
+ *
+ * While the caller has a mmu_notifier get the mn->mm pointer will remain
+ * valid, and can be converted to an active mm pointer via mmget_not_zero().
  */
 int mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
 {
-	return do_mmu_notifier_register(mn, mm, 1);
+	int ret;
+
+	down_write(&mm->mmap_sem);
+	ret = __mmu_notifier_register(mn, mm);
+	up_write(&mm->mmap_sem);
+	return ret;
 }
 EXPORT_SYMBOL_GPL(mmu_notifier_register);
-/*
- * Same as mmu_notifier_register but here the caller must hold the
- * mmap_sem in write mode.
+static struct mmu_notifier *
+find_get_mmu_notifier(struct mm_struct *mm, const struct mmu_notifier_ops *ops)
+{
+	struct mmu_notifier *mn;
+
+	spin_lock(&mm->mmu_notifier_mm->lock);
+	hlist_for_each_entry_rcu (mn, &mm->mmu_notifier_mm->list, hlist) {
+		if (mn->ops != ops)
+			continue;
+
+		if (likely(mn->users != UINT_MAX))
+			mn->users++;
+		else
+			mn = ERR_PTR(-EOVERFLOW);
+		spin_unlock(&mm->mmu_notifier_mm->lock);
+		return mn;
+	}
+	spin_unlock(&mm->mmu_notifier_mm->lock);
+	return NULL;
+}
+
+/**
+ * mmu_notifier_get_locked - Return the single struct mmu_notifier for
+ *                           the mm & ops
+ * @ops: The operations struct being subscribe with
+ * @mm : The mm to attach notifiers too
+ *
+ * This function either allocates a new mmu_notifier via
+ * ops->alloc_notifier(), or returns an already existing notifier on the
+ * list. The value of the ops pointer is used to determine when two notifiers
+ * are the same.
+ *
+ * Each call to mmu_notifier_get() must be paired with a call to
+ * mmu_notifier_put(). The caller must hold the write side of mm->mmap_sem.
+ *
+ * While the caller has a mmu_notifier get the mm pointer will remain valid,
+ * and can be converted to an active mm pointer via mmget_not_zero().
  */
-int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
+struct mmu_notifier *mmu_notifier_get_locked(const struct mmu_notifier_ops *ops,
+					     struct mm_struct *mm)
 {
-	return do_mmu_notifier_register(mn, mm, 0);
+	struct mmu_notifier *mn;
+	int ret;
+
+	lockdep_assert_held_write(&mm->mmap_sem);
+
+	if (mm->mmu_notifier_mm) {
+		mn = find_get_mmu_notifier(mm, ops);
+		if (mn)
+			return mn;
+	}
+
+	mn = ops->alloc_notifier(mm);
+	if (IS_ERR(mn))
+		return mn;
+	mn->ops = ops;
+	ret = __mmu_notifier_register(mn, mm);
+	if (ret)
+		goto out_free;
+	return mn;
+out_free:
+	mn->ops->free_notifier(mn);
+	return ERR_PTR(ret);
 }
-EXPORT_SYMBOL_GPL(__mmu_notifier_register);
+EXPORT_SYMBOL_GPL(mmu_notifier_get_locked);
 /* this is called after the last mmu_notifier_unregister() returned */
 void __mmu_notifier_mm_destroy(struct mm_struct *mm)
@@ -394,6 +472,75 @@ void mmu_notifier_unregister_no_release(struct mmu_notifier *mn,
 }
 EXPORT_SYMBOL_GPL(mmu_notifier_unregister_no_release);
+static void mmu_notifier_free_rcu(struct rcu_head *rcu)
+{
+	struct mmu_notifier *mn = container_of(rcu, struct mmu_notifier, rcu);
+	struct mm_struct *mm = mn->mm;
+
+	mn->ops->free_notifier(mn);
+	/* Pairs with the get in __mmu_notifier_register() */
+	mmdrop(mm);
+}
+
+/**
+ * mmu_notifier_put - Release the reference on the notifier
+ * @mn: The notifier to act on
+ *
+ * This function must be paired with each mmu_notifier_get(), it releases the
+ * reference obtained by the get. If this is the last reference then process
+ * to free the notifier will be run asynchronously.
+ *
+ * Unlike mmu_notifier_unregister() the get/put flow only calls ops->release
+ * when the mm_struct is destroyed. Instead free_notifier is always called to
+ * release any resources held by the user.
+ *
+ * As ops->release is not guaranteed to be called, the user must ensure that
+ * all sptes are dropped, and no new sptes can be established before
+ * mmu_notifier_put() is called.
+ *
+ * This function can be called from the ops->release callback, however the
+ * caller must still ensure it is called pairwise with mmu_notifier_get().
+ *
+ * Modules calling this function must call mmu_notifier_synchronize() in
+ * their __exit functions to ensure the async work is completed.
+ */
+void mmu_notifier_put(struct mmu_notifier *mn)
+{
+	struct mm_struct *mm = mn->mm;
+
+	spin_lock(&mm->mmu_notifier_mm->lock);
+	if (WARN_ON(!mn->users) || --mn->users)
+		goto out_unlock;
+	hlist_del_init_rcu(&mn->hlist);
+	spin_unlock(&mm->mmu_notifier_mm->lock);
+
+	call_srcu(&srcu, &mn->rcu, mmu_notifier_free_rcu);
+	return;
+
+out_unlock:
+	spin_unlock(&mm->mmu_notifier_mm->lock);
+}
+EXPORT_SYMBOL_GPL(mmu_notifier_put);
+
+/**
+ * mmu_notifier_synchronize - Ensure all mmu_notifiers are freed
+ *
+ * This function ensures that all outstanding async SRU work from
+ * mmu_notifier_put() is completed. After it returns any mmu_notifier_ops
+ * associated with an unused mmu_notifier will no longer be called.
+ *
+ * Before using the caller must ensure that all of its mmu_notifiers have been
+ * fully released via mmu_notifier_put().
+ *
+ * Modules using the mmu_notifier_put() API should call this in their __exit
+ * function to avoid module unloading races.
+ */
+void mmu_notifier_synchronize(void)
+{
+	synchronize_srcu(&srcu);
+}
+EXPORT_SYMBOL_GPL(mmu_notifier_synchronize);
 bool
 mmu_notifier_range_update_to_read_only(const struct mmu_notifier_range *range)
 {
...