diff --git a/mm/mmap.c b/mm/mmap.c index 7cba84f8e3a5..548bc45a27bf 100644 --- a/mm/mmap.c +++ b/mm/mmap.c @@ -1740,16 +1740,7 @@ int do_vma_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma, unsigned long start, unsigned long end, struct list_head *uf, bool unlock) { - struct mm_struct *mm = vma->vm_mm; - - /* - * Check if memory is sealed, prevent unmapping a sealed VMA. - * can_modify_mm assumes we have acquired the lock on MM. - */ - if (unlikely(!can_modify_mm(mm, start, end))) - return -EPERM; - - return do_vmi_align_munmap(vmi, vma, mm, start, end, uf, unlock); + return do_vmi_align_munmap(vmi, vma, vma->vm_mm, start, end, uf, unlock); } /* diff --git a/mm/vma.c b/mm/vma.c index 84965f2cd580..5850f7c0949b 100644 --- a/mm/vma.c +++ b/mm/vma.c @@ -712,6 +712,12 @@ do_vmi_align_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma, if (end < vma->vm_end && mm->map_count >= sysctl_max_map_count) goto map_count_exceeded; + /* Don't bother splitting the VMA if we can't unmap it anyway */ + if (!can_modify_vma(vma)) { + error = -EPERM; + goto start_split_failed; + } + error = __split_vma(vmi, vma, start, 1); if (error) goto start_split_failed; @@ -723,6 +729,11 @@ do_vmi_align_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma, */ next = vma; do { + if (!can_modify_vma(next)) { + error = -EPERM; + goto modify_vma_failed; + } + /* Does it split the end? */ if (next->vm_end > end) { error = __split_vma(vmi, next, end, 0); @@ -815,6 +826,7 @@ do_vmi_align_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma, __mt_destroy(&mt_detach); return 0; +modify_vma_failed: clear_tree_failed: userfaultfd_error: munmap_gather_failed: @@ -860,13 +872,6 @@ int do_vmi_munmap(struct vma_iterator *vmi, struct mm_struct *mm, if (end == start) return -EINVAL; - /* - * Check if memory is sealed, prevent unmapping a sealed VMA. - * can_modify_mm assumes we have acquired the lock on MM. - */ - if (unlikely(!can_modify_mm(mm, start, end))) - return -EPERM; - /* Find the first overlapping VMA */ vma = vma_find(vmi, end); if (!vma) {