Commit 8d24c0b4 authored by Thomas Graf's avatar Thomas Graf Committed by David S. Miller

rhashtable: Do hashing inside of rhashtable_lookup_compare()

Hash the key inside of rhashtable_lookup_compare() like
rhashtable_lookup() does. This allows to simplify the hashing
functions and keep them private.
Signed-off-by: default avatarThomas Graf <tgraf@suug.ch>
Cc: netfilter-devel@vger.kernel.org
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent dd955398
...@@ -96,9 +96,6 @@ static inline int lockdep_rht_mutex_is_held(const struct rhashtable *ht) ...@@ -96,9 +96,6 @@ static inline int lockdep_rht_mutex_is_held(const struct rhashtable *ht)
int rhashtable_init(struct rhashtable *ht, struct rhashtable_params *params); int rhashtable_init(struct rhashtable *ht, struct rhashtable_params *params);
u32 rhashtable_hashfn(const struct rhashtable *ht, const void *key, u32 len);
u32 rhashtable_obj_hashfn(const struct rhashtable *ht, void *ptr);
void rhashtable_insert(struct rhashtable *ht, struct rhash_head *node); void rhashtable_insert(struct rhashtable *ht, struct rhash_head *node);
bool rhashtable_remove(struct rhashtable *ht, struct rhash_head *node); bool rhashtable_remove(struct rhashtable *ht, struct rhash_head *node);
void rhashtable_remove_pprev(struct rhashtable *ht, struct rhash_head *obj, void rhashtable_remove_pprev(struct rhashtable *ht, struct rhash_head *obj,
...@@ -111,7 +108,7 @@ int rhashtable_expand(struct rhashtable *ht); ...@@ -111,7 +108,7 @@ int rhashtable_expand(struct rhashtable *ht);
int rhashtable_shrink(struct rhashtable *ht); int rhashtable_shrink(struct rhashtable *ht);
void *rhashtable_lookup(const struct rhashtable *ht, const void *key); void *rhashtable_lookup(const struct rhashtable *ht, const void *key);
void *rhashtable_lookup_compare(const struct rhashtable *ht, u32 hash, void *rhashtable_lookup_compare(const struct rhashtable *ht, const void *key,
bool (*compare)(void *, void *), void *arg); bool (*compare)(void *, void *), void *arg);
void rhashtable_destroy(const struct rhashtable *ht); void rhashtable_destroy(const struct rhashtable *ht);
......
...@@ -42,69 +42,39 @@ static void *rht_obj(const struct rhashtable *ht, const struct rhash_head *he) ...@@ -42,69 +42,39 @@ static void *rht_obj(const struct rhashtable *ht, const struct rhash_head *he)
return (void *) he - ht->p.head_offset; return (void *) he - ht->p.head_offset;
} }
static u32 __hashfn(const struct rhashtable *ht, const void *key, static u32 rht_bucket_index(const struct bucket_table *tbl, u32 hash)
u32 len, u32 hsize)
{ {
u32 h; return hash & (tbl->size - 1);
h = ht->p.hashfn(key, len, ht->p.hash_rnd);
return h & (hsize - 1);
}
/**
* rhashtable_hashfn - compute hash for key of given length
* @ht: hash table to compute for
* @key: pointer to key
* @len: length of key
*
* Computes the hash value using the hash function provided in the 'hashfn'
* of struct rhashtable_params. The returned value is guaranteed to be
* smaller than the number of buckets in the hash table.
*/
u32 rhashtable_hashfn(const struct rhashtable *ht, const void *key, u32 len)
{
struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
return __hashfn(ht, key, len, tbl->size);
} }
EXPORT_SYMBOL_GPL(rhashtable_hashfn);
static u32 obj_hashfn(const struct rhashtable *ht, const void *ptr, u32 hsize) static u32 obj_raw_hashfn(const struct rhashtable *ht, const void *ptr)
{ {
if (unlikely(!ht->p.key_len)) { u32 hash;
u32 h;
h = ht->p.obj_hashfn(ptr, ht->p.hash_rnd);
return h & (hsize - 1); if (unlikely(!ht->p.key_len))
} hash = ht->p.obj_hashfn(ptr, ht->p.hash_rnd);
else
hash = ht->p.hashfn(ptr + ht->p.key_offset, ht->p.key_len,
ht->p.hash_rnd);
return __hashfn(ht, ptr + ht->p.key_offset, ht->p.key_len, hsize); return hash;
} }
/** static u32 key_hashfn(const struct rhashtable *ht, const void *key, u32 len)
* rhashtable_obj_hashfn - compute hash for hashed object
* @ht: hash table to compute for
* @ptr: pointer to hashed object
*
* Computes the hash value using the hash function `hashfn` respectively
* 'obj_hashfn' depending on whether the hash table is set up to work with
* a fixed length key. The returned value is guaranteed to be smaller than
* the number of buckets in the hash table.
*/
u32 rhashtable_obj_hashfn(const struct rhashtable *ht, void *ptr)
{ {
struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
u32 hash;
hash = ht->p.hashfn(key, len, ht->p.hash_rnd);
return obj_hashfn(ht, ptr, tbl->size); return rht_bucket_index(tbl, hash);
} }
EXPORT_SYMBOL_GPL(rhashtable_obj_hashfn);
static u32 head_hashfn(const struct rhashtable *ht, static u32 head_hashfn(const struct rhashtable *ht,
const struct rhash_head *he, u32 hsize) const struct bucket_table *tbl,
const struct rhash_head *he)
{ {
return obj_hashfn(ht, rht_obj(ht, he), hsize); return rht_bucket_index(tbl, obj_raw_hashfn(ht, rht_obj(ht, he)));
} }
static struct bucket_table *bucket_table_alloc(size_t nbuckets) static struct bucket_table *bucket_table_alloc(size_t nbuckets)
...@@ -170,9 +140,9 @@ static void hashtable_chain_unzip(const struct rhashtable *ht, ...@@ -170,9 +140,9 @@ static void hashtable_chain_unzip(const struct rhashtable *ht,
* reaches a node that doesn't hash to the same bucket as the * reaches a node that doesn't hash to the same bucket as the
* previous node p. Call the previous node p; * previous node p. Call the previous node p;
*/ */
h = head_hashfn(ht, p, new_tbl->size); h = head_hashfn(ht, new_tbl, p);
rht_for_each(he, p->next, ht) { rht_for_each(he, p->next, ht) {
if (head_hashfn(ht, he, new_tbl->size) != h) if (head_hashfn(ht, new_tbl, he) != h)
break; break;
p = he; p = he;
} }
...@@ -184,7 +154,7 @@ static void hashtable_chain_unzip(const struct rhashtable *ht, ...@@ -184,7 +154,7 @@ static void hashtable_chain_unzip(const struct rhashtable *ht,
next = NULL; next = NULL;
if (he) { if (he) {
rht_for_each(he, he->next, ht) { rht_for_each(he, he->next, ht) {
if (head_hashfn(ht, he, new_tbl->size) == h) { if (head_hashfn(ht, new_tbl, he) == h) {
next = he; next = he;
break; break;
} }
...@@ -237,9 +207,9 @@ int rhashtable_expand(struct rhashtable *ht) ...@@ -237,9 +207,9 @@ int rhashtable_expand(struct rhashtable *ht)
* single imprecise chain. * single imprecise chain.
*/ */
for (i = 0; i < new_tbl->size; i++) { for (i = 0; i < new_tbl->size; i++) {
h = i & (old_tbl->size - 1); h = rht_bucket_index(old_tbl, i);
rht_for_each(he, old_tbl->buckets[h], ht) { rht_for_each(he, old_tbl->buckets[h], ht) {
if (head_hashfn(ht, he, new_tbl->size) == i) { if (head_hashfn(ht, new_tbl, he) == i) {
RCU_INIT_POINTER(new_tbl->buckets[i], he); RCU_INIT_POINTER(new_tbl->buckets[i], he);
break; break;
} }
...@@ -353,7 +323,7 @@ void rhashtable_insert(struct rhashtable *ht, struct rhash_head *obj) ...@@ -353,7 +323,7 @@ void rhashtable_insert(struct rhashtable *ht, struct rhash_head *obj)
ASSERT_RHT_MUTEX(ht); ASSERT_RHT_MUTEX(ht);
hash = head_hashfn(ht, obj, tbl->size); hash = head_hashfn(ht, tbl, obj);
RCU_INIT_POINTER(obj->next, tbl->buckets[hash]); RCU_INIT_POINTER(obj->next, tbl->buckets[hash]);
rcu_assign_pointer(tbl->buckets[hash], obj); rcu_assign_pointer(tbl->buckets[hash], obj);
ht->nelems++; ht->nelems++;
...@@ -413,7 +383,7 @@ bool rhashtable_remove(struct rhashtable *ht, struct rhash_head *obj) ...@@ -413,7 +383,7 @@ bool rhashtable_remove(struct rhashtable *ht, struct rhash_head *obj)
ASSERT_RHT_MUTEX(ht); ASSERT_RHT_MUTEX(ht);
h = head_hashfn(ht, obj, tbl->size); h = head_hashfn(ht, tbl, obj);
pprev = &tbl->buckets[h]; pprev = &tbl->buckets[h];
rht_for_each(he, tbl->buckets[h], ht) { rht_for_each(he, tbl->buckets[h], ht) {
...@@ -452,7 +422,7 @@ void *rhashtable_lookup(const struct rhashtable *ht, const void *key) ...@@ -452,7 +422,7 @@ void *rhashtable_lookup(const struct rhashtable *ht, const void *key)
BUG_ON(!ht->p.key_len); BUG_ON(!ht->p.key_len);
h = __hashfn(ht, key, ht->p.key_len, tbl->size); h = key_hashfn(ht, key, ht->p.key_len);
rht_for_each_rcu(he, tbl->buckets[h], ht) { rht_for_each_rcu(he, tbl->buckets[h], ht) {
if (memcmp(rht_obj(ht, he) + ht->p.key_offset, key, if (memcmp(rht_obj(ht, he) + ht->p.key_offset, key,
ht->p.key_len)) ht->p.key_len))
...@@ -467,7 +437,7 @@ EXPORT_SYMBOL_GPL(rhashtable_lookup); ...@@ -467,7 +437,7 @@ EXPORT_SYMBOL_GPL(rhashtable_lookup);
/** /**
* rhashtable_lookup_compare - search hash table with compare function * rhashtable_lookup_compare - search hash table with compare function
* @ht: hash table * @ht: hash table
* @hash: hash value of desired entry * @key: the pointer to the key
* @compare: compare function, must return true on match * @compare: compare function, must return true on match
* @arg: argument passed on to compare function * @arg: argument passed on to compare function
* *
...@@ -479,15 +449,14 @@ EXPORT_SYMBOL_GPL(rhashtable_lookup); ...@@ -479,15 +449,14 @@ EXPORT_SYMBOL_GPL(rhashtable_lookup);
* *
* Returns the first entry on which the compare function returned true. * Returns the first entry on which the compare function returned true.
*/ */
void *rhashtable_lookup_compare(const struct rhashtable *ht, u32 hash, void *rhashtable_lookup_compare(const struct rhashtable *ht, const void *key,
bool (*compare)(void *, void *), void *arg) bool (*compare)(void *, void *), void *arg)
{ {
const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
struct rhash_head *he; struct rhash_head *he;
u32 hash;
if (unlikely(hash >= tbl->size)) hash = key_hashfn(ht, key, ht->p.key_len);
return NULL;
rht_for_each_rcu(he, tbl->buckets[hash], ht) { rht_for_each_rcu(he, tbl->buckets[hash], ht) {
if (!compare(rht_obj(ht, he), arg)) if (!compare(rht_obj(ht, he), arg))
continue; continue;
......
...@@ -94,28 +94,40 @@ static void nft_hash_remove(const struct nft_set *set, ...@@ -94,28 +94,40 @@ static void nft_hash_remove(const struct nft_set *set,
kfree(he); kfree(he);
} }
struct nft_compare_arg {
const struct nft_set *set;
struct nft_set_elem *elem;
};
static bool nft_hash_compare(void *ptr, void *arg)
{
struct nft_hash_elem *he = ptr;
struct nft_compare_arg *x = arg;
if (!nft_data_cmp(&he->key, &x->elem->key, x->set->klen)) {
x->elem->cookie = &he->node;
x->elem->flags = 0;
if (x->set->flags & NFT_SET_MAP)
nft_data_copy(&x->elem->data, he->data);
return true;
}
return false;
}
static int nft_hash_get(const struct nft_set *set, struct nft_set_elem *elem) static int nft_hash_get(const struct nft_set *set, struct nft_set_elem *elem)
{ {
const struct rhashtable *priv = nft_set_priv(set); const struct rhashtable *priv = nft_set_priv(set);
const struct bucket_table *tbl = rht_dereference_rcu(priv->tbl, priv); struct nft_compare_arg arg = {
struct rhash_head __rcu * const *pprev; .set = set,
struct nft_hash_elem *he; .elem = elem,
u32 h; };
h = rhashtable_hashfn(priv, &elem->key, set->klen);
pprev = &tbl->buckets[h];
rht_for_each_entry_rcu(he, tbl->buckets[h], node) {
if (nft_data_cmp(&he->key, &elem->key, set->klen)) {
pprev = &he->node.next;
continue;
}
elem->cookie = (void *)pprev; if (rhashtable_lookup_compare(priv, &elem->key,
elem->flags = 0; &nft_hash_compare, &arg))
if (set->flags & NFT_SET_MAP)
nft_data_copy(&elem->data, he->data);
return 0; return 0;
}
return -ENOENT; return -ENOENT;
} }
......
...@@ -1002,11 +1002,8 @@ static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid, ...@@ -1002,11 +1002,8 @@ static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid,
.net = net, .net = net,
.portid = portid, .portid = portid,
}; };
u32 hash;
hash = rhashtable_hashfn(&table->hash, &portid, sizeof(portid)); return rhashtable_lookup_compare(&table->hash, &portid,
return rhashtable_lookup_compare(&table->hash, hash,
&netlink_compare, &arg); &netlink_compare, &arg);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment