Commit be957c88 authored by Jason Gunthorpe

mm/hmm: make hmm_range_fault return 0 or -1

hmm_vma_walk->last is supposed to be updated after every write to the
pfns, so that it can be returned by hmm_range_fault(). However, this is
not done consistently. Fortunately nothing checks the return code of
hmm_range_fault() for anything other than error.

More importantly, last must be set before returning -EBUSY, as it is used to
prevent reading an output pfn as input flags when the loop restarts.

For clarity and simplicity make hmm_range_fault() return 0 or -ERRNO. Only
set last when returning -EBUSY.

Link: https://lore.kernel.org/r/2-v2-b4e84f444c7d+24f57-hmm_no_flags_jgg@mellanox.com
Acked-by: Felix Kuehling <Felix.Kuehling@amd.com>
Tested-by: Ralph Campbell <rcampbell@nvidia.com>
Reviewed-by: Christoph Hellwig <hch@lst.de>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
parent 0e698dfa
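To make the new contract concrete, here is a minimal caller-side sketch, not taken from the patch: a hypothetical helper that retries hmm_range_fault() on -EBUSY and treats any other non-zero value as a hard error, loosely modeled on the amdgpu and nouveau hunks below. The helper name, the timeout policy, and the surrounding locking are simplified assumptions.

/*
 * Illustrative sketch only (not part of the patch): a hypothetical
 * caller retry loop under the new 0-or-negative-errno contract.
 * Locking and timeout handling are simplified assumptions.
 */
static int example_populate_range(struct mm_struct *mm,
				  struct hmm_range *range)
{
	unsigned long timeout =
		jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
	int ret;

	do {
		if (time_after(jiffies, timeout))
			return -EBUSY;

		down_read(&mm->mmap_sem);
		ret = hmm_range_fault(range);
		up_read(&mm->mmap_sem);
		/* -EBUSY: the walk raced and must be restarted. */
	} while (ret == -EBUSY);

	/* 0 means the range was populated; otherwise a hard -ERRNO. */
	return ret;
}

Under the old interface the same helper would have had to treat a return of 0 and -EBUSY alike as "retry", which is exactly the pattern the driver hunks below delete.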
@@ -161,7 +161,7 @@ device must complete the update before the driver callback returns.
 When the device driver wants to populate a range of virtual addresses, it can
 use::
 
-  long hmm_range_fault(struct hmm_range *range);
+  int hmm_range_fault(struct hmm_range *range);
 
 It will trigger a page fault on missing or read-only entries if write access is
 requested (see below). Page faults use the generic mm page fault code path just
...
@@ -852,12 +852,12 @@ int amdgpu_ttm_tt_get_user_pages(struct amdgpu_bo *bo, struct page **pages)
 	down_read(&mm->mmap_sem);
 	r = hmm_range_fault(range);
 	up_read(&mm->mmap_sem);
-	if (unlikely(r <= 0)) {
+	if (unlikely(r)) {
 		/*
 		 * FIXME: This timeout should encompass the retry from
 		 * mmu_interval_read_retry() as well.
 		 */
-		if ((r == 0 || r == -EBUSY) && !time_after(jiffies, timeout))
+		if (r == -EBUSY && !time_after(jiffies, timeout))
			goto retry;
 		goto out_free_pfns;
 	}
...
@@ -536,7 +536,7 @@ static int nouveau_range_fault(struct nouveau_svmm *svmm,
 		.pfn_shift = NVIF_VMM_PFNMAP_V0_ADDR_SHIFT,
 	};
 	struct mm_struct *mm = notifier->notifier.mm;
-	long ret;
+	int ret;
 
 	while (true) {
 		if (time_after(jiffies, timeout))
@@ -548,8 +548,8 @@ static int nouveau_range_fault(struct nouveau_svmm *svmm,
 		down_read(&mm->mmap_sem);
 		ret = hmm_range_fault(&range);
 		up_read(&mm->mmap_sem);
-		if (ret <= 0) {
-			if (ret == 0 || ret == -EBUSY)
+		if (ret) {
+			if (ret == -EBUSY)
				continue;
 			return ret;
 		}
...
@@ -120,7 +120,7 @@ static inline struct page *hmm_device_entry_to_page(const struct hmm_range *range
 /*
  * Please see Documentation/vm/hmm.rst for how to use the range API.
  */
-long hmm_range_fault(struct hmm_range *range);
+int hmm_range_fault(struct hmm_range *range);
 
 /*
  * HMM_RANGE_DEFAULT_TIMEOUT - default timeout (ms) when waiting for a range
...
@@ -174,7 +174,6 @@ static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
 	}
 	if (required_fault)
 		return hmm_vma_fault(addr, end, required_fault, walk);
-	hmm_vma_walk->last = addr;
 	return hmm_pfns_fill(addr, end, range, HMM_PFN_NONE);
 }
 
@@ -207,7 +206,6 @@ static int hmm_vma_handle_pmd(struct mm_walk *walk, unsigned long addr,
 	pfn = pmd_pfn(pmd) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
 	for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++)
 		pfns[i] = hmm_device_entry_from_pfn(range, pfn) | cpu_flags;
-	hmm_vma_walk->last = end;
 	return 0;
 }
 #else /* CONFIG_TRANSPARENT_HUGEPAGE */
@@ -386,13 +384,10 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
 		r = hmm_vma_handle_pte(walk, addr, end, pmdp, ptep, pfns);
 		if (r) {
 			/* hmm_vma_handle_pte() did pte_unmap() */
-			hmm_vma_walk->last = addr;
 			return r;
 		}
 	}
 	pte_unmap(ptep - 1);
-
-	hmm_vma_walk->last = addr;
 	return 0;
 }
 
@@ -455,7 +450,6 @@ static int hmm_vma_walk_pud(pud_t *pudp, unsigned long start, unsigned long end,
 		for (i = 0; i < npages; ++i, ++pfn)
 			pfns[i] = hmm_device_entry_from_pfn(range, pfn) |
				  cpu_flags;
-		hmm_vma_walk->last = end;
 		goto out_unlock;
 	}
 
@@ -500,7 +494,6 @@ static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
 	for (; addr < end; addr += PAGE_SIZE, i++, pfn++)
 		range->pfns[i] = hmm_device_entry_from_pfn(range, pfn) |
				 cpu_flags;
-	hmm_vma_walk->last = end;
 	spin_unlock(ptl);
 	return 0;
 }
@@ -537,7 +530,6 @@ static int hmm_vma_walk_test(unsigned long start, unsigned long end,
 		return -EFAULT;
 
 	hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
-	hmm_vma_walk->last = end;
 
 	/* Skip this vma and continue processing the next vma. */
 	return 1;
@@ -555,9 +547,7 @@ static const struct mm_walk_ops hmm_walk_ops = {
  * hmm_range_fault - try to fault some address in a virtual address range
  * @range: argument structure
  *
- * Return: the number of valid pages in range->pfns[] (from range start
- * address), which may be zero. On error one of the following status codes
- * can be returned:
+ * Returns 0 on success or one of the following error codes:
  *
  * -EINVAL: Invalid arguments or mm or virtual address is in an invalid vma
  *		(e.g., device file vma).
@@ -572,7 +562,7 @@ static const struct mm_walk_ops hmm_walk_ops = {
  * This is similar to get_user_pages(), except that it can read the page tables
  * without mutating them (ie causing faults).
  */
-long hmm_range_fault(struct hmm_range *range)
+int hmm_range_fault(struct hmm_range *range)
 {
 	struct hmm_vma_walk hmm_vma_walk = {
 		.range = range,
@@ -590,10 +580,13 @@ long hmm_range_fault(struct hmm_range *range)
 			return -EBUSY;
 		ret = walk_page_range(mm, hmm_vma_walk.last, range->end,
				      &hmm_walk_ops, &hmm_vma_walk);
+		/*
+		 * When -EBUSY is returned the loop restarts with
+		 * hmm_vma_walk.last set to an address that has not been stored
+		 * in pfns. All entries < last in the pfn array are set to their
+		 * output, and all >= are still at their input values.
+		 */
 	} while (ret == -EBUSY);
-
-	if (ret)
-		return ret;
-	return (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
+	return ret;
 }
 EXPORT_SYMBOL(hmm_range_fault);