Commit f8c888a3 authored by Christoph Hellwig, committed by Jason Gunthorpe

mm/hmm: don't handle the non-fault case in hmm_vma_walk_hole_()

Setting a pfns entry to NONE before returning -EBUSY is a bug that will
cause corruption of the input flags on the next loop.

There is just a single caller using hmm_vma_walk_hole_() for the non-fault
case.  Use hmm_pfns_fill() to fill the whole pfn array with zeroes in the
only caller for the non-fault case and remove the non-fault path from
hmm_vma_walk_hole_(). This avoids setting NONE before returning -EBUSY.

Also rename the function to hmm_vma_fault() to better describe what it
does.

Fixes: 2aee09d8 ("mm/hmm: change hmm_vma_fault() to allow write fault on page basis")
Link: https://lore.kernel.org/r/20200316135310.899364-5-hch@lst.de
Signed-off-by: Christoph Hellwig <hch@lst.de>
Reviewed-by: Jason Gunthorpe <jgg@mellanox.com>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
parent 45050692
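
For context on the bug described above: hmm_range_fault() callers encode their per-page requests in range->pfns[], and the fault path returns -EBUSY so the walk gets retried. A minimal sketch of why a stale NONE entry defeats that retry, condensed from the hmm_pte_need_fault() logic of this era (sketch_need_fault() is a hypothetical name for illustration; the write-fault side is omitted):

#include <linux/hmm.h>

/*
 * Illustrative sketch only, not part of this commit: the per-page fault
 * decision is driven by the request flags stored in the pfns[] entry.
 */
static bool sketch_need_fault(const struct hmm_range *range, uint64_t pfn_req)
{
        /* Combine the per-page request with the range-wide defaults. */
        pfn_req = (pfn_req & range->pfn_flags_mask) | range->default_flags;

        /* No request left in this entry -> no fault is deemed necessary. */
        return pfn_req & range->flags[HMM_PFN_VALID];
}

Before this patch, hmm_vma_walk_hole_() overwrote pfns[i] with range->values[HMM_PFN_NONE] before returning -EBUSY, so on the retry the entry no longer carried the caller's request flags and a check like the one above could conclude that no fault was needed, silently skipping the page.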
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -73,45 +73,41 @@ static int hmm_pfns_fill(unsigned long addr, unsigned long end,
 }
 
 /*
- * hmm_vma_walk_hole_() - handle a range lacking valid pmd or pte(s)
+ * hmm_vma_fault() - fault in a range lacking valid pmd or pte(s)
  * @addr: range virtual start address (inclusive)
  * @end: range virtual end address (exclusive)
  * @fault: should we fault or not ?
  * @write_fault: write fault ?
  * @walk: mm_walk structure
- * Return: 0 on success, -EBUSY after page fault, or page fault error
+ * Return: -EBUSY after page fault, or page fault error
  *
  * This function will be called whenever pmd_none() or pte_none() returns true,
  * or whenever there is no page directory covering the virtual address range.
  */
-static int hmm_vma_walk_hole_(unsigned long addr, unsigned long end,
-                              bool fault, bool write_fault,
-                              struct mm_walk *walk)
+static int hmm_vma_fault(unsigned long addr, unsigned long end,
+                         bool fault, bool write_fault,
+                         struct mm_walk *walk)
 {
         struct hmm_vma_walk *hmm_vma_walk = walk->private;
         struct hmm_range *range = hmm_vma_walk->range;
         uint64_t *pfns = range->pfns;
-        unsigned long i;
+        unsigned long i = (addr - range->start) >> PAGE_SHIFT;
 
+        WARN_ON_ONCE(!fault && !write_fault);
         hmm_vma_walk->last = addr;
-        i = (addr - range->start) >> PAGE_SHIFT;
 
         if (write_fault && walk->vma && !(walk->vma->vm_flags & VM_WRITE))
                 return -EPERM;
 
         for (; addr < end; addr += PAGE_SIZE, i++) {
-                pfns[i] = range->values[HMM_PFN_NONE];
-                if (fault || write_fault) {
-                        int ret;
+                int ret;
 
-                        ret = hmm_vma_do_fault(walk, addr, write_fault,
-                                               &pfns[i]);
-                        if (ret != -EBUSY)
-                                return ret;
-                }
+                ret = hmm_vma_do_fault(walk, addr, write_fault, &pfns[i]);
+                if (ret != -EBUSY)
+                        return ret;
         }
 
-        return (fault || write_fault) ? -EBUSY : 0;
+        return -EBUSY;
 }
 
 static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
@@ -193,7 +189,10 @@ static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
         pfns = &range->pfns[i];
         hmm_range_need_fault(hmm_vma_walk, pfns, npages,
                              0, &fault, &write_fault);
-        return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
+        if (fault || write_fault)
+                return hmm_vma_fault(addr, end, fault, write_fault, walk);
+        hmm_vma_walk->last = addr;
+        return hmm_pfns_fill(addr, end, range, HMM_PFN_NONE);
 }
 
 static inline uint64_t pmd_to_hmm_pfn_flags(struct hmm_range *range, pmd_t pmd)
@@ -221,7 +220,7 @@ static int hmm_vma_handle_pmd(struct mm_walk *walk, unsigned long addr,
                              &fault, &write_fault);
 
         if (fault || write_fault)
-                return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
+                return hmm_vma_fault(addr, end, fault, write_fault, walk);
 
         pfn = pmd_pfn(pmd) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
         for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++) {
@@ -360,7 +359,7 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
         }
         pte_unmap(ptep);
         /* Fault any virtual address we were asked to fault */
-        return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
+        return hmm_vma_fault(addr, end, fault, write_fault, walk);
 }
 
 static int hmm_vma_walk_pmd(pmd_t *pmdp,
@@ -512,7 +511,7 @@ static int hmm_vma_walk_pud(pud_t *pudp, unsigned long start, unsigned long end,
                              cpu_flags, &fault, &write_fault);
         if (fault || write_fault) {
                 spin_unlock(ptl);
-                return hmm_vma_walk_hole_(addr, end, fault, write_fault,
-                                          walk);
+                return hmm_vma_fault(addr, end, fault, write_fault,
+                                     walk);
         }
@@ -572,7 +571,7 @@ static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
                               &fault, &write_fault);
         if (fault || write_fault) {
                 spin_unlock(ptl);
-                return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
+                return hmm_vma_fault(addr, end, fault, write_fault, walk);
         }
 
         pfn = pte_pfn(entry) + ((start & ~hmask) >> PAGE_SHIFT);
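
The "next loop" the commit message refers to is the retry loop in hmm_range_fault() itself, which re-walks from the last address the previous pass reached. Roughly, condensed from the mm/hmm.c of this period (notifier validity and locking checks omitted):

        do {
                /* Restart the walk from wherever the previous pass stopped. */
                ret = walk_page_range(mm, hmm_vma_walk.last, range->end,
                                      &hmm_walk_ops, &hmm_vma_walk);
        } while (ret == -EBUSY);

With the old code, each -EBUSY pass through hmm_vma_walk_hole_() could clobber input entries before this loop re-ran; after the patch only the single non-fault caller writes the NONE value, and it does so via hmm_pfns_fill() without ever returning -EBUSY.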