Commit 3fc50ab5 authored by Jann Horn, committed by Jens Axboe

io-wq: fix handling of NUMA node IDs

There are several things that can go wrong in the current code on NUMA
systems, especially if not all nodes are online all the time:

 - If the identifiers of the online nodes do not form a single contiguous
   block starting at zero, wq->wqes will be too small, and OOB memory
   accesses will occur e.g. in the loop in io_wq_create().
 - If a node comes online between the call to num_online_nodes() and the
   node iteration loop (for_each_online_node()) in io_wq_create(), an OOB
   write will occur.
 - If a node comes online between io_wq_create() and io_wq_enqueue(), a
   lookup is performed for an element that doesn't exist, and an OOB read
   will probably occur.

Fix it by:

 - using nr_node_ids instead of num_online_nodes() for the allocation size;
   nr_node_ids is calculated by setup_nr_node_ids() to be bigger than the
   highest node ID that could possibly come online at some point, even if
   those nodes' identifiers are not a contiguous block
 - creating workers for all possible NUMA nodes, not just all online ones

This is basically what the normal workqueue code also does, as far as I can
tell.
Signed-off-by: Jann Horn <jannh@google.com>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
parent ad6e005c
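
For readers less familiar with sparse NUMA node numbering, the first problem
described above can be illustrated with a small, purely hypothetical
user-space sketch (not part of the commit; all names and values below are
made up). Sizing the per-node array by the number of online nodes while
indexing it by raw node ID runs past the end of the array as soon as the
online IDs are not a contiguous 0..N-1 range:

#include <stdio.h>
#include <stdlib.h>

int main(void)
{
        /* Imagine a machine where only nodes 0 and 2 are online. */
        int online_nodes[] = { 0, 2 };
        int num_online = 2;     /* analogue of num_online_nodes()        */
        int nr_ids = 3;         /* analogue of nr_node_ids: max ID + 1   */

        /* Buggy sizing: one slot per online node ...                    */
        void **wqes = calloc(num_online, sizeof(*wqes));
        if (!wqes)
                return 1;

        /* ... but lookups index by node ID, so node 2 would need        */
        /* wqes[2], one past the end of a 2-element array.               */
        for (int i = 0; i < 2; i++)
                printf("node %d -> slot %d (array has %d slots)\n",
                       online_nodes[i], online_nodes[i], num_online);

        free(wqes);

        /* Fixed sizing: one slot per possible node ID.                  */
        wqes = calloc(nr_ids, sizeof(*wqes));
        free(wqes);
        return 0;
}

The fix below sizes the array by nr_node_ids instead, which covers every
node ID that could ever exist on the system.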
@@ -105,7 +105,6 @@ struct io_wqe {
 struct io_wq {
         struct io_wqe **wqes;
         unsigned long state;
-        unsigned nr_wqes;
 
         get_work_fn *get_work;
         put_work_fn *put_work;
@@ -632,21 +631,22 @@ static inline bool io_wqe_need_worker(struct io_wqe *wqe, int index)
 static int io_wq_manager(void *data)
 {
         struct io_wq *wq = data;
-        int i;
+        int workers_to_create = num_possible_nodes();
+        int node;
 
         /* create fixed workers */
-        refcount_set(&wq->refs, wq->nr_wqes);
-        for (i = 0; i < wq->nr_wqes; i++) {
-                if (create_io_worker(wq, wq->wqes[i], IO_WQ_ACCT_BOUND))
-                        continue;
-                goto err;
+        refcount_set(&wq->refs, workers_to_create);
+        for_each_node(node) {
+                if (!create_io_worker(wq, wq->wqes[node], IO_WQ_ACCT_BOUND))
+                        goto err;
+                workers_to_create--;
         }
 
         complete(&wq->done);
 
         while (!kthread_should_stop()) {
-                for (i = 0; i < wq->nr_wqes; i++) {
-                        struct io_wqe *wqe = wq->wqes[i];
+                for_each_node(node) {
+                        struct io_wqe *wqe = wq->wqes[node];
                         bool fork_worker[2] = { false, false };
 
                         spin_lock_irq(&wqe->lock);
@@ -668,7 +668,7 @@ static int io_wq_manager(void *data)
 err:
         set_bit(IO_WQ_BIT_ERROR, &wq->state);
         set_bit(IO_WQ_BIT_EXIT, &wq->state);
-        if (refcount_sub_and_test(wq->nr_wqes - i, &wq->refs))
+        if (refcount_sub_and_test(workers_to_create, &wq->refs))
                 complete(&wq->done);
         return 0;
 }
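
A note on the accounting in io_wq_manager() above: wq->refs starts at
num_possible_nodes(), and workers_to_create is decremented once per worker
that was successfully created, so at any point it counts the references that
no worker will ever drop. On failure, refcount_sub_and_test(workers_to_create,
&wq->refs) releases exactly those leftover references, so wq->done still
completes once the workers that did start have exited and dropped theirs
(that happens in code outside this hunk).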
@@ -776,7 +776,7 @@ static bool io_wq_for_each_worker(struct io_wqe *wqe,
 void io_wq_cancel_all(struct io_wq *wq)
 {
-        int i;
+        int node;
 
         set_bit(IO_WQ_BIT_CANCEL, &wq->state);
@@ -785,8 +785,8 @@ void io_wq_cancel_all(struct io_wq *wq)
          * to a worker and the worker putting itself on the busy_list
          */
         rcu_read_lock();
-        for (i = 0; i < wq->nr_wqes; i++) {
-                struct io_wqe *wqe = wq->wqes[i];
+        for_each_node(node) {
+                struct io_wqe *wqe = wq->wqes[node];
 
                 io_wq_for_each_worker(wqe, io_wqe_worker_send_sig, NULL);
         }
@@ -859,10 +859,10 @@ enum io_wq_cancel io_wq_cancel_cb(struct io_wq *wq, work_cancel_fn *cancel,
                                   void *data)
 {
         enum io_wq_cancel ret = IO_WQ_CANCEL_NOTFOUND;
-        int i;
+        int node;
 
-        for (i = 0; i < wq->nr_wqes; i++) {
-                struct io_wqe *wqe = wq->wqes[i];
+        for_each_node(node) {
+                struct io_wqe *wqe = wq->wqes[node];
 
                 ret = io_wqe_cancel_cb_work(wqe, cancel, data);
                 if (ret != IO_WQ_CANCEL_NOTFOUND)
@@ -936,10 +936,10 @@ static enum io_wq_cancel io_wqe_cancel_work(struct io_wqe *wqe,
 enum io_wq_cancel io_wq_cancel_work(struct io_wq *wq, struct io_wq_work *cwork)
 {
         enum io_wq_cancel ret = IO_WQ_CANCEL_NOTFOUND;
-        int i;
+        int node;
 
-        for (i = 0; i < wq->nr_wqes; i++) {
-                struct io_wqe *wqe = wq->wqes[i];
+        for_each_node(node) {
+                struct io_wqe *wqe = wq->wqes[node];
 
                 ret = io_wqe_cancel_work(wqe, cwork);
                 if (ret != IO_WQ_CANCEL_NOTFOUND)
@@ -970,10 +970,10 @@ static void io_wq_flush_func(struct io_wq_work **workptr)
 void io_wq_flush(struct io_wq *wq)
 {
         struct io_wq_flush_data data;
-        int i;
+        int node;
 
-        for (i = 0; i < wq->nr_wqes; i++) {
-                struct io_wqe *wqe = wq->wqes[i];
+        for_each_node(node) {
+                struct io_wqe *wqe = wq->wqes[node];
 
                 init_completion(&data.done);
                 INIT_IO_WORK(&data.work, io_wq_flush_func);
@@ -985,15 +985,14 @@ void io_wq_flush(struct io_wq *wq)
 struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
 {
-        int ret = -ENOMEM, i, node;
+        int ret = -ENOMEM, node;
         struct io_wq *wq;
 
         wq = kzalloc(sizeof(*wq), GFP_KERNEL);
         if (!wq)
                 return ERR_PTR(-ENOMEM);
 
-        wq->nr_wqes = num_online_nodes();
-        wq->wqes = kcalloc(wq->nr_wqes, sizeof(struct io_wqe *), GFP_KERNEL);
+        wq->wqes = kcalloc(nr_node_ids, sizeof(struct io_wqe *), GFP_KERNEL);
         if (!wq->wqes) {
                 kfree(wq);
                 return ERR_PTR(-ENOMEM);
@@ -1006,14 +1005,13 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
         wq->user = data->user;
         wq->creds = data->creds;
 
-        i = 0;
-        for_each_online_node(node) {
+        for_each_node(node) {
                 struct io_wqe *wqe;
 
                 wqe = kzalloc_node(sizeof(struct io_wqe), GFP_KERNEL, node);
                 if (!wqe)
-                        break;
-                wq->wqes[i] = wqe;
+                        goto err;
+                wq->wqes[node] = wqe;
                 wqe->node = node;
                 wqe->acct[IO_WQ_ACCT_BOUND].max_workers = bounded;
                 atomic_set(&wqe->acct[IO_WQ_ACCT_BOUND].nr_running, 0);
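
The change from for_each_online_node() to for_each_node() in the allocation
loop above is what guarantees that a wqe exists for every node ID the enqueue
path might ever see, including nodes that only come online later. The two
iterators differ only in which node state mask they walk; roughly
(paraphrasing include/linux/nodemask.h from this kernel era, not quoted from
the commit):

#define for_each_node(node)         for_each_node_state(node, N_POSSIBLE)
#define for_each_online_node(node)  for_each_node_state(node, N_ONLINE)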
@@ -1029,15 +1027,10 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
                 INIT_HLIST_NULLS_HEAD(&wqe->free_list, 0);
                 INIT_HLIST_NULLS_HEAD(&wqe->busy_list, 1);
                 INIT_LIST_HEAD(&wqe->all_list);
-
-                i++;
         }
 
         init_completion(&wq->done);
 
-        if (i != wq->nr_wqes)
-                goto err;
-
         /* caller must have already done mmgrab() on this mm */
         wq->mm = data->mm;
@@ -1056,8 +1049,8 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
                 ret = PTR_ERR(wq->manager);
                 complete(&wq->done);
 err:
-        for (i = 0; i < wq->nr_wqes; i++)
-                kfree(wq->wqes[i]);
+        for_each_node(node)
+                kfree(wq->wqes[node]);
         kfree(wq->wqes);
         kfree(wq);
         return ERR_PTR(ret);
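
Freeing with for_each_node() in the err path is safe even when only some of
the per-node structures were allocated: wq->wqes comes from kcalloc(), so the
slots for nodes the loop never reached are still NULL, and kfree(NULL) is a
no-op.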
@@ -1071,26 +1064,21 @@ static bool io_wq_worker_wake(struct io_worker *worker, void *data)
 void io_wq_destroy(struct io_wq *wq)
 {
-        int i;
+        int node;
 
         set_bit(IO_WQ_BIT_EXIT, &wq->state);
         if (wq->manager)
                 kthread_stop(wq->manager);
 
         rcu_read_lock();
-        for (i = 0; i < wq->nr_wqes; i++) {
-                struct io_wqe *wqe = wq->wqes[i];
-
-                if (!wqe)
-                        continue;
-                io_wq_for_each_worker(wqe, io_wq_worker_wake, NULL);
-        }
+        for_each_node(node)
+                io_wq_for_each_worker(wq->wqes[node], io_wq_worker_wake, NULL);
         rcu_read_unlock();
 
         wait_for_completion(&wq->done);
 
-        for (i = 0; i < wq->nr_wqes; i++)
-                kfree(wq->wqes[i]);
+        for_each_node(node)
+                kfree(wq->wqes[node]);
         kfree(wq->wqes);
         kfree(wq);
 }