diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index 9696d654b01faa45afecf14893d7182c160ed5ab..18d2f584945b94bc49a7c8e6c3f72721cca619c6 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -2902,13 +2902,13 @@ void kvm_mmu_zap_all(struct kvm *kvm)
 	kvm_flush_remote_tlbs(kvm);
 }
 
-static void kvm_mmu_remove_one_alloc_mmu_page(struct kvm *kvm)
+static int kvm_mmu_remove_some_alloc_mmu_pages(struct kvm *kvm)
 {
 	struct kvm_mmu_page *page;
 
 	page = container_of(kvm->arch.active_mmu_pages.prev,
 			    struct kvm_mmu_page, link);
-	kvm_mmu_zap_page(kvm, page);
+	return kvm_mmu_zap_page(kvm, page) + 1;
 }
 
 static int mmu_shrink(int nr_to_scan, gfp_t gfp_mask)
@@ -2920,7 +2920,7 @@ static int mmu_shrink(int nr_to_scan, gfp_t gfp_mask)
 	spin_lock(&kvm_lock);
 
 	list_for_each_entry(kvm, &vm_list, vm_list) {
-		int npages, idx;
+		int npages, idx, freed_pages;
 
 		idx = srcu_read_lock(&kvm->srcu);
 		spin_lock(&kvm->mmu_lock);
@@ -2928,8 +2928,8 @@ static int mmu_shrink(int nr_to_scan, gfp_t gfp_mask)
 			 kvm->arch.n_free_mmu_pages;
 		cache_count += npages;
 		if (!kvm_freed && nr_to_scan > 0 && npages > 0) {
-			kvm_mmu_remove_one_alloc_mmu_page(kvm);
-			cache_count--;
+			freed_pages = kvm_mmu_remove_some_alloc_mmu_pages(kvm);
+			cache_count -= freed_pages;
 			kvm_freed = kvm;
 		}
 		nr_to_scan--;