Commit cccf0ee8 authored by Jens Axboe's avatar Jens Axboe

io_uring/io-wq: don't use static creds/mm assignments

We currently setup the io_wq with a static set of mm and creds. Even for
a single-use io-wq per io_uring, this is suboptimal as we have may have
multiple enters of the ring. For sharing the io-wq backend, it doesn't
work at all.

Switch to passing in the creds and mm when the work item is setup. This
means that async work is no longer deferred to the io_uring mm and creds,
it is done with the current mm and creds.

Flag this behavior with IORING_FEAT_CUR_PERSONALITY, so applications know
they can rely on the current personality (mm and creds) being the same
for direct issue and async issue.
Reviewed-by: default avatarStefan Metzmacher <metze@samba.org>
Signed-off-by: default avatarJens Axboe <axboe@kernel.dk>
parent 848f7e18
...@@ -56,7 +56,8 @@ struct io_worker { ...@@ -56,7 +56,8 @@ struct io_worker {
struct rcu_head rcu; struct rcu_head rcu;
struct mm_struct *mm; struct mm_struct *mm;
const struct cred *creds; const struct cred *cur_creds;
const struct cred *saved_creds;
struct files_struct *restore_files; struct files_struct *restore_files;
}; };
...@@ -109,8 +110,6 @@ struct io_wq { ...@@ -109,8 +110,6 @@ struct io_wq {
struct task_struct *manager; struct task_struct *manager;
struct user_struct *user; struct user_struct *user;
const struct cred *creds;
struct mm_struct *mm;
refcount_t refs; refcount_t refs;
struct completion done; struct completion done;
...@@ -137,9 +136,9 @@ static bool __io_worker_unuse(struct io_wqe *wqe, struct io_worker *worker) ...@@ -137,9 +136,9 @@ static bool __io_worker_unuse(struct io_wqe *wqe, struct io_worker *worker)
{ {
bool dropped_lock = false; bool dropped_lock = false;
if (worker->creds) { if (worker->saved_creds) {
revert_creds(worker->creds); revert_creds(worker->saved_creds);
worker->creds = NULL; worker->cur_creds = worker->saved_creds = NULL;
} }
if (current->files != worker->restore_files) { if (current->files != worker->restore_files) {
...@@ -398,6 +397,43 @@ static struct io_wq_work *io_get_next_work(struct io_wqe *wqe, unsigned *hash) ...@@ -398,6 +397,43 @@ static struct io_wq_work *io_get_next_work(struct io_wqe *wqe, unsigned *hash)
return NULL; return NULL;
} }
static void io_wq_switch_mm(struct io_worker *worker, struct io_wq_work *work)
{
if (worker->mm) {
unuse_mm(worker->mm);
mmput(worker->mm);
worker->mm = NULL;
}
if (!work->mm) {
set_fs(KERNEL_DS);
return;
}
if (mmget_not_zero(work->mm)) {
use_mm(work->mm);
if (!worker->mm)
set_fs(USER_DS);
worker->mm = work->mm;
/* hang on to this mm */
work->mm = NULL;
return;
}
/* failed grabbing mm, ensure work gets cancelled */
work->flags |= IO_WQ_WORK_CANCEL;
}
static void io_wq_switch_creds(struct io_worker *worker,
struct io_wq_work *work)
{
const struct cred *old_creds = override_creds(work->creds);
worker->cur_creds = work->creds;
if (worker->saved_creds)
put_cred(old_creds); /* creds set by previous switch */
else
worker->saved_creds = old_creds;
}
static void io_worker_handle_work(struct io_worker *worker) static void io_worker_handle_work(struct io_worker *worker)
__releases(wqe->lock) __releases(wqe->lock)
{ {
...@@ -446,18 +482,10 @@ static void io_worker_handle_work(struct io_worker *worker) ...@@ -446,18 +482,10 @@ static void io_worker_handle_work(struct io_worker *worker)
current->files = work->files; current->files = work->files;
task_unlock(current); task_unlock(current);
} }
if ((work->flags & IO_WQ_WORK_NEEDS_USER) && !worker->mm && if (work->mm != worker->mm)
wq->mm) { io_wq_switch_mm(worker, work);
if (mmget_not_zero(wq->mm)) { if (worker->cur_creds != work->creds)
use_mm(wq->mm); io_wq_switch_creds(worker, work);
set_fs(USER_DS);
worker->mm = wq->mm;
} else {
work->flags |= IO_WQ_WORK_CANCEL;
}
}
if (!worker->creds)
worker->creds = override_creds(wq->creds);
/* /*
* OK to set IO_WQ_WORK_CANCEL even for uncancellable work, * OK to set IO_WQ_WORK_CANCEL even for uncancellable work,
* the worker function will do the right thing. * the worker function will do the right thing.
...@@ -1037,7 +1065,6 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data) ...@@ -1037,7 +1065,6 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
/* caller must already hold a reference to this */ /* caller must already hold a reference to this */
wq->user = data->user; wq->user = data->user;
wq->creds = data->creds;
for_each_node(node) { for_each_node(node) {
struct io_wqe *wqe; struct io_wqe *wqe;
...@@ -1064,9 +1091,6 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data) ...@@ -1064,9 +1091,6 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
init_completion(&wq->done); init_completion(&wq->done);
/* caller must have already done mmgrab() on this mm */
wq->mm = data->mm;
wq->manager = kthread_create(io_wq_manager, wq, "io_wq_manager"); wq->manager = kthread_create(io_wq_manager, wq, "io_wq_manager");
if (!IS_ERR(wq->manager)) { if (!IS_ERR(wq->manager)) {
wake_up_process(wq->manager); wake_up_process(wq->manager);
......
...@@ -7,7 +7,6 @@ enum { ...@@ -7,7 +7,6 @@ enum {
IO_WQ_WORK_CANCEL = 1, IO_WQ_WORK_CANCEL = 1,
IO_WQ_WORK_HAS_MM = 2, IO_WQ_WORK_HAS_MM = 2,
IO_WQ_WORK_HASHED = 4, IO_WQ_WORK_HASHED = 4,
IO_WQ_WORK_NEEDS_USER = 8,
IO_WQ_WORK_NEEDS_FILES = 16, IO_WQ_WORK_NEEDS_FILES = 16,
IO_WQ_WORK_UNBOUND = 32, IO_WQ_WORK_UNBOUND = 32,
IO_WQ_WORK_INTERNAL = 64, IO_WQ_WORK_INTERNAL = 64,
...@@ -74,6 +73,8 @@ struct io_wq_work { ...@@ -74,6 +73,8 @@ struct io_wq_work {
}; };
void (*func)(struct io_wq_work **); void (*func)(struct io_wq_work **);
struct files_struct *files; struct files_struct *files;
struct mm_struct *mm;
const struct cred *creds;
unsigned flags; unsigned flags;
}; };
...@@ -83,15 +84,15 @@ struct io_wq_work { ...@@ -83,15 +84,15 @@ struct io_wq_work {
(work)->func = _func; \ (work)->func = _func; \
(work)->flags = 0; \ (work)->flags = 0; \
(work)->files = NULL; \ (work)->files = NULL; \
(work)->mm = NULL; \
(work)->creds = NULL; \
} while (0) \ } while (0) \
typedef void (get_work_fn)(struct io_wq_work *); typedef void (get_work_fn)(struct io_wq_work *);
typedef void (put_work_fn)(struct io_wq_work *); typedef void (put_work_fn)(struct io_wq_work *);
struct io_wq_data { struct io_wq_data {
struct mm_struct *mm;
struct user_struct *user; struct user_struct *user;
const struct cred *creds;
get_work_fn *get_work; get_work_fn *get_work;
put_work_fn *put_work; put_work_fn *put_work;
......
...@@ -875,6 +875,29 @@ static void __io_commit_cqring(struct io_ring_ctx *ctx) ...@@ -875,6 +875,29 @@ static void __io_commit_cqring(struct io_ring_ctx *ctx)
} }
} }
static inline void io_req_work_grab_env(struct io_kiocb *req,
const struct io_op_def *def)
{
if (!req->work.mm && def->needs_mm) {
mmgrab(current->mm);
req->work.mm = current->mm;
}
if (!req->work.creds)
req->work.creds = get_current_cred();
}
static inline void io_req_work_drop_env(struct io_kiocb *req)
{
if (req->work.mm) {
mmdrop(req->work.mm);
req->work.mm = NULL;
}
if (req->work.creds) {
put_cred(req->work.creds);
req->work.creds = NULL;
}
}
static inline bool io_prep_async_work(struct io_kiocb *req, static inline bool io_prep_async_work(struct io_kiocb *req,
struct io_kiocb **link) struct io_kiocb **link)
{ {
...@@ -888,8 +911,8 @@ static inline bool io_prep_async_work(struct io_kiocb *req, ...@@ -888,8 +911,8 @@ static inline bool io_prep_async_work(struct io_kiocb *req,
if (def->unbound_nonreg_file) if (def->unbound_nonreg_file)
req->work.flags |= IO_WQ_WORK_UNBOUND; req->work.flags |= IO_WQ_WORK_UNBOUND;
} }
if (def->needs_mm)
req->work.flags |= IO_WQ_WORK_NEEDS_USER; io_req_work_grab_env(req, def);
*link = io_prep_linked_timeout(req); *link = io_prep_linked_timeout(req);
return do_hashed; return do_hashed;
...@@ -1180,6 +1203,8 @@ static void __io_req_aux_free(struct io_kiocb *req) ...@@ -1180,6 +1203,8 @@ static void __io_req_aux_free(struct io_kiocb *req)
else else
fput(req->file); fput(req->file);
} }
io_req_work_drop_env(req);
} }
static void __io_free_req(struct io_kiocb *req) static void __io_free_req(struct io_kiocb *req)
...@@ -3963,6 +3988,8 @@ static int io_req_defer_prep(struct io_kiocb *req, ...@@ -3963,6 +3988,8 @@ static int io_req_defer_prep(struct io_kiocb *req,
{ {
ssize_t ret = 0; ssize_t ret = 0;
io_req_work_grab_env(req, &io_op_defs[req->opcode]);
switch (req->opcode) { switch (req->opcode) {
case IORING_OP_NOP: case IORING_OP_NOP:
break; break;
...@@ -5725,9 +5752,7 @@ static int io_sq_offload_start(struct io_ring_ctx *ctx, ...@@ -5725,9 +5752,7 @@ static int io_sq_offload_start(struct io_ring_ctx *ctx,
goto err; goto err;
} }
data.mm = ctx->sqo_mm;
data.user = ctx->user; data.user = ctx->user;
data.creds = ctx->creds;
data.get_work = io_get_work; data.get_work = io_get_work;
data.put_work = io_put_work; data.put_work = io_put_work;
...@@ -6535,7 +6560,8 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p) ...@@ -6535,7 +6560,8 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p)
goto err; goto err;
p->features = IORING_FEAT_SINGLE_MMAP | IORING_FEAT_NODROP | p->features = IORING_FEAT_SINGLE_MMAP | IORING_FEAT_NODROP |
IORING_FEAT_SUBMIT_STABLE | IORING_FEAT_RW_CUR_POS; IORING_FEAT_SUBMIT_STABLE | IORING_FEAT_RW_CUR_POS |
IORING_FEAT_CUR_PERSONALITY;
trace_io_uring_create(ret, ctx, p->sq_entries, p->cq_entries, p->flags); trace_io_uring_create(ret, ctx, p->sq_entries, p->cq_entries, p->flags);
return ret; return ret;
err: err:
......
...@@ -195,6 +195,7 @@ struct io_uring_params { ...@@ -195,6 +195,7 @@ struct io_uring_params {
#define IORING_FEAT_NODROP (1U << 1) #define IORING_FEAT_NODROP (1U << 1)
#define IORING_FEAT_SUBMIT_STABLE (1U << 2) #define IORING_FEAT_SUBMIT_STABLE (1U << 2)
#define IORING_FEAT_RW_CUR_POS (1U << 3) #define IORING_FEAT_RW_CUR_POS (1U << 3)
#define IORING_FEAT_CUR_PERSONALITY (1U << 4)
/* /*
* io_uring_register(2) opcodes and arguments * io_uring_register(2) opcodes and arguments
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment