diff --git a/arch/x86/mm/pat.c b/arch/x86/mm/pat.c
index 1aca073ba5715ce82cde679736457a47abbe9590..031782e7423197ab4dbadb857eeed3349bcde396 100644
--- a/arch/x86/mm/pat.c
+++ b/arch/x86/mm/pat.c
@@ -586,7 +586,7 @@ int free_memtype(u64 start, u64 end)
 	entry = rbt_memtype_erase(start, end);
 	spin_unlock(&memtype_lock);
 
-	if (!entry) {
+	if (IS_ERR(entry)) {
 		pr_info("x86/PAT: %s:%d freeing invalid memtype [mem %#010Lx-%#010Lx]\n",
 			current->comm, current->pid, start, end - 1);
 		return -EINVAL;
diff --git a/arch/x86/mm/pat_rbtree.c b/arch/x86/mm/pat_rbtree.c
index 63931080366aaae1626c48b8ec780acd78d09569..2f7702253ccfaad30da5233ac9293cb5421d1742 100644
--- a/arch/x86/mm/pat_rbtree.c
+++ b/arch/x86/mm/pat_rbtree.c
@@ -98,8 +98,13 @@ static struct memtype *memtype_rb_lowest_match(struct rb_root *root,
 	return last_lower; /* Returns NULL if there is no overlap */
 }
 
-static struct memtype *memtype_rb_exact_match(struct rb_root *root,
-				u64 start, u64 end)
+enum {
+	MEMTYPE_EXACT_MATCH	= 0,
+	MEMTYPE_END_MATCH	= 1
+};
+
+static struct memtype *memtype_rb_match(struct rb_root *root,
+				u64 start, u64 end, int match_type)
 {
 	struct memtype *match;
 
@@ -107,7 +112,12 @@ static struct memtype *memtype_rb_exact_match(struct rb_root *root,
 	while (match != NULL && match->start < end) {
 		struct rb_node *node;
 
-		if (match->start == start && match->end == end)
+		if ((match_type == MEMTYPE_EXACT_MATCH) &&
+		    (match->start == start) && (match->end == end))
+			return match;
+
+		if ((match_type == MEMTYPE_END_MATCH) &&
+		    (match->start < start) && (match->end == end))
 			return match;
 
 		node = rb_next(&match->rb);
@@ -117,7 +127,7 @@ static struct memtype *memtype_rb_exact_match(struct rb_root *root,
 			match = NULL;
 	}
 
-	return NULL; /* Returns NULL if there is no exact match */
+	return NULL; /* Returns NULL if there is no match */
 }
 
 static int memtype_rb_check_conflict(struct rb_root *root,
@@ -210,12 +220,36 @@ struct memtype *rbt_memtype_erase(u64 start, u64 end)
 {
 	struct memtype *data;
 
-	data = memtype_rb_exact_match(&memtype_rbroot, start, end);
-	if (!data)
-		goto out;
+	/*
+	 * Since the memtype_rbroot tree allows overlapping ranges,
+	 * rbt_memtype_erase() checks with EXACT_MATCH first, i.e. free
+	 * a whole node for the munmap case.  If no such entry is found,
+	 * it then checks with END_MATCH, i.e. shrink the size of a node
+	 * from the end for the mremap case.
+	 */
+	data = memtype_rb_match(&memtype_rbroot, start, end,
+				MEMTYPE_EXACT_MATCH);
+	if (!data) {
+		data = memtype_rb_match(&memtype_rbroot, start, end,
+					MEMTYPE_END_MATCH);
+		if (!data)
+			return ERR_PTR(-EINVAL);
+	}
+
+	if (data->start == start) {
+		/* munmap: erase this node */
+		rb_erase_augmented(&data->rb, &memtype_rbroot,
+					&memtype_rb_augment_cb);
+	} else {
+		/* mremap: update the end value of this node */
+		rb_erase_augmented(&data->rb, &memtype_rbroot,
+					&memtype_rb_augment_cb);
+		data->end = start;
+		data->subtree_max_end = data->end;
+		memtype_rb_insert(&memtype_rbroot, data);
+		return NULL;
+	}
 
-	rb_erase_augmented(&data->rb, &memtype_rbroot, &memtype_rb_augment_cb);
-out:
 	return data;
 }