Commit 6d816e08 authored by Jens Axboe's avatar Jens Axboe

io_uring: hold 'ctx' reference around task_work queue + execute

We're holding the request reference, but we need to go one higher
to ensure that the ctx remains valid after the request has finished.
If the ring is closed with pending task_work inflight, and the
given io_kiocb finishes sync during issue, then we need a reference
to the ring itself around the task_work execution cycle.

Cc: stable@vger.kernel.org # v5.7+
Reported-by: syzbot+9b260fc33297966f5a8e@syzkaller.appspotmail.com
Signed-off-by: Jens Axboe <axboe@kernel.dk>
parent efa8480a
@@ -1821,8 +1821,10 @@ static void __io_req_task_submit(struct io_kiocb *req)
 static void io_req_task_submit(struct callback_head *cb)
 {
 	struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
+	struct io_ring_ctx *ctx = req->ctx;
 
 	__io_req_task_submit(req);
+	percpu_ref_put(&ctx->refs);
 }
 
 static void io_req_task_queue(struct io_kiocb *req)
@@ -1830,6 +1832,7 @@ static void io_req_task_queue(struct io_kiocb *req)
 	int ret;
 
 	init_task_work(&req->task_work, io_req_task_submit);
+	percpu_ref_get(&req->ctx->refs);
 
 	ret = io_req_task_work_add(req, &req->task_work);
 	if (unlikely(ret)) {
@@ -2318,6 +2321,8 @@ static void io_rw_resubmit(struct callback_head *cb)
 		refcount_inc(&req->refs);
 		io_queue_async_work(req);
 	}
+
+	percpu_ref_put(&ctx->refs);
 }
 #endif
@@ -2330,6 +2335,8 @@ static bool io_rw_reissue(struct io_kiocb *req, long res)
 		return false;
 
 	init_task_work(&req->task_work, io_rw_resubmit);
+	percpu_ref_get(&req->ctx->refs);
+
 	ret = io_req_task_work_add(req, &req->task_work);
 	if (!ret)
 		return true;
@@ -3033,6 +3040,8 @@ static int io_async_buf_func(struct wait_queue_entry *wait, unsigned mode,
 	list_del_init(&wait->entry);
 
 	init_task_work(&req->task_work, io_req_task_submit);
+	percpu_ref_get(&req->ctx->refs);
+
 	/* submit ref gets dropped, acquire a new one */
 	refcount_inc(&req->refs);
 	ret = io_req_task_work_add(req, &req->task_work);
@@ -4565,6 +4574,8 @@ static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll,
 	req->result = mask;
 	init_task_work(&req->task_work, func);
+	percpu_ref_get(&req->ctx->refs);
+
 	/*
 	 * If this fails, then the task is exiting. When a task exits, the
 	 * work gets canceled, so just cancel this request as well instead
@@ -4652,11 +4663,13 @@ static void io_poll_task_handler(struct io_kiocb *req, struct io_kiocb **nxt)
 static void io_poll_task_func(struct callback_head *cb)
 {
 	struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
+	struct io_ring_ctx *ctx = req->ctx;
 	struct io_kiocb *nxt = NULL;
 
 	io_poll_task_handler(req, &nxt);
 	if (nxt)
 		__io_req_task_submit(nxt);
+	percpu_ref_put(&ctx->refs);
 }
 
 static int io_poll_double_wake(struct wait_queue_entry *wait, unsigned mode,
@@ -4752,6 +4765,7 @@ static void io_async_task_func(struct callback_head *cb)
 	if (io_poll_rewait(req, &apoll->poll)) {
 		spin_unlock_irq(&ctx->completion_lock);
+		percpu_ref_put(&ctx->refs);
 		return;
 	}
@@ -4767,6 +4781,7 @@ static void io_async_task_func(struct callback_head *cb)
 	else
 		__io_req_task_cancel(req, -ECANCELED);
 
+	percpu_ref_put(&ctx->refs);
 	kfree(apoll->double_poll);
 	kfree(apoll);
 }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment