Commit 7beceebf authored by David S. Miller's avatar David S. Miller

Merge branch 'rhashtable-next'

Thomas Graf says:

====================
rhashtable: Per bucket locks & deferred table resizing

Prepares for and introduces per bucket spinlocks and deferred table
resizing. This allows for parallel table mutations in different hash
buckets from atomic context. The resizing occurs in the background
in a separate worker thread while lookups, inserts, and removals can
continue.

Also modified the chain linked list to be terminated with a special
nulls marker to allow entries to move between multiple lists.

Last but not least, reintroduces lockless netlink_lookup() with
deferred Netlink socket destruction to avoid the side effect of
increased netlink_release() runtime.
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents dd955398 21e4902a
...@@ -21,8 +21,9 @@ struct hlist_nulls_head { ...@@ -21,8 +21,9 @@ struct hlist_nulls_head {
struct hlist_nulls_node { struct hlist_nulls_node {
struct hlist_nulls_node *next, **pprev; struct hlist_nulls_node *next, **pprev;
}; };
#define NULLS_MARKER(value) (1UL | (((long)value) << 1))
#define INIT_HLIST_NULLS_HEAD(ptr, nulls) \ #define INIT_HLIST_NULLS_HEAD(ptr, nulls) \
((ptr)->first = (struct hlist_nulls_node *) (1UL | (((long)nulls) << 1))) ((ptr)->first = (struct hlist_nulls_node *) NULLS_MARKER(nulls))
#define hlist_nulls_entry(ptr, type, member) container_of(ptr,type,member) #define hlist_nulls_entry(ptr, type, member) container_of(ptr,type,member)
/** /**
......
...@@ -18,16 +18,43 @@ ...@@ -18,16 +18,43 @@
#ifndef _LINUX_RHASHTABLE_H #ifndef _LINUX_RHASHTABLE_H
#define _LINUX_RHASHTABLE_H #define _LINUX_RHASHTABLE_H
#include <linux/rculist.h> #include <linux/list_nulls.h>
#include <linux/workqueue.h>
/*
* The end of the chain is marked with a special nulls marks which has
* the following format:
*
* +-------+-----------------------------------------------------+-+
* | Base | Hash |1|
* +-------+-----------------------------------------------------+-+
*
* Base (4 bits) : Reserved to distinguish between multiple tables.
* Specified via &struct rhashtable_params.nulls_base.
* Hash (27 bits): Full hash (unmasked) of first element added to bucket
* 1 (1 bit) : Nulls marker (always set)
*
* The remaining bits of the next pointer remain unused for now.
*/
#define RHT_BASE_BITS 4
#define RHT_HASH_BITS 27
#define RHT_BASE_SHIFT RHT_HASH_BITS
struct rhash_head { struct rhash_head {
struct rhash_head __rcu *next; struct rhash_head __rcu *next;
}; };
#define INIT_HASH_HEAD(ptr) ((ptr)->next = NULL) /**
* struct bucket_table - Table of hash buckets
* @size: Number of hash buckets
* @locks_mask: Mask to apply before accessing locks[]
* @locks: Array of spinlocks protecting individual buckets
* @buckets: size * hash buckets
*/
struct bucket_table { struct bucket_table {
size_t size; size_t size;
unsigned int locks_mask;
spinlock_t *locks;
struct rhash_head __rcu *buckets[]; struct rhash_head __rcu *buckets[];
}; };
...@@ -45,11 +72,12 @@ struct rhashtable; ...@@ -45,11 +72,12 @@ struct rhashtable;
* @hash_rnd: Seed to use while hashing * @hash_rnd: Seed to use while hashing
* @max_shift: Maximum number of shifts while expanding * @max_shift: Maximum number of shifts while expanding
* @min_shift: Minimum number of shifts while shrinking * @min_shift: Minimum number of shifts while shrinking
* @nulls_base: Base value to generate nulls marker
* @locks_mul: Number of bucket locks to allocate per cpu (default: 128)
* @hashfn: Function to hash key * @hashfn: Function to hash key
* @obj_hashfn: Function to hash object * @obj_hashfn: Function to hash object
* @grow_decision: If defined, may return true if table should expand * @grow_decision: If defined, may return true if table should expand
* @shrink_decision: If defined, may return true if table should shrink * @shrink_decision: If defined, may return true if table should shrink
* @mutex_is_held: Must return true if protecting mutex is held
*/ */
struct rhashtable_params { struct rhashtable_params {
size_t nelem_hint; size_t nelem_hint;
...@@ -59,36 +87,67 @@ struct rhashtable_params { ...@@ -59,36 +87,67 @@ struct rhashtable_params {
u32 hash_rnd; u32 hash_rnd;
size_t max_shift; size_t max_shift;
size_t min_shift; size_t min_shift;
u32 nulls_base;
size_t locks_mul;
rht_hashfn_t hashfn; rht_hashfn_t hashfn;
rht_obj_hashfn_t obj_hashfn; rht_obj_hashfn_t obj_hashfn;
bool (*grow_decision)(const struct rhashtable *ht, bool (*grow_decision)(const struct rhashtable *ht,
size_t new_size); size_t new_size);
bool (*shrink_decision)(const struct rhashtable *ht, bool (*shrink_decision)(const struct rhashtable *ht,
size_t new_size); size_t new_size);
#ifdef CONFIG_PROVE_LOCKING
int (*mutex_is_held)(void *parent);
void *parent;
#endif
}; };
/** /**
* struct rhashtable - Hash table handle * struct rhashtable - Hash table handle
* @tbl: Bucket table * @tbl: Bucket table
* @future_tbl: Table under construction during expansion/shrinking
* @nelems: Number of elements in table * @nelems: Number of elements in table
* @shift: Current size (1 << shift) * @shift: Current size (1 << shift)
* @p: Configuration parameters * @p: Configuration parameters
* @run_work: Deferred worker to expand/shrink asynchronously
* @mutex: Mutex to protect current/future table swapping
* @being_destroyed: True if table is set up for destruction
*/ */
struct rhashtable { struct rhashtable {
struct bucket_table __rcu *tbl; struct bucket_table __rcu *tbl;
size_t nelems; struct bucket_table __rcu *future_tbl;
atomic_t nelems;
size_t shift; size_t shift;
struct rhashtable_params p; struct rhashtable_params p;
struct delayed_work run_work;
struct mutex mutex;
bool being_destroyed;
}; };
static inline unsigned long rht_marker(const struct rhashtable *ht, u32 hash)
{
return NULLS_MARKER(ht->p.nulls_base + hash);
}
#define INIT_RHT_NULLS_HEAD(ptr, ht, hash) \
((ptr) = (typeof(ptr)) rht_marker(ht, hash))
static inline bool rht_is_a_nulls(const struct rhash_head *ptr)
{
return ((unsigned long) ptr & 1);
}
static inline unsigned long rht_get_nulls_value(const struct rhash_head *ptr)
{
return ((unsigned long) ptr) >> 1;
}
#ifdef CONFIG_PROVE_LOCKING #ifdef CONFIG_PROVE_LOCKING
int lockdep_rht_mutex_is_held(const struct rhashtable *ht); int lockdep_rht_mutex_is_held(struct rhashtable *ht);
int lockdep_rht_bucket_is_held(const struct bucket_table *tbl, u32 hash);
#else #else
static inline int lockdep_rht_mutex_is_held(const struct rhashtable *ht) static inline int lockdep_rht_mutex_is_held(struct rhashtable *ht)
{
return 1;
}
static inline int lockdep_rht_bucket_is_held(const struct bucket_table *tbl,
u32 hash)
{ {
return 1; return 1;
} }
...@@ -96,13 +155,8 @@ static inline int lockdep_rht_mutex_is_held(const struct rhashtable *ht) ...@@ -96,13 +155,8 @@ static inline int lockdep_rht_mutex_is_held(const struct rhashtable *ht)
int rhashtable_init(struct rhashtable *ht, struct rhashtable_params *params); int rhashtable_init(struct rhashtable *ht, struct rhashtable_params *params);
u32 rhashtable_hashfn(const struct rhashtable *ht, const void *key, u32 len);
u32 rhashtable_obj_hashfn(const struct rhashtable *ht, void *ptr);
void rhashtable_insert(struct rhashtable *ht, struct rhash_head *node); void rhashtable_insert(struct rhashtable *ht, struct rhash_head *node);
bool rhashtable_remove(struct rhashtable *ht, struct rhash_head *node); bool rhashtable_remove(struct rhashtable *ht, struct rhash_head *node);
void rhashtable_remove_pprev(struct rhashtable *ht, struct rhash_head *obj,
struct rhash_head __rcu **pprev);
bool rht_grow_above_75(const struct rhashtable *ht, size_t new_size); bool rht_grow_above_75(const struct rhashtable *ht, size_t new_size);
bool rht_shrink_below_30(const struct rhashtable *ht, size_t new_size); bool rht_shrink_below_30(const struct rhashtable *ht, size_t new_size);
...@@ -110,11 +164,11 @@ bool rht_shrink_below_30(const struct rhashtable *ht, size_t new_size); ...@@ -110,11 +164,11 @@ bool rht_shrink_below_30(const struct rhashtable *ht, size_t new_size);
int rhashtable_expand(struct rhashtable *ht); int rhashtable_expand(struct rhashtable *ht);
int rhashtable_shrink(struct rhashtable *ht); int rhashtable_shrink(struct rhashtable *ht);
void *rhashtable_lookup(const struct rhashtable *ht, const void *key); void *rhashtable_lookup(struct rhashtable *ht, const void *key);
void *rhashtable_lookup_compare(const struct rhashtable *ht, u32 hash, void *rhashtable_lookup_compare(struct rhashtable *ht, const void *key,
bool (*compare)(void *, void *), void *arg); bool (*compare)(void *, void *), void *arg);
void rhashtable_destroy(const struct rhashtable *ht); void rhashtable_destroy(struct rhashtable *ht);
#define rht_dereference(p, ht) \ #define rht_dereference(p, ht) \
rcu_dereference_protected(p, lockdep_rht_mutex_is_held(ht)) rcu_dereference_protected(p, lockdep_rht_mutex_is_held(ht))
...@@ -122,92 +176,144 @@ void rhashtable_destroy(const struct rhashtable *ht); ...@@ -122,92 +176,144 @@ void rhashtable_destroy(const struct rhashtable *ht);
#define rht_dereference_rcu(p, ht) \ #define rht_dereference_rcu(p, ht) \
rcu_dereference_check(p, lockdep_rht_mutex_is_held(ht)) rcu_dereference_check(p, lockdep_rht_mutex_is_held(ht))
#define rht_entry(ptr, type, member) container_of(ptr, type, member) #define rht_dereference_bucket(p, tbl, hash) \
#define rht_entry_safe(ptr, type, member) \ rcu_dereference_protected(p, lockdep_rht_bucket_is_held(tbl, hash))
({ \
typeof(ptr) __ptr = (ptr); \
__ptr ? rht_entry(__ptr, type, member) : NULL; \
})
#define rht_next_entry_safe(pos, ht, member) \ #define rht_dereference_bucket_rcu(p, tbl, hash) \
({ \ rcu_dereference_check(p, lockdep_rht_bucket_is_held(tbl, hash))
pos ? rht_entry_safe(rht_dereference((pos)->member.next, ht), \
typeof(*(pos)), member) : NULL; \ #define rht_entry(tpos, pos, member) \
}) ({ tpos = container_of(pos, typeof(*tpos), member); 1; })
/**
* rht_for_each_continue - continue iterating over hash chain
* @pos: the &struct rhash_head to use as a loop cursor.
* @head: the previous &struct rhash_head to continue from
* @tbl: the &struct bucket_table
* @hash: the hash value / bucket index
*/
#define rht_for_each_continue(pos, head, tbl, hash) \
for (pos = rht_dereference_bucket(head, tbl, hash); \
!rht_is_a_nulls(pos); \
pos = rht_dereference_bucket((pos)->next, tbl, hash))
/** /**
* rht_for_each - iterate over hash chain * rht_for_each - iterate over hash chain
* @pos: &struct rhash_head to use as a loop cursor. * @pos: the &struct rhash_head to use as a loop cursor.
* @head: head of the hash chain (struct rhash_head *) * @tbl: the &struct bucket_table
* @ht: pointer to your struct rhashtable * @hash: the hash value / bucket index
*/ */
#define rht_for_each(pos, head, ht) \ #define rht_for_each(pos, tbl, hash) \
for (pos = rht_dereference(head, ht); \ rht_for_each_continue(pos, (tbl)->buckets[hash], tbl, hash)
pos; \
pos = rht_dereference((pos)->next, ht)) /**
* rht_for_each_entry_continue - continue iterating over hash chain
* @tpos: the type * to use as a loop cursor.
* @pos: the &struct rhash_head to use as a loop cursor.
* @head: the previous &struct rhash_head to continue from
* @tbl: the &struct bucket_table
* @hash: the hash value / bucket index
* @member: name of the &struct rhash_head within the hashable struct.
*/
#define rht_for_each_entry_continue(tpos, pos, head, tbl, hash, member) \
for (pos = rht_dereference_bucket(head, tbl, hash); \
(!rht_is_a_nulls(pos)) && rht_entry(tpos, pos, member); \
pos = rht_dereference_bucket((pos)->next, tbl, hash))
/** /**
* rht_for_each_entry - iterate over hash chain of given type * rht_for_each_entry - iterate over hash chain of given type
* @pos: type * to use as a loop cursor. * @tpos: the type * to use as a loop cursor.
* @head: head of the hash chain (struct rhash_head *) * @pos: the &struct rhash_head to use as a loop cursor.
* @ht: pointer to your struct rhashtable * @tbl: the &struct bucket_table
* @member: name of the rhash_head within the hashable struct. * @hash: the hash value / bucket index
* @member: name of the &struct rhash_head within the hashable struct.
*/ */
#define rht_for_each_entry(pos, head, ht, member) \ #define rht_for_each_entry(tpos, pos, tbl, hash, member) \
for (pos = rht_entry_safe(rht_dereference(head, ht), \ rht_for_each_entry_continue(tpos, pos, (tbl)->buckets[hash], \
typeof(*(pos)), member); \ tbl, hash, member)
pos; \
pos = rht_next_entry_safe(pos, ht, member))
/** /**
* rht_for_each_entry_safe - safely iterate over hash chain of given type * rht_for_each_entry_safe - safely iterate over hash chain of given type
* @pos: type * to use as a loop cursor. * @tpos: the type * to use as a loop cursor.
* @n: type * to use for temporary next object storage * @pos: the &struct rhash_head to use as a loop cursor.
* @head: head of the hash chain (struct rhash_head *) * @next: the &struct rhash_head to use as next in loop cursor.
* @ht: pointer to your struct rhashtable * @tbl: the &struct bucket_table
* @member: name of the rhash_head within the hashable struct. * @hash: the hash value / bucket index
* @member: name of the &struct rhash_head within the hashable struct.
* *
* This hash chain list-traversal primitive allows for the looped code to * This hash chain list-traversal primitive allows for the looped code to
* remove the loop cursor from the list. * remove the loop cursor from the list.
*/ */
#define rht_for_each_entry_safe(pos, n, head, ht, member) \ #define rht_for_each_entry_safe(tpos, pos, next, tbl, hash, member) \
for (pos = rht_entry_safe(rht_dereference(head, ht), \ for (pos = rht_dereference_bucket((tbl)->buckets[hash], tbl, hash), \
typeof(*(pos)), member), \ next = !rht_is_a_nulls(pos) ? \
n = rht_next_entry_safe(pos, ht, member); \ rht_dereference_bucket(pos->next, tbl, hash) : NULL; \
pos; \ (!rht_is_a_nulls(pos)) && rht_entry(tpos, pos, member); \
pos = n, \ pos = next)
n = rht_next_entry_safe(pos, ht, member))
/**
* rht_for_each_rcu_continue - continue iterating over rcu hash chain
* @pos: the &struct rhash_head to use as a loop cursor.
* @head: the previous &struct rhash_head to continue from
* @tbl: the &struct bucket_table
* @hash: the hash value / bucket index
*
* This hash chain list-traversal primitive may safely run concurrently with
* the _rcu mutation primitives such as rhashtable_insert() as long as the
* traversal is guarded by rcu_read_lock().
*/
#define rht_for_each_rcu_continue(pos, head, tbl, hash) \
for (({barrier(); }), \
pos = rht_dereference_bucket_rcu(head, tbl, hash); \
!rht_is_a_nulls(pos); \
pos = rcu_dereference_raw(pos->next))
/** /**
* rht_for_each_rcu - iterate over rcu hash chain * rht_for_each_rcu - iterate over rcu hash chain
* @pos: &struct rhash_head to use as a loop cursor. * @pos: the &struct rhash_head to use as a loop cursor.
* @head: head of the hash chain (struct rhash_head *) * @tbl: the &struct bucket_table
* @ht: pointer to your struct rhashtable * @hash: the hash value / bucket index
*
* This hash chain list-traversal primitive may safely run concurrently with
* the _rcu mutation primitives such as rhashtable_insert() as long as the
* traversal is guarded by rcu_read_lock().
*/
#define rht_for_each_rcu(pos, tbl, hash) \
rht_for_each_rcu_continue(pos, (tbl)->buckets[hash], tbl, hash)
/**
* rht_for_each_entry_rcu_continue - continue iterating over rcu hash chain
* @tpos: the type * to use as a loop cursor.
* @pos: the &struct rhash_head to use as a loop cursor.
* @head: the previous &struct rhash_head to continue from
* @tbl: the &struct bucket_table
* @hash: the hash value / bucket index
* @member: name of the &struct rhash_head within the hashable struct.
* *
* This hash chain list-traversal primitive may safely run concurrently with * This hash chain list-traversal primitive may safely run concurrently with
* the _rcu fkht mutation primitives such as rht_insert() as long as the * the _rcu mutation primitives such as rhashtable_insert() as long as the
* traversal is guarded by rcu_read_lock(). * traversal is guarded by rcu_read_lock().
*/ */
#define rht_for_each_rcu(pos, head, ht) \ #define rht_for_each_entry_rcu_continue(tpos, pos, head, tbl, hash, member) \
for (pos = rht_dereference_rcu(head, ht); \ for (({barrier(); }), \
pos; \ pos = rht_dereference_bucket_rcu(head, tbl, hash); \
pos = rht_dereference_rcu((pos)->next, ht)) (!rht_is_a_nulls(pos)) && rht_entry(tpos, pos, member); \
pos = rht_dereference_bucket_rcu(pos->next, tbl, hash))
/** /**
* rht_for_each_entry_rcu - iterate over rcu hash chain of given type * rht_for_each_entry_rcu - iterate over rcu hash chain of given type
* @pos: type * to use as a loop cursor. * @tpos: the type * to use as a loop cursor.
* @head: head of the hash chain (struct rhash_head *) * @pos: the &struct rhash_head to use as a loop cursor.
* @member: name of the rhash_head within the hashable struct. * @tbl: the &struct bucket_table
* @hash: the hash value / bucket index
* @member: name of the &struct rhash_head within the hashable struct.
* *
* This hash chain list-traversal primitive may safely run concurrently with * This hash chain list-traversal primitive may safely run concurrently with
* the _rcu fkht mutation primitives such as rht_insert() as long as the * the _rcu mutation primitives such as rhashtable_insert() as long as the
* traversal is guarded by rcu_read_lock(). * traversal is guarded by rcu_read_lock().
*/ */
#define rht_for_each_entry_rcu(pos, head, member) \ #define rht_for_each_entry_rcu(tpos, pos, tbl, hash, member) \
for (pos = rht_entry_safe(rcu_dereference_raw(head), \ rht_for_each_entry_rcu_continue(tpos, pos, (tbl)->buckets[hash],\
typeof(*(pos)), member); \ tbl, hash, member)
pos; \
pos = rht_entry_safe(rcu_dereference_raw((pos)->member.next), \
typeof(*(pos)), member))
#endif /* _LINUX_RHASHTABLE_H */ #endif /* _LINUX_RHASHTABLE_H */
...@@ -190,6 +190,8 @@ static inline void do_raw_spin_unlock(raw_spinlock_t *lock) __releases(lock) ...@@ -190,6 +190,8 @@ static inline void do_raw_spin_unlock(raw_spinlock_t *lock) __releases(lock)
#ifdef CONFIG_DEBUG_LOCK_ALLOC #ifdef CONFIG_DEBUG_LOCK_ALLOC
# define raw_spin_lock_nested(lock, subclass) \ # define raw_spin_lock_nested(lock, subclass) \
_raw_spin_lock_nested(lock, subclass) _raw_spin_lock_nested(lock, subclass)
# define raw_spin_lock_bh_nested(lock, subclass) \
_raw_spin_lock_bh_nested(lock, subclass)
# define raw_spin_lock_nest_lock(lock, nest_lock) \ # define raw_spin_lock_nest_lock(lock, nest_lock) \
do { \ do { \
...@@ -205,6 +207,7 @@ static inline void do_raw_spin_unlock(raw_spinlock_t *lock) __releases(lock) ...@@ -205,6 +207,7 @@ static inline void do_raw_spin_unlock(raw_spinlock_t *lock) __releases(lock)
# define raw_spin_lock_nested(lock, subclass) \ # define raw_spin_lock_nested(lock, subclass) \
_raw_spin_lock(((void)(subclass), (lock))) _raw_spin_lock(((void)(subclass), (lock)))
# define raw_spin_lock_nest_lock(lock, nest_lock) _raw_spin_lock(lock) # define raw_spin_lock_nest_lock(lock, nest_lock) _raw_spin_lock(lock)
# define raw_spin_lock_bh_nested(lock, subclass) _raw_spin_lock_bh(lock)
#endif #endif
#if defined(CONFIG_SMP) || defined(CONFIG_DEBUG_SPINLOCK) #if defined(CONFIG_SMP) || defined(CONFIG_DEBUG_SPINLOCK)
...@@ -324,6 +327,11 @@ do { \ ...@@ -324,6 +327,11 @@ do { \
raw_spin_lock_nested(spinlock_check(lock), subclass); \ raw_spin_lock_nested(spinlock_check(lock), subclass); \
} while (0) } while (0)
#define spin_lock_bh_nested(lock, subclass) \
do { \
raw_spin_lock_bh_nested(spinlock_check(lock), subclass);\
} while (0)
#define spin_lock_nest_lock(lock, nest_lock) \ #define spin_lock_nest_lock(lock, nest_lock) \
do { \ do { \
raw_spin_lock_nest_lock(spinlock_check(lock), nest_lock); \ raw_spin_lock_nest_lock(spinlock_check(lock), nest_lock); \
......
...@@ -22,6 +22,8 @@ int in_lock_functions(unsigned long addr); ...@@ -22,6 +22,8 @@ int in_lock_functions(unsigned long addr);
void __lockfunc _raw_spin_lock(raw_spinlock_t *lock) __acquires(lock); void __lockfunc _raw_spin_lock(raw_spinlock_t *lock) __acquires(lock);
void __lockfunc _raw_spin_lock_nested(raw_spinlock_t *lock, int subclass) void __lockfunc _raw_spin_lock_nested(raw_spinlock_t *lock, int subclass)
__acquires(lock); __acquires(lock);
void __lockfunc _raw_spin_lock_bh_nested(raw_spinlock_t *lock, int subclass)
__acquires(lock);
void __lockfunc void __lockfunc
_raw_spin_lock_nest_lock(raw_spinlock_t *lock, struct lockdep_map *map) _raw_spin_lock_nest_lock(raw_spinlock_t *lock, struct lockdep_map *map)
__acquires(lock); __acquires(lock);
......
...@@ -57,6 +57,7 @@ ...@@ -57,6 +57,7 @@
#define _raw_spin_lock(lock) __LOCK(lock) #define _raw_spin_lock(lock) __LOCK(lock)
#define _raw_spin_lock_nested(lock, subclass) __LOCK(lock) #define _raw_spin_lock_nested(lock, subclass) __LOCK(lock)
#define _raw_spin_lock_bh_nested(lock, subclass) __LOCK(lock)
#define _raw_read_lock(lock) __LOCK(lock) #define _raw_read_lock(lock) __LOCK(lock)
#define _raw_write_lock(lock) __LOCK(lock) #define _raw_write_lock(lock) __LOCK(lock)
#define _raw_spin_lock_bh(lock) __LOCK_BH(lock) #define _raw_spin_lock_bh(lock) __LOCK_BH(lock)
......
...@@ -363,6 +363,14 @@ void __lockfunc _raw_spin_lock_nested(raw_spinlock_t *lock, int subclass) ...@@ -363,6 +363,14 @@ void __lockfunc _raw_spin_lock_nested(raw_spinlock_t *lock, int subclass)
} }
EXPORT_SYMBOL(_raw_spin_lock_nested); EXPORT_SYMBOL(_raw_spin_lock_nested);
void __lockfunc _raw_spin_lock_bh_nested(raw_spinlock_t *lock, int subclass)
{
__local_bh_disable_ip(_RET_IP_, SOFTIRQ_LOCK_OFFSET);
spin_acquire(&lock->dep_map, subclass, 0, _RET_IP_);
LOCK_CONTENDED(lock, do_raw_spin_trylock, do_raw_spin_lock);
}
EXPORT_SYMBOL(_raw_spin_lock_bh_nested);
unsigned long __lockfunc _raw_spin_lock_irqsave_nested(raw_spinlock_t *lock, unsigned long __lockfunc _raw_spin_lock_irqsave_nested(raw_spinlock_t *lock,
int subclass) int subclass)
{ {
......
...@@ -26,15 +26,47 @@ ...@@ -26,15 +26,47 @@
#define HASH_DEFAULT_SIZE 64UL #define HASH_DEFAULT_SIZE 64UL
#define HASH_MIN_SIZE 4UL #define HASH_MIN_SIZE 4UL
#define BUCKET_LOCKS_PER_CPU 128UL
/* Base bits plus 1 bit for nulls marker */
#define HASH_RESERVED_SPACE (RHT_BASE_BITS + 1)
enum {
RHT_LOCK_NORMAL,
RHT_LOCK_NESTED,
RHT_LOCK_NESTED2,
};
/* The bucket lock is selected based on the hash and protects mutations
* on a group of hash buckets.
*
* IMPORTANT: When holding the bucket lock of both the old and new table
* during expansions and shrinking, the old bucket lock must always be
* acquired first.
*/
static spinlock_t *bucket_lock(const struct bucket_table *tbl, u32 hash)
{
return &tbl->locks[hash & tbl->locks_mask];
}
#define ASSERT_RHT_MUTEX(HT) BUG_ON(!lockdep_rht_mutex_is_held(HT)) #define ASSERT_RHT_MUTEX(HT) BUG_ON(!lockdep_rht_mutex_is_held(HT))
#define ASSERT_BUCKET_LOCK(TBL, HASH) \
BUG_ON(!lockdep_rht_bucket_is_held(TBL, HASH))
#ifdef CONFIG_PROVE_LOCKING #ifdef CONFIG_PROVE_LOCKING
int lockdep_rht_mutex_is_held(const struct rhashtable *ht) int lockdep_rht_mutex_is_held(struct rhashtable *ht)
{ {
return ht->p.mutex_is_held(ht->p.parent); return (debug_locks) ? lockdep_is_held(&ht->mutex) : 1;
} }
EXPORT_SYMBOL_GPL(lockdep_rht_mutex_is_held); EXPORT_SYMBOL_GPL(lockdep_rht_mutex_is_held);
int lockdep_rht_bucket_is_held(const struct bucket_table *tbl, u32 hash)
{
spinlock_t *lock = bucket_lock(tbl, hash);
return (debug_locks) ? lockdep_is_held(lock) : 1;
}
EXPORT_SYMBOL_GPL(lockdep_rht_bucket_is_held);
#endif #endif
static void *rht_obj(const struct rhashtable *ht, const struct rhash_head *he) static void *rht_obj(const struct rhashtable *ht, const struct rhash_head *he)
...@@ -42,75 +74,101 @@ static void *rht_obj(const struct rhashtable *ht, const struct rhash_head *he) ...@@ -42,75 +74,101 @@ static void *rht_obj(const struct rhashtable *ht, const struct rhash_head *he)
return (void *) he - ht->p.head_offset; return (void *) he - ht->p.head_offset;
} }
static u32 __hashfn(const struct rhashtable *ht, const void *key, static u32 rht_bucket_index(const struct bucket_table *tbl, u32 hash)
u32 len, u32 hsize)
{ {
u32 h; return hash & (tbl->size - 1);
}
static u32 obj_raw_hashfn(const struct rhashtable *ht, const void *ptr)
{
u32 hash;
h = ht->p.hashfn(key, len, ht->p.hash_rnd); if (unlikely(!ht->p.key_len))
hash = ht->p.obj_hashfn(ptr, ht->p.hash_rnd);
else
hash = ht->p.hashfn(ptr + ht->p.key_offset, ht->p.key_len,
ht->p.hash_rnd);
return h & (hsize - 1); return hash >> HASH_RESERVED_SPACE;
} }
/** static u32 key_hashfn(struct rhashtable *ht, const void *key, u32 len)
* rhashtable_hashfn - compute hash for key of given length
* @ht: hash table to compute for
* @key: pointer to key
* @len: length of key
*
* Computes the hash value using the hash function provided in the 'hashfn'
* of struct rhashtable_params. The returned value is guaranteed to be
* smaller than the number of buckets in the hash table.
*/
u32 rhashtable_hashfn(const struct rhashtable *ht, const void *key, u32 len)
{ {
struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
u32 hash;
return __hashfn(ht, key, len, tbl->size); hash = ht->p.hashfn(key, len, ht->p.hash_rnd);
hash >>= HASH_RESERVED_SPACE;
return rht_bucket_index(tbl, hash);
} }
EXPORT_SYMBOL_GPL(rhashtable_hashfn);
static u32 obj_hashfn(const struct rhashtable *ht, const void *ptr, u32 hsize) static u32 head_hashfn(const struct rhashtable *ht,
const struct bucket_table *tbl,
const struct rhash_head *he)
{ {
if (unlikely(!ht->p.key_len)) { return rht_bucket_index(tbl, obj_raw_hashfn(ht, rht_obj(ht, he)));
u32 h; }
h = ht->p.obj_hashfn(ptr, ht->p.hash_rnd); static struct rhash_head __rcu **bucket_tail(struct bucket_table *tbl, u32 n)
{
struct rhash_head __rcu **pprev;
return h & (hsize - 1); for (pprev = &tbl->buckets[n];
} !rht_is_a_nulls(rht_dereference_bucket(*pprev, tbl, n));
pprev = &rht_dereference_bucket(*pprev, tbl, n)->next)
;
return __hashfn(ht, ptr + ht->p.key_offset, ht->p.key_len, hsize); return pprev;
} }
/** static int alloc_bucket_locks(struct rhashtable *ht, struct bucket_table *tbl)
* rhashtable_obj_hashfn - compute hash for hashed object
* @ht: hash table to compute for
* @ptr: pointer to hashed object
*
* Computes the hash value using the hash function `hashfn` respectively
* 'obj_hashfn' depending on whether the hash table is set up to work with
* a fixed length key. The returned value is guaranteed to be smaller than
* the number of buckets in the hash table.
*/
u32 rhashtable_obj_hashfn(const struct rhashtable *ht, void *ptr)
{ {
struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); unsigned int i, size;
#if defined(CONFIG_PROVE_LOCKING)
unsigned int nr_pcpus = 2;
#else
unsigned int nr_pcpus = num_possible_cpus();
#endif
return obj_hashfn(ht, ptr, tbl->size); nr_pcpus = min_t(unsigned int, nr_pcpus, 32UL);
size = roundup_pow_of_two(nr_pcpus * ht->p.locks_mul);
/* Never allocate more than one lock per bucket */
size = min_t(unsigned int, size, tbl->size);
if (sizeof(spinlock_t) != 0) {
#ifdef CONFIG_NUMA
if (size * sizeof(spinlock_t) > PAGE_SIZE)
tbl->locks = vmalloc(size * sizeof(spinlock_t));
else
#endif
tbl->locks = kmalloc_array(size, sizeof(spinlock_t),
GFP_KERNEL);
if (!tbl->locks)
return -ENOMEM;
for (i = 0; i < size; i++)
spin_lock_init(&tbl->locks[i]);
}
tbl->locks_mask = size - 1;
return 0;
} }
EXPORT_SYMBOL_GPL(rhashtable_obj_hashfn);
static u32 head_hashfn(const struct rhashtable *ht, static void bucket_table_free(const struct bucket_table *tbl)
const struct rhash_head *he, u32 hsize)
{ {
return obj_hashfn(ht, rht_obj(ht, he), hsize); if (tbl)
kvfree(tbl->locks);
kvfree(tbl);
} }
static struct bucket_table *bucket_table_alloc(size_t nbuckets) static struct bucket_table *bucket_table_alloc(struct rhashtable *ht,
size_t nbuckets)
{ {
struct bucket_table *tbl; struct bucket_table *tbl;
size_t size; size_t size;
int i;
size = sizeof(*tbl) + nbuckets * sizeof(tbl->buckets[0]); size = sizeof(*tbl) + nbuckets * sizeof(tbl->buckets[0]);
tbl = kzalloc(size, GFP_KERNEL | __GFP_NOWARN); tbl = kzalloc(size, GFP_KERNEL | __GFP_NOWARN);
...@@ -122,12 +180,15 @@ static struct bucket_table *bucket_table_alloc(size_t nbuckets) ...@@ -122,12 +180,15 @@ static struct bucket_table *bucket_table_alloc(size_t nbuckets)
tbl->size = nbuckets; tbl->size = nbuckets;
return tbl; if (alloc_bucket_locks(ht, tbl) < 0) {
} bucket_table_free(tbl);
return NULL;
}
static void bucket_table_free(const struct bucket_table *tbl) for (i = 0; i < nbuckets; i++)
{ INIT_RHT_NULLS_HEAD(tbl->buckets[i], ht, i);
kvfree(tbl);
return tbl;
} }
/** /**
...@@ -138,7 +199,7 @@ static void bucket_table_free(const struct bucket_table *tbl) ...@@ -138,7 +199,7 @@ static void bucket_table_free(const struct bucket_table *tbl)
bool rht_grow_above_75(const struct rhashtable *ht, size_t new_size) bool rht_grow_above_75(const struct rhashtable *ht, size_t new_size)
{ {
/* Expand table when exceeding 75% load */ /* Expand table when exceeding 75% load */
return ht->nelems > (new_size / 4 * 3); return atomic_read(&ht->nelems) > (new_size / 4 * 3);
} }
EXPORT_SYMBOL_GPL(rht_grow_above_75); EXPORT_SYMBOL_GPL(rht_grow_above_75);
...@@ -150,41 +211,59 @@ EXPORT_SYMBOL_GPL(rht_grow_above_75); ...@@ -150,41 +211,59 @@ EXPORT_SYMBOL_GPL(rht_grow_above_75);
bool rht_shrink_below_30(const struct rhashtable *ht, size_t new_size) bool rht_shrink_below_30(const struct rhashtable *ht, size_t new_size)
{ {
/* Shrink table beneath 30% load */ /* Shrink table beneath 30% load */
return ht->nelems < (new_size * 3 / 10); return atomic_read(&ht->nelems) < (new_size * 3 / 10);
} }
EXPORT_SYMBOL_GPL(rht_shrink_below_30); EXPORT_SYMBOL_GPL(rht_shrink_below_30);
static void hashtable_chain_unzip(const struct rhashtable *ht, static void hashtable_chain_unzip(const struct rhashtable *ht,
const struct bucket_table *new_tbl, const struct bucket_table *new_tbl,
struct bucket_table *old_tbl, size_t n) struct bucket_table *old_tbl,
size_t old_hash)
{ {
struct rhash_head *he, *p, *next; struct rhash_head *he, *p, *next;
unsigned int h; spinlock_t *new_bucket_lock, *new_bucket_lock2 = NULL;
unsigned int new_hash, new_hash2;
ASSERT_BUCKET_LOCK(old_tbl, old_hash);
/* Old bucket empty, no work needed. */ /* Old bucket empty, no work needed. */
p = rht_dereference(old_tbl->buckets[n], ht); p = rht_dereference_bucket(old_tbl->buckets[old_hash], old_tbl,
if (!p) old_hash);
if (rht_is_a_nulls(p))
return; return;
new_hash = new_hash2 = head_hashfn(ht, new_tbl, p);
new_bucket_lock = bucket_lock(new_tbl, new_hash);
/* Advance the old bucket pointer one or more times until it /* Advance the old bucket pointer one or more times until it
* reaches a node that doesn't hash to the same bucket as the * reaches a node that doesn't hash to the same bucket as the
* previous node p. Call the previous node p; * previous node p. Call the previous node p;
*/ */
h = head_hashfn(ht, p, new_tbl->size); rht_for_each_continue(he, p->next, old_tbl, old_hash) {
rht_for_each(he, p->next, ht) { new_hash2 = head_hashfn(ht, new_tbl, he);
if (head_hashfn(ht, he, new_tbl->size) != h) if (new_hash != new_hash2)
break; break;
p = he; p = he;
} }
RCU_INIT_POINTER(old_tbl->buckets[n], p->next); rcu_assign_pointer(old_tbl->buckets[old_hash], p->next);
spin_lock_bh_nested(new_bucket_lock, RHT_LOCK_NESTED);
/* If we have encountered an entry that maps to a different bucket in
* the new table, lock down that bucket as well as we might cut off
* the end of the chain.
*/
new_bucket_lock2 = bucket_lock(new_tbl, new_hash);
if (new_bucket_lock != new_bucket_lock2)
spin_lock_bh_nested(new_bucket_lock2, RHT_LOCK_NESTED2);
/* Find the subsequent node which does hash to the same /* Find the subsequent node which does hash to the same
* bucket as node P, or NULL if no such node exists. * bucket as node P, or NULL if no such node exists.
*/ */
next = NULL; INIT_RHT_NULLS_HEAD(next, ht, old_hash);
if (he) { if (!rht_is_a_nulls(he)) {
rht_for_each(he, he->next, ht) { rht_for_each_continue(he, he->next, old_tbl, old_hash) {
if (head_hashfn(ht, he, new_tbl->size) == h) { if (head_hashfn(ht, new_tbl, he) == new_hash) {
next = he; next = he;
break; break;
} }
...@@ -194,7 +273,23 @@ static void hashtable_chain_unzip(const struct rhashtable *ht, ...@@ -194,7 +273,23 @@ static void hashtable_chain_unzip(const struct rhashtable *ht,
/* Set p's next pointer to that subsequent node pointer, /* Set p's next pointer to that subsequent node pointer,
* bypassing the nodes which do not hash to p's bucket * bypassing the nodes which do not hash to p's bucket
*/ */
RCU_INIT_POINTER(p->next, next); rcu_assign_pointer(p->next, next);
if (new_bucket_lock != new_bucket_lock2)
spin_unlock_bh(new_bucket_lock2);
spin_unlock_bh(new_bucket_lock);
}
static void link_old_to_new(struct bucket_table *new_tbl,
unsigned int new_hash, struct rhash_head *entry)
{
spinlock_t *new_bucket_lock;
new_bucket_lock = bucket_lock(new_tbl, new_hash);
spin_lock_bh_nested(new_bucket_lock, RHT_LOCK_NESTED);
rcu_assign_pointer(*bucket_tail(new_tbl, new_hash), entry);
spin_unlock_bh(new_bucket_lock);
} }
/** /**
...@@ -207,43 +302,59 @@ static void hashtable_chain_unzip(const struct rhashtable *ht, ...@@ -207,43 +302,59 @@ static void hashtable_chain_unzip(const struct rhashtable *ht,
* This function may only be called in a context where it is safe to call * This function may only be called in a context where it is safe to call
* synchronize_rcu(), e.g. not within a rcu_read_lock() section. * synchronize_rcu(), e.g. not within a rcu_read_lock() section.
* *
* The caller must ensure that no concurrent table mutations take place. * The caller must ensure that no concurrent resizing occurs by holding
* It is however valid to have concurrent lookups if they are RCU protected. * ht->mutex.
*
* It is valid to have concurrent insertions and deletions protected by per
* bucket locks or concurrent RCU protected lookups and traversals.
*/ */
int rhashtable_expand(struct rhashtable *ht) int rhashtable_expand(struct rhashtable *ht)
{ {
struct bucket_table *new_tbl, *old_tbl = rht_dereference(ht->tbl, ht); struct bucket_table *new_tbl, *old_tbl = rht_dereference(ht->tbl, ht);
struct rhash_head *he; struct rhash_head *he;
unsigned int i, h; spinlock_t *old_bucket_lock;
bool complete; unsigned int new_hash, old_hash;
bool complete = false;
ASSERT_RHT_MUTEX(ht); ASSERT_RHT_MUTEX(ht);
if (ht->p.max_shift && ht->shift >= ht->p.max_shift) if (ht->p.max_shift && ht->shift >= ht->p.max_shift)
return 0; return 0;
new_tbl = bucket_table_alloc(old_tbl->size * 2); new_tbl = bucket_table_alloc(ht, old_tbl->size * 2);
if (new_tbl == NULL) if (new_tbl == NULL)
return -ENOMEM; return -ENOMEM;
ht->shift++; ht->shift++;
/* For each new bucket, search the corresponding old bucket /* Make insertions go into the new, empty table right away. Deletions
* for the first entry that hashes to the new bucket, and * and lookups will be attempted in both tables until we synchronize.
* link the new bucket to that entry. Since all the entries * The synchronize_rcu() guarantees for the new table to be picked up
* which will end up in the new bucket appear in the same * so no new additions go into the old table while we relink.
* old bucket, this constructs an entirely valid new hash */
* table, but with multiple buckets "zipped" together into a rcu_assign_pointer(ht->future_tbl, new_tbl);
* single imprecise chain. synchronize_rcu();
/* For each new bucket, search the corresponding old bucket for the
* first entry that hashes to the new bucket, and link the end of
* newly formed bucket chain (containing entries added to future
* table) to that entry. Since all the entries which will end up in
* the new bucket appear in the same old bucket, this constructs an
* entirely valid new hash table, but with multiple buckets
* "zipped" together into a single imprecise chain.
*/ */
for (i = 0; i < new_tbl->size; i++) { for (new_hash = 0; new_hash < new_tbl->size; new_hash++) {
h = i & (old_tbl->size - 1); old_hash = rht_bucket_index(old_tbl, new_hash);
rht_for_each(he, old_tbl->buckets[h], ht) { old_bucket_lock = bucket_lock(old_tbl, old_hash);
if (head_hashfn(ht, he, new_tbl->size) == i) {
RCU_INIT_POINTER(new_tbl->buckets[i], he); spin_lock_bh(old_bucket_lock);
rht_for_each(he, old_tbl, old_hash) {
if (head_hashfn(ht, new_tbl, he) == new_hash) {
link_old_to_new(new_tbl, new_hash, he);
break; break;
} }
} }
spin_unlock_bh(old_bucket_lock);
} }
/* Publish the new table pointer. Lookups may now traverse /* Publish the new table pointer. Lookups may now traverse
...@@ -253,7 +364,7 @@ int rhashtable_expand(struct rhashtable *ht) ...@@ -253,7 +364,7 @@ int rhashtable_expand(struct rhashtable *ht)
rcu_assign_pointer(ht->tbl, new_tbl); rcu_assign_pointer(ht->tbl, new_tbl);
/* Unzip interleaved hash chains */ /* Unzip interleaved hash chains */
do { while (!complete && !ht->being_destroyed) {
/* Wait for readers. All new readers will see the new /* Wait for readers. All new readers will see the new
* table, and thus no references to the old table will * table, and thus no references to the old table will
* remain. * remain.
...@@ -265,12 +376,21 @@ int rhashtable_expand(struct rhashtable *ht) ...@@ -265,12 +376,21 @@ int rhashtable_expand(struct rhashtable *ht)
* table): ... * table): ...
*/ */
complete = true; complete = true;
for (i = 0; i < old_tbl->size; i++) { for (old_hash = 0; old_hash < old_tbl->size; old_hash++) {
hashtable_chain_unzip(ht, new_tbl, old_tbl, i); struct rhash_head *head;
if (old_tbl->buckets[i] != NULL)
old_bucket_lock = bucket_lock(old_tbl, old_hash);
spin_lock_bh(old_bucket_lock);
hashtable_chain_unzip(ht, new_tbl, old_tbl, old_hash);
head = rht_dereference_bucket(old_tbl->buckets[old_hash],
old_tbl, old_hash);
if (!rht_is_a_nulls(head))
complete = false; complete = false;
spin_unlock_bh(old_bucket_lock);
} }
} while (!complete); }
bucket_table_free(old_tbl); bucket_table_free(old_tbl);
return 0; return 0;
...@@ -284,45 +404,65 @@ EXPORT_SYMBOL_GPL(rhashtable_expand); ...@@ -284,45 +404,65 @@ EXPORT_SYMBOL_GPL(rhashtable_expand);
* This function may only be called in a context where it is safe to call * This function may only be called in a context where it is safe to call
* synchronize_rcu(), e.g. not within a rcu_read_lock() section. * synchronize_rcu(), e.g. not within a rcu_read_lock() section.
* *
* The caller must ensure that no concurrent resizing occurs by holding
* ht->mutex.
*
* The caller must ensure that no concurrent table mutations take place. * The caller must ensure that no concurrent table mutations take place.
* It is however valid to have concurrent lookups if they are RCU protected. * It is however valid to have concurrent lookups if they are RCU protected.
*
* It is valid to have concurrent insertions and deletions protected by per
* bucket locks or concurrent RCU protected lookups and traversals.
*/ */
int rhashtable_shrink(struct rhashtable *ht) int rhashtable_shrink(struct rhashtable *ht)
{ {
struct bucket_table *ntbl, *tbl = rht_dereference(ht->tbl, ht); struct bucket_table *new_tbl, *tbl = rht_dereference(ht->tbl, ht);
struct rhash_head __rcu **pprev; spinlock_t *new_bucket_lock, *old_bucket_lock1, *old_bucket_lock2;
unsigned int i; unsigned int new_hash;
ASSERT_RHT_MUTEX(ht); ASSERT_RHT_MUTEX(ht);
if (ht->shift <= ht->p.min_shift) if (ht->shift <= ht->p.min_shift)
return 0; return 0;
ntbl = bucket_table_alloc(tbl->size / 2); new_tbl = bucket_table_alloc(ht, tbl->size / 2);
if (ntbl == NULL) if (new_tbl == NULL)
return -ENOMEM; return -ENOMEM;
ht->shift--; rcu_assign_pointer(ht->future_tbl, new_tbl);
synchronize_rcu();
/* Link each bucket in the new table to the first bucket /* Link the first entry in the old bucket to the end of the
* in the old table that contains entries which will hash * bucket in the new table. As entries are concurrently being
* to the new bucket. * added to the new table, lock down the new bucket. As we
* always divide the size in half when shrinking, each bucket
* in the new table maps to exactly two buckets in the old
* table.
*
* As removals can occur concurrently on the old table, we need
* to lock down both matching buckets in the old table.
*/ */
for (i = 0; i < ntbl->size; i++) { for (new_hash = 0; new_hash < new_tbl->size; new_hash++) {
ntbl->buckets[i] = tbl->buckets[i]; old_bucket_lock1 = bucket_lock(tbl, new_hash);
old_bucket_lock2 = bucket_lock(tbl, new_hash + new_tbl->size);
/* Link each bucket in the new table to the first bucket new_bucket_lock = bucket_lock(new_tbl, new_hash);
* in the old table that contains entries which will hash
* to the new bucket. spin_lock_bh(old_bucket_lock1);
*/ spin_lock_bh_nested(old_bucket_lock2, RHT_LOCK_NESTED);
for (pprev = &ntbl->buckets[i]; *pprev != NULL; spin_lock_bh_nested(new_bucket_lock, RHT_LOCK_NESTED2);
pprev = &rht_dereference(*pprev, ht)->next)
; rcu_assign_pointer(*bucket_tail(new_tbl, new_hash),
RCU_INIT_POINTER(*pprev, tbl->buckets[i + ntbl->size]); tbl->buckets[new_hash]);
rcu_assign_pointer(*bucket_tail(new_tbl, new_hash),
tbl->buckets[new_hash + new_tbl->size]);
spin_unlock_bh(new_bucket_lock);
spin_unlock_bh(old_bucket_lock2);
spin_unlock_bh(old_bucket_lock1);
} }
/* Publish the new, valid hash table */ /* Publish the new, valid hash table */
rcu_assign_pointer(ht->tbl, ntbl); rcu_assign_pointer(ht->tbl, new_tbl);
ht->shift--;
/* Wait for readers. No new readers will have references to the /* Wait for readers. No new readers will have references to the
* old hash table. * old hash table.
...@@ -335,59 +475,71 @@ int rhashtable_shrink(struct rhashtable *ht) ...@@ -335,59 +475,71 @@ int rhashtable_shrink(struct rhashtable *ht)
} }
EXPORT_SYMBOL_GPL(rhashtable_shrink); EXPORT_SYMBOL_GPL(rhashtable_shrink);
/** static void rht_deferred_worker(struct work_struct *work)
* rhashtable_insert - insert object into hash hash table
* @ht: hash table
* @obj: pointer to hash head inside object
*
* Will automatically grow the table via rhashtable_expand() if the the
* grow_decision function specified at rhashtable_init() returns true.
*
* The caller must ensure that no concurrent table mutations occur. It is
* however valid to have concurrent lookups if they are RCU protected.
*/
void rhashtable_insert(struct rhashtable *ht, struct rhash_head *obj)
{ {
struct bucket_table *tbl = rht_dereference(ht->tbl, ht); struct rhashtable *ht;
u32 hash; struct bucket_table *tbl;
ASSERT_RHT_MUTEX(ht); ht = container_of(work, struct rhashtable, run_work.work);
mutex_lock(&ht->mutex);
hash = head_hashfn(ht, obj, tbl->size); tbl = rht_dereference(ht->tbl, ht);
RCU_INIT_POINTER(obj->next, tbl->buckets[hash]);
rcu_assign_pointer(tbl->buckets[hash], obj);
ht->nelems++;
if (ht->p.grow_decision && ht->p.grow_decision(ht, tbl->size)) if (ht->p.grow_decision && ht->p.grow_decision(ht, tbl->size))
rhashtable_expand(ht); rhashtable_expand(ht);
else if (ht->p.shrink_decision && ht->p.shrink_decision(ht, tbl->size))
rhashtable_shrink(ht);
mutex_unlock(&ht->mutex);
} }
EXPORT_SYMBOL_GPL(rhashtable_insert);
/** /**
* rhashtable_remove_pprev - remove object from hash table given previous element * rhashtable_insert - insert object into hash hash table
* @ht: hash table * @ht: hash table
* @obj: pointer to hash head inside object * @obj: pointer to hash head inside object
* @pprev: pointer to previous element
* *
* Identical to rhashtable_remove() but caller is alreayd aware of the element * Will take a per bucket spinlock to protect against mutual mutations
* in front of the element to be deleted. This is in particular useful for * on the same bucket. Multiple insertions may occur in parallel unless
* deletion when combined with walking or lookup. * they map to the same bucket lock.
*
* It is safe to call this function from atomic context.
*
* Will trigger an automatic deferred table resizing if the size grows
* beyond the watermark indicated by grow_decision() which can be passed
* to rhashtable_init().
*/ */
void rhashtable_remove_pprev(struct rhashtable *ht, struct rhash_head *obj, void rhashtable_insert(struct rhashtable *ht, struct rhash_head *obj)
struct rhash_head __rcu **pprev)
{ {
struct bucket_table *tbl = rht_dereference(ht->tbl, ht); struct bucket_table *tbl;
struct rhash_head *head;
spinlock_t *lock;
unsigned hash;
ASSERT_RHT_MUTEX(ht); rcu_read_lock();
RCU_INIT_POINTER(*pprev, obj->next); tbl = rht_dereference_rcu(ht->future_tbl, ht);
ht->nelems--; hash = head_hashfn(ht, tbl, obj);
lock = bucket_lock(tbl, hash);
if (ht->p.shrink_decision && spin_lock_bh(lock);
ht->p.shrink_decision(ht, tbl->size)) head = rht_dereference_bucket(tbl->buckets[hash], tbl, hash);
rhashtable_shrink(ht); if (rht_is_a_nulls(head))
INIT_RHT_NULLS_HEAD(obj->next, ht, hash);
else
RCU_INIT_POINTER(obj->next, head);
rcu_assign_pointer(tbl->buckets[hash], obj);
spin_unlock_bh(lock);
atomic_inc(&ht->nelems);
/* Only grow the table if no resizing is currently in progress. */
if (ht->tbl != ht->future_tbl &&
ht->p.grow_decision && ht->p.grow_decision(ht, tbl->size))
schedule_delayed_work(&ht->run_work, 0);
rcu_read_unlock();
} }
EXPORT_SYMBOL_GPL(rhashtable_remove_pprev); EXPORT_SYMBOL_GPL(rhashtable_insert);
/** /**
* rhashtable_remove - remove object from hash table * rhashtable_remove - remove object from hash table
...@@ -406,26 +558,56 @@ EXPORT_SYMBOL_GPL(rhashtable_remove_pprev); ...@@ -406,26 +558,56 @@ EXPORT_SYMBOL_GPL(rhashtable_remove_pprev);
*/ */
bool rhashtable_remove(struct rhashtable *ht, struct rhash_head *obj) bool rhashtable_remove(struct rhashtable *ht, struct rhash_head *obj)
{ {
struct bucket_table *tbl = rht_dereference(ht->tbl, ht); struct bucket_table *tbl;
struct rhash_head __rcu **pprev; struct rhash_head __rcu **pprev;
struct rhash_head *he; struct rhash_head *he;
u32 h; spinlock_t *lock;
unsigned int hash;
ASSERT_RHT_MUTEX(ht); rcu_read_lock();
tbl = rht_dereference_rcu(ht->tbl, ht);
hash = head_hashfn(ht, tbl, obj);
h = head_hashfn(ht, obj, tbl->size); lock = bucket_lock(tbl, hash);
spin_lock_bh(lock);
pprev = &tbl->buckets[h]; restart:
rht_for_each(he, tbl->buckets[h], ht) { pprev = &tbl->buckets[hash];
rht_for_each(he, tbl, hash) {
if (he != obj) { if (he != obj) {
pprev = &he->next; pprev = &he->next;
continue; continue;
} }
rhashtable_remove_pprev(ht, he, pprev); rcu_assign_pointer(*pprev, obj->next);
atomic_dec(&ht->nelems);
spin_unlock_bh(lock);
if (ht->tbl != ht->future_tbl &&
ht->p.shrink_decision &&
ht->p.shrink_decision(ht, tbl->size))
schedule_delayed_work(&ht->run_work, 0);
rcu_read_unlock();
return true; return true;
} }
if (tbl != rht_dereference_rcu(ht->tbl, ht)) {
spin_unlock_bh(lock);
tbl = rht_dereference_rcu(ht->tbl, ht);
hash = head_hashfn(ht, tbl, obj);
lock = bucket_lock(tbl, hash);
spin_lock_bh(lock);
goto restart;
}
spin_unlock_bh(lock);
rcu_read_unlock();
return false; return false;
} }
EXPORT_SYMBOL_GPL(rhashtable_remove); EXPORT_SYMBOL_GPL(rhashtable_remove);
...@@ -441,25 +623,35 @@ EXPORT_SYMBOL_GPL(rhashtable_remove); ...@@ -441,25 +623,35 @@ EXPORT_SYMBOL_GPL(rhashtable_remove);
* This lookup function may only be used for fixed key hash table (key_len * This lookup function may only be used for fixed key hash table (key_len
* paramter set). It will BUG() if used inappropriately. * paramter set). It will BUG() if used inappropriately.
* *
* Lookups may occur in parallel with hash mutations as long as the lookup is * Lookups may occur in parallel with hashtable mutations and resizing.
* guarded by rcu_read_lock(). The caller must take care of this.
*/ */
void *rhashtable_lookup(const struct rhashtable *ht, const void *key) void *rhashtable_lookup(struct rhashtable *ht, const void *key)
{ {
const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); const struct bucket_table *tbl, *old_tbl;
struct rhash_head *he; struct rhash_head *he;
u32 h; u32 hash;
BUG_ON(!ht->p.key_len); BUG_ON(!ht->p.key_len);
h = __hashfn(ht, key, ht->p.key_len, tbl->size); rcu_read_lock();
rht_for_each_rcu(he, tbl->buckets[h], ht) { old_tbl = rht_dereference_rcu(ht->tbl, ht);
tbl = rht_dereference_rcu(ht->future_tbl, ht);
hash = key_hashfn(ht, key, ht->p.key_len);
restart:
rht_for_each_rcu(he, tbl, rht_bucket_index(tbl, hash)) {
if (memcmp(rht_obj(ht, he) + ht->p.key_offset, key, if (memcmp(rht_obj(ht, he) + ht->p.key_offset, key,
ht->p.key_len)) ht->p.key_len))
continue; continue;
return (void *) he - ht->p.head_offset; rcu_read_unlock();
return rht_obj(ht, he);
} }
if (unlikely(tbl != old_tbl)) {
tbl = old_tbl;
goto restart;
}
rcu_read_unlock();
return NULL; return NULL;
} }
EXPORT_SYMBOL_GPL(rhashtable_lookup); EXPORT_SYMBOL_GPL(rhashtable_lookup);
...@@ -467,33 +659,43 @@ EXPORT_SYMBOL_GPL(rhashtable_lookup); ...@@ -467,33 +659,43 @@ EXPORT_SYMBOL_GPL(rhashtable_lookup);
/** /**
* rhashtable_lookup_compare - search hash table with compare function * rhashtable_lookup_compare - search hash table with compare function
* @ht: hash table * @ht: hash table
* @hash: hash value of desired entry * @key: the pointer to the key
* @compare: compare function, must return true on match * @compare: compare function, must return true on match
* @arg: argument passed on to compare function * @arg: argument passed on to compare function
* *
* Traverses the bucket chain behind the provided hash value and calls the * Traverses the bucket chain behind the provided hash value and calls the
* specified compare function for each entry. * specified compare function for each entry.
* *
* Lookups may occur in parallel with hash mutations as long as the lookup is * Lookups may occur in parallel with hashtable mutations and resizing.
* guarded by rcu_read_lock(). The caller must take care of this.
* *
* Returns the first entry on which the compare function returned true. * Returns the first entry on which the compare function returned true.
*/ */
void *rhashtable_lookup_compare(const struct rhashtable *ht, u32 hash, void *rhashtable_lookup_compare(struct rhashtable *ht, const void *key,
bool (*compare)(void *, void *), void *arg) bool (*compare)(void *, void *), void *arg)
{ {
const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); const struct bucket_table *tbl, *old_tbl;
struct rhash_head *he; struct rhash_head *he;
u32 hash;
if (unlikely(hash >= tbl->size)) rcu_read_lock();
return NULL;
rht_for_each_rcu(he, tbl->buckets[hash], ht) { old_tbl = rht_dereference_rcu(ht->tbl, ht);
tbl = rht_dereference_rcu(ht->future_tbl, ht);
hash = key_hashfn(ht, key, ht->p.key_len);
restart:
rht_for_each_rcu(he, tbl, rht_bucket_index(tbl, hash)) {
if (!compare(rht_obj(ht, he), arg)) if (!compare(rht_obj(ht, he), arg))
continue; continue;
return (void *) he - ht->p.head_offset; rcu_read_unlock();
return rht_obj(ht, he);
} }
if (unlikely(tbl != old_tbl)) {
tbl = old_tbl;
goto restart;
}
rcu_read_unlock();
return NULL; return NULL;
} }
EXPORT_SYMBOL_GPL(rhashtable_lookup_compare); EXPORT_SYMBOL_GPL(rhashtable_lookup_compare);
...@@ -525,9 +727,7 @@ static size_t rounded_hashtable_size(struct rhashtable_params *params) ...@@ -525,9 +727,7 @@ static size_t rounded_hashtable_size(struct rhashtable_params *params)
* .key_offset = offsetof(struct test_obj, key), * .key_offset = offsetof(struct test_obj, key),
* .key_len = sizeof(int), * .key_len = sizeof(int),
* .hashfn = jhash, * .hashfn = jhash,
* #ifdef CONFIG_PROVE_LOCKING * .nulls_base = (1U << RHT_BASE_SHIFT),
* .mutex_is_held = &my_mutex_is_held,
* #endif
* }; * };
* *
* Configuration Example 2: Variable length keys * Configuration Example 2: Variable length keys
...@@ -547,9 +747,6 @@ static size_t rounded_hashtable_size(struct rhashtable_params *params) ...@@ -547,9 +747,6 @@ static size_t rounded_hashtable_size(struct rhashtable_params *params)
* .head_offset = offsetof(struct test_obj, node), * .head_offset = offsetof(struct test_obj, node),
* .hashfn = jhash, * .hashfn = jhash,
* .obj_hashfn = my_hash_fn, * .obj_hashfn = my_hash_fn,
* #ifdef CONFIG_PROVE_LOCKING
* .mutex_is_held = &my_mutex_is_held,
* #endif
* }; * };
*/ */
int rhashtable_init(struct rhashtable *ht, struct rhashtable_params *params) int rhashtable_init(struct rhashtable *ht, struct rhashtable_params *params)
...@@ -563,24 +760,38 @@ int rhashtable_init(struct rhashtable *ht, struct rhashtable_params *params) ...@@ -563,24 +760,38 @@ int rhashtable_init(struct rhashtable *ht, struct rhashtable_params *params)
(!params->key_len && !params->obj_hashfn)) (!params->key_len && !params->obj_hashfn))
return -EINVAL; return -EINVAL;
if (params->nulls_base && params->nulls_base < (1U << RHT_BASE_SHIFT))
return -EINVAL;
params->min_shift = max_t(size_t, params->min_shift, params->min_shift = max_t(size_t, params->min_shift,
ilog2(HASH_MIN_SIZE)); ilog2(HASH_MIN_SIZE));
if (params->nelem_hint) if (params->nelem_hint)
size = rounded_hashtable_size(params); size = rounded_hashtable_size(params);
tbl = bucket_table_alloc(size); memset(ht, 0, sizeof(*ht));
mutex_init(&ht->mutex);
memcpy(&ht->p, params, sizeof(*params));
if (params->locks_mul)
ht->p.locks_mul = roundup_pow_of_two(params->locks_mul);
else
ht->p.locks_mul = BUCKET_LOCKS_PER_CPU;
tbl = bucket_table_alloc(ht, size);
if (tbl == NULL) if (tbl == NULL)
return -ENOMEM; return -ENOMEM;
memset(ht, 0, sizeof(*ht));
ht->shift = ilog2(tbl->size); ht->shift = ilog2(tbl->size);
memcpy(&ht->p, params, sizeof(*params));
RCU_INIT_POINTER(ht->tbl, tbl); RCU_INIT_POINTER(ht->tbl, tbl);
RCU_INIT_POINTER(ht->future_tbl, tbl);
if (!ht->p.hash_rnd) if (!ht->p.hash_rnd)
get_random_bytes(&ht->p.hash_rnd, sizeof(ht->p.hash_rnd)); get_random_bytes(&ht->p.hash_rnd, sizeof(ht->p.hash_rnd));
if (ht->p.grow_decision || ht->p.shrink_decision)
INIT_DEFERRABLE_WORK(&ht->run_work, rht_deferred_worker);
return 0; return 0;
} }
EXPORT_SYMBOL_GPL(rhashtable_init); EXPORT_SYMBOL_GPL(rhashtable_init);
...@@ -593,9 +804,16 @@ EXPORT_SYMBOL_GPL(rhashtable_init); ...@@ -593,9 +804,16 @@ EXPORT_SYMBOL_GPL(rhashtable_init);
* has to make sure that no resizing may happen by unpublishing the hashtable * has to make sure that no resizing may happen by unpublishing the hashtable
* and waiting for the quiescent cycle before releasing the bucket array. * and waiting for the quiescent cycle before releasing the bucket array.
*/ */
void rhashtable_destroy(const struct rhashtable *ht) void rhashtable_destroy(struct rhashtable *ht)
{ {
bucket_table_free(ht->tbl); ht->being_destroyed = true;
mutex_lock(&ht->mutex);
cancel_delayed_work(&ht->run_work);
bucket_table_free(rht_dereference(ht->tbl, ht));
mutex_unlock(&ht->mutex);
} }
EXPORT_SYMBOL_GPL(rhashtable_destroy); EXPORT_SYMBOL_GPL(rhashtable_destroy);
...@@ -610,13 +828,6 @@ EXPORT_SYMBOL_GPL(rhashtable_destroy); ...@@ -610,13 +828,6 @@ EXPORT_SYMBOL_GPL(rhashtable_destroy);
#define TEST_PTR ((void *) 0xdeadbeef) #define TEST_PTR ((void *) 0xdeadbeef)
#define TEST_NEXPANDS 4 #define TEST_NEXPANDS 4
#ifdef CONFIG_PROVE_LOCKING
static int test_mutex_is_held(void *parent)
{
return 1;
}
#endif
struct test_obj { struct test_obj {
void *ptr; void *ptr;
int value; int value;
...@@ -656,6 +867,7 @@ static int __init test_rht_lookup(struct rhashtable *ht) ...@@ -656,6 +867,7 @@ static int __init test_rht_lookup(struct rhashtable *ht)
static void test_bucket_stats(struct rhashtable *ht, bool quiet) static void test_bucket_stats(struct rhashtable *ht, bool quiet)
{ {
unsigned int cnt, rcu_cnt, i, total = 0; unsigned int cnt, rcu_cnt, i, total = 0;
struct rhash_head *pos;
struct test_obj *obj; struct test_obj *obj;
struct bucket_table *tbl; struct bucket_table *tbl;
...@@ -666,14 +878,14 @@ static void test_bucket_stats(struct rhashtable *ht, bool quiet) ...@@ -666,14 +878,14 @@ static void test_bucket_stats(struct rhashtable *ht, bool quiet)
if (!quiet) if (!quiet)
pr_info(" [%#4x/%zu]", i, tbl->size); pr_info(" [%#4x/%zu]", i, tbl->size);
rht_for_each_entry_rcu(obj, tbl->buckets[i], node) { rht_for_each_entry_rcu(obj, pos, tbl, i, node) {
cnt++; cnt++;
total++; total++;
if (!quiet) if (!quiet)
pr_cont(" [%p],", obj); pr_cont(" [%p],", obj);
} }
rht_for_each_entry_rcu(obj, tbl->buckets[i], node) rht_for_each_entry_rcu(obj, pos, tbl, i, node)
rcu_cnt++; rcu_cnt++;
if (rcu_cnt != cnt) if (rcu_cnt != cnt)
...@@ -685,17 +897,18 @@ static void test_bucket_stats(struct rhashtable *ht, bool quiet) ...@@ -685,17 +897,18 @@ static void test_bucket_stats(struct rhashtable *ht, bool quiet)
i, tbl->buckets[i], cnt); i, tbl->buckets[i], cnt);
} }
pr_info(" Traversal complete: counted=%u, nelems=%zu, entries=%d\n", pr_info(" Traversal complete: counted=%u, nelems=%u, entries=%d\n",
total, ht->nelems, TEST_ENTRIES); total, atomic_read(&ht->nelems), TEST_ENTRIES);
if (total != ht->nelems || total != TEST_ENTRIES) if (total != atomic_read(&ht->nelems) || total != TEST_ENTRIES)
pr_warn("Test failed: Total count mismatch ^^^"); pr_warn("Test failed: Total count mismatch ^^^");
} }
static int __init test_rhashtable(struct rhashtable *ht) static int __init test_rhashtable(struct rhashtable *ht)
{ {
struct bucket_table *tbl; struct bucket_table *tbl;
struct test_obj *obj, *next; struct test_obj *obj;
struct rhash_head *pos, *next;
int err; int err;
unsigned int i; unsigned int i;
...@@ -726,7 +939,9 @@ static int __init test_rhashtable(struct rhashtable *ht) ...@@ -726,7 +939,9 @@ static int __init test_rhashtable(struct rhashtable *ht)
for (i = 0; i < TEST_NEXPANDS; i++) { for (i = 0; i < TEST_NEXPANDS; i++) {
pr_info(" Table expansion iteration %u...\n", i); pr_info(" Table expansion iteration %u...\n", i);
mutex_lock(&ht->mutex);
rhashtable_expand(ht); rhashtable_expand(ht);
mutex_unlock(&ht->mutex);
rcu_read_lock(); rcu_read_lock();
pr_info(" Verifying lookups...\n"); pr_info(" Verifying lookups...\n");
...@@ -736,7 +951,9 @@ static int __init test_rhashtable(struct rhashtable *ht) ...@@ -736,7 +951,9 @@ static int __init test_rhashtable(struct rhashtable *ht)
for (i = 0; i < TEST_NEXPANDS; i++) { for (i = 0; i < TEST_NEXPANDS; i++) {
pr_info(" Table shrinkage iteration %u...\n", i); pr_info(" Table shrinkage iteration %u...\n", i);
mutex_lock(&ht->mutex);
rhashtable_shrink(ht); rhashtable_shrink(ht);
mutex_unlock(&ht->mutex);
rcu_read_lock(); rcu_read_lock();
pr_info(" Verifying lookups...\n"); pr_info(" Verifying lookups...\n");
...@@ -764,7 +981,7 @@ static int __init test_rhashtable(struct rhashtable *ht) ...@@ -764,7 +981,7 @@ static int __init test_rhashtable(struct rhashtable *ht)
error: error:
tbl = rht_dereference_rcu(ht->tbl, ht); tbl = rht_dereference_rcu(ht->tbl, ht);
for (i = 0; i < tbl->size; i++) for (i = 0; i < tbl->size; i++)
rht_for_each_entry_safe(obj, next, tbl->buckets[i], ht, node) rht_for_each_entry_safe(obj, pos, next, tbl, i, node)
kfree(obj); kfree(obj);
return err; return err;
...@@ -779,9 +996,7 @@ static int __init test_rht_init(void) ...@@ -779,9 +996,7 @@ static int __init test_rht_init(void)
.key_offset = offsetof(struct test_obj, value), .key_offset = offsetof(struct test_obj, value),
.key_len = sizeof(int), .key_len = sizeof(int),
.hashfn = jhash, .hashfn = jhash,
#ifdef CONFIG_PROVE_LOCKING .nulls_base = (3U << RHT_BASE_SHIFT),
.mutex_is_held = &test_mutex_is_held,
#endif
.grow_decision = rht_grow_above_75, .grow_decision = rht_grow_above_75,
.shrink_decision = rht_shrink_below_30, .shrink_decision = rht_shrink_below_30,
}; };
......
...@@ -33,7 +33,7 @@ static bool nft_hash_lookup(const struct nft_set *set, ...@@ -33,7 +33,7 @@ static bool nft_hash_lookup(const struct nft_set *set,
const struct nft_data *key, const struct nft_data *key,
struct nft_data *data) struct nft_data *data)
{ {
const struct rhashtable *priv = nft_set_priv(set); struct rhashtable *priv = nft_set_priv(set);
const struct nft_hash_elem *he; const struct nft_hash_elem *he;
he = rhashtable_lookup(priv, key); he = rhashtable_lookup(priv, key);
...@@ -83,46 +83,53 @@ static void nft_hash_remove(const struct nft_set *set, ...@@ -83,46 +83,53 @@ static void nft_hash_remove(const struct nft_set *set,
const struct nft_set_elem *elem) const struct nft_set_elem *elem)
{ {
struct rhashtable *priv = nft_set_priv(set); struct rhashtable *priv = nft_set_priv(set);
struct rhash_head *he, __rcu **pprev;
pprev = elem->cookie; rhashtable_remove(priv, elem->cookie);
he = rht_dereference((*pprev), priv); synchronize_rcu();
kfree(elem->cookie);
}
rhashtable_remove_pprev(priv, he, pprev); struct nft_compare_arg {
const struct nft_set *set;
struct nft_set_elem *elem;
};
synchronize_rcu(); static bool nft_hash_compare(void *ptr, void *arg)
kfree(he); {
struct nft_hash_elem *he = ptr;
struct nft_compare_arg *x = arg;
if (!nft_data_cmp(&he->key, &x->elem->key, x->set->klen)) {
x->elem->cookie = he;
x->elem->flags = 0;
if (x->set->flags & NFT_SET_MAP)
nft_data_copy(&x->elem->data, he->data);
return true;
}
return false;
} }
static int nft_hash_get(const struct nft_set *set, struct nft_set_elem *elem) static int nft_hash_get(const struct nft_set *set, struct nft_set_elem *elem)
{ {
const struct rhashtable *priv = nft_set_priv(set); struct rhashtable *priv = nft_set_priv(set);
const struct bucket_table *tbl = rht_dereference_rcu(priv->tbl, priv); struct nft_compare_arg arg = {
struct rhash_head __rcu * const *pprev; .set = set,
struct nft_hash_elem *he; .elem = elem,
u32 h; };
h = rhashtable_hashfn(priv, &elem->key, set->klen);
pprev = &tbl->buckets[h];
rht_for_each_entry_rcu(he, tbl->buckets[h], node) {
if (nft_data_cmp(&he->key, &elem->key, set->klen)) {
pprev = &he->node.next;
continue;
}
elem->cookie = (void *)pprev; if (rhashtable_lookup_compare(priv, &elem->key,
elem->flags = 0; &nft_hash_compare, &arg))
if (set->flags & NFT_SET_MAP)
nft_data_copy(&elem->data, he->data);
return 0; return 0;
}
return -ENOENT; return -ENOENT;
} }
static void nft_hash_walk(const struct nft_ctx *ctx, const struct nft_set *set, static void nft_hash_walk(const struct nft_ctx *ctx, const struct nft_set *set,
struct nft_set_iter *iter) struct nft_set_iter *iter)
{ {
const struct rhashtable *priv = nft_set_priv(set); struct rhashtable *priv = nft_set_priv(set);
const struct bucket_table *tbl; const struct bucket_table *tbl;
const struct nft_hash_elem *he; const struct nft_hash_elem *he;
struct nft_set_elem elem; struct nft_set_elem elem;
...@@ -130,7 +137,9 @@ static void nft_hash_walk(const struct nft_ctx *ctx, const struct nft_set *set, ...@@ -130,7 +137,9 @@ static void nft_hash_walk(const struct nft_ctx *ctx, const struct nft_set *set,
tbl = rht_dereference_rcu(priv->tbl, priv); tbl = rht_dereference_rcu(priv->tbl, priv);
for (i = 0; i < tbl->size; i++) { for (i = 0; i < tbl->size; i++) {
rht_for_each_entry_rcu(he, tbl->buckets[i], node) { struct rhash_head *pos;
rht_for_each_entry_rcu(he, pos, tbl, i, node) {
if (iter->count < iter->skip) if (iter->count < iter->skip)
goto cont; goto cont;
...@@ -153,13 +162,6 @@ static unsigned int nft_hash_privsize(const struct nlattr * const nla[]) ...@@ -153,13 +162,6 @@ static unsigned int nft_hash_privsize(const struct nlattr * const nla[])
return sizeof(struct rhashtable); return sizeof(struct rhashtable);
} }
#ifdef CONFIG_PROVE_LOCKING
static int lockdep_nfnl_lock_is_held(void *parent)
{
return lockdep_nfnl_is_held(NFNL_SUBSYS_NFTABLES);
}
#endif
static int nft_hash_init(const struct nft_set *set, static int nft_hash_init(const struct nft_set *set,
const struct nft_set_desc *desc, const struct nft_set_desc *desc,
const struct nlattr * const tb[]) const struct nlattr * const tb[])
...@@ -173,9 +175,6 @@ static int nft_hash_init(const struct nft_set *set, ...@@ -173,9 +175,6 @@ static int nft_hash_init(const struct nft_set *set,
.hashfn = jhash, .hashfn = jhash,
.grow_decision = rht_grow_above_75, .grow_decision = rht_grow_above_75,
.shrink_decision = rht_shrink_below_30, .shrink_decision = rht_shrink_below_30,
#ifdef CONFIG_PROVE_LOCKING
.mutex_is_held = lockdep_nfnl_lock_is_held,
#endif
}; };
return rhashtable_init(priv, &params); return rhashtable_init(priv, &params);
...@@ -183,18 +182,23 @@ static int nft_hash_init(const struct nft_set *set, ...@@ -183,18 +182,23 @@ static int nft_hash_init(const struct nft_set *set,
static void nft_hash_destroy(const struct nft_set *set) static void nft_hash_destroy(const struct nft_set *set)
{ {
const struct rhashtable *priv = nft_set_priv(set); struct rhashtable *priv = nft_set_priv(set);
const struct bucket_table *tbl = priv->tbl; const struct bucket_table *tbl;
struct nft_hash_elem *he, *next; struct nft_hash_elem *he;
struct rhash_head *pos, *next;
unsigned int i; unsigned int i;
/* Stop an eventual async resizing */
priv->being_destroyed = true;
mutex_lock(&priv->mutex);
tbl = rht_dereference(priv->tbl, priv);
for (i = 0; i < tbl->size; i++) { for (i = 0; i < tbl->size; i++) {
for (he = rht_entry(tbl->buckets[i], struct nft_hash_elem, node); rht_for_each_entry_safe(he, pos, next, tbl, i, node)
he != NULL; he = next) {
next = rht_entry(he->node.next, struct nft_hash_elem, node);
nft_hash_elem_destroy(set, he); nft_hash_elem_destroy(set, he);
}
} }
mutex_unlock(&priv->mutex);
rhashtable_destroy(priv); rhashtable_destroy(priv);
} }
......
...@@ -97,12 +97,12 @@ static int netlink_dump(struct sock *sk); ...@@ -97,12 +97,12 @@ static int netlink_dump(struct sock *sk);
static void netlink_skb_destructor(struct sk_buff *skb); static void netlink_skb_destructor(struct sk_buff *skb);
/* nl_table locking explained: /* nl_table locking explained:
* Lookup and traversal are protected with nl_sk_hash_lock or nl_table_lock * Lookup and traversal are protected with an RCU read-side lock. Insertion
* combined with an RCU read-side lock. Insertion and removal are protected * and removal are protected with nl_sk_hash_lock while using RCU list
* with nl_sk_hash_lock while using RCU list modification primitives and may * modification primitives and may run in parallel to RCU protected lookups.
* run in parallel to nl_table_lock protected lookups. Destruction of the * Destruction of the Netlink socket may only occur *after* nl_table_lock has
* Netlink socket may only occur *after* nl_table_lock has been acquired * been acquired * either during or after the socket has been removed from
* either during or after the socket has been removed from the list. * the list and after an RCU grace period.
*/ */
DEFINE_RWLOCK(nl_table_lock); DEFINE_RWLOCK(nl_table_lock);
EXPORT_SYMBOL_GPL(nl_table_lock); EXPORT_SYMBOL_GPL(nl_table_lock);
...@@ -114,15 +114,6 @@ static atomic_t nl_table_users = ATOMIC_INIT(0); ...@@ -114,15 +114,6 @@ static atomic_t nl_table_users = ATOMIC_INIT(0);
DEFINE_MUTEX(nl_sk_hash_lock); DEFINE_MUTEX(nl_sk_hash_lock);
EXPORT_SYMBOL_GPL(nl_sk_hash_lock); EXPORT_SYMBOL_GPL(nl_sk_hash_lock);
#ifdef CONFIG_PROVE_LOCKING
static int lockdep_nl_sk_hash_is_held(void *parent)
{
if (debug_locks)
return lockdep_is_held(&nl_sk_hash_lock) || lockdep_is_held(&nl_table_lock);
return 1;
}
#endif
static ATOMIC_NOTIFIER_HEAD(netlink_chain); static ATOMIC_NOTIFIER_HEAD(netlink_chain);
static DEFINE_SPINLOCK(netlink_tap_lock); static DEFINE_SPINLOCK(netlink_tap_lock);
...@@ -1002,11 +993,8 @@ static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid, ...@@ -1002,11 +993,8 @@ static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid,
.net = net, .net = net,
.portid = portid, .portid = portid,
}; };
u32 hash;
hash = rhashtable_hashfn(&table->hash, &portid, sizeof(portid));
return rhashtable_lookup_compare(&table->hash, hash, return rhashtable_lookup_compare(&table->hash, &portid,
&netlink_compare, &arg); &netlink_compare, &arg);
} }
...@@ -1015,13 +1003,11 @@ static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid) ...@@ -1015,13 +1003,11 @@ static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid)
struct netlink_table *table = &nl_table[protocol]; struct netlink_table *table = &nl_table[protocol];
struct sock *sk; struct sock *sk;
read_lock(&nl_table_lock);
rcu_read_lock(); rcu_read_lock();
sk = __netlink_lookup(table, portid, net); sk = __netlink_lookup(table, portid, net);
if (sk) if (sk)
sock_hold(sk); sock_hold(sk);
rcu_read_unlock(); rcu_read_unlock();
read_unlock(&nl_table_lock);
return sk; return sk;
} }
...@@ -1066,7 +1052,8 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid) ...@@ -1066,7 +1052,8 @@ static int netlink_insert(struct sock *sk, struct net *net, u32 portid)
goto err; goto err;
err = -ENOMEM; err = -ENOMEM;
if (BITS_PER_LONG > 32 && unlikely(table->hash.nelems >= UINT_MAX)) if (BITS_PER_LONG > 32 &&
unlikely(atomic_read(&table->hash.nelems) >= UINT_MAX))
goto err; goto err;
nlk_sk(sk)->portid = portid; nlk_sk(sk)->portid = portid;
...@@ -1194,6 +1181,13 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol, ...@@ -1194,6 +1181,13 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol,
goto out; goto out;
} }
static void deferred_put_nlk_sk(struct rcu_head *head)
{
struct netlink_sock *nlk = container_of(head, struct netlink_sock, rcu);
sock_put(&nlk->sk);
}
static int netlink_release(struct socket *sock) static int netlink_release(struct socket *sock)
{ {
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
...@@ -1259,7 +1253,7 @@ static int netlink_release(struct socket *sock) ...@@ -1259,7 +1253,7 @@ static int netlink_release(struct socket *sock)
local_bh_disable(); local_bh_disable();
sock_prot_inuse_add(sock_net(sk), &netlink_proto, -1); sock_prot_inuse_add(sock_net(sk), &netlink_proto, -1);
local_bh_enable(); local_bh_enable();
sock_put(sk); call_rcu(&nlk->rcu, deferred_put_nlk_sk);
return 0; return 0;
} }
...@@ -1274,7 +1268,6 @@ static int netlink_autobind(struct socket *sock) ...@@ -1274,7 +1268,6 @@ static int netlink_autobind(struct socket *sock)
retry: retry:
cond_resched(); cond_resched();
netlink_table_grab();
rcu_read_lock(); rcu_read_lock();
if (__netlink_lookup(table, portid, net)) { if (__netlink_lookup(table, portid, net)) {
/* Bind collision, search negative portid values. */ /* Bind collision, search negative portid values. */
...@@ -1282,11 +1275,9 @@ static int netlink_autobind(struct socket *sock) ...@@ -1282,11 +1275,9 @@ static int netlink_autobind(struct socket *sock)
if (rover > -4097) if (rover > -4097)
rover = -4097; rover = -4097;
rcu_read_unlock(); rcu_read_unlock();
netlink_table_ungrab();
goto retry; goto retry;
} }
rcu_read_unlock(); rcu_read_unlock();
netlink_table_ungrab();
err = netlink_insert(sk, net, portid); err = netlink_insert(sk, net, portid);
if (err == -EADDRINUSE) if (err == -EADDRINUSE)
...@@ -2901,7 +2892,9 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos) ...@@ -2901,7 +2892,9 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos)
const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht); const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
for (j = 0; j < tbl->size; j++) { for (j = 0; j < tbl->size; j++) {
rht_for_each_entry_rcu(nlk, tbl->buckets[j], node) { struct rhash_head *node;
rht_for_each_entry_rcu(nlk, node, tbl, j, node) {
s = (struct sock *)nlk; s = (struct sock *)nlk;
if (sock_net(s) != seq_file_net(seq)) if (sock_net(s) != seq_file_net(seq))
...@@ -2919,9 +2912,8 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos) ...@@ -2919,9 +2912,8 @@ static struct sock *netlink_seq_socket_idx(struct seq_file *seq, loff_t pos)
} }
static void *netlink_seq_start(struct seq_file *seq, loff_t *pos) static void *netlink_seq_start(struct seq_file *seq, loff_t *pos)
__acquires(nl_table_lock) __acquires(RCU) __acquires(RCU)
{ {
read_lock(&nl_table_lock);
rcu_read_lock(); rcu_read_lock();
return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN; return *pos ? netlink_seq_socket_idx(seq, *pos - 1) : SEQ_START_TOKEN;
} }
...@@ -2929,6 +2921,8 @@ static void *netlink_seq_start(struct seq_file *seq, loff_t *pos) ...@@ -2929,6 +2921,8 @@ static void *netlink_seq_start(struct seq_file *seq, loff_t *pos)
static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
{ {
struct rhashtable *ht; struct rhashtable *ht;
const struct bucket_table *tbl;
struct rhash_head *node;
struct netlink_sock *nlk; struct netlink_sock *nlk;
struct nl_seq_iter *iter; struct nl_seq_iter *iter;
struct net *net; struct net *net;
...@@ -2945,17 +2939,17 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) ...@@ -2945,17 +2939,17 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
i = iter->link; i = iter->link;
ht = &nl_table[i].hash; ht = &nl_table[i].hash;
rht_for_each_entry(nlk, nlk->node.next, ht, node) tbl = rht_dereference_rcu(ht->tbl, ht);
rht_for_each_entry_rcu_continue(nlk, node, nlk->node.next, tbl, iter->hash_idx, node)
if (net_eq(sock_net((struct sock *)nlk), net)) if (net_eq(sock_net((struct sock *)nlk), net))
return nlk; return nlk;
j = iter->hash_idx + 1; j = iter->hash_idx + 1;
do { do {
const struct bucket_table *tbl = rht_dereference_rcu(ht->tbl, ht);
for (; j < tbl->size; j++) { for (; j < tbl->size; j++) {
rht_for_each_entry(nlk, tbl->buckets[j], ht, node) { rht_for_each_entry_rcu(nlk, node, tbl, j, node) {
if (net_eq(sock_net((struct sock *)nlk), net)) { if (net_eq(sock_net((struct sock *)nlk), net)) {
iter->link = i; iter->link = i;
iter->hash_idx = j; iter->hash_idx = j;
...@@ -2971,10 +2965,9 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) ...@@ -2971,10 +2965,9 @@ static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos)
} }
static void netlink_seq_stop(struct seq_file *seq, void *v) static void netlink_seq_stop(struct seq_file *seq, void *v)
__releases(RCU) __releases(nl_table_lock) __releases(RCU)
{ {
rcu_read_unlock(); rcu_read_unlock();
read_unlock(&nl_table_lock);
} }
...@@ -3121,9 +3114,6 @@ static int __init netlink_proto_init(void) ...@@ -3121,9 +3114,6 @@ static int __init netlink_proto_init(void)
.max_shift = 16, /* 64K */ .max_shift = 16, /* 64K */
.grow_decision = rht_grow_above_75, .grow_decision = rht_grow_above_75,
.shrink_decision = rht_shrink_below_30, .shrink_decision = rht_shrink_below_30,
#ifdef CONFIG_PROVE_LOCKING
.mutex_is_held = lockdep_nl_sk_hash_is_held,
#endif
}; };
if (err != 0) if (err != 0)
......
...@@ -50,6 +50,7 @@ struct netlink_sock { ...@@ -50,6 +50,7 @@ struct netlink_sock {
#endif /* CONFIG_NETLINK_MMAP */ #endif /* CONFIG_NETLINK_MMAP */
struct rhash_head node; struct rhash_head node;
struct rcu_head rcu;
}; };
static inline struct netlink_sock *nlk_sk(struct sock *sk) static inline struct netlink_sock *nlk_sk(struct sock *sk)
......
...@@ -113,7 +113,9 @@ static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, ...@@ -113,7 +113,9 @@ static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
req = nlmsg_data(cb->nlh); req = nlmsg_data(cb->nlh);
for (i = 0; i < htbl->size; i++) { for (i = 0; i < htbl->size; i++) {
rht_for_each_entry(nlsk, htbl->buckets[i], ht, node) { struct rhash_head *pos;
rht_for_each_entry(nlsk, pos, htbl, i, node) {
sk = (struct sock *)nlsk; sk = (struct sock *)nlsk;
if (!net_eq(sock_net(sk), net)) if (!net_eq(sock_net(sk), net))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment