Commit 2286a691 authored by Liam R. Howlett's avatar Liam R. Howlett Committed by Andrew Morton

mm: change mprotect_fixup to vma iterator

Use the vma iterator so that the iterator can be invalidated or updated to
avoid each caller doing so.

Link: https://lkml.kernel.org/r/20230120162650.984577-18-Liam.Howlett@oracle.comSigned-off-by: default avatarLiam R. Howlett <Liam.Howlett@oracle.com>
Signed-off-by: default avatarAndrew Morton <akpm@linux-foundation.org>
parent 11a9b902
...@@ -758,6 +758,7 @@ int setup_arg_pages(struct linux_binprm *bprm, ...@@ -758,6 +758,7 @@ int setup_arg_pages(struct linux_binprm *bprm,
unsigned long stack_expand; unsigned long stack_expand;
unsigned long rlim_stack; unsigned long rlim_stack;
struct mmu_gather tlb; struct mmu_gather tlb;
struct vma_iterator vmi;
#ifdef CONFIG_STACK_GROWSUP #ifdef CONFIG_STACK_GROWSUP
/* Limit stack size */ /* Limit stack size */
...@@ -812,8 +813,10 @@ int setup_arg_pages(struct linux_binprm *bprm, ...@@ -812,8 +813,10 @@ int setup_arg_pages(struct linux_binprm *bprm,
vm_flags |= mm->def_flags; vm_flags |= mm->def_flags;
vm_flags |= VM_STACK_INCOMPLETE_SETUP; vm_flags |= VM_STACK_INCOMPLETE_SETUP;
vma_iter_init(&vmi, mm, vma->vm_start);
tlb_gather_mmu(&tlb, mm); tlb_gather_mmu(&tlb, mm);
ret = mprotect_fixup(&tlb, vma, &prev, vma->vm_start, vma->vm_end, ret = mprotect_fixup(&vmi, &tlb, vma, &prev, vma->vm_start, vma->vm_end,
vm_flags); vm_flags);
tlb_finish_mmu(&tlb); tlb_finish_mmu(&tlb);
......
...@@ -2197,9 +2197,9 @@ bool can_change_pte_writable(struct vm_area_struct *vma, unsigned long addr, ...@@ -2197,9 +2197,9 @@ bool can_change_pte_writable(struct vm_area_struct *vma, unsigned long addr,
extern long change_protection(struct mmu_gather *tlb, extern long change_protection(struct mmu_gather *tlb,
struct vm_area_struct *vma, unsigned long start, struct vm_area_struct *vma, unsigned long start,
unsigned long end, unsigned long cp_flags); unsigned long end, unsigned long cp_flags);
extern int mprotect_fixup(struct mmu_gather *tlb, struct vm_area_struct *vma, extern int mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
struct vm_area_struct **pprev, unsigned long start, struct vm_area_struct *vma, struct vm_area_struct **pprev,
unsigned long end, unsigned long newflags); unsigned long start, unsigned long end, unsigned long newflags);
/* /*
* doesn't attempt to fault and will return short. * doesn't attempt to fault and will return short.
......
...@@ -585,9 +585,9 @@ static const struct mm_walk_ops prot_none_walk_ops = { ...@@ -585,9 +585,9 @@ static const struct mm_walk_ops prot_none_walk_ops = {
}; };
int int
mprotect_fixup(struct mmu_gather *tlb, struct vm_area_struct *vma, mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
struct vm_area_struct **pprev, unsigned long start, struct vm_area_struct *vma, struct vm_area_struct **pprev,
unsigned long end, unsigned long newflags) unsigned long start, unsigned long end, unsigned long newflags)
{ {
struct mm_struct *mm = vma->vm_mm; struct mm_struct *mm = vma->vm_mm;
unsigned long oldflags = vma->vm_flags; unsigned long oldflags = vma->vm_flags;
...@@ -642,7 +642,7 @@ mprotect_fixup(struct mmu_gather *tlb, struct vm_area_struct *vma, ...@@ -642,7 +642,7 @@ mprotect_fixup(struct mmu_gather *tlb, struct vm_area_struct *vma,
* First try to merge with previous and/or next vma. * First try to merge with previous and/or next vma.
*/ */
pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT); pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
*pprev = vma_merge(mm, *pprev, start, end, newflags, *pprev = vmi_vma_merge(vmi, mm, *pprev, start, end, newflags,
vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma), vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
vma->vm_userfaultfd_ctx, anon_vma_name(vma)); vma->vm_userfaultfd_ctx, anon_vma_name(vma));
if (*pprev) { if (*pprev) {
...@@ -654,13 +654,13 @@ mprotect_fixup(struct mmu_gather *tlb, struct vm_area_struct *vma, ...@@ -654,13 +654,13 @@ mprotect_fixup(struct mmu_gather *tlb, struct vm_area_struct *vma,
*pprev = vma; *pprev = vma;
if (start != vma->vm_start) { if (start != vma->vm_start) {
error = split_vma(mm, vma, start, 1); error = vmi_split_vma(vmi, mm, vma, start, 1);
if (error) if (error)
goto fail; goto fail;
} }
if (end != vma->vm_end) { if (end != vma->vm_end) {
error = split_vma(mm, vma, end, 0); error = vmi_split_vma(vmi, mm, vma, end, 0);
if (error) if (error)
goto fail; goto fail;
} }
...@@ -709,7 +709,7 @@ static int do_mprotect_pkey(unsigned long start, size_t len, ...@@ -709,7 +709,7 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
const bool rier = (current->personality & READ_IMPLIES_EXEC) && const bool rier = (current->personality & READ_IMPLIES_EXEC) &&
(prot & PROT_READ); (prot & PROT_READ);
struct mmu_gather tlb; struct mmu_gather tlb;
MA_STATE(mas, &current->mm->mm_mt, 0, 0); struct vma_iterator vmi;
start = untagged_addr(start); start = untagged_addr(start);
...@@ -741,8 +741,8 @@ static int do_mprotect_pkey(unsigned long start, size_t len, ...@@ -741,8 +741,8 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
if ((pkey != -1) && !mm_pkey_is_allocated(current->mm, pkey)) if ((pkey != -1) && !mm_pkey_is_allocated(current->mm, pkey))
goto out; goto out;
mas_set(&mas, start); vma_iter_init(&vmi, current->mm, start);
vma = mas_find(&mas, ULONG_MAX); vma = vma_find(&vmi, end);
error = -ENOMEM; error = -ENOMEM;
if (!vma) if (!vma)
goto out; goto out;
...@@ -765,18 +765,22 @@ static int do_mprotect_pkey(unsigned long start, size_t len, ...@@ -765,18 +765,22 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
} }
} }
prev = vma_prev(&vmi);
if (start > vma->vm_start) if (start > vma->vm_start)
prev = vma; prev = vma;
else
prev = mas_prev(&mas, 0);
tlb_gather_mmu(&tlb, current->mm); tlb_gather_mmu(&tlb, current->mm);
for (nstart = start ; ; ) { nstart = start;
tmp = vma->vm_start;
for_each_vma_range(vmi, vma, end) {
unsigned long mask_off_old_flags; unsigned long mask_off_old_flags;
unsigned long newflags; unsigned long newflags;
int new_vma_pkey; int new_vma_pkey;
/* Here we know that vma->vm_start <= nstart < vma->vm_end. */ if (vma->vm_start != tmp) {
error = -ENOMEM;
break;
}
/* Does the application expect PROT_READ to imply PROT_EXEC */ /* Does the application expect PROT_READ to imply PROT_EXEC */
if (rier && (vma->vm_flags & VM_MAYEXEC)) if (rier && (vma->vm_flags & VM_MAYEXEC))
...@@ -824,25 +828,18 @@ static int do_mprotect_pkey(unsigned long start, size_t len, ...@@ -824,25 +828,18 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
break; break;
} }
error = mprotect_fixup(&tlb, vma, &prev, nstart, tmp, newflags); error = mprotect_fixup(&vmi, &tlb, vma, &prev, nstart, tmp, newflags);
if (error) if (error)
break; break;
nstart = tmp; nstart = tmp;
if (nstart < prev->vm_end)
nstart = prev->vm_end;
if (nstart >= end)
break;
vma = find_vma(current->mm, prev->vm_end);
if (!vma || vma->vm_start != nstart) {
error = -ENOMEM;
break;
}
prot = reqprot; prot = reqprot;
} }
tlb_finish_mmu(&tlb); tlb_finish_mmu(&tlb);
if (vma_iter_end(&vmi) < end)
error = -ENOMEM;
out: out:
mmap_write_unlock(current->mm); mmap_write_unlock(current->mm);
return error; return error;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment