Commit 0cbc06b3 authored by Florian Westphal's avatar Florian Westphal Committed by Pablo Neira Ayuso

netfilter: nf_tables: remove synchronize_rcu in commit phase

synchronize_rcu() is expensive.

The commit phase currently enforces an unconditional
synchronize_rcu() after incrementing the generation counter.

This is to make sure that a packet always sees a consistent chain, either
nft_do_chain is still using old generation (it will skip the newly added
rules), or the new one (it will skip old ones that might still be linked
into the list).

We could just remove the synchronize_rcu(), it would not cause a crash but
it could cause us to evaluate a rule that was removed and new rule for the
same packet, instead of either-or.

To resolve this, add rule pointer array holding two generations, the
current one and the future generation.

In commit phase, allocate the rule blob and populate it with the rules that
will be active in the new generation.

Then, make this rule blob public, replacing the old generation pointer.

Then the generation counter can be incremented.

nft_do_chain() will either continue to use the current generation
(in case loop was invoked right before increment), or the new one.
Suggested-by: default avatarPablo Neira Ayuso <pablo@netfilter.org>
Signed-off-by: default avatarFlorian Westphal <fw@strlen.de>
Signed-off-by: default avatarPablo Neira Ayuso <pablo@netfilter.org>
parent 00308791
...@@ -858,6 +858,8 @@ enum nft_chain_flags { ...@@ -858,6 +858,8 @@ enum nft_chain_flags {
* @name: name of the chain * @name: name of the chain
*/ */
struct nft_chain { struct nft_chain {
struct nft_rule *__rcu *rules_gen_0;
struct nft_rule *__rcu *rules_gen_1;
struct list_head rules; struct list_head rules;
struct list_head list; struct list_head list;
struct nft_table *table; struct nft_table *table;
...@@ -867,6 +869,9 @@ struct nft_chain { ...@@ -867,6 +869,9 @@ struct nft_chain {
u8 flags:6, u8 flags:6,
genmask:2; genmask:2;
char *name; char *name;
/* Only used during control plane commit phase: */
struct nft_rule **rules_next;
}; };
enum nft_chain_types { enum nft_chain_types {
......
...@@ -1237,12 +1237,29 @@ static void nft_chain_stats_replace(struct nft_base_chain *chain, ...@@ -1237,12 +1237,29 @@ static void nft_chain_stats_replace(struct nft_base_chain *chain,
rcu_assign_pointer(chain->stats, newstats); rcu_assign_pointer(chain->stats, newstats);
} }
static void nf_tables_chain_free_chain_rules(struct nft_chain *chain)
{
struct nft_rule **g0 = rcu_dereference_raw(chain->rules_gen_0);
struct nft_rule **g1 = rcu_dereference_raw(chain->rules_gen_1);
if (g0 != g1)
kvfree(g1);
kvfree(g0);
/* should be NULL either via abort or via successful commit */
WARN_ON_ONCE(chain->rules_next);
kvfree(chain->rules_next);
}
static void nf_tables_chain_destroy(struct nft_ctx *ctx) static void nf_tables_chain_destroy(struct nft_ctx *ctx)
{ {
struct nft_chain *chain = ctx->chain; struct nft_chain *chain = ctx->chain;
BUG_ON(chain->use > 0); BUG_ON(chain->use > 0);
/* no concurrent access possible anymore */
nf_tables_chain_free_chain_rules(chain);
if (nft_is_base_chain(chain)) { if (nft_is_base_chain(chain)) {
struct nft_base_chain *basechain = nft_base_chain(chain); struct nft_base_chain *basechain = nft_base_chain(chain);
...@@ -1335,6 +1352,27 @@ static void nft_chain_release_hook(struct nft_chain_hook *hook) ...@@ -1335,6 +1352,27 @@ static void nft_chain_release_hook(struct nft_chain_hook *hook)
module_put(hook->type->owner); module_put(hook->type->owner);
} }
struct nft_rules_old {
struct rcu_head h;
struct nft_rule **start;
};
static struct nft_rule **nf_tables_chain_alloc_rules(const struct nft_chain *chain,
unsigned int alloc)
{
if (alloc > INT_MAX)
return NULL;
alloc += 1; /* NULL, ends rules */
if (sizeof(struct nft_rule *) > INT_MAX / alloc)
return NULL;
alloc *= sizeof(struct nft_rule *);
alloc += sizeof(struct nft_rules_old);
return kvmalloc(alloc, GFP_KERNEL);
}
static int nf_tables_addchain(struct nft_ctx *ctx, u8 family, u8 genmask, static int nf_tables_addchain(struct nft_ctx *ctx, u8 family, u8 genmask,
u8 policy, bool create) u8 policy, bool create)
{ {
...@@ -1344,6 +1382,7 @@ static int nf_tables_addchain(struct nft_ctx *ctx, u8 family, u8 genmask, ...@@ -1344,6 +1382,7 @@ static int nf_tables_addchain(struct nft_ctx *ctx, u8 family, u8 genmask,
struct nft_stats __percpu *stats; struct nft_stats __percpu *stats;
struct net *net = ctx->net; struct net *net = ctx->net;
struct nft_chain *chain; struct nft_chain *chain;
struct nft_rule **rules;
int err; int err;
if (table->use == UINT_MAX) if (table->use == UINT_MAX)
...@@ -1406,6 +1445,16 @@ static int nf_tables_addchain(struct nft_ctx *ctx, u8 family, u8 genmask, ...@@ -1406,6 +1445,16 @@ static int nf_tables_addchain(struct nft_ctx *ctx, u8 family, u8 genmask,
goto err1; goto err1;
} }
rules = nf_tables_chain_alloc_rules(chain, 0);
if (!rules) {
err = -ENOMEM;
goto err1;
}
*rules = NULL;
rcu_assign_pointer(chain->rules_gen_0, rules);
rcu_assign_pointer(chain->rules_gen_1, rules);
err = nf_tables_register_hook(net, table, chain); err = nf_tables_register_hook(net, table, chain);
if (err < 0) if (err < 0)
goto err1; goto err1;
...@@ -5850,21 +5899,162 @@ static void nf_tables_commit_release(struct net *net) ...@@ -5850,21 +5899,162 @@ static void nf_tables_commit_release(struct net *net)
} }
} }
static int nf_tables_commit_chain_prepare(struct net *net, struct nft_chain *chain)
{
struct nft_rule *rule;
unsigned int alloc = 0;
int i;
/* already handled or inactive chain? */
if (chain->rules_next || !nft_is_active_next(net, chain))
return 0;
rule = list_entry(&chain->rules, struct nft_rule, list);
i = 0;
list_for_each_entry_continue(rule, &chain->rules, list) {
if (nft_is_active_next(net, rule))
alloc++;
}
chain->rules_next = nf_tables_chain_alloc_rules(chain, alloc);
if (!chain->rules_next)
return -ENOMEM;
list_for_each_entry_continue(rule, &chain->rules, list) {
if (nft_is_active_next(net, rule))
chain->rules_next[i++] = rule;
}
chain->rules_next[i] = NULL;
return 0;
}
static void nf_tables_commit_chain_prepare_cancel(struct net *net)
{
struct nft_trans *trans, *next;
list_for_each_entry_safe(trans, next, &net->nft.commit_list, list) {
struct nft_chain *chain = trans->ctx.chain;
if (trans->msg_type == NFT_MSG_NEWRULE ||
trans->msg_type == NFT_MSG_DELRULE) {
kvfree(chain->rules_next);
chain->rules_next = NULL;
}
}
}
static void __nf_tables_commit_chain_free_rules_old(struct rcu_head *h)
{
struct nft_rules_old *o = container_of(h, struct nft_rules_old, h);
kvfree(o->start);
}
static void nf_tables_commit_chain_free_rules_old(struct nft_rule **rules)
{
struct nft_rule **r = rules;
struct nft_rules_old *old;
while (*r)
r++;
r++; /* rcu_head is after end marker */
old = (void *) r;
old->start = rules;
call_rcu(&old->h, __nf_tables_commit_chain_free_rules_old);
}
static void nf_tables_commit_chain_active(struct net *net, struct nft_chain *chain)
{
struct nft_rule **g0, **g1;
bool next_genbit;
next_genbit = nft_gencursor_next(net);
g0 = rcu_dereference_protected(chain->rules_gen_0,
lockdep_nfnl_is_held(NFNL_SUBSYS_NFTABLES));
g1 = rcu_dereference_protected(chain->rules_gen_1,
lockdep_nfnl_is_held(NFNL_SUBSYS_NFTABLES));
/* No changes to this chain? */
if (chain->rules_next == NULL) {
/* chain had no change in last or next generation */
if (g0 == g1)
return;
/*
* chain had no change in this generation; make sure next
* one uses same rules as current generation.
*/
if (next_genbit) {
rcu_assign_pointer(chain->rules_gen_1, g0);
nf_tables_commit_chain_free_rules_old(g1);
} else {
rcu_assign_pointer(chain->rules_gen_0, g1);
nf_tables_commit_chain_free_rules_old(g0);
}
return;
}
if (next_genbit)
rcu_assign_pointer(chain->rules_gen_1, chain->rules_next);
else
rcu_assign_pointer(chain->rules_gen_0, chain->rules_next);
chain->rules_next = NULL;
if (g0 == g1)
return;
if (next_genbit)
nf_tables_commit_chain_free_rules_old(g1);
else
nf_tables_commit_chain_free_rules_old(g0);
}
static int nf_tables_commit(struct net *net, struct sk_buff *skb) static int nf_tables_commit(struct net *net, struct sk_buff *skb)
{ {
struct nft_trans *trans, *next; struct nft_trans *trans, *next;
struct nft_trans_elem *te; struct nft_trans_elem *te;
struct nft_chain *chain;
struct nft_table *table;
/* Bump generation counter, invalidate any dump in progress */ /* 1. Allocate space for next generation rules_gen_X[] */
while (++net->nft.base_seq == 0); list_for_each_entry_safe(trans, next, &net->nft.commit_list, list) {
int ret;
/* A new generation has just started */ if (trans->msg_type == NFT_MSG_NEWRULE ||
net->nft.gencursor = nft_gencursor_next(net); trans->msg_type == NFT_MSG_DELRULE) {
chain = trans->ctx.chain;
ret = nf_tables_commit_chain_prepare(net, chain);
if (ret < 0) {
nf_tables_commit_chain_prepare_cancel(net);
return ret;
}
}
}
/* Make sure all packets have left the previous generation before /* step 2. Make rules_gen_X visible to packet path */
* purging old rules. list_for_each_entry(table, &net->nft.tables, list) {
list_for_each_entry(chain, &table->chains, list) {
if (!nft_is_active_next(net, chain))
continue;
nf_tables_commit_chain_active(net, chain);
}
}
/*
* Bump generation counter, invalidate any dump in progress.
* Cannot fail after this point.
*/ */
synchronize_rcu(); while (++net->nft.base_seq == 0);
/* step 3. Start new generation, rules_gen_X now in use. */
net->nft.gencursor = nft_gencursor_next(net);
list_for_each_entry_safe(trans, next, &net->nft.commit_list, list) { list_for_each_entry_safe(trans, next, &net->nft.commit_list, list) {
switch (trans->msg_type) { switch (trans->msg_type) {
......
...@@ -133,7 +133,7 @@ static noinline void nft_update_chain_stats(const struct nft_chain *chain, ...@@ -133,7 +133,7 @@ static noinline void nft_update_chain_stats(const struct nft_chain *chain,
struct nft_jumpstack { struct nft_jumpstack {
const struct nft_chain *chain; const struct nft_chain *chain;
const struct nft_rule *rule; struct nft_rule *const *rules;
}; };
unsigned int unsigned int
...@@ -141,27 +141,29 @@ nft_do_chain(struct nft_pktinfo *pkt, void *priv) ...@@ -141,27 +141,29 @@ nft_do_chain(struct nft_pktinfo *pkt, void *priv)
{ {
const struct nft_chain *chain = priv, *basechain = chain; const struct nft_chain *chain = priv, *basechain = chain;
const struct net *net = nft_net(pkt); const struct net *net = nft_net(pkt);
struct nft_rule *const *rules;
const struct nft_rule *rule; const struct nft_rule *rule;
const struct nft_expr *expr, *last; const struct nft_expr *expr, *last;
struct nft_regs regs; struct nft_regs regs;
unsigned int stackptr = 0; unsigned int stackptr = 0;
struct nft_jumpstack jumpstack[NFT_JUMP_STACK_SIZE]; struct nft_jumpstack jumpstack[NFT_JUMP_STACK_SIZE];
unsigned int gencursor = nft_genmask_cur(net); bool genbit = READ_ONCE(net->nft.gencursor);
struct nft_traceinfo info; struct nft_traceinfo info;
info.trace = false; info.trace = false;
if (static_branch_unlikely(&nft_trace_enabled)) if (static_branch_unlikely(&nft_trace_enabled))
nft_trace_init(&info, pkt, &regs.verdict, basechain); nft_trace_init(&info, pkt, &regs.verdict, basechain);
do_chain: do_chain:
rule = list_entry(&chain->rules, struct nft_rule, list); if (genbit)
rules = rcu_dereference(chain->rules_gen_1);
else
rules = rcu_dereference(chain->rules_gen_0);
next_rule: next_rule:
rule = *rules;
regs.verdict.code = NFT_CONTINUE; regs.verdict.code = NFT_CONTINUE;
list_for_each_entry_continue_rcu(rule, &chain->rules, list) { for (; *rules ; rules++) {
rule = *rules;
/* This rule is not active, skip. */
if (unlikely(rule->genmask & gencursor))
continue;
nft_rule_for_each_expr(expr, last, rule) { nft_rule_for_each_expr(expr, last, rule) {
if (expr->ops == &nft_cmp_fast_ops) if (expr->ops == &nft_cmp_fast_ops)
nft_cmp_fast_eval(expr, &regs); nft_cmp_fast_eval(expr, &regs);
...@@ -199,7 +201,7 @@ nft_do_chain(struct nft_pktinfo *pkt, void *priv) ...@@ -199,7 +201,7 @@ nft_do_chain(struct nft_pktinfo *pkt, void *priv)
case NFT_JUMP: case NFT_JUMP:
BUG_ON(stackptr >= NFT_JUMP_STACK_SIZE); BUG_ON(stackptr >= NFT_JUMP_STACK_SIZE);
jumpstack[stackptr].chain = chain; jumpstack[stackptr].chain = chain;
jumpstack[stackptr].rule = rule; jumpstack[stackptr].rules = rules + 1;
stackptr++; stackptr++;
/* fall through */ /* fall through */
case NFT_GOTO: case NFT_GOTO:
...@@ -221,7 +223,7 @@ nft_do_chain(struct nft_pktinfo *pkt, void *priv) ...@@ -221,7 +223,7 @@ nft_do_chain(struct nft_pktinfo *pkt, void *priv)
if (stackptr > 0) { if (stackptr > 0) {
stackptr--; stackptr--;
chain = jumpstack[stackptr].chain; chain = jumpstack[stackptr].chain;
rule = jumpstack[stackptr].rule; rules = jumpstack[stackptr].rules;
goto next_rule; goto next_rule;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment