Commit c85d6913 authored by Roman Gushchin, committed by Alexei Starovoitov

bpf: move memory size checks to bpf_map_charge_init()

Most bpf map types do similar checks and the same bytes-to-pages
conversion during memory allocation and charging.

Let's unify these checks by moving them into bpf_map_charge_init().
Signed-off-by: Roman Gushchin <guro@fb.com>
Acked-by: Song Liu <songliubraving@fb.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
parent b936ca64
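
For context, every hunk below removes the same open-coded pattern from a map's ->map_alloc() callback. A simplified before/after sketch (distilled from the hunks below, not the verbatim code of any single map type; "map" here stands for whichever structure embeds struct bpf_map):

	/* Before: each map type checked for u32 page-count overflow and
	 * converted bytes to pages itself before charging memlock.
	 */
	if (cost >= U32_MAX - PAGE_SIZE)
		return ERR_PTR(-E2BIG);
	ret = bpf_map_charge_init(&map->memory,
				  round_up(cost, PAGE_SIZE) >> PAGE_SHIFT);

	/* After: callers pass the cost in bytes; bpf_map_charge_init()
	 * rejects oversized requests with -E2BIG and does the
	 * bytes-to-pages conversion in one place.
	 */
	ret = bpf_map_charge_init(&map->memory, cost);

One side effect visible in the hunks: map types that previously returned -ENOMEM or -EINVAL for oversized requests now uniformly get -E2BIG from the helper.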
include/linux/bpf.h
@@ -652,7 +652,7 @@ void bpf_map_put_with_uref(struct bpf_map *map);
 void bpf_map_put(struct bpf_map *map);
 int bpf_map_charge_memlock(struct bpf_map *map, u32 pages);
 void bpf_map_uncharge_memlock(struct bpf_map *map, u32 pages);
-int bpf_map_charge_init(struct bpf_map_memory *mem, u32 pages);
+int bpf_map_charge_init(struct bpf_map_memory *mem, size_t size);
 void bpf_map_charge_finish(struct bpf_map_memory *mem);
 void bpf_map_charge_move(struct bpf_map_memory *dst,
 			 struct bpf_map_memory *src);
kernel/bpf/arraymap.c
@@ -117,14 +117,8 @@ static struct bpf_map *array_map_alloc(union bpf_attr *attr)
 	/* make sure there is no u32 overflow later in round_up() */
 	cost = array_size;
-	if (cost >= U32_MAX - PAGE_SIZE)
-		return ERR_PTR(-ENOMEM);
-	if (percpu) {
+	if (percpu)
 		cost += (u64)attr->max_entries * elem_size * num_possible_cpus();
-		if (cost >= U32_MAX - PAGE_SIZE)
-			return ERR_PTR(-ENOMEM);
-	}
-	cost = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
 
 	ret = bpf_map_charge_init(&mem, cost);
 	if (ret < 0)
kernel/bpf/cpumap.c
@@ -106,12 +106,9 @@ static struct bpf_map *cpu_map_alloc(union bpf_attr *attr)
 	/* make sure page count doesn't overflow */
 	cost = (u64) cmap->map.max_entries * sizeof(struct bpf_cpu_map_entry *);
 	cost += cpu_map_bitmap_size(attr) * num_possible_cpus();
-	if (cost >= U32_MAX - PAGE_SIZE)
-		goto free_cmap;
 
 	/* Notice returns -EPERM on if map size is larger than memlock limit */
-	ret = bpf_map_charge_init(&cmap->map.memory,
-				  round_up(cost, PAGE_SIZE) >> PAGE_SHIFT);
+	ret = bpf_map_charge_init(&cmap->map.memory, cost);
 	if (ret) {
 		err = ret;
 		goto free_cmap;
kernel/bpf/devmap.c
@@ -108,12 +108,9 @@ static struct bpf_map *dev_map_alloc(union bpf_attr *attr)
 	/* make sure page count doesn't overflow */
 	cost = (u64) dtab->map.max_entries * sizeof(struct bpf_dtab_netdev *);
 	cost += dev_map_bitmap_size(attr) * num_possible_cpus();
-	if (cost >= U32_MAX - PAGE_SIZE)
-		goto free_dtab;
 
 	/* if map size is larger than memlock limit, reject it */
-	err = bpf_map_charge_init(&dtab->map.memory,
-				  round_up(cost, PAGE_SIZE) >> PAGE_SHIFT);
+	err = bpf_map_charge_init(&dtab->map.memory, cost);
 	if (err)
 		goto free_dtab;
kernel/bpf/hashtab.c
@@ -360,13 +360,8 @@ static struct bpf_map *htab_map_alloc(union bpf_attr *attr)
 	else
 		cost += (u64) htab->elem_size * num_possible_cpus();
 
-	if (cost >= U32_MAX - PAGE_SIZE)
-		/* make sure page count doesn't overflow */
-		goto free_htab;
-
 	/* if map size is larger than memlock limit, reject it */
-	err = bpf_map_charge_init(&htab->map.memory,
-				  round_up(cost, PAGE_SIZE) >> PAGE_SHIFT);
+	err = bpf_map_charge_init(&htab->map.memory, cost);
 	if (err)
 		goto free_htab;
kernel/bpf/local_storage.c
@@ -273,7 +273,6 @@ static struct bpf_map *cgroup_storage_map_alloc(union bpf_attr *attr)
 	int numa_node = bpf_map_attr_numa_node(attr);
 	struct bpf_cgroup_storage_map *map;
 	struct bpf_map_memory mem;
-	u32 pages;
 	int ret;
 
 	if (attr->key_size != sizeof(struct bpf_cgroup_storage_key))
@@ -293,9 +292,7 @@ static struct bpf_map *cgroup_storage_map_alloc(union bpf_attr *attr)
 		/* max_entries is not used and enforced to be 0 */
 		return ERR_PTR(-EINVAL);
 
-	pages = round_up(sizeof(struct bpf_cgroup_storage_map), PAGE_SIZE) >>
-		PAGE_SHIFT;
-	ret = bpf_map_charge_init(&mem, pages);
+	ret = bpf_map_charge_init(&mem, sizeof(struct bpf_cgroup_storage_map));
 	if (ret < 0)
 		return ERR_PTR(ret);
kernel/bpf/lpm_trie.c
@@ -573,13 +573,8 @@ static struct bpf_map *trie_alloc(union bpf_attr *attr)
 	cost_per_node = sizeof(struct lpm_trie_node) +
 			attr->value_size + trie->data_size;
 	cost += (u64) attr->max_entries * cost_per_node;
-	if (cost >= U32_MAX - PAGE_SIZE) {
-		ret = -E2BIG;
-		goto out_err;
-	}
 
-	ret = bpf_map_charge_init(&trie->map.memory,
-				  round_up(cost, PAGE_SIZE) >> PAGE_SHIFT);
+	ret = bpf_map_charge_init(&trie->map.memory, cost);
 	if (ret)
 		goto out_err;
kernel/bpf/queue_stack_maps.c
@@ -73,10 +73,6 @@ static struct bpf_map *queue_stack_map_alloc(union bpf_attr *attr)
 	size = (u64) attr->max_entries + 1;
 	cost = queue_size = sizeof(*qs) + size * attr->value_size;
-	if (cost >= U32_MAX - PAGE_SIZE)
-		return ERR_PTR(-E2BIG);
-
-	cost = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
 
 	ret = bpf_map_charge_init(&mem, cost);
 	if (ret < 0)
kernel/bpf/reuseport_array.c
@@ -152,7 +152,7 @@ static struct bpf_map *reuseport_array_alloc(union bpf_attr *attr)
 	int err, numa_node = bpf_map_attr_numa_node(attr);
 	struct reuseport_array *array;
 	struct bpf_map_memory mem;
-	u64 cost, array_size;
+	u64 array_size;
 
 	if (!capable(CAP_SYS_ADMIN))
 		return ERR_PTR(-EPERM);
@@ -160,13 +160,7 @@ static struct bpf_map *reuseport_array_alloc(union bpf_attr *attr)
 	array_size = sizeof(*array);
 	array_size += (u64)attr->max_entries * sizeof(struct sock *);
 
-	/* make sure there is no u32 overflow later in round_up() */
-	cost = array_size;
-	if (cost >= U32_MAX - PAGE_SIZE)
-		return ERR_PTR(-ENOMEM);
-	cost = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
-
-	err = bpf_map_charge_init(&mem, cost);
+	err = bpf_map_charge_init(&mem, array_size);
 	if (err)
 		return ERR_PTR(err);
kernel/bpf/stackmap.c
@@ -117,14 +117,8 @@ static struct bpf_map *stack_map_alloc(union bpf_attr *attr)
 	n_buckets = roundup_pow_of_two(attr->max_entries);
 
 	cost = n_buckets * sizeof(struct stack_map_bucket *) + sizeof(*smap);
-	if (cost >= U32_MAX - PAGE_SIZE)
-		return ERR_PTR(-E2BIG);
 	cost += n_buckets * (value_size + sizeof(struct stack_map_bucket));
-	if (cost >= U32_MAX - PAGE_SIZE)
-		return ERR_PTR(-E2BIG);
-
-	err = bpf_map_charge_init(&mem,
-				  round_up(cost, PAGE_SIZE) >> PAGE_SHIFT);
+	err = bpf_map_charge_init(&mem, cost);
 	if (err)
 		return ERR_PTR(err);
kernel/bpf/syscall.c
@@ -205,11 +205,16 @@ static void bpf_uncharge_memlock(struct user_struct *user, u32 pages)
 	atomic_long_sub(pages, &user->locked_vm);
 }
 
-int bpf_map_charge_init(struct bpf_map_memory *mem, u32 pages)
+int bpf_map_charge_init(struct bpf_map_memory *mem, size_t size)
 {
-	struct user_struct *user = get_current_user();
+	u32 pages = round_up(size, PAGE_SIZE) >> PAGE_SHIFT;
+	struct user_struct *user;
 	int ret;
 
+	if (size >= U32_MAX - PAGE_SIZE)
+		return -E2BIG;
+
+	user = get_current_user();
 	ret = bpf_charge_memlock(user, pages);
 	if (ret) {
 		free_uid(user);
kernel/bpf/xskmap.c
@@ -37,12 +37,9 @@ static struct bpf_map *xsk_map_alloc(union bpf_attr *attr)
 	cost = (u64)m->map.max_entries * sizeof(struct xdp_sock *);
 	cost += sizeof(struct list_head) * num_possible_cpus();
-	if (cost >= U32_MAX - PAGE_SIZE)
-		goto free_m;
 
 	/* Notice returns -EPERM on if map size is larger than memlock limit */
-	err = bpf_map_charge_init(&m->map.memory,
-				  round_up(cost, PAGE_SIZE) >> PAGE_SHIFT);
+	err = bpf_map_charge_init(&m->map.memory, cost);
 	if (err)
 		goto free_m;
net/core/bpf_sk_storage.c
@@ -626,7 +626,6 @@ static struct bpf_map *bpf_sk_storage_map_alloc(union bpf_attr *attr)
 	struct bpf_sk_storage_map *smap;
 	unsigned int i;
 	u32 nbuckets;
-	u32 pages;
 	u64 cost;
 	int ret;
 
@@ -638,9 +637,8 @@ static struct bpf_map *bpf_sk_storage_map_alloc(union bpf_attr *attr)
 	smap->bucket_log = ilog2(roundup_pow_of_two(num_possible_cpus()));
 	nbuckets = 1U << smap->bucket_log;
 	cost = sizeof(*smap->buckets) * nbuckets + sizeof(*smap);
-	pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
 
-	ret = bpf_map_charge_init(&smap->map.memory, pages);
+	ret = bpf_map_charge_init(&smap->map.memory, cost);
 	if (ret < 0) {
 		kfree(smap);
 		return ERR_PTR(ret);
net/core/sock_map.c
@@ -44,13 +44,7 @@ static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
 	/* Make sure page count doesn't overflow. */
 	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
-	if (cost >= U32_MAX - PAGE_SIZE) {
-		err = -EINVAL;
-		goto free_stab;
-	}
-
-	err = bpf_map_charge_init(&stab->map.memory,
-				  round_up(cost, PAGE_SIZE) >> PAGE_SHIFT);
+	err = bpf_map_charge_init(&stab->map.memory, cost);
 	if (err)
 		goto free_stab;