diff --git a/drivers/md/dm-zoned-metadata.c b/drivers/md/dm-zoned-metadata.c
index c2c17149d9687b3a595ab64577b59c429cea81ea..086a870087cff81ff7fe57b8311a7e10000f6b9b 100644
--- a/drivers/md/dm-zoned-metadata.c
+++ b/drivers/md/dm-zoned-metadata.c
@@ -132,6 +132,7 @@ struct dmz_metadata {
 
 	sector_t		zone_bitmap_size;
 	unsigned int		zone_nr_bitmap_blocks;
+	unsigned int		zone_bits_per_mblk;
 
 	unsigned int		nr_bitmap_blocks;
 	unsigned int		nr_map_blocks;
@@ -1165,7 +1166,10 @@ static int dmz_init_zones(struct dmz_metadata *zmd)
 
 	/* Init */
 	zmd->zone_bitmap_size = dev->zone_nr_blocks >> 3;
-	zmd->zone_nr_bitmap_blocks = zmd->zone_bitmap_size >> DMZ_BLOCK_SHIFT;
+	zmd->zone_nr_bitmap_blocks =
+		max_t(sector_t, 1, zmd->zone_bitmap_size >> DMZ_BLOCK_SHIFT);
+	zmd->zone_bits_per_mblk = min_t(sector_t, dev->zone_nr_blocks,
+					DMZ_BLOCK_SIZE_BITS);
 
 	/* Allocate zone array */
 	zmd->zones = kcalloc(dev->nr_zones, sizeof(struct dm_zone), GFP_KERNEL);
@@ -1982,7 +1986,7 @@ int dmz_copy_valid_blocks(struct dmz_metadata *zmd, struct dm_zone *from_zone,
 		dmz_release_mblock(zmd, to_mblk);
 		dmz_release_mblock(zmd, from_mblk);
 
-		chunk_block += DMZ_BLOCK_SIZE_BITS;
+		chunk_block += zmd->zone_bits_per_mblk;
 	}
 
 	to_zone->weight = from_zone->weight;
@@ -2043,7 +2047,7 @@ int dmz_validate_blocks(struct dmz_metadata *zmd, struct dm_zone *zone,
 
 		/* Set bits */
 		bit = chunk_block & DMZ_BLOCK_MASK_BITS;
-		nr_bits = min(nr_blocks, DMZ_BLOCK_SIZE_BITS - bit);
+		nr_bits = min(nr_blocks, zmd->zone_bits_per_mblk - bit);
 
 		count = dmz_set_bits((unsigned long *)mblk->data, bit, nr_bits);
 		if (count) {
@@ -2122,7 +2126,7 @@ int dmz_invalidate_blocks(struct dmz_metadata *zmd, struct dm_zone *zone,
 
 		/* Clear bits */
 		bit = chunk_block & DMZ_BLOCK_MASK_BITS;
-		nr_bits = min(nr_blocks, DMZ_BLOCK_SIZE_BITS - bit);
+		nr_bits = min(nr_blocks, zmd->zone_bits_per_mblk - bit);
 
 		count = dmz_clear_bits((unsigned long *)mblk->data,
 				       bit, nr_bits);
@@ -2182,6 +2186,7 @@ static int dmz_to_next_set_block(struct dmz_metadata *zmd, struct dm_zone *zone,
 {
 	struct dmz_mblock *mblk;
 	unsigned int bit, set_bit, nr_bits;
+	unsigned int zone_bits = zmd->zone_bits_per_mblk;
 	unsigned long *bitmap;
 	int n = 0;
 
@@ -2196,15 +2201,15 @@ static int dmz_to_next_set_block(struct dmz_metadata *zmd, struct dm_zone *zone,
 		/* Get offset */
 		bitmap = (unsigned long *) mblk->data;
 		bit = chunk_block & DMZ_BLOCK_MASK_BITS;
-		nr_bits = min(nr_blocks, DMZ_BLOCK_SIZE_BITS - bit);
+		nr_bits = min(nr_blocks, zone_bits - bit);
 		if (set)
-			set_bit = find_next_bit(bitmap, DMZ_BLOCK_SIZE_BITS, bit);
+			set_bit = find_next_bit(bitmap, zone_bits, bit);
 		else
-			set_bit = find_next_zero_bit(bitmap, DMZ_BLOCK_SIZE_BITS, bit);
+			set_bit = find_next_zero_bit(bitmap, zone_bits, bit);
 		dmz_release_mblock(zmd, mblk);
 
 		n += set_bit - bit;
-		if (set_bit < DMZ_BLOCK_SIZE_BITS)
+		if (set_bit < zone_bits)
 			break;
 
 		nr_blocks -= nr_bits;
@@ -2307,7 +2312,7 @@ static void dmz_get_zone_weight(struct dmz_metadata *zmd, struct dm_zone *zone)
 		/* Count bits in this block */
 		bitmap = mblk->data;
 		bit = chunk_block & DMZ_BLOCK_MASK_BITS;
-		nr_bits = min(nr_blocks, DMZ_BLOCK_SIZE_BITS - bit);
+		nr_bits = min(nr_blocks, zmd->zone_bits_per_mblk - bit);
 		n += dmz_count_bits(bitmap, bit, nr_bits);
 
 		dmz_release_mblock(zmd, mblk);