Commit 8fe45924 authored by Teng Qin's avatar Teng Qin Committed by David S. Miller

bpf: map_get_next_key to return first key on NULL

When iterating through a map, we need to find a key that does not exist
in the map so map_get_next_key will give us the first key of the map.
This often requires a lot of guessing in production systems.

This patch makes map_get_next_key return the first key when the key
pointer in the parameter is NULL.
Signed-off-by: default avatarTeng Qin <qinteng@fb.com>
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
Acked-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 472ecf08
...@@ -321,11 +321,14 @@ TRACE_EVENT(bpf_map_next_key, ...@@ -321,11 +321,14 @@ TRACE_EVENT(bpf_map_next_key,
__dynamic_array(u8, key, map->key_size) __dynamic_array(u8, key, map->key_size)
__dynamic_array(u8, nxt, map->key_size) __dynamic_array(u8, nxt, map->key_size)
__field(bool, key_trunc) __field(bool, key_trunc)
__field(bool, key_null)
__field(int, ufd) __field(int, ufd)
), ),
TP_fast_assign( TP_fast_assign(
memcpy(__get_dynamic_array(key), key, map->key_size); if (key)
memcpy(__get_dynamic_array(key), key, map->key_size);
__entry->key_null = !key;
memcpy(__get_dynamic_array(nxt), key_next, map->key_size); memcpy(__get_dynamic_array(nxt), key_next, map->key_size);
__entry->type = map->map_type; __entry->type = map->map_type;
__entry->key_len = min(map->key_size, 16U); __entry->key_len = min(map->key_size, 16U);
...@@ -336,8 +339,9 @@ TRACE_EVENT(bpf_map_next_key, ...@@ -336,8 +339,9 @@ TRACE_EVENT(bpf_map_next_key,
TP_printk("map type=%s ufd=%d key=[%s%s] next=[%s%s]", TP_printk("map type=%s ufd=%d key=[%s%s] next=[%s%s]",
__print_symbolic(__entry->type, __MAP_TYPE_SYM_TAB), __print_symbolic(__entry->type, __MAP_TYPE_SYM_TAB),
__entry->ufd, __entry->ufd,
__print_hex(__get_dynamic_array(key), __entry->key_len), __entry->key_null ? "NULL" : __print_hex(__get_dynamic_array(key),
__entry->key_trunc ? " ..." : "", __entry->key_len),
__entry->key_trunc && !__entry->key_null ? " ..." : "",
__print_hex(__get_dynamic_array(nxt), __entry->key_len), __print_hex(__get_dynamic_array(nxt), __entry->key_len),
__entry->key_trunc ? " ..." : "") __entry->key_trunc ? " ..." : "")
); );
......
...@@ -182,7 +182,7 @@ int bpf_percpu_array_copy(struct bpf_map *map, void *key, void *value) ...@@ -182,7 +182,7 @@ int bpf_percpu_array_copy(struct bpf_map *map, void *key, void *value)
static int array_map_get_next_key(struct bpf_map *map, void *key, void *next_key) static int array_map_get_next_key(struct bpf_map *map, void *key, void *next_key)
{ {
struct bpf_array *array = container_of(map, struct bpf_array, map); struct bpf_array *array = container_of(map, struct bpf_array, map);
u32 index = *(u32 *)key; u32 index = key ? *(u32 *)key : U32_MAX;
u32 *next = (u32 *)next_key; u32 *next = (u32 *)next_key;
if (index >= array->map.max_entries) { if (index >= array->map.max_entries) {
......
...@@ -540,12 +540,15 @@ static int htab_map_get_next_key(struct bpf_map *map, void *key, void *next_key) ...@@ -540,12 +540,15 @@ static int htab_map_get_next_key(struct bpf_map *map, void *key, void *next_key)
struct hlist_nulls_head *head; struct hlist_nulls_head *head;
struct htab_elem *l, *next_l; struct htab_elem *l, *next_l;
u32 hash, key_size; u32 hash, key_size;
int i; int i = 0;
WARN_ON_ONCE(!rcu_read_lock_held()); WARN_ON_ONCE(!rcu_read_lock_held());
key_size = map->key_size; key_size = map->key_size;
if (!key)
goto find_first_elem;
hash = htab_map_hash(key, key_size); hash = htab_map_hash(key, key_size);
head = select_bucket(htab, hash); head = select_bucket(htab, hash);
...@@ -553,10 +556,8 @@ static int htab_map_get_next_key(struct bpf_map *map, void *key, void *next_key) ...@@ -553,10 +556,8 @@ static int htab_map_get_next_key(struct bpf_map *map, void *key, void *next_key)
/* lookup the key */ /* lookup the key */
l = lookup_nulls_elem_raw(head, hash, key, key_size, htab->n_buckets); l = lookup_nulls_elem_raw(head, hash, key, key_size, htab->n_buckets);
if (!l) { if (!l)
i = 0;
goto find_first_elem; goto find_first_elem;
}
/* key was found, get next key in the same bucket */ /* key was found, get next key in the same bucket */
next_l = hlist_nulls_entry_safe(rcu_dereference_raw(hlist_nulls_next_rcu(&l->hash_node)), next_l = hlist_nulls_entry_safe(rcu_dereference_raw(hlist_nulls_next_rcu(&l->hash_node)),
......
...@@ -536,14 +536,18 @@ static int map_get_next_key(union bpf_attr *attr) ...@@ -536,14 +536,18 @@ static int map_get_next_key(union bpf_attr *attr)
if (IS_ERR(map)) if (IS_ERR(map))
return PTR_ERR(map); return PTR_ERR(map);
err = -ENOMEM; if (ukey) {
key = kmalloc(map->key_size, GFP_USER); err = -ENOMEM;
if (!key) key = kmalloc(map->key_size, GFP_USER);
goto err_put; if (!key)
goto err_put;
err = -EFAULT;
if (copy_from_user(key, ukey, map->key_size) != 0) err = -EFAULT;
goto free_key; if (copy_from_user(key, ukey, map->key_size) != 0)
goto free_key;
} else {
key = NULL;
}
err = -ENOMEM; err = -ENOMEM;
next_key = kmalloc(map->key_size, GFP_USER); next_key = kmalloc(map->key_size, GFP_USER);
......
...@@ -28,7 +28,7 @@ static int map_flags; ...@@ -28,7 +28,7 @@ static int map_flags;
static void test_hashmap(int task, void *data) static void test_hashmap(int task, void *data)
{ {
long long key, next_key, value; long long key, next_key, first_key, value;
int fd; int fd;
fd = bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(key), sizeof(value), fd = bpf_create_map(BPF_MAP_TYPE_HASH, sizeof(key), sizeof(value),
...@@ -89,10 +89,13 @@ static void test_hashmap(int task, void *data) ...@@ -89,10 +89,13 @@ static void test_hashmap(int task, void *data)
assert(bpf_map_delete_elem(fd, &key) == -1 && errno == ENOENT); assert(bpf_map_delete_elem(fd, &key) == -1 && errno == ENOENT);
/* Iterate over two elements. */ /* Iterate over two elements. */
assert(bpf_map_get_next_key(fd, NULL, &first_key) == 0 &&
(first_key == 1 || first_key == 2));
assert(bpf_map_get_next_key(fd, &key, &next_key) == 0 && assert(bpf_map_get_next_key(fd, &key, &next_key) == 0 &&
(next_key == 1 || next_key == 2)); (next_key == first_key));
assert(bpf_map_get_next_key(fd, &next_key, &next_key) == 0 && assert(bpf_map_get_next_key(fd, &next_key, &next_key) == 0 &&
(next_key == 1 || next_key == 2)); (next_key == 1 || next_key == 2) &&
(next_key != first_key));
assert(bpf_map_get_next_key(fd, &next_key, &next_key) == -1 && assert(bpf_map_get_next_key(fd, &next_key, &next_key) == -1 &&
errno == ENOENT); errno == ENOENT);
...@@ -105,6 +108,8 @@ static void test_hashmap(int task, void *data) ...@@ -105,6 +108,8 @@ static void test_hashmap(int task, void *data)
key = 0; key = 0;
/* Check that map is empty. */ /* Check that map is empty. */
assert(bpf_map_get_next_key(fd, NULL, &next_key) == -1 &&
errno == ENOENT);
assert(bpf_map_get_next_key(fd, &key, &next_key) == -1 && assert(bpf_map_get_next_key(fd, &key, &next_key) == -1 &&
errno == ENOENT); errno == ENOENT);
...@@ -133,7 +138,7 @@ static void test_hashmap_percpu(int task, void *data) ...@@ -133,7 +138,7 @@ static void test_hashmap_percpu(int task, void *data)
{ {
unsigned int nr_cpus = bpf_num_possible_cpus(); unsigned int nr_cpus = bpf_num_possible_cpus();
long long value[nr_cpus]; long long value[nr_cpus];
long long key, next_key; long long key, next_key, first_key;
int expected_key_mask = 0; int expected_key_mask = 0;
int fd, i; int fd, i;
...@@ -193,7 +198,13 @@ static void test_hashmap_percpu(int task, void *data) ...@@ -193,7 +198,13 @@ static void test_hashmap_percpu(int task, void *data)
assert(bpf_map_delete_elem(fd, &key) == -1 && errno == ENOENT); assert(bpf_map_delete_elem(fd, &key) == -1 && errno == ENOENT);
/* Iterate over two elements. */ /* Iterate over two elements. */
assert(bpf_map_get_next_key(fd, NULL, &first_key) == 0 &&
((expected_key_mask & first_key) == first_key));
while (!bpf_map_get_next_key(fd, &key, &next_key)) { while (!bpf_map_get_next_key(fd, &key, &next_key)) {
if (first_key) {
assert(next_key == first_key);
first_key = 0;
}
assert((expected_key_mask & next_key) == next_key); assert((expected_key_mask & next_key) == next_key);
expected_key_mask &= ~next_key; expected_key_mask &= ~next_key;
...@@ -219,6 +230,8 @@ static void test_hashmap_percpu(int task, void *data) ...@@ -219,6 +230,8 @@ static void test_hashmap_percpu(int task, void *data)
key = 0; key = 0;
/* Check that map is empty. */ /* Check that map is empty. */
assert(bpf_map_get_next_key(fd, NULL, &next_key) == -1 &&
errno == ENOENT);
assert(bpf_map_get_next_key(fd, &key, &next_key) == -1 && assert(bpf_map_get_next_key(fd, &key, &next_key) == -1 &&
errno == ENOENT); errno == ENOENT);
...@@ -264,6 +277,8 @@ static void test_arraymap(int task, void *data) ...@@ -264,6 +277,8 @@ static void test_arraymap(int task, void *data)
assert(bpf_map_lookup_elem(fd, &key, &value) == -1 && errno == ENOENT); assert(bpf_map_lookup_elem(fd, &key, &value) == -1 && errno == ENOENT);
/* Iterate over two elements. */ /* Iterate over two elements. */
assert(bpf_map_get_next_key(fd, NULL, &next_key) == 0 &&
next_key == 0);
assert(bpf_map_get_next_key(fd, &key, &next_key) == 0 && assert(bpf_map_get_next_key(fd, &key, &next_key) == 0 &&
next_key == 0); next_key == 0);
assert(bpf_map_get_next_key(fd, &next_key, &next_key) == 0 && assert(bpf_map_get_next_key(fd, &next_key, &next_key) == 0 &&
...@@ -319,6 +334,8 @@ static void test_arraymap_percpu(int task, void *data) ...@@ -319,6 +334,8 @@ static void test_arraymap_percpu(int task, void *data)
assert(bpf_map_lookup_elem(fd, &key, values) == -1 && errno == ENOENT); assert(bpf_map_lookup_elem(fd, &key, values) == -1 && errno == ENOENT);
/* Iterate over two elements. */ /* Iterate over two elements. */
assert(bpf_map_get_next_key(fd, NULL, &next_key) == 0 &&
next_key == 0);
assert(bpf_map_get_next_key(fd, &key, &next_key) == 0 && assert(bpf_map_get_next_key(fd, &key, &next_key) == 0 &&
next_key == 0); next_key == 0);
assert(bpf_map_get_next_key(fd, &next_key, &next_key) == 0 && assert(bpf_map_get_next_key(fd, &next_key, &next_key) == 0 &&
...@@ -400,6 +417,8 @@ static void test_map_large(void) ...@@ -400,6 +417,8 @@ static void test_map_large(void)
errno == E2BIG); errno == E2BIG);
/* Iterate through all elements. */ /* Iterate through all elements. */
assert(bpf_map_get_next_key(fd, NULL, &key) == 0);
key.c = -1;
for (i = 0; i < MAP_SIZE; i++) for (i = 0; i < MAP_SIZE; i++)
assert(bpf_map_get_next_key(fd, &key, &key) == 0); assert(bpf_map_get_next_key(fd, &key, &key) == 0);
assert(bpf_map_get_next_key(fd, &key, &key) == -1 && errno == ENOENT); assert(bpf_map_get_next_key(fd, &key, &key) == -1 && errno == ENOENT);
...@@ -499,6 +518,7 @@ static void test_map_parallel(void) ...@@ -499,6 +518,7 @@ static void test_map_parallel(void)
errno == EEXIST); errno == EEXIST);
/* Check that all elements were inserted. */ /* Check that all elements were inserted. */
assert(bpf_map_get_next_key(fd, NULL, &key) == 0);
key = -1; key = -1;
for (i = 0; i < MAP_SIZE; i++) for (i = 0; i < MAP_SIZE; i++)
assert(bpf_map_get_next_key(fd, &key, &key) == 0); assert(bpf_map_get_next_key(fd, &key, &key) == 0);
...@@ -518,6 +538,7 @@ static void test_map_parallel(void) ...@@ -518,6 +538,7 @@ static void test_map_parallel(void)
/* Nothing should be left. */ /* Nothing should be left. */
key = -1; key = -1;
assert(bpf_map_get_next_key(fd, NULL, &key) == -1 && errno == ENOENT);
assert(bpf_map_get_next_key(fd, &key, &key) == -1 && errno == ENOENT); assert(bpf_map_get_next_key(fd, &key, &key) == -1 && errno == ENOENT);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment