Commit 92fed820 authored by Liam R. Howlett's avatar Liam R. Howlett Committed by Andrew Morton

mm/mmap: convert brk to use vma iterator

Use the vma iterator API for the brk() system call.  This will provide
type safety at compile time.

Link: https://lkml.kernel.org/r/20230120162650.984577-9-Liam.Howlett@oracle.comSigned-off-by: default avatarLiam R. Howlett <Liam.Howlett@oracle.com>
Signed-off-by: default avatarAndrew Morton <akpm@linux-foundation.org>
parent b62b633e
...@@ -180,10 +180,10 @@ static int check_brk_limits(unsigned long addr, unsigned long len) ...@@ -180,10 +180,10 @@ static int check_brk_limits(unsigned long addr, unsigned long len)
return mlock_future_check(current->mm, current->mm->def_flags, len); return mlock_future_check(current->mm, current->mm->def_flags, len);
} }
static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, static int do_brk_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma,
unsigned long newbrk, unsigned long oldbrk, unsigned long newbrk, unsigned long oldbrk,
struct list_head *uf); struct list_head *uf);
static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *brkvma, static int do_brk_flags(struct vma_iterator *vmi, struct vm_area_struct *brkvma,
unsigned long addr, unsigned long request, unsigned long flags); unsigned long addr, unsigned long request, unsigned long flags);
SYSCALL_DEFINE1(brk, unsigned long, brk) SYSCALL_DEFINE1(brk, unsigned long, brk)
{ {
...@@ -194,7 +194,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) ...@@ -194,7 +194,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
bool populate; bool populate;
bool downgraded = false; bool downgraded = false;
LIST_HEAD(uf); LIST_HEAD(uf);
MA_STATE(mas, &mm->mm_mt, 0, 0); struct vma_iterator vmi;
if (mmap_write_lock_killable(mm)) if (mmap_write_lock_killable(mm))
return -EINTR; return -EINTR;
...@@ -242,8 +242,8 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) ...@@ -242,8 +242,8 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
int ret; int ret;
/* Search one past newbrk */ /* Search one past newbrk */
mas_set(&mas, newbrk); vma_iter_init(&vmi, mm, newbrk);
brkvma = mas_find(&mas, oldbrk); brkvma = vma_find(&vmi, oldbrk);
if (!brkvma || brkvma->vm_start >= oldbrk) if (!brkvma || brkvma->vm_start >= oldbrk)
goto out; /* mapping intersects with an existing non-brk vma. */ goto out; /* mapping intersects with an existing non-brk vma. */
/* /*
...@@ -252,7 +252,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) ...@@ -252,7 +252,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
* before calling do_brk_munmap(). * before calling do_brk_munmap().
*/ */
mm->brk = brk; mm->brk = brk;
ret = do_brk_munmap(&mas, brkvma, newbrk, oldbrk, &uf); ret = do_brk_munmap(&vmi, brkvma, newbrk, oldbrk, &uf);
if (ret == 1) { if (ret == 1) {
downgraded = true; downgraded = true;
goto success; goto success;
...@@ -270,14 +270,14 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) ...@@ -270,14 +270,14 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
* Only check if the next VMA is within the stack_guard_gap of the * Only check if the next VMA is within the stack_guard_gap of the
* expansion area * expansion area
*/ */
mas_set(&mas, oldbrk); vma_iter_init(&vmi, mm, oldbrk);
next = mas_find(&mas, newbrk - 1 + PAGE_SIZE + stack_guard_gap); next = vma_find(&vmi, newbrk + PAGE_SIZE + stack_guard_gap);
if (next && newbrk + PAGE_SIZE > vm_start_gap(next)) if (next && newbrk + PAGE_SIZE > vm_start_gap(next))
goto out; goto out;
brkvma = mas_prev(&mas, mm->start_brk); brkvma = vma_prev_limit(&vmi, mm->start_brk);
/* Ok, looks good - let it rip. */ /* Ok, looks good - let it rip. */
if (do_brk_flags(&mas, brkvma, oldbrk, newbrk - oldbrk, 0) < 0) if (do_brk_flags(&vmi, brkvma, oldbrk, newbrk - oldbrk, 0) < 0)
goto out; goto out;
mm->brk = brk; mm->brk = brk;
...@@ -2917,8 +2917,8 @@ SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size, ...@@ -2917,8 +2917,8 @@ SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size,
} }
/* /*
* brk_munmap() - Unmap a partial vma. * brk_munmap() - Unmap a full or partial vma.
* @mas: The maple tree state. * @vmi: The vma iterator
* @vma: The vma to be modified * @vma: The vma to be modified
* @newbrk: the start of the address to unmap * @newbrk: the start of the address to unmap
* @oldbrk: The end of the address to unmap * @oldbrk: The end of the address to unmap
...@@ -2928,7 +2928,7 @@ SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size, ...@@ -2928,7 +2928,7 @@ SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size,
* unmaps a partial VMA mapping. Does not handle alignment, downgrades lock if * unmaps a partial VMA mapping. Does not handle alignment, downgrades lock if
* possible. * possible.
*/ */
static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, static int do_brk_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma,
unsigned long newbrk, unsigned long oldbrk, unsigned long newbrk, unsigned long oldbrk,
struct list_head *uf) struct list_head *uf)
{ {
...@@ -2936,14 +2936,14 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, ...@@ -2936,14 +2936,14 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
int ret; int ret;
arch_unmap(mm, newbrk, oldbrk); arch_unmap(mm, newbrk, oldbrk);
ret = do_mas_align_munmap(mas, vma, mm, newbrk, oldbrk, uf, true); ret = do_mas_align_munmap(&vmi->mas, vma, mm, newbrk, oldbrk, uf, true);
validate_mm_mt(mm); validate_mm_mt(mm);
return ret; return ret;
} }
/* /*
* do_brk_flags() - Increase the brk vma if the flags match. * do_brk_flags() - Increase the brk vma if the flags match.
* @mas: The maple tree state. * @vmi: The vma iterator
* @addr: The start address * @addr: The start address
* @len: The length of the increase * @len: The length of the increase
* @vma: The vma, * @vma: The vma,
...@@ -2953,7 +2953,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, ...@@ -2953,7 +2953,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma,
* do not match then create a new anonymous VMA. Eventually we may be able to * do not match then create a new anonymous VMA. Eventually we may be able to
* do some brk-specific accounting here. * do some brk-specific accounting here.
*/ */
static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma, static int do_brk_flags(struct vma_iterator *vmi, struct vm_area_struct *vma,
unsigned long addr, unsigned long len, unsigned long flags) unsigned long addr, unsigned long len, unsigned long flags)
{ {
struct mm_struct *mm = current->mm; struct mm_struct *mm = current->mm;
...@@ -2980,8 +2980,7 @@ static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma, ...@@ -2980,8 +2980,7 @@ static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma,
if (vma && vma->vm_end == addr && !vma_policy(vma) && if (vma && vma->vm_end == addr && !vma_policy(vma) &&
can_vma_merge_after(vma, flags, NULL, NULL, can_vma_merge_after(vma, flags, NULL, NULL,
addr >> PAGE_SHIFT, NULL_VM_UFFD_CTX, NULL)) { addr >> PAGE_SHIFT, NULL_VM_UFFD_CTX, NULL)) {
mas_set_range(mas, vma->vm_start, addr + len - 1); if (vma_iter_prealloc(vmi))
if (mas_preallocate(mas, GFP_KERNEL))
goto unacct_fail; goto unacct_fail;
vma_adjust_trans_huge(vma, vma->vm_start, addr + len, 0); vma_adjust_trans_huge(vma, vma->vm_start, addr + len, 0);
...@@ -2991,7 +2990,7 @@ static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma, ...@@ -2991,7 +2990,7 @@ static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma,
} }
vma->vm_end = addr + len; vma->vm_end = addr + len;
vma->vm_flags |= VM_SOFTDIRTY; vma->vm_flags |= VM_SOFTDIRTY;
mas_store_prealloc(mas, vma); vma_iter_store(vmi, vma);
if (vma->anon_vma) { if (vma->anon_vma) {
anon_vma_interval_tree_post_update_vma(vma); anon_vma_interval_tree_post_update_vma(vma);
...@@ -3012,8 +3011,7 @@ static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma, ...@@ -3012,8 +3011,7 @@ static int do_brk_flags(struct ma_state *mas, struct vm_area_struct *vma,
vma->vm_pgoff = addr >> PAGE_SHIFT; vma->vm_pgoff = addr >> PAGE_SHIFT;
vma->vm_flags = flags; vma->vm_flags = flags;
vma->vm_page_prot = vm_get_page_prot(flags); vma->vm_page_prot = vm_get_page_prot(flags);
mas_set_range(mas, vma->vm_start, addr + len - 1); if (vma_iter_store_gfp(vmi, vma, GFP_KERNEL))
if (mas_store_gfp(mas, vma, GFP_KERNEL))
goto mas_store_fail; goto mas_store_fail;
mm->map_count++; mm->map_count++;
...@@ -3042,7 +3040,7 @@ int vm_brk_flags(unsigned long addr, unsigned long request, unsigned long flags) ...@@ -3042,7 +3040,7 @@ int vm_brk_flags(unsigned long addr, unsigned long request, unsigned long flags)
int ret; int ret;
bool populate; bool populate;
LIST_HEAD(uf); LIST_HEAD(uf);
MA_STATE(mas, &mm->mm_mt, addr, addr); VMA_ITERATOR(vmi, mm, addr);
len = PAGE_ALIGN(request); len = PAGE_ALIGN(request);
if (len < request) if (len < request)
...@@ -3061,12 +3059,12 @@ int vm_brk_flags(unsigned long addr, unsigned long request, unsigned long flags) ...@@ -3061,12 +3059,12 @@ int vm_brk_flags(unsigned long addr, unsigned long request, unsigned long flags)
if (ret) if (ret)
goto limits_failed; goto limits_failed;
ret = do_mas_munmap(&mas, mm, addr, len, &uf, 0); ret = do_mas_munmap(&vmi.mas, mm, addr, len, &uf, 0);
if (ret) if (ret)
goto munmap_failed; goto munmap_failed;
vma = mas_prev(&mas, 0); vma = vma_prev(&vmi);
ret = do_brk_flags(&mas, vma, addr, len, flags); ret = do_brk_flags(&vmi, vma, addr, len, flags);
populate = ((mm->def_flags & VM_LOCKED) != 0); populate = ((mm->def_flags & VM_LOCKED) != 0);
mmap_write_unlock(mm); mmap_write_unlock(mm);
userfaultfd_unmap_complete(mm, &uf); userfaultfd_unmap_complete(mm, &uf);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment