Commit 36845663 authored by Huang Shijie's avatar Huang Shijie Committed by Linus Torvalds

lib/genalloc: fix the overflow when size is too big

Some graphic card has very big memory on chip, such as 32G bytes.

In the following case, it will cause overflow:

    pool = gen_pool_create(PAGE_SHIFT, NUMA_NO_NODE);
    ret = gen_pool_add(pool, 0x1000000, SZ_32G, NUMA_NO_NODE);

    va = gen_pool_alloc(pool, SZ_4G);

The overflow occurs in gen_pool_alloc_algo_owner():

		....
		size = nbits << order;
		....

The @nbits is "int" type, so it will overflow.
Then the gen_pool_avail() will return the wrong value.

This patch converts some "int" to "unsigned long", and
changes the compare code in while.

Link: https://lkml.kernel.org/r/20201229060657.3389-1-sjhuang@iluvatar.aiSigned-off-by: default avatarHuang Shijie <sjhuang@iluvatar.ai>
Reported-by: default avatarShi Jiasheng <jiasheng.shi@iluvatar.ai>
Signed-off-by: default avatarAndrew Morton <akpm@linux-foundation.org>
Signed-off-by: default avatarLinus Torvalds <torvalds@linux-foundation.org>
parent aa8c7db4
...@@ -81,14 +81,14 @@ static int clear_bits_ll(unsigned long *addr, unsigned long mask_to_clear) ...@@ -81,14 +81,14 @@ static int clear_bits_ll(unsigned long *addr, unsigned long mask_to_clear)
* users set the same bit, one user will return remain bits, otherwise * users set the same bit, one user will return remain bits, otherwise
* return 0. * return 0.
*/ */
static int bitmap_set_ll(unsigned long *map, int start, int nr) static int bitmap_set_ll(unsigned long *map, unsigned long start, unsigned long nr)
{ {
unsigned long *p = map + BIT_WORD(start); unsigned long *p = map + BIT_WORD(start);
const int size = start + nr; const unsigned long size = start + nr;
int bits_to_set = BITS_PER_LONG - (start % BITS_PER_LONG); int bits_to_set = BITS_PER_LONG - (start % BITS_PER_LONG);
unsigned long mask_to_set = BITMAP_FIRST_WORD_MASK(start); unsigned long mask_to_set = BITMAP_FIRST_WORD_MASK(start);
while (nr - bits_to_set >= 0) { while (nr >= bits_to_set) {
if (set_bits_ll(p, mask_to_set)) if (set_bits_ll(p, mask_to_set))
return nr; return nr;
nr -= bits_to_set; nr -= bits_to_set;
...@@ -116,14 +116,15 @@ static int bitmap_set_ll(unsigned long *map, int start, int nr) ...@@ -116,14 +116,15 @@ static int bitmap_set_ll(unsigned long *map, int start, int nr)
* users clear the same bit, one user will return remain bits, * users clear the same bit, one user will return remain bits,
* otherwise return 0. * otherwise return 0.
*/ */
static int bitmap_clear_ll(unsigned long *map, int start, int nr) static unsigned long
bitmap_clear_ll(unsigned long *map, unsigned long start, unsigned long nr)
{ {
unsigned long *p = map + BIT_WORD(start); unsigned long *p = map + BIT_WORD(start);
const int size = start + nr; const unsigned long size = start + nr;
int bits_to_clear = BITS_PER_LONG - (start % BITS_PER_LONG); int bits_to_clear = BITS_PER_LONG - (start % BITS_PER_LONG);
unsigned long mask_to_clear = BITMAP_FIRST_WORD_MASK(start); unsigned long mask_to_clear = BITMAP_FIRST_WORD_MASK(start);
while (nr - bits_to_clear >= 0) { while (nr >= bits_to_clear) {
if (clear_bits_ll(p, mask_to_clear)) if (clear_bits_ll(p, mask_to_clear))
return nr; return nr;
nr -= bits_to_clear; nr -= bits_to_clear;
...@@ -183,8 +184,8 @@ int gen_pool_add_owner(struct gen_pool *pool, unsigned long virt, phys_addr_t ph ...@@ -183,8 +184,8 @@ int gen_pool_add_owner(struct gen_pool *pool, unsigned long virt, phys_addr_t ph
size_t size, int nid, void *owner) size_t size, int nid, void *owner)
{ {
struct gen_pool_chunk *chunk; struct gen_pool_chunk *chunk;
int nbits = size >> pool->min_alloc_order; unsigned long nbits = size >> pool->min_alloc_order;
int nbytes = sizeof(struct gen_pool_chunk) + unsigned long nbytes = sizeof(struct gen_pool_chunk) +
BITS_TO_LONGS(nbits) * sizeof(long); BITS_TO_LONGS(nbits) * sizeof(long);
chunk = vzalloc_node(nbytes, nid); chunk = vzalloc_node(nbytes, nid);
...@@ -242,7 +243,7 @@ void gen_pool_destroy(struct gen_pool *pool) ...@@ -242,7 +243,7 @@ void gen_pool_destroy(struct gen_pool *pool)
struct list_head *_chunk, *_next_chunk; struct list_head *_chunk, *_next_chunk;
struct gen_pool_chunk *chunk; struct gen_pool_chunk *chunk;
int order = pool->min_alloc_order; int order = pool->min_alloc_order;
int bit, end_bit; unsigned long bit, end_bit;
list_for_each_safe(_chunk, _next_chunk, &pool->chunks) { list_for_each_safe(_chunk, _next_chunk, &pool->chunks) {
chunk = list_entry(_chunk, struct gen_pool_chunk, next_chunk); chunk = list_entry(_chunk, struct gen_pool_chunk, next_chunk);
...@@ -278,7 +279,7 @@ unsigned long gen_pool_alloc_algo_owner(struct gen_pool *pool, size_t size, ...@@ -278,7 +279,7 @@ unsigned long gen_pool_alloc_algo_owner(struct gen_pool *pool, size_t size,
struct gen_pool_chunk *chunk; struct gen_pool_chunk *chunk;
unsigned long addr = 0; unsigned long addr = 0;
int order = pool->min_alloc_order; int order = pool->min_alloc_order;
int nbits, start_bit, end_bit, remain; unsigned long nbits, start_bit, end_bit, remain;
#ifndef CONFIG_ARCH_HAVE_NMI_SAFE_CMPXCHG #ifndef CONFIG_ARCH_HAVE_NMI_SAFE_CMPXCHG
BUG_ON(in_nmi()); BUG_ON(in_nmi());
...@@ -487,7 +488,7 @@ void gen_pool_free_owner(struct gen_pool *pool, unsigned long addr, size_t size, ...@@ -487,7 +488,7 @@ void gen_pool_free_owner(struct gen_pool *pool, unsigned long addr, size_t size,
{ {
struct gen_pool_chunk *chunk; struct gen_pool_chunk *chunk;
int order = pool->min_alloc_order; int order = pool->min_alloc_order;
int start_bit, nbits, remain; unsigned long start_bit, nbits, remain;
#ifndef CONFIG_ARCH_HAVE_NMI_SAFE_CMPXCHG #ifndef CONFIG_ARCH_HAVE_NMI_SAFE_CMPXCHG
BUG_ON(in_nmi()); BUG_ON(in_nmi());
...@@ -755,7 +756,7 @@ unsigned long gen_pool_best_fit(unsigned long *map, unsigned long size, ...@@ -755,7 +756,7 @@ unsigned long gen_pool_best_fit(unsigned long *map, unsigned long size,
index = bitmap_find_next_zero_area(map, size, start, nr, 0); index = bitmap_find_next_zero_area(map, size, start, nr, 0);
while (index < size) { while (index < size) {
int next_bit = find_next_bit(map, size, index + nr); unsigned long next_bit = find_next_bit(map, size, index + nr);
if ((next_bit - index) < len) { if ((next_bit - index) < len) {
len = next_bit - index; len = next_bit - index;
start_bit = index; start_bit = index;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment