Commit 304e0242 authored by Cong Wang's avatar Cong Wang Committed by David S. Miller

net_sched: add a temporary refcnt for struct tcindex_data

Although we intentionally use an ordered workqueue for all tc
filter works, the ordering is not guaranteed by RCU work,
given that tcf_queue_work() is esstenially a call_rcu().

This problem is demostrated by Thomas:

  CPU 0:
    tcf_queue_work()
      tcf_queue_work(&r->rwork, tcindex_destroy_rexts_work);

  -> Migration to CPU 1

  CPU 1:
     tcf_queue_work(&p->rwork, tcindex_destroy_work);

so the 2nd work could be queued before the 1st one, which leads
to a free-after-free.

Enforcing this order in RCU work is hard as it requires to change
RCU code too. Fortunately we can workaround this problem in tcindex
filter by taking a temporary refcnt, we only refcnt it right before
we begin to destroy it. This simplifies the code a lot as a full
refcnt requires much more changes in tcindex_set_parms().

Reported-by: syzbot+46f513c3033d592409d2@syzkaller.appspotmail.com
Fixes: 3d210534 ("net_sched: fix a race condition in tcindex_destroy()")
Cc: Thomas Gleixner <tglx@linutronix.de>
Cc: Paul E. McKenney <paulmck@kernel.org>
Cc: Jamal Hadi Salim <jhs@mojatatu.com>
Cc: Jiri Pirko <jiri@resnulli.us>
Signed-off-by: default avatarCong Wang <xiyou.wangcong@gmail.com>
Reviewed-by: default avatarPaul E. McKenney <paulmck@kernel.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 1a323ea5
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <linux/skbuff.h> #include <linux/skbuff.h>
#include <linux/errno.h> #include <linux/errno.h>
#include <linux/slab.h> #include <linux/slab.h>
#include <linux/refcount.h>
#include <net/act_api.h> #include <net/act_api.h>
#include <net/netlink.h> #include <net/netlink.h>
#include <net/pkt_cls.h> #include <net/pkt_cls.h>
...@@ -26,9 +27,12 @@ ...@@ -26,9 +27,12 @@
#define DEFAULT_HASH_SIZE 64 /* optimized for diffserv */ #define DEFAULT_HASH_SIZE 64 /* optimized for diffserv */
struct tcindex_data;
struct tcindex_filter_result { struct tcindex_filter_result {
struct tcf_exts exts; struct tcf_exts exts;
struct tcf_result res; struct tcf_result res;
struct tcindex_data *p;
struct rcu_work rwork; struct rcu_work rwork;
}; };
...@@ -49,6 +53,7 @@ struct tcindex_data { ...@@ -49,6 +53,7 @@ struct tcindex_data {
u32 hash; /* hash table size; 0 if undefined */ u32 hash; /* hash table size; 0 if undefined */
u32 alloc_hash; /* allocated size */ u32 alloc_hash; /* allocated size */
u32 fall_through; /* 0: only classify if explicit match */ u32 fall_through; /* 0: only classify if explicit match */
refcount_t refcnt; /* a temporary refcnt for perfect hash */
struct rcu_work rwork; struct rcu_work rwork;
}; };
...@@ -57,6 +62,20 @@ static inline int tcindex_filter_is_set(struct tcindex_filter_result *r) ...@@ -57,6 +62,20 @@ static inline int tcindex_filter_is_set(struct tcindex_filter_result *r)
return tcf_exts_has_actions(&r->exts) || r->res.classid; return tcf_exts_has_actions(&r->exts) || r->res.classid;
} }
static void tcindex_data_get(struct tcindex_data *p)
{
refcount_inc(&p->refcnt);
}
static void tcindex_data_put(struct tcindex_data *p)
{
if (refcount_dec_and_test(&p->refcnt)) {
kfree(p->perfect);
kfree(p->h);
kfree(p);
}
}
static struct tcindex_filter_result *tcindex_lookup(struct tcindex_data *p, static struct tcindex_filter_result *tcindex_lookup(struct tcindex_data *p,
u16 key) u16 key)
{ {
...@@ -141,6 +160,7 @@ static void __tcindex_destroy_rexts(struct tcindex_filter_result *r) ...@@ -141,6 +160,7 @@ static void __tcindex_destroy_rexts(struct tcindex_filter_result *r)
{ {
tcf_exts_destroy(&r->exts); tcf_exts_destroy(&r->exts);
tcf_exts_put_net(&r->exts); tcf_exts_put_net(&r->exts);
tcindex_data_put(r->p);
} }
static void tcindex_destroy_rexts_work(struct work_struct *work) static void tcindex_destroy_rexts_work(struct work_struct *work)
...@@ -212,6 +232,8 @@ static int tcindex_delete(struct tcf_proto *tp, void *arg, bool *last, ...@@ -212,6 +232,8 @@ static int tcindex_delete(struct tcf_proto *tp, void *arg, bool *last,
else else
__tcindex_destroy_fexts(f); __tcindex_destroy_fexts(f);
} else { } else {
tcindex_data_get(p);
if (tcf_exts_get_net(&r->exts)) if (tcf_exts_get_net(&r->exts))
tcf_queue_work(&r->rwork, tcindex_destroy_rexts_work); tcf_queue_work(&r->rwork, tcindex_destroy_rexts_work);
else else
...@@ -228,9 +250,7 @@ static void tcindex_destroy_work(struct work_struct *work) ...@@ -228,9 +250,7 @@ static void tcindex_destroy_work(struct work_struct *work)
struct tcindex_data, struct tcindex_data,
rwork); rwork);
kfree(p->perfect); tcindex_data_put(p);
kfree(p->h);
kfree(p);
} }
static inline int static inline int
...@@ -248,9 +268,11 @@ static const struct nla_policy tcindex_policy[TCA_TCINDEX_MAX + 1] = { ...@@ -248,9 +268,11 @@ static const struct nla_policy tcindex_policy[TCA_TCINDEX_MAX + 1] = {
}; };
static int tcindex_filter_result_init(struct tcindex_filter_result *r, static int tcindex_filter_result_init(struct tcindex_filter_result *r,
struct tcindex_data *p,
struct net *net) struct net *net)
{ {
memset(r, 0, sizeof(*r)); memset(r, 0, sizeof(*r));
r->p = p;
return tcf_exts_init(&r->exts, net, TCA_TCINDEX_ACT, return tcf_exts_init(&r->exts, net, TCA_TCINDEX_ACT,
TCA_TCINDEX_POLICE); TCA_TCINDEX_POLICE);
} }
...@@ -290,6 +312,7 @@ static int tcindex_alloc_perfect_hash(struct net *net, struct tcindex_data *cp) ...@@ -290,6 +312,7 @@ static int tcindex_alloc_perfect_hash(struct net *net, struct tcindex_data *cp)
TCA_TCINDEX_ACT, TCA_TCINDEX_POLICE); TCA_TCINDEX_ACT, TCA_TCINDEX_POLICE);
if (err < 0) if (err < 0)
goto errout; goto errout;
cp->perfect[i].p = cp;
} }
return 0; return 0;
...@@ -334,6 +357,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base, ...@@ -334,6 +357,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base,
cp->alloc_hash = p->alloc_hash; cp->alloc_hash = p->alloc_hash;
cp->fall_through = p->fall_through; cp->fall_through = p->fall_through;
cp->tp = tp; cp->tp = tp;
refcount_set(&cp->refcnt, 1); /* Paired with tcindex_destroy_work() */
if (tb[TCA_TCINDEX_HASH]) if (tb[TCA_TCINDEX_HASH])
cp->hash = nla_get_u32(tb[TCA_TCINDEX_HASH]); cp->hash = nla_get_u32(tb[TCA_TCINDEX_HASH]);
...@@ -366,7 +390,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base, ...@@ -366,7 +390,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base,
} }
cp->h = p->h; cp->h = p->h;
err = tcindex_filter_result_init(&new_filter_result, net); err = tcindex_filter_result_init(&new_filter_result, cp, net);
if (err < 0) if (err < 0)
goto errout_alloc; goto errout_alloc;
if (old_r) if (old_r)
...@@ -434,7 +458,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base, ...@@ -434,7 +458,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base,
goto errout_alloc; goto errout_alloc;
f->key = handle; f->key = handle;
f->next = NULL; f->next = NULL;
err = tcindex_filter_result_init(&f->result, net); err = tcindex_filter_result_init(&f->result, cp, net);
if (err < 0) { if (err < 0) {
kfree(f); kfree(f);
goto errout_alloc; goto errout_alloc;
...@@ -447,7 +471,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base, ...@@ -447,7 +471,7 @@ tcindex_set_parms(struct net *net, struct tcf_proto *tp, unsigned long base,
} }
if (old_r && old_r != r) { if (old_r && old_r != r) {
err = tcindex_filter_result_init(old_r, net); err = tcindex_filter_result_init(old_r, cp, net);
if (err < 0) { if (err < 0) {
kfree(f); kfree(f);
goto errout_alloc; goto errout_alloc;
...@@ -571,6 +595,14 @@ static void tcindex_destroy(struct tcf_proto *tp, bool rtnl_held, ...@@ -571,6 +595,14 @@ static void tcindex_destroy(struct tcf_proto *tp, bool rtnl_held,
for (i = 0; i < p->hash; i++) { for (i = 0; i < p->hash; i++) {
struct tcindex_filter_result *r = p->perfect + i; struct tcindex_filter_result *r = p->perfect + i;
/* tcf_queue_work() does not guarantee the ordering we
* want, so we have to take this refcnt temporarily to
* ensure 'p' is freed after all tcindex_filter_result
* here. Imperfect hash does not need this, because it
* uses linked lists rather than an array.
*/
tcindex_data_get(p);
tcf_unbind_filter(tp, &r->res); tcf_unbind_filter(tp, &r->res);
if (tcf_exts_get_net(&r->exts)) if (tcf_exts_get_net(&r->exts))
tcf_queue_work(&r->rwork, tcf_queue_work(&r->rwork,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment