Commit 05bd375b authored by Linus Torvalds

Merge tag 'for-5.5/io_uring-post-20191128' of git://git.kernel.dk/linux-block

Pull more io_uring updates from Jens Axboe:
 "As mentioned in the first pull request, there was a later batch as
  well. This contains fixes to the stuff that already went in, cleanups,
  and a few later additions. In particular, this contains:

   - Cleanups/fixes/unification of the submission and completion path
     (Pavel, me)

   - Linked timeouts improvements (Pavel, me)

   - Error path fixes (me)

   - Fix lookup window where cancellations wouldn't work (me)

   - Improve DRAIN support (Pavel)

   - Fix backlog flushing -EBUSY on submit (me)

   - Add support for connect(2) (me)

   - Fix for non-iter based fixed IO (Pavel)

   - creds inheritance for async workers (me)

   - Disable cmsg/ancillary data for sendmsg/recvmsg (me)

   - Shrink io_kiocb to 3 cachelines (me)

   - NUMA fix for io-wq (Jann)"

* tag 'for-5.5/io_uring-post-20191128' of git://git.kernel.dk/linux-block: (42 commits)
  io_uring: make poll->wait dynamically allocated
  io-wq: shrink io_wq_work a bit
  io-wq: fix handling of NUMA node IDs
  io_uring: use kzalloc instead of kcalloc for single-element allocations
  io_uring: cleanup io_import_fixed()
  io_uring: inline struct sqe_submit
  io_uring: store timeout's sqe->off in proper place
  net: disallow ancillary data for __sys_{send,recv}msg_file()
  net: separate out the msghdr copy from ___sys_{send,recv}msg()
  io_uring: remove superfluous check for sqe->off in io_accept()
  io_uring: async workers should inherit the user creds
  io-wq: have io_wq_create() take a 'data' argument
  io_uring: fix dead-hung for non-iter fixed rw
  io_uring: add support for IORING_OP_CONNECT
  net: add __sys_connect_file() helper
  io_uring: only return -EBUSY for submit on non-flushed backlog
  io_uring: only !null ptr to io_issue_sqe()
  io_uring: simplify io_req_link_next()
  io_uring: pass only !null to io_req_find_next()
  io_uring: remove io_free_req_find_next()
  ...
parents a6ed68d6 e944475e
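One of the additions in this batch is connect(2) support (IORING_OP_CONNECT, backed by the new __sys_connect_file() helper in the diff below). For orientation only, here is a hedged userspace sketch of how the opcode is typically driven; it is not part of this merge, and it assumes a 5.5+ kernel plus a liburing build that exposes io_uring_prep_connect(). Most error handling is trimmed.

/* Illustrative only, not part of this merge: issuing connect(2) through
 * io_uring. Assumes a 5.5+ kernel and a liburing that provides
 * io_uring_prep_connect(). */
#include <liburing.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

int main(void)
{
    struct io_uring ring;
    struct io_uring_sqe *sqe;
    struct io_uring_cqe *cqe;
    struct sockaddr_in addr;
    int fd;

    fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0 || io_uring_queue_init(8, &ring, 0) < 0)
        return 1;

    memset(&addr, 0, sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_port = htons(8080);                      /* example port */
    inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr);  /* example address */

    /* queue the connect as an SQE instead of calling connect(2) directly */
    sqe = io_uring_get_sqe(&ring);
    io_uring_prep_connect(sqe, fd, (struct sockaddr *) &addr, sizeof(addr));
    io_uring_submit(&ring);

    if (io_uring_wait_cqe(&ring, &cqe) == 0) {
        printf("connect result: %d\n", cqe->res);  /* 0 on success, -errno on failure */
        io_uring_cqe_seen(&ring, cqe);
    }

    io_uring_queue_exit(&ring);
    close(fd);
    return 0;
}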
@@ -33,6 +33,7 @@ enum {
enum { enum {
IO_WQ_BIT_EXIT = 0, /* wq exiting */ IO_WQ_BIT_EXIT = 0, /* wq exiting */
IO_WQ_BIT_CANCEL = 1, /* cancel work on list */ IO_WQ_BIT_CANCEL = 1, /* cancel work on list */
IO_WQ_BIT_ERROR = 2, /* error on setup */
}; };
enum { enum {
@@ -56,6 +57,7 @@ struct io_worker {
struct rcu_head rcu; struct rcu_head rcu;
struct mm_struct *mm; struct mm_struct *mm;
const struct cred *creds;
struct files_struct *restore_files; struct files_struct *restore_files;
}; };
@@ -82,7 +84,7 @@ enum {
struct io_wqe { struct io_wqe {
struct { struct {
spinlock_t lock; spinlock_t lock;
struct list_head work_list; struct io_wq_work_list work_list;
unsigned long hash_map; unsigned long hash_map;
unsigned flags; unsigned flags;
} ____cacheline_aligned_in_smp; } ____cacheline_aligned_in_smp;
@@ -103,13 +105,13 @@ struct io_wqe {
struct io_wq { struct io_wq {
struct io_wqe **wqes; struct io_wqe **wqes;
unsigned long state; unsigned long state;
unsigned nr_wqes;
get_work_fn *get_work; get_work_fn *get_work;
put_work_fn *put_work; put_work_fn *put_work;
struct task_struct *manager; struct task_struct *manager;
struct user_struct *user; struct user_struct *user;
struct cred *creds;
struct mm_struct *mm; struct mm_struct *mm;
refcount_t refs; refcount_t refs;
struct completion done; struct completion done;
@@ -135,6 +137,11 @@ static bool __io_worker_unuse(struct io_wqe *wqe, struct io_worker *worker)
{ {
bool dropped_lock = false; bool dropped_lock = false;
if (worker->creds) {
revert_creds(worker->creds);
worker->creds = NULL;
}
if (current->files != worker->restore_files) { if (current->files != worker->restore_files) {
__acquire(&wqe->lock); __acquire(&wqe->lock);
spin_unlock_irq(&wqe->lock); spin_unlock_irq(&wqe->lock);
@@ -229,7 +236,8 @@ static void io_worker_exit(struct io_worker *worker)
static inline bool io_wqe_run_queue(struct io_wqe *wqe) static inline bool io_wqe_run_queue(struct io_wqe *wqe)
__must_hold(wqe->lock) __must_hold(wqe->lock)
{ {
if (!list_empty(&wqe->work_list) && !(wqe->flags & IO_WQE_FLAG_STALLED)) if (!wq_list_empty(&wqe->work_list) &&
!(wqe->flags & IO_WQE_FLAG_STALLED))
return true; return true;
return false; return false;
} }
@@ -327,9 +335,9 @@ static void __io_worker_busy(struct io_wqe *wqe, struct io_worker *worker,
* If worker is moving from bound to unbound (or vice versa), then * If worker is moving from bound to unbound (or vice versa), then
* ensure we update the running accounting. * ensure we update the running accounting.
*/ */
worker_bound = (worker->flags & IO_WORKER_F_BOUND) != 0; worker_bound = (worker->flags & IO_WORKER_F_BOUND) != 0;
work_bound = (work->flags & IO_WQ_WORK_UNBOUND) == 0; work_bound = (work->flags & IO_WQ_WORK_UNBOUND) == 0;
if (worker_bound != work_bound) { if (worker_bound != work_bound) {
io_wqe_dec_running(wqe, worker); io_wqe_dec_running(wqe, worker);
if (work_bound) { if (work_bound) {
worker->flags |= IO_WORKER_F_BOUND; worker->flags |= IO_WORKER_F_BOUND;
@@ -368,12 +376,15 @@ static bool __io_worker_idle(struct io_wqe *wqe, struct io_worker *worker)
static struct io_wq_work *io_get_next_work(struct io_wqe *wqe, unsigned *hash) static struct io_wq_work *io_get_next_work(struct io_wqe *wqe, unsigned *hash)
__must_hold(wqe->lock) __must_hold(wqe->lock)
{ {
struct io_wq_work_node *node, *prev;
struct io_wq_work *work; struct io_wq_work *work;
list_for_each_entry(work, &wqe->work_list, list) { wq_list_for_each(node, prev, &wqe->work_list) {
work = container_of(node, struct io_wq_work, list);
/* not hashed, can run anytime */ /* not hashed, can run anytime */
if (!(work->flags & IO_WQ_WORK_HASHED)) { if (!(work->flags & IO_WQ_WORK_HASHED)) {
list_del(&work->list); wq_node_del(&wqe->work_list, node, prev);
return work; return work;
} }
@@ -381,7 +392,7 @@ static struct io_wq_work *io_get_next_work(struct io_wqe *wqe, unsigned *hash)
*hash = work->flags >> IO_WQ_HASH_SHIFT; *hash = work->flags >> IO_WQ_HASH_SHIFT;
if (!(wqe->hash_map & BIT_ULL(*hash))) { if (!(wqe->hash_map & BIT_ULL(*hash))) {
wqe->hash_map |= BIT_ULL(*hash); wqe->hash_map |= BIT_ULL(*hash);
list_del(&work->list); wq_node_del(&wqe->work_list, node, prev);
return work; return work;
} }
} }
@@ -409,7 +420,7 @@ static void io_worker_handle_work(struct io_worker *worker)
work = io_get_next_work(wqe, &hash); work = io_get_next_work(wqe, &hash);
if (work) if (work)
__io_worker_busy(wqe, worker, work); __io_worker_busy(wqe, worker, work);
else if (!list_empty(&wqe->work_list)) else if (!wq_list_empty(&wqe->work_list))
wqe->flags |= IO_WQE_FLAG_STALLED; wqe->flags |= IO_WQE_FLAG_STALLED;
spin_unlock_irq(&wqe->lock); spin_unlock_irq(&wqe->lock);
@@ -426,6 +437,9 @@ static void io_worker_handle_work(struct io_worker *worker)
worker->cur_work = work; worker->cur_work = work;
spin_unlock_irq(&worker->lock); spin_unlock_irq(&worker->lock);
if (work->flags & IO_WQ_WORK_CB)
work->func(&work);
if ((work->flags & IO_WQ_WORK_NEEDS_FILES) && if ((work->flags & IO_WQ_WORK_NEEDS_FILES) &&
current->files != work->files) { current->files != work->files) {
task_lock(current); task_lock(current);
@@ -438,6 +452,8 @@ static void io_worker_handle_work(struct io_worker *worker)
set_fs(USER_DS); set_fs(USER_DS);
worker->mm = wq->mm; worker->mm = wq->mm;
} }
if (!worker->creds)
worker->creds = override_creds(wq->creds);
if (test_bit(IO_WQ_BIT_CANCEL, &wq->state)) if (test_bit(IO_WQ_BIT_CANCEL, &wq->state))
work->flags |= IO_WQ_WORK_CANCEL; work->flags |= IO_WQ_WORK_CANCEL;
if (worker->mm) if (worker->mm)
@@ -514,7 +530,7 @@ static int io_wqe_worker(void *data)
if (test_bit(IO_WQ_BIT_EXIT, &wq->state)) { if (test_bit(IO_WQ_BIT_EXIT, &wq->state)) {
spin_lock_irq(&wqe->lock); spin_lock_irq(&wqe->lock);
if (!list_empty(&wqe->work_list)) if (!wq_list_empty(&wqe->work_list))
io_worker_handle_work(worker); io_worker_handle_work(worker);
else else
spin_unlock_irq(&wqe->lock); spin_unlock_irq(&wqe->lock);
@@ -562,14 +578,14 @@ void io_wq_worker_sleeping(struct task_struct *tsk)
spin_unlock_irq(&wqe->lock); spin_unlock_irq(&wqe->lock);
} }
static void create_io_worker(struct io_wq *wq, struct io_wqe *wqe, int index) static bool create_io_worker(struct io_wq *wq, struct io_wqe *wqe, int index)
{ {
struct io_wqe_acct *acct = &wqe->acct[index]; struct io_wqe_acct *acct = &wqe->acct[index];
struct io_worker *worker; struct io_worker *worker;
worker = kcalloc_node(1, sizeof(*worker), GFP_KERNEL, wqe->node); worker = kzalloc_node(sizeof(*worker), GFP_KERNEL, wqe->node);
if (!worker) if (!worker)
return; return false;
refcount_set(&worker->ref, 1); refcount_set(&worker->ref, 1);
worker->nulls_node.pprev = NULL; worker->nulls_node.pprev = NULL;
@@ -581,7 +597,7 @@ static void create_io_worker(struct io_wq *wq, struct io_wqe *wqe, int index)
"io_wqe_worker-%d/%d", index, wqe->node); "io_wqe_worker-%d/%d", index, wqe->node);
if (IS_ERR(worker->task)) { if (IS_ERR(worker->task)) {
kfree(worker); kfree(worker);
return; return false;
} }
spin_lock_irq(&wqe->lock); spin_lock_irq(&wqe->lock);
@@ -599,6 +615,7 @@ static void create_io_worker(struct io_wq *wq, struct io_wqe *wqe, int index)
atomic_inc(&wq->user->processes); atomic_inc(&wq->user->processes);
wake_up_process(worker->task); wake_up_process(worker->task);
return true;
} }
static inline bool io_wqe_need_worker(struct io_wqe *wqe, int index) static inline bool io_wqe_need_worker(struct io_wqe *wqe, int index)
@@ -606,9 +623,6 @@ static inline bool io_wqe_need_worker(struct io_wqe *wqe, int index)
{ {
struct io_wqe_acct *acct = &wqe->acct[index]; struct io_wqe_acct *acct = &wqe->acct[index];
/* always ensure we have one bounded worker */
if (index == IO_WQ_ACCT_BOUND && !acct->nr_workers)
return true;
/* if we have available workers or no work, no need */ /* if we have available workers or no work, no need */
if (!hlist_nulls_empty(&wqe->free_list) || !io_wqe_run_queue(wqe)) if (!hlist_nulls_empty(&wqe->free_list) || !io_wqe_run_queue(wqe))
return false; return false;
@@ -621,12 +635,22 @@ static inline bool io_wqe_need_worker(struct io_wqe *wqe, int index)
static int io_wq_manager(void *data) static int io_wq_manager(void *data)
{ {
struct io_wq *wq = data; struct io_wq *wq = data;
int workers_to_create = num_possible_nodes();
int node;
while (!kthread_should_stop()) { /* create fixed workers */
int i; refcount_set(&wq->refs, workers_to_create);
for_each_node(node) {
if (!create_io_worker(wq, wq->wqes[node], IO_WQ_ACCT_BOUND))
goto err;
workers_to_create--;
}
complete(&wq->done);
for (i = 0; i < wq->nr_wqes; i++) { while (!kthread_should_stop()) {
struct io_wqe *wqe = wq->wqes[i]; for_each_node(node) {
struct io_wqe *wqe = wq->wqes[node];
bool fork_worker[2] = { false, false }; bool fork_worker[2] = { false, false };
spin_lock_irq(&wqe->lock); spin_lock_irq(&wqe->lock);
@@ -645,6 +669,12 @@ static int io_wq_manager(void *data)
} }
return 0; return 0;
err:
set_bit(IO_WQ_BIT_ERROR, &wq->state);
set_bit(IO_WQ_BIT_EXIT, &wq->state);
if (refcount_sub_and_test(workers_to_create, &wq->refs))
complete(&wq->done);
return 0;
} }
static bool io_wq_can_queue(struct io_wqe *wqe, struct io_wqe_acct *acct, static bool io_wq_can_queue(struct io_wqe *wqe, struct io_wqe_acct *acct,
@@ -688,7 +718,7 @@ static void io_wqe_enqueue(struct io_wqe *wqe, struct io_wq_work *work)
} }
spin_lock_irqsave(&wqe->lock, flags); spin_lock_irqsave(&wqe->lock, flags);
list_add_tail(&work->list, &wqe->work_list); wq_list_add_tail(&work->list, &wqe->work_list);
wqe->flags &= ~IO_WQE_FLAG_STALLED; wqe->flags &= ~IO_WQE_FLAG_STALLED;
spin_unlock_irqrestore(&wqe->lock, flags); spin_unlock_irqrestore(&wqe->lock, flags);
@@ -750,7 +780,7 @@ static bool io_wq_for_each_worker(struct io_wqe *wqe,
void io_wq_cancel_all(struct io_wq *wq) void io_wq_cancel_all(struct io_wq *wq)
{ {
int i; int node;
set_bit(IO_WQ_BIT_CANCEL, &wq->state); set_bit(IO_WQ_BIT_CANCEL, &wq->state);
@@ -759,8 +789,8 @@ void io_wq_cancel_all(struct io_wq *wq)
* to a worker and the worker putting itself on the busy_list * to a worker and the worker putting itself on the busy_list
*/ */
rcu_read_lock(); rcu_read_lock();
for (i = 0; i < wq->nr_wqes; i++) { for_each_node(node) {
struct io_wqe *wqe = wq->wqes[i]; struct io_wqe *wqe = wq->wqes[node];
io_wq_for_each_worker(wqe, io_wqe_worker_send_sig, NULL); io_wq_for_each_worker(wqe, io_wqe_worker_send_sig, NULL);
} }
@@ -803,14 +833,17 @@ static enum io_wq_cancel io_wqe_cancel_cb_work(struct io_wqe *wqe,
.cancel = cancel, .cancel = cancel,
.caller_data = cancel_data, .caller_data = cancel_data,
}; };
struct io_wq_work_node *node, *prev;
struct io_wq_work *work; struct io_wq_work *work;
unsigned long flags; unsigned long flags;
bool found = false; bool found = false;
spin_lock_irqsave(&wqe->lock, flags); spin_lock_irqsave(&wqe->lock, flags);
list_for_each_entry(work, &wqe->work_list, list) { wq_list_for_each(node, prev, &wqe->work_list) {
work = container_of(node, struct io_wq_work, list);
if (cancel(work, cancel_data)) { if (cancel(work, cancel_data)) {
list_del(&work->list); wq_node_del(&wqe->work_list, node, prev);
found = true; found = true;
break; break;
} }
@@ -833,10 +866,10 @@ enum io_wq_cancel io_wq_cancel_cb(struct io_wq *wq, work_cancel_fn *cancel,
void *data) void *data)
{ {
enum io_wq_cancel ret = IO_WQ_CANCEL_NOTFOUND; enum io_wq_cancel ret = IO_WQ_CANCEL_NOTFOUND;
int i; int node;
for (i = 0; i < wq->nr_wqes; i++) { for_each_node(node) {
struct io_wqe *wqe = wq->wqes[i]; struct io_wqe *wqe = wq->wqes[node];
ret = io_wqe_cancel_cb_work(wqe, cancel, data); ret = io_wqe_cancel_cb_work(wqe, cancel, data);
if (ret != IO_WQ_CANCEL_NOTFOUND) if (ret != IO_WQ_CANCEL_NOTFOUND)
@@ -868,6 +901,7 @@ static bool io_wq_worker_cancel(struct io_worker *worker, void *data)
static enum io_wq_cancel io_wqe_cancel_work(struct io_wqe *wqe, static enum io_wq_cancel io_wqe_cancel_work(struct io_wqe *wqe,
struct io_wq_work *cwork) struct io_wq_work *cwork)
{ {
struct io_wq_work_node *node, *prev;
struct io_wq_work *work; struct io_wq_work *work;
unsigned long flags; unsigned long flags;
bool found = false; bool found = false;
@@ -880,9 +914,11 @@ static enum io_wq_cancel io_wqe_cancel_work(struct io_wqe *wqe,
* no completion will be posted for it. * no completion will be posted for it.
*/ */
spin_lock_irqsave(&wqe->lock, flags); spin_lock_irqsave(&wqe->lock, flags);
list_for_each_entry(work, &wqe->work_list, list) { wq_list_for_each(node, prev, &wqe->work_list) {
work = container_of(node, struct io_wq_work, list);
if (work == cwork) { if (work == cwork) {
list_del(&work->list); wq_node_del(&wqe->work_list, node, prev);
found = true; found = true;
break; break;
} }
@@ -910,10 +946,10 @@ static enum io_wq_cancel io_wqe_cancel_work(struct io_wqe *wqe,
enum io_wq_cancel io_wq_cancel_work(struct io_wq *wq, struct io_wq_work *cwork) enum io_wq_cancel io_wq_cancel_work(struct io_wq *wq, struct io_wq_work *cwork)
{ {
enum io_wq_cancel ret = IO_WQ_CANCEL_NOTFOUND; enum io_wq_cancel ret = IO_WQ_CANCEL_NOTFOUND;
int i; int node;
for (i = 0; i < wq->nr_wqes; i++) { for_each_node(node) {
struct io_wqe *wqe = wq->wqes[i]; struct io_wqe *wqe = wq->wqes[node];
ret = io_wqe_cancel_work(wqe, cwork); ret = io_wqe_cancel_work(wqe, cwork);
if (ret != IO_WQ_CANCEL_NOTFOUND) if (ret != IO_WQ_CANCEL_NOTFOUND)
@@ -944,10 +980,10 @@ static void io_wq_flush_func(struct io_wq_work **workptr)
void io_wq_flush(struct io_wq *wq) void io_wq_flush(struct io_wq *wq)
{ {
struct io_wq_flush_data data; struct io_wq_flush_data data;
int i; int node;
for (i = 0; i < wq->nr_wqes; i++) { for_each_node(node) {
struct io_wqe *wqe = wq->wqes[i]; struct io_wqe *wqe = wq->wqes[node];
init_completion(&data.done); init_completion(&data.done);
INIT_IO_WORK(&data.work, io_wq_flush_func); INIT_IO_WORK(&data.work, io_wq_flush_func);
@@ -957,43 +993,39 @@ void io_wq_flush(struct io_wq *wq)
} }
} }
struct io_wq *io_wq_create(unsigned bounded, struct mm_struct *mm, struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
struct user_struct *user, get_work_fn *get_work,
put_work_fn *put_work)
{ {
int ret = -ENOMEM, i, node; int ret = -ENOMEM, node;
struct io_wq *wq; struct io_wq *wq;
wq = kcalloc(1, sizeof(*wq), GFP_KERNEL); wq = kzalloc(sizeof(*wq), GFP_KERNEL);
if (!wq) if (!wq)
return ERR_PTR(-ENOMEM); return ERR_PTR(-ENOMEM);
wq->nr_wqes = num_online_nodes(); wq->wqes = kcalloc(nr_node_ids, sizeof(struct io_wqe *), GFP_KERNEL);
wq->wqes = kcalloc(wq->nr_wqes, sizeof(struct io_wqe *), GFP_KERNEL);
if (!wq->wqes) { if (!wq->wqes) {
kfree(wq); kfree(wq);
return ERR_PTR(-ENOMEM); return ERR_PTR(-ENOMEM);
} }
wq->get_work = get_work; wq->get_work = data->get_work;
wq->put_work = put_work; wq->put_work = data->put_work;
/* caller must already hold a reference to this */ /* caller must already hold a reference to this */
wq->user = user; wq->user = data->user;
wq->creds = data->creds;
i = 0; for_each_node(node) {
refcount_set(&wq->refs, wq->nr_wqes);
for_each_online_node(node) {
struct io_wqe *wqe; struct io_wqe *wqe;
wqe = kcalloc_node(1, sizeof(struct io_wqe), GFP_KERNEL, node); wqe = kzalloc_node(sizeof(struct io_wqe), GFP_KERNEL, node);
if (!wqe) if (!wqe)
break; goto err;
wq->wqes[i] = wqe; wq->wqes[node] = wqe;
wqe->node = node; wqe->node = node;
wqe->acct[IO_WQ_ACCT_BOUND].max_workers = bounded; wqe->acct[IO_WQ_ACCT_BOUND].max_workers = bounded;
atomic_set(&wqe->acct[IO_WQ_ACCT_BOUND].nr_running, 0); atomic_set(&wqe->acct[IO_WQ_ACCT_BOUND].nr_running, 0);
if (user) { if (wq->user) {
wqe->acct[IO_WQ_ACCT_UNBOUND].max_workers = wqe->acct[IO_WQ_ACCT_UNBOUND].max_workers =
task_rlimit(current, RLIMIT_NPROC); task_rlimit(current, RLIMIT_NPROC);
} }
@@ -1001,33 +1033,36 @@ struct io_wq *io_wq_create(unsigned bounded, struct mm_struct *mm,
wqe->node = node; wqe->node = node;
wqe->wq = wq; wqe->wq = wq;
spin_lock_init(&wqe->lock); spin_lock_init(&wqe->lock);
INIT_LIST_HEAD(&wqe->work_list); INIT_WQ_LIST(&wqe->work_list);
INIT_HLIST_NULLS_HEAD(&wqe->free_list, 0); INIT_HLIST_NULLS_HEAD(&wqe->free_list, 0);
INIT_HLIST_NULLS_HEAD(&wqe->busy_list, 1); INIT_HLIST_NULLS_HEAD(&wqe->busy_list, 1);
INIT_LIST_HEAD(&wqe->all_list); INIT_LIST_HEAD(&wqe->all_list);
i++;
} }
init_completion(&wq->done); init_completion(&wq->done);
if (i != wq->nr_wqes)
goto err;
/* caller must have already done mmgrab() on this mm */ /* caller must have already done mmgrab() on this mm */
wq->mm = mm; wq->mm = data->mm;
wq->manager = kthread_create(io_wq_manager, wq, "io_wq_manager"); wq->manager = kthread_create(io_wq_manager, wq, "io_wq_manager");
if (!IS_ERR(wq->manager)) { if (!IS_ERR(wq->manager)) {
wake_up_process(wq->manager); wake_up_process(wq->manager);
wait_for_completion(&wq->done);
if (test_bit(IO_WQ_BIT_ERROR, &wq->state)) {
ret = -ENOMEM;
goto err;
}
reinit_completion(&wq->done);
return wq; return wq;
} }
ret = PTR_ERR(wq->manager); ret = PTR_ERR(wq->manager);
wq->manager = NULL;
err:
complete(&wq->done); complete(&wq->done);
io_wq_destroy(wq); err:
for_each_node(node)
kfree(wq->wqes[node]);
kfree(wq->wqes);
kfree(wq);
return ERR_PTR(ret); return ERR_PTR(ret);
} }
@@ -1039,27 +1074,21 @@ static bool io_wq_worker_wake(struct io_worker *worker, void *data)
void io_wq_destroy(struct io_wq *wq) void io_wq_destroy(struct io_wq *wq)
{ {
int i; int node;
if (wq->manager) { set_bit(IO_WQ_BIT_EXIT, &wq->state);
set_bit(IO_WQ_BIT_EXIT, &wq->state); if (wq->manager)
kthread_stop(wq->manager); kthread_stop(wq->manager);
}
rcu_read_lock(); rcu_read_lock();
for (i = 0; i < wq->nr_wqes; i++) { for_each_node(node)
struct io_wqe *wqe = wq->wqes[i]; io_wq_for_each_worker(wq->wqes[node], io_wq_worker_wake, NULL);
if (!wqe)
continue;
io_wq_for_each_worker(wqe, io_wq_worker_wake, NULL);
}
rcu_read_unlock(); rcu_read_unlock();
wait_for_completion(&wq->done); wait_for_completion(&wq->done);
for (i = 0; i < wq->nr_wqes; i++) for_each_node(node)
kfree(wq->wqes[i]); kfree(wq->wqes[node]);
kfree(wq->wqes); kfree(wq->wqes);
kfree(wq); kfree(wq);
} }
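A note on the credentials change visible in the io-wq.c hunks above: io_worker_handle_work() now does worker->creds = override_creds(wq->creds) before running a work item, and __io_worker_unuse() calls revert_creds() to drop back to the worker thread's own identity. The fragment below only restates that override/revert pairing in isolation; it is not code from this merge, and run_with_creds() is a made-up helper name.

/* Illustrative sketch of the override_creds()/revert_creds() pairing used by
 * io-wq above; run_with_creds() is hypothetical. override_creds() installs
 * 'new' as the task's subjective credentials and returns the previously
 * active creds, which must be handed back to revert_creds(). */
#include <linux/cred.h>

static void run_with_creds(const struct cred *new, void (*fn)(void *), void *arg)
{
    const struct cred *old;

    old = override_creds(new);  /* act as the submitting task */
    fn(arg);                    /* work runs with the borrowed identity */
    revert_creds(old);          /* restore the worker's own creds */
}

In the diff itself the pointer returned by override_creds() is parked in worker->creds so the revert can happen later, once the worker goes idle or exits.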
@@ -11,6 +11,7 @@ enum {
IO_WQ_WORK_NEEDS_FILES = 16, IO_WQ_WORK_NEEDS_FILES = 16,
IO_WQ_WORK_UNBOUND = 32, IO_WQ_WORK_UNBOUND = 32,
IO_WQ_WORK_INTERNAL = 64, IO_WQ_WORK_INTERNAL = 64,
IO_WQ_WORK_CB = 128,
IO_WQ_HASH_SHIFT = 24, /* upper 8 bits are used for hash key */ IO_WQ_HASH_SHIFT = 24, /* upper 8 bits are used for hash key */
}; };
@@ -21,15 +22,60 @@ enum io_wq_cancel {
IO_WQ_CANCEL_NOTFOUND, /* work not found */ IO_WQ_CANCEL_NOTFOUND, /* work not found */
}; };
struct io_wq_work_node {
struct io_wq_work_node *next;
};
struct io_wq_work_list {
struct io_wq_work_node *first;
struct io_wq_work_node *last;
};
static inline void wq_list_add_tail(struct io_wq_work_node *node,
struct io_wq_work_list *list)
{
if (!list->first) {
list->first = list->last = node;
} else {
list->last->next = node;
list->last = node;
}
}
static inline void wq_node_del(struct io_wq_work_list *list,
struct io_wq_work_node *node,
struct io_wq_work_node *prev)
{
if (node == list->first)
list->first = node->next;
if (node == list->last)
list->last = prev;
if (prev)
prev->next = node->next;
}
#define wq_list_for_each(pos, prv, head) \
for (pos = (head)->first, prv = NULL; pos; prv = pos, pos = (pos)->next)
#define wq_list_empty(list) ((list)->first == NULL)
#define INIT_WQ_LIST(list) do { \
(list)->first = NULL; \
(list)->last = NULL; \
} while (0)
struct io_wq_work { struct io_wq_work {
struct list_head list; union {
struct io_wq_work_node list;
void *data;
};
void (*func)(struct io_wq_work **); void (*func)(struct io_wq_work **);
unsigned flags;
struct files_struct *files; struct files_struct *files;
unsigned flags;
}; };
#define INIT_IO_WORK(work, _func) \ #define INIT_IO_WORK(work, _func) \
do { \ do { \
(work)->list.next = NULL; \
(work)->func = _func; \ (work)->func = _func; \
(work)->flags = 0; \ (work)->flags = 0; \
(work)->files = NULL; \ (work)->files = NULL; \
@@ -38,9 +84,16 @@ struct io_wq_work {
typedef void (get_work_fn)(struct io_wq_work *); typedef void (get_work_fn)(struct io_wq_work *);
typedef void (put_work_fn)(struct io_wq_work *); typedef void (put_work_fn)(struct io_wq_work *);
struct io_wq *io_wq_create(unsigned bounded, struct mm_struct *mm, struct io_wq_data {
struct user_struct *user, struct mm_struct *mm;
get_work_fn *get_work, put_work_fn *put_work); struct user_struct *user;
struct cred *creds;
get_work_fn *get_work;
put_work_fn *put_work;
};
struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data);
void io_wq_destroy(struct io_wq *wq); void io_wq_destroy(struct io_wq *wq);
void io_wq_enqueue(struct io_wq *wq, struct io_wq_work *work); void io_wq_enqueue(struct io_wq *wq, struct io_wq_work *work);
...
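The io-wq.h hunks above swap the list_head based work list for a singly linked io_wq_work_list plus the wq_list_add_tail(), wq_node_del() and wq_list_for_each() helpers; the single-pointer link is also what lets io_wq_work share that field with 'data' in a union. The program below is an illustrative userspace restatement of those helpers (not code from this merge) showing how the iterator's 'prev' pointer gives O(1) unlinking without a back pointer:

/* Illustrative userspace restatement of the wq_list helpers added above.
 * As in io-wq, nodes must start out with next == NULL (INIT_IO_WORK does
 * this in the kernel); here the initializers take care of it. */
#include <stdio.h>

struct io_wq_work_node { struct io_wq_work_node *next; };
struct io_wq_work_list { struct io_wq_work_node *first, *last; };

struct item { struct io_wq_work_node list; int id; };  /* list must be first member */

#define wq_list_for_each(pos, prv, head) \
    for (pos = (head)->first, prv = NULL; pos; prv = pos, pos = (pos)->next)

static void wq_list_add_tail(struct io_wq_work_node *node,
                             struct io_wq_work_list *list)
{
    if (!list->first) {
        list->first = list->last = node;
    } else {
        list->last->next = node;
        list->last = node;
    }
}

static void wq_node_del(struct io_wq_work_list *list,
                        struct io_wq_work_node *node,
                        struct io_wq_work_node *prev)
{
    if (node == list->first)
        list->first = node->next;
    if (node == list->last)
        list->last = prev;
    if (prev)
        prev->next = node->next;
}

int main(void)
{
    struct io_wq_work_list list = { NULL, NULL };
    struct item a = { { NULL }, 1 }, b = { { NULL }, 2 }, c = { { NULL }, 3 };
    struct io_wq_work_node *node, *prev;

    wq_list_add_tail(&a.list, &list);
    wq_list_add_tail(&b.list, &list);
    wq_list_add_tail(&c.list, &list);

    /* unlink the middle entry while walking, as io_get_next_work() does */
    wq_list_for_each(node, prev, &list) {
        if (((struct item *) node)->id == 2) {
            wq_node_del(&list, node, prev);
            break;
        }
    }

    wq_list_for_each(node, prev, &list)
        printf("%d\n", ((struct item *) node)->id);  /* prints 1 then 3 */
    return 0;
}

The tradeoff matches the "shrink io_wq_work a bit" commit: giving up arbitrary O(1) deletion is fine because io-wq only ever removes the node it is currently visiting, where 'prev' is already at hand.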
@@ -186,6 +186,7 @@ struct io_ring_ctx {
bool compat; bool compat;
bool account_mem; bool account_mem;
bool cq_overflow_flushed; bool cq_overflow_flushed;
bool drain_next;
/* /*
* Ring buffer of indices into array of io_uring_sqe, which is * Ring buffer of indices into array of io_uring_sqe, which is
@@ -236,6 +237,8 @@ struct io_ring_ctx {
struct user_struct *user; struct user_struct *user;
struct cred *creds;
/* 0 is for ctx quiesce/reinit/free, 1 is for sqo_thread started */ /* 0 is for ctx quiesce/reinit/free, 1 is for sqo_thread started */
struct completion *completions; struct completion *completions;
@@ -278,16 +281,6 @@ struct io_ring_ctx {
} ____cacheline_aligned_in_smp; } ____cacheline_aligned_in_smp;
}; };
struct sqe_submit {
const struct io_uring_sqe *sqe;
struct file *ring_file;
int ring_fd;
u32 sequence;
bool has_user;
bool in_async;
bool needs_fixed_file;
};
/* /*
* First field must be the file pointer in all the * First field must be the file pointer in all the
* iocb unions! See also 'struct kiocb' in <linux/fs.h> * iocb unions! See also 'struct kiocb' in <linux/fs.h>
@@ -298,12 +291,20 @@ struct io_poll_iocb {
__poll_t events; __poll_t events;
bool done; bool done;
bool canceled; bool canceled;
struct wait_queue_entry wait; struct wait_queue_entry *wait;
};
struct io_timeout_data {
struct io_kiocb *req;
struct hrtimer timer;
struct timespec64 ts;
enum hrtimer_mode mode;
u32 seq_offset;
}; };
struct io_timeout { struct io_timeout {
struct file *file; struct file *file;
struct hrtimer timer; struct io_timeout_data *data;
}; };
/* /*
@@ -320,7 +321,12 @@ struct io_kiocb {
struct io_timeout timeout; struct io_timeout timeout;
}; };
struct sqe_submit submit; const struct io_uring_sqe *sqe;
struct file *ring_file;
int ring_fd;
bool has_user;
bool in_async;
bool needs_fixed_file;
struct io_ring_ctx *ctx; struct io_ring_ctx *ctx;
union { union {
@@ -333,19 +339,20 @@ struct io_kiocb {
#define REQ_F_NOWAIT 1 /* must not punt to workers */ #define REQ_F_NOWAIT 1 /* must not punt to workers */
#define REQ_F_IOPOLL_COMPLETED 2 /* polled IO has completed */ #define REQ_F_IOPOLL_COMPLETED 2 /* polled IO has completed */
#define REQ_F_FIXED_FILE 4 /* ctx owns file */ #define REQ_F_FIXED_FILE 4 /* ctx owns file */
#define REQ_F_SEQ_PREV 8 /* sequential with previous */ #define REQ_F_LINK_NEXT 8 /* already grabbed next link */
#define REQ_F_IO_DRAIN 16 /* drain existing IO first */ #define REQ_F_IO_DRAIN 16 /* drain existing IO first */
#define REQ_F_IO_DRAINED 32 /* drain done */ #define REQ_F_IO_DRAINED 32 /* drain done */
#define REQ_F_LINK 64 /* linked sqes */ #define REQ_F_LINK 64 /* linked sqes */
#define REQ_F_LINK_TIMEOUT 128 /* has linked timeout */ #define REQ_F_LINK_TIMEOUT 128 /* has linked timeout */
#define REQ_F_FAIL_LINK 256 /* fail rest of links */ #define REQ_F_FAIL_LINK 256 /* fail rest of links */
#define REQ_F_SHADOW_DRAIN 512 /* link-drain shadow req */ #define REQ_F_DRAIN_LINK 512 /* link should be fully drained */
#define REQ_F_TIMEOUT 1024 /* timeout request */ #define REQ_F_TIMEOUT 1024 /* timeout request */
#define REQ_F_ISREG 2048 /* regular file */ #define REQ_F_ISREG 2048 /* regular file */
#define REQ_F_MUST_PUNT 4096 /* must be punted even for NONBLOCK */ #define REQ_F_MUST_PUNT 4096 /* must be punted even for NONBLOCK */
#define REQ_F_TIMEOUT_NOSEQ 8192 /* no timeout sequence */ #define REQ_F_TIMEOUT_NOSEQ 8192 /* no timeout sequence */
#define REQ_F_INFLIGHT 16384 /* on inflight list */ #define REQ_F_INFLIGHT 16384 /* on inflight list */
#define REQ_F_COMP_LOCKED 32768 /* completion under lock */ #define REQ_F_COMP_LOCKED 32768 /* completion under lock */
#define REQ_F_FREE_SQE 65536 /* free sqe if not async queued */
u64 user_data; u64 user_data;
u32 result; u32 result;
u32 sequence; u32 sequence;
@@ -383,6 +390,9 @@ static void io_cqring_fill_event(struct io_kiocb *req, long res);
static void __io_free_req(struct io_kiocb *req); static void __io_free_req(struct io_kiocb *req);
static void io_put_req(struct io_kiocb *req); static void io_put_req(struct io_kiocb *req);
static void io_double_put_req(struct io_kiocb *req); static void io_double_put_req(struct io_kiocb *req);
static void __io_double_put_req(struct io_kiocb *req);
static struct io_kiocb *io_prep_linked_timeout(struct io_kiocb *req);
static void io_queue_linked_timeout(struct io_kiocb *req);
static struct kmem_cache *req_cachep; static struct kmem_cache *req_cachep;
@@ -521,12 +531,13 @@ static inline bool io_sqe_needs_user(const struct io_uring_sqe *sqe)
opcode == IORING_OP_WRITE_FIXED); opcode == IORING_OP_WRITE_FIXED);
} }
static inline bool io_prep_async_work(struct io_kiocb *req) static inline bool io_prep_async_work(struct io_kiocb *req,
struct io_kiocb **link)
{ {
bool do_hashed = false; bool do_hashed = false;
if (req->submit.sqe) { if (req->sqe) {
switch (req->submit.sqe->opcode) { switch (req->sqe->opcode) {
case IORING_OP_WRITEV: case IORING_OP_WRITEV:
case IORING_OP_WRITE_FIXED: case IORING_OP_WRITE_FIXED:
do_hashed = true; do_hashed = true;
@@ -537,6 +548,7 @@ static inline bool io_prep_async_work(struct io_kiocb *req)
case IORING_OP_RECVMSG: case IORING_OP_RECVMSG:
case IORING_OP_ACCEPT: case IORING_OP_ACCEPT:
case IORING_OP_POLL_ADD: case IORING_OP_POLL_ADD:
case IORING_OP_CONNECT:
/* /*
* We know REQ_F_ISREG is not set on some of these * We know REQ_F_ISREG is not set on some of these
* opcodes, but this enables us to keep the check in * opcodes, but this enables us to keep the check in
@@ -546,17 +558,21 @@ static inline bool io_prep_async_work(struct io_kiocb *req)
req->work.flags |= IO_WQ_WORK_UNBOUND; req->work.flags |= IO_WQ_WORK_UNBOUND;
break; break;
} }
if (io_sqe_needs_user(req->submit.sqe)) if (io_sqe_needs_user(req->sqe))
req->work.flags |= IO_WQ_WORK_NEEDS_USER; req->work.flags |= IO_WQ_WORK_NEEDS_USER;
} }
*link = io_prep_linked_timeout(req);
return do_hashed; return do_hashed;
} }
static inline void io_queue_async_work(struct io_kiocb *req) static inline void io_queue_async_work(struct io_kiocb *req)
{ {
bool do_hashed = io_prep_async_work(req);
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
struct io_kiocb *link;
bool do_hashed;
do_hashed = io_prep_async_work(req, &link);
trace_io_uring_queue_async_work(ctx, do_hashed, req, &req->work, trace_io_uring_queue_async_work(ctx, do_hashed, req, &req->work,
req->flags); req->flags);
@@ -566,13 +582,16 @@ static inline void io_queue_async_work(struct io_kiocb *req)
io_wq_enqueue_hashed(ctx->io_wq, &req->work, io_wq_enqueue_hashed(ctx->io_wq, &req->work,
file_inode(req->file)); file_inode(req->file));
} }
if (link)
io_queue_linked_timeout(link);
} }
static void io_kill_timeout(struct io_kiocb *req) static void io_kill_timeout(struct io_kiocb *req)
{ {
int ret; int ret;
ret = hrtimer_try_to_cancel(&req->timeout.timer); ret = hrtimer_try_to_cancel(&req->timeout.data->timer);
if (ret != -1) { if (ret != -1) {
atomic_inc(&req->ctx->cq_timeouts); atomic_inc(&req->ctx->cq_timeouts);
list_del_init(&req->list); list_del_init(&req->list);
@@ -601,11 +620,6 @@ static void io_commit_cqring(struct io_ring_ctx *ctx)
__io_commit_cqring(ctx); __io_commit_cqring(ctx);
while ((req = io_get_deferred_req(ctx)) != NULL) { while ((req = io_get_deferred_req(ctx)) != NULL) {
if (req->flags & REQ_F_SHADOW_DRAIN) {
/* Just for drain, free it. */
__io_free_req(req);
continue;
}
req->flags |= REQ_F_IO_DRAINED; req->flags |= REQ_F_IO_DRAINED;
io_queue_async_work(req); io_queue_async_work(req);
} }
@@ -639,7 +653,8 @@ static void io_cqring_ev_posted(struct io_ring_ctx *ctx)
eventfd_signal(ctx->cq_ev_fd, 1); eventfd_signal(ctx->cq_ev_fd, 1);
} }
static void io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force) /* Returns true if there are no backlogged entries after the flush */
static bool io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force)
{ {
struct io_rings *rings = ctx->rings; struct io_rings *rings = ctx->rings;
struct io_uring_cqe *cqe; struct io_uring_cqe *cqe;
@@ -649,10 +664,10 @@ static void io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force)
if (!force) { if (!force) {
if (list_empty_careful(&ctx->cq_overflow_list)) if (list_empty_careful(&ctx->cq_overflow_list))
return; return true;
if ((ctx->cached_cq_tail - READ_ONCE(rings->cq.head) == if ((ctx->cached_cq_tail - READ_ONCE(rings->cq.head) ==
rings->cq_ring_entries)) rings->cq_ring_entries))
return; return false;
} }
spin_lock_irqsave(&ctx->completion_lock, flags); spin_lock_irqsave(&ctx->completion_lock, flags);
@@ -661,6 +676,7 @@ static void io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force)
if (force) if (force)
ctx->cq_overflow_flushed = true; ctx->cq_overflow_flushed = true;
cqe = NULL;
while (!list_empty(&ctx->cq_overflow_list)) { while (!list_empty(&ctx->cq_overflow_list)) {
cqe = io_get_cqring(ctx); cqe = io_get_cqring(ctx);
if (!cqe && !force) if (!cqe && !force)
@@ -688,6 +704,8 @@ static void io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force)
list_del(&req->list); list_del(&req->list);
io_put_req(req); io_put_req(req);
} }
return cqe != NULL;
} }
static void io_cqring_fill_event(struct io_kiocb *req, long res) static void io_cqring_fill_event(struct io_kiocb *req, long res)
@@ -787,6 +805,7 @@ static struct io_kiocb *io_get_req(struct io_ring_ctx *ctx,
} }
got_it: got_it:
req->ring_file = NULL;
req->file = NULL; req->file = NULL;
req->ctx = ctx; req->ctx = ctx;
req->flags = 0; req->flags = 0;
@@ -816,6 +835,8 @@ static void __io_free_req(struct io_kiocb *req)
{ {
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
if (req->flags & REQ_F_FREE_SQE)
kfree(req->sqe);
if (req->file && !(req->flags & REQ_F_FIXED_FILE)) if (req->file && !(req->flags & REQ_F_FIXED_FILE))
fput(req->file); fput(req->file);
if (req->flags & REQ_F_INFLIGHT) { if (req->flags & REQ_F_INFLIGHT) {
@@ -827,6 +848,8 @@ static void __io_free_req(struct io_kiocb *req)
wake_up(&ctx->inflight_wait); wake_up(&ctx->inflight_wait);
spin_unlock_irqrestore(&ctx->inflight_lock, flags); spin_unlock_irqrestore(&ctx->inflight_lock, flags);
} }
if (req->flags & REQ_F_TIMEOUT)
kfree(req->timeout.data);
percpu_ref_put(&ctx->refs); percpu_ref_put(&ctx->refs);
if (likely(!io_is_fallback_req(req))) if (likely(!io_is_fallback_req(req)))
kmem_cache_free(req_cachep, req); kmem_cache_free(req_cachep, req);
@@ -839,7 +862,7 @@ static bool io_link_cancel_timeout(struct io_kiocb *req)
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
int ret; int ret;
ret = hrtimer_try_to_cancel(&req->timeout.timer); ret = hrtimer_try_to_cancel(&req->timeout.data->timer);
if (ret != -1) { if (ret != -1) {
io_cqring_fill_event(req, -ECANCELED); io_cqring_fill_event(req, -ECANCELED);
io_commit_cqring(ctx); io_commit_cqring(ctx);
@@ -857,6 +880,10 @@ static void io_req_link_next(struct io_kiocb *req, struct io_kiocb **nxtptr)
struct io_kiocb *nxt; struct io_kiocb *nxt;
bool wake_ev = false; bool wake_ev = false;
/* Already got next link */
if (req->flags & REQ_F_LINK_NEXT)
return;
/* /*
* The list should never be empty when we are called here. But could * The list should never be empty when we are called here. But could
* potentially happen if the chain is messed up, check to be on the * potentially happen if the chain is messed up, check to be on the
@@ -865,31 +892,26 @@ static void io_req_link_next(struct io_kiocb *req, struct io_kiocb **nxtptr)
nxt = list_first_entry_or_null(&req->link_list, struct io_kiocb, list); nxt = list_first_entry_or_null(&req->link_list, struct io_kiocb, list);
while (nxt) { while (nxt) {
list_del_init(&nxt->list); list_del_init(&nxt->list);
if ((req->flags & REQ_F_LINK_TIMEOUT) &&
(nxt->flags & REQ_F_TIMEOUT)) {
wake_ev |= io_link_cancel_timeout(nxt);
nxt = list_first_entry_or_null(&req->link_list,
struct io_kiocb, list);
req->flags &= ~REQ_F_LINK_TIMEOUT;
continue;
}
if (!list_empty(&req->link_list)) { if (!list_empty(&req->link_list)) {
INIT_LIST_HEAD(&nxt->link_list); INIT_LIST_HEAD(&nxt->link_list);
list_splice(&req->link_list, &nxt->link_list); list_splice(&req->link_list, &nxt->link_list);
nxt->flags |= REQ_F_LINK; nxt->flags |= REQ_F_LINK;
} }
/* *nxtptr = nxt;
* If we're in async work, we can continue processing the chain break;
* in this context instead of having to queue up new async work.
*/
if (req->flags & REQ_F_LINK_TIMEOUT) {
wake_ev = io_link_cancel_timeout(nxt);
/* we dropped this link, get next */
nxt = list_first_entry_or_null(&req->link_list,
struct io_kiocb, list);
} else if (nxtptr && io_wq_current_is_worker()) {
*nxtptr = nxt;
break;
} else {
io_queue_async_work(nxt);
break;
}
} }
req->flags |= REQ_F_LINK_NEXT;
if (wake_ev) if (wake_ev)
io_cqring_ev_posted(ctx); io_cqring_ev_posted(ctx);
} }
@@ -912,12 +934,13 @@ static void io_fail_links(struct io_kiocb *req)
trace_io_uring_fail_link(req, link); trace_io_uring_fail_link(req, link);
if ((req->flags & REQ_F_LINK_TIMEOUT) && if ((req->flags & REQ_F_LINK_TIMEOUT) &&
link->submit.sqe->opcode == IORING_OP_LINK_TIMEOUT) { link->sqe->opcode == IORING_OP_LINK_TIMEOUT) {
io_link_cancel_timeout(link); io_link_cancel_timeout(link);
} else { } else {
io_cqring_fill_event(link, -ECANCELED); io_cqring_fill_event(link, -ECANCELED);
io_double_put_req(link); __io_double_put_req(link);
} }
req->flags &= ~REQ_F_LINK_TIMEOUT;
} }
io_commit_cqring(ctx); io_commit_cqring(ctx);
@@ -925,12 +948,10 @@ static void io_fail_links(struct io_kiocb *req)
io_cqring_ev_posted(ctx); io_cqring_ev_posted(ctx);
} }
static void io_free_req_find_next(struct io_kiocb *req, struct io_kiocb **nxt) static void io_req_find_next(struct io_kiocb *req, struct io_kiocb **nxt)
{ {
if (likely(!(req->flags & REQ_F_LINK))) { if (likely(!(req->flags & REQ_F_LINK)))
__io_free_req(req);
return; return;
}
/* /*
* If LINK is set, we have dependent requests in this chain. If we * If LINK is set, we have dependent requests in this chain. If we
@@ -956,32 +977,30 @@ static void io_free_req_find_next(struct io_kiocb *req, struct io_kiocb **nxt)
} else { } else {
io_req_link_next(req, nxt); io_req_link_next(req, nxt);
} }
__io_free_req(req);
} }
static void io_free_req(struct io_kiocb *req) static void io_free_req(struct io_kiocb *req)
{ {
io_free_req_find_next(req, NULL); struct io_kiocb *nxt = NULL;
io_req_find_next(req, &nxt);
__io_free_req(req);
if (nxt)
io_queue_async_work(nxt);
} }
/* /*
* Drop reference to request, return next in chain (if there is one) if this * Drop reference to request, return next in chain (if there is one) if this
* was the last reference to this request. * was the last reference to this request.
*/ */
__attribute__((nonnull))
static void io_put_req_find_next(struct io_kiocb *req, struct io_kiocb **nxtptr) static void io_put_req_find_next(struct io_kiocb *req, struct io_kiocb **nxtptr)
{ {
struct io_kiocb *nxt = NULL; io_req_find_next(req, nxtptr);
if (refcount_dec_and_test(&req->refs)) if (refcount_dec_and_test(&req->refs))
io_free_req_find_next(req, &nxt); __io_free_req(req);
if (nxt) {
if (nxtptr)
*nxtptr = nxt;
else
io_queue_async_work(nxt);
}
} }
static void io_put_req(struct io_kiocb *req) static void io_put_req(struct io_kiocb *req)
@@ -990,13 +1009,24 @@ static void io_put_req(struct io_kiocb *req)
io_free_req(req); io_free_req(req);
} }
static void io_double_put_req(struct io_kiocb *req) /*
* Must only be used if we don't need to care about links, usually from
* within the completion handling itself.
*/
static void __io_double_put_req(struct io_kiocb *req)
{ {
/* drop both submit and complete references */ /* drop both submit and complete references */
if (refcount_sub_and_test(2, &req->refs)) if (refcount_sub_and_test(2, &req->refs))
__io_free_req(req); __io_free_req(req);
} }
static void io_double_put_req(struct io_kiocb *req)
{
/* drop both submit and complete references */
if (refcount_sub_and_test(2, &req->refs))
io_free_req(req);
}
static unsigned io_cqring_events(struct io_ring_ctx *ctx, bool noflush) static unsigned io_cqring_events(struct io_ring_ctx *ctx, bool noflush)
{ {
struct io_rings *rings = ctx->rings; struct io_rings *rings = ctx->rings;
@@ -1048,7 +1078,8 @@ static void io_iopoll_complete(struct io_ring_ctx *ctx, unsigned int *nr_events,
* completions for those, only batch free for fixed * completions for those, only batch free for fixed
* file and non-linked commands. * file and non-linked commands.
*/ */
if (((req->flags & (REQ_F_FIXED_FILE|REQ_F_LINK)) == if (((req->flags &
(REQ_F_FIXED_FILE|REQ_F_LINK|REQ_F_FREE_SQE)) ==
REQ_F_FIXED_FILE) && !io_is_fallback_req(req)) { REQ_F_FIXED_FILE) && !io_is_fallback_req(req)) {
reqs[to_free++] = req; reqs[to_free++] = req;
if (to_free == ARRAY_SIZE(reqs)) if (to_free == ARRAY_SIZE(reqs))
@@ -1366,7 +1397,7 @@ static bool io_file_supports_async(struct file *file)
static int io_prep_rw(struct io_kiocb *req, bool force_nonblock) static int io_prep_rw(struct io_kiocb *req, bool force_nonblock)
{ {
const struct io_uring_sqe *sqe = req->submit.sqe; const struct io_uring_sqe *sqe = req->sqe;
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
struct kiocb *kiocb = &req->rw; struct kiocb *kiocb = &req->rw;
unsigned ioprio; unsigned ioprio;
@@ -1453,15 +1484,15 @@ static inline void io_rw_done(struct kiocb *kiocb, ssize_t ret)
static void kiocb_done(struct kiocb *kiocb, ssize_t ret, struct io_kiocb **nxt, static void kiocb_done(struct kiocb *kiocb, ssize_t ret, struct io_kiocb **nxt,
bool in_async) bool in_async)
{ {
if (in_async && ret >= 0 && nxt && kiocb->ki_complete == io_complete_rw) if (in_async && ret >= 0 && kiocb->ki_complete == io_complete_rw)
*nxt = __io_complete_rw(kiocb, ret); *nxt = __io_complete_rw(kiocb, ret);
else else
io_rw_done(kiocb, ret); io_rw_done(kiocb, ret);
} }
static int io_import_fixed(struct io_ring_ctx *ctx, int rw, static ssize_t io_import_fixed(struct io_ring_ctx *ctx, int rw,
const struct io_uring_sqe *sqe, const struct io_uring_sqe *sqe,
struct iov_iter *iter) struct iov_iter *iter)
{ {
size_t len = READ_ONCE(sqe->len); size_t len = READ_ONCE(sqe->len);
struct io_mapped_ubuf *imu; struct io_mapped_ubuf *imu;
@@ -1533,11 +1564,10 @@ static int io_import_fixed(struct io_ring_ctx *ctx, int rw,
return len; return len;
} }
static ssize_t io_import_iovec(struct io_ring_ctx *ctx, int rw, static ssize_t io_import_iovec(int rw, struct io_kiocb *req,
const struct sqe_submit *s, struct iovec **iovec, struct iovec **iovec, struct iov_iter *iter)
struct iov_iter *iter)
{ {
const struct io_uring_sqe *sqe = s->sqe; const struct io_uring_sqe *sqe = req->sqe;
void __user *buf = u64_to_user_ptr(READ_ONCE(sqe->addr)); void __user *buf = u64_to_user_ptr(READ_ONCE(sqe->addr));
size_t sqe_len = READ_ONCE(sqe->len); size_t sqe_len = READ_ONCE(sqe->len);
u8 opcode; u8 opcode;
@@ -1551,18 +1581,16 @@ static ssize_t io_import_iovec(struct io_ring_ctx *ctx, int rw,
* flag. * flag.
*/ */
opcode = READ_ONCE(sqe->opcode); opcode = READ_ONCE(sqe->opcode);
if (opcode == IORING_OP_READ_FIXED || if (opcode == IORING_OP_READ_FIXED || opcode == IORING_OP_WRITE_FIXED) {
opcode == IORING_OP_WRITE_FIXED) {
ssize_t ret = io_import_fixed(ctx, rw, sqe, iter);
*iovec = NULL; *iovec = NULL;
return ret; return io_import_fixed(req->ctx, rw, sqe, iter);
} }
if (!s->has_user) if (!req->has_user)
return -EFAULT; return -EFAULT;
#ifdef CONFIG_COMPAT #ifdef CONFIG_COMPAT
if (ctx->compat) if (req->ctx->compat)
return compat_import_iovec(rw, buf, sqe_len, UIO_FASTIOV, return compat_import_iovec(rw, buf, sqe_len, UIO_FASTIOV,
iovec, iter); iovec, iter);
#endif #endif
@@ -1590,9 +1618,19 @@ static ssize_t loop_rw_iter(int rw, struct file *file, struct kiocb *kiocb,
return -EAGAIN; return -EAGAIN;
while (iov_iter_count(iter)) { while (iov_iter_count(iter)) {
struct iovec iovec = iov_iter_iovec(iter); struct iovec iovec;
ssize_t nr; ssize_t nr;
if (!iov_iter_is_bvec(iter)) {
iovec = iov_iter_iovec(iter);
} else {
/* fixed buffers import bvec */
iovec.iov_base = kmap(iter->bvec->bv_page)
+ iter->iov_offset;
iovec.iov_len = min(iter->count,
iter->bvec->bv_len - iter->iov_offset);
}
if (rw == READ) { if (rw == READ) {
nr = file->f_op->read(file, iovec.iov_base, nr = file->f_op->read(file, iovec.iov_base,
iovec.iov_len, &kiocb->ki_pos); iovec.iov_len, &kiocb->ki_pos);
@@ -1601,6 +1639,9 @@ static ssize_t loop_rw_iter(int rw, struct file *file, struct kiocb *kiocb,
iovec.iov_len, &kiocb->ki_pos); iovec.iov_len, &kiocb->ki_pos);
} }
if (iov_iter_is_bvec(iter))
kunmap(iter->bvec->bv_page);
if (nr < 0) { if (nr < 0) {
if (!ret) if (!ret)
ret = nr; ret = nr;
@@ -1633,7 +1674,7 @@ static int io_read(struct io_kiocb *req, struct io_kiocb **nxt,
if (unlikely(!(file->f_mode & FMODE_READ))) if (unlikely(!(file->f_mode & FMODE_READ)))
return -EBADF; return -EBADF;
ret = io_import_iovec(req->ctx, READ, &req->submit, &iovec, &iter); ret = io_import_iovec(READ, req, &iovec, &iter);
if (ret < 0) if (ret < 0)
return ret; return ret;
@@ -1665,7 +1706,7 @@ static int io_read(struct io_kiocb *req, struct io_kiocb **nxt,
ret2 = -EAGAIN; ret2 = -EAGAIN;
/* Catch -EAGAIN return for forced non-blocking submission */ /* Catch -EAGAIN return for forced non-blocking submission */
if (!force_nonblock || ret2 != -EAGAIN) if (!force_nonblock || ret2 != -EAGAIN)
kiocb_done(kiocb, ret2, nxt, req->submit.in_async); kiocb_done(kiocb, ret2, nxt, req->in_async);
else else
ret = -EAGAIN; ret = -EAGAIN;
} }
@@ -1691,7 +1732,7 @@ static int io_write(struct io_kiocb *req, struct io_kiocb **nxt,
if (unlikely(!(file->f_mode & FMODE_WRITE))) if (unlikely(!(file->f_mode & FMODE_WRITE)))
return -EBADF; return -EBADF;
ret = io_import_iovec(req->ctx, WRITE, &req->submit, &iovec, &iter); ret = io_import_iovec(WRITE, req, &iovec, &iter);
if (ret < 0) if (ret < 0)
return ret; return ret;
@@ -1728,7 +1769,7 @@ static int io_write(struct io_kiocb *req, struct io_kiocb **nxt,
else else
ret2 = loop_rw_iter(WRITE, file, kiocb, &iter); ret2 = loop_rw_iter(WRITE, file, kiocb, &iter);
if (!force_nonblock || ret2 != -EAGAIN) if (!force_nonblock || ret2 != -EAGAIN)
kiocb_done(kiocb, ret2, nxt, req->submit.in_async); kiocb_done(kiocb, ret2, nxt, req->in_async);
else else
ret = -EAGAIN; ret = -EAGAIN;
} }
@@ -1918,7 +1959,7 @@ static int io_accept(struct io_kiocb *req, const struct io_uring_sqe *sqe,
if (unlikely(req->ctx->flags & (IORING_SETUP_IOPOLL|IORING_SETUP_SQPOLL))) if (unlikely(req->ctx->flags & (IORING_SETUP_IOPOLL|IORING_SETUP_SQPOLL)))
return -EINVAL; return -EINVAL;
if (sqe->ioprio || sqe->off || sqe->len || sqe->buf_index) if (sqe->ioprio || sqe->len || sqe->buf_index)
return -EINVAL; return -EINVAL;
addr = (struct sockaddr __user *) (unsigned long) READ_ONCE(sqe->addr); addr = (struct sockaddr __user *) (unsigned long) READ_ONCE(sqe->addr);
@@ -1943,6 +1984,38 @@ static int io_accept(struct io_kiocb *req, const struct io_uring_sqe *sqe,
#endif #endif
} }
static int io_connect(struct io_kiocb *req, const struct io_uring_sqe *sqe,
struct io_kiocb **nxt, bool force_nonblock)
{
#if defined(CONFIG_NET)
struct sockaddr __user *addr;
unsigned file_flags;
int addr_len, ret;
if (unlikely(req->ctx->flags & (IORING_SETUP_IOPOLL|IORING_SETUP_SQPOLL)))
return -EINVAL;
if (sqe->ioprio || sqe->len || sqe->buf_index || sqe->rw_flags)
return -EINVAL;
addr = (struct sockaddr __user *) (unsigned long) READ_ONCE(sqe->addr);
addr_len = READ_ONCE(sqe->addr2);
file_flags = force_nonblock ? O_NONBLOCK : 0;
ret = __sys_connect_file(req->file, addr, addr_len, file_flags);
if (ret == -EAGAIN && force_nonblock)
return -EAGAIN;
if (ret == -ERESTARTSYS)
ret = -EINTR;
if (ret < 0 && (req->flags & REQ_F_LINK))
req->flags |= REQ_F_FAIL_LINK;
io_cqring_add_event(req, ret);
io_put_req_find_next(req, nxt);
return 0;
#else
return -EOPNOTSUPP;
#endif
}
static inline void io_poll_remove_req(struct io_kiocb *req) static inline void io_poll_remove_req(struct io_kiocb *req)
{ {
if (!RB_EMPTY_NODE(&req->rb_node)) { if (!RB_EMPTY_NODE(&req->rb_node)) {
@@ -1957,8 +2030,8 @@ static void io_poll_remove_one(struct io_kiocb *req)
spin_lock(&poll->head->lock); spin_lock(&poll->head->lock);
WRITE_ONCE(poll->canceled, true); WRITE_ONCE(poll->canceled, true);
if (!list_empty(&poll->wait.entry)) { if (!list_empty(&poll->wait->entry)) {
list_del_init(&poll->wait.entry); list_del_init(&poll->wait->entry);
io_queue_async_work(req); io_queue_async_work(req);
} }
spin_unlock(&poll->head->lock); spin_unlock(&poll->head->lock);
@@ -2026,12 +2099,16 @@ static int io_poll_remove(struct io_kiocb *req, const struct io_uring_sqe *sqe)
return 0; return 0;
} }
static void io_poll_complete(struct io_kiocb *req, __poll_t mask) static void io_poll_complete(struct io_kiocb *req, __poll_t mask, int error)
{ {
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
req->poll.done = true; req->poll.done = true;
io_cqring_fill_event(req, mangle_poll(mask)); kfree(req->poll.wait);
if (error)
io_cqring_fill_event(req, error);
else
io_cqring_fill_event(req, mangle_poll(mask));
io_commit_cqring(ctx); io_commit_cqring(ctx);
} }
@@ -2044,11 +2121,16 @@ static void io_poll_complete_work(struct io_wq_work **workptr)
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
struct io_kiocb *nxt = NULL; struct io_kiocb *nxt = NULL;
__poll_t mask = 0; __poll_t mask = 0;
int ret = 0;
if (work->flags & IO_WQ_WORK_CANCEL) if (work->flags & IO_WQ_WORK_CANCEL) {
WRITE_ONCE(poll->canceled, true); WRITE_ONCE(poll->canceled, true);
ret = -ECANCELED;
} else if (READ_ONCE(poll->canceled)) {
ret = -ECANCELED;
}
if (!READ_ONCE(poll->canceled)) if (ret != -ECANCELED)
mask = vfs_poll(poll->file, &pt) & poll->events; mask = vfs_poll(poll->file, &pt) & poll->events;
/* /*
@@ -2059,17 +2141,19 @@ static void io_poll_complete_work(struct io_wq_work **workptr)
* avoid further branches in the fast path. * avoid further branches in the fast path.
*/ */
spin_lock_irq(&ctx->completion_lock); spin_lock_irq(&ctx->completion_lock);
if (!mask && !READ_ONCE(poll->canceled)) { if (!mask && ret != -ECANCELED) {
add_wait_queue(poll->head, &poll->wait); add_wait_queue(poll->head, poll->wait);
spin_unlock_irq(&ctx->completion_lock); spin_unlock_irq(&ctx->completion_lock);
return; return;
} }
io_poll_remove_req(req); io_poll_remove_req(req);
io_poll_complete(req, mask); io_poll_complete(req, mask, ret);
spin_unlock_irq(&ctx->completion_lock); spin_unlock_irq(&ctx->completion_lock);
io_cqring_ev_posted(ctx); io_cqring_ev_posted(ctx);
if (ret < 0 && req->flags & REQ_F_LINK)
req->flags |= REQ_F_FAIL_LINK;
io_put_req_find_next(req, &nxt); io_put_req_find_next(req, &nxt);
if (nxt) if (nxt)
*workptr = &nxt->work; *workptr = &nxt->work;
@@ -2078,8 +2162,7 @@ static void io_poll_complete_work(struct io_wq_work **workptr)
static int io_poll_wake(struct wait_queue_entry *wait, unsigned mode, int sync, static int io_poll_wake(struct wait_queue_entry *wait, unsigned mode, int sync,
void *key) void *key)
{ {
struct io_poll_iocb *poll = container_of(wait, struct io_poll_iocb, struct io_poll_iocb *poll = wait->private;
wait);
struct io_kiocb *req = container_of(poll, struct io_kiocb, poll); struct io_kiocb *req = container_of(poll, struct io_kiocb, poll);
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
__poll_t mask = key_to_poll(key); __poll_t mask = key_to_poll(key);
@@ -2089,7 +2172,7 @@ static int io_poll_wake(struct wait_queue_entry *wait, unsigned mode, int sync,
if (mask && !(mask & poll->events)) if (mask && !(mask & poll->events))
return 0; return 0;
list_del_init(&poll->wait.entry); list_del_init(&poll->wait->entry);
/* /*
* Run completion inline if we can. We're using trylock here because * Run completion inline if we can. We're using trylock here because
...@@ -2099,7 +2182,7 @@ static int io_poll_wake(struct wait_queue_entry *wait, unsigned mode, int sync, ...@@ -2099,7 +2182,7 @@ static int io_poll_wake(struct wait_queue_entry *wait, unsigned mode, int sync,
*/ */
if (mask && spin_trylock_irqsave(&ctx->completion_lock, flags)) { if (mask && spin_trylock_irqsave(&ctx->completion_lock, flags)) {
io_poll_remove_req(req); io_poll_remove_req(req);
io_poll_complete(req, mask); io_poll_complete(req, mask, 0);
req->flags |= REQ_F_COMP_LOCKED; req->flags |= REQ_F_COMP_LOCKED;
io_put_req(req); io_put_req(req);
spin_unlock_irqrestore(&ctx->completion_lock, flags); spin_unlock_irqrestore(&ctx->completion_lock, flags);
...@@ -2130,7 +2213,7 @@ static void io_poll_queue_proc(struct file *file, struct wait_queue_head *head, ...@@ -2130,7 +2213,7 @@ static void io_poll_queue_proc(struct file *file, struct wait_queue_head *head,
pt->error = 0; pt->error = 0;
pt->req->poll.head = head; pt->req->poll.head = head;
add_wait_queue(head, &pt->req->poll.wait); add_wait_queue(head, pt->req->poll.wait);
} }
static void io_poll_req_insert(struct io_kiocb *req) static void io_poll_req_insert(struct io_kiocb *req)
...@@ -2169,7 +2252,11 @@ static int io_poll_add(struct io_kiocb *req, const struct io_uring_sqe *sqe, ...@@ -2169,7 +2252,11 @@ static int io_poll_add(struct io_kiocb *req, const struct io_uring_sqe *sqe,
if (!poll->file) if (!poll->file)
return -EBADF; return -EBADF;
req->submit.sqe = NULL; poll->wait = kmalloc(sizeof(*poll->wait), GFP_KERNEL);
if (!poll->wait)
return -ENOMEM;
req->sqe = NULL;
INIT_IO_WORK(&req->work, io_poll_complete_work); INIT_IO_WORK(&req->work, io_poll_complete_work);
events = READ_ONCE(sqe->poll_events); events = READ_ONCE(sqe->poll_events);
poll->events = demangle_poll(events) | EPOLLERR | EPOLLHUP; poll->events = demangle_poll(events) | EPOLLERR | EPOLLHUP;
...@@ -2185,8 +2272,9 @@ static int io_poll_add(struct io_kiocb *req, const struct io_uring_sqe *sqe, ...@@ -2185,8 +2272,9 @@ static int io_poll_add(struct io_kiocb *req, const struct io_uring_sqe *sqe,
ipt.error = -EINVAL; /* same as no support for IOCB_CMD_POLL */ ipt.error = -EINVAL; /* same as no support for IOCB_CMD_POLL */
/* initialized the list so that we can do list_empty checks */ /* initialized the list so that we can do list_empty checks */
INIT_LIST_HEAD(&poll->wait.entry); INIT_LIST_HEAD(&poll->wait->entry);
init_waitqueue_func_entry(&poll->wait, io_poll_wake); init_waitqueue_func_entry(poll->wait, io_poll_wake);
poll->wait->private = poll;
INIT_LIST_HEAD(&req->list); INIT_LIST_HEAD(&req->list);
...@@ -2195,14 +2283,14 @@ static int io_poll_add(struct io_kiocb *req, const struct io_uring_sqe *sqe, ...@@ -2195,14 +2283,14 @@ static int io_poll_add(struct io_kiocb *req, const struct io_uring_sqe *sqe,
spin_lock_irq(&ctx->completion_lock); spin_lock_irq(&ctx->completion_lock);
if (likely(poll->head)) { if (likely(poll->head)) {
spin_lock(&poll->head->lock); spin_lock(&poll->head->lock);
if (unlikely(list_empty(&poll->wait.entry))) { if (unlikely(list_empty(&poll->wait->entry))) {
if (ipt.error) if (ipt.error)
cancel = true; cancel = true;
ipt.error = 0; ipt.error = 0;
mask = 0; mask = 0;
} }
if (mask || ipt.error) if (mask || ipt.error)
list_del_init(&poll->wait.entry); list_del_init(&poll->wait->entry);
else if (cancel) else if (cancel)
WRITE_ONCE(poll->canceled, true); WRITE_ONCE(poll->canceled, true);
else if (!poll->done) /* actually waiting for an event */ else if (!poll->done) /* actually waiting for an event */
...@@ -2211,7 +2299,7 @@ static int io_poll_add(struct io_kiocb *req, const struct io_uring_sqe *sqe, ...@@ -2211,7 +2299,7 @@ static int io_poll_add(struct io_kiocb *req, const struct io_uring_sqe *sqe,
} }
if (mask) { /* no async, we'd stolen it */ if (mask) { /* no async, we'd stolen it */
ipt.error = 0; ipt.error = 0;
io_poll_complete(req, mask); io_poll_complete(req, mask, 0);
} }
spin_unlock_irq(&ctx->completion_lock); spin_unlock_irq(&ctx->completion_lock);
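For context (not part of this diff), a minimal userspace sketch of exercising IORING_OP_POLL_ADD, assuming liburing and its io_uring_prep_poll_add() helper:

#include <liburing.h>
#include <poll.h>

/* Hypothetical sketch: wait for fd to become readable via the ring.
 * Returns the poll event mask from cqe->res, or a negative errno. */
static int wait_readable(struct io_uring *ring, int fd)
{
	struct io_uring_sqe *sqe = io_uring_get_sqe(ring);
	struct io_uring_cqe *cqe;
	int ret;

	if (!sqe)
		return -EAGAIN;
	io_uring_prep_poll_add(sqe, fd, POLLIN);
	io_uring_submit(ring);

	ret = io_uring_wait_cqe(ring, &cqe);
	if (ret < 0)
		return ret;
	ret = cqe->res;
	io_uring_cqe_seen(ring, cqe);
	return ret;
}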
...@@ -2224,12 +2312,12 @@ static int io_poll_add(struct io_kiocb *req, const struct io_uring_sqe *sqe, ...@@ -2224,12 +2312,12 @@ static int io_poll_add(struct io_kiocb *req, const struct io_uring_sqe *sqe,
static enum hrtimer_restart io_timeout_fn(struct hrtimer *timer) static enum hrtimer_restart io_timeout_fn(struct hrtimer *timer)
{ {
struct io_ring_ctx *ctx; struct io_timeout_data *data = container_of(timer,
struct io_kiocb *req; struct io_timeout_data, timer);
struct io_kiocb *req = data->req;
struct io_ring_ctx *ctx = req->ctx;
unsigned long flags; unsigned long flags;
req = container_of(timer, struct io_kiocb, timeout.timer);
ctx = req->ctx;
atomic_inc(&ctx->cq_timeouts); atomic_inc(&ctx->cq_timeouts);
spin_lock_irqsave(&ctx->completion_lock, flags); spin_lock_irqsave(&ctx->completion_lock, flags);
...@@ -2279,10 +2367,12 @@ static int io_timeout_cancel(struct io_ring_ctx *ctx, __u64 user_data) ...@@ -2279,10 +2367,12 @@ static int io_timeout_cancel(struct io_ring_ctx *ctx, __u64 user_data)
if (ret == -ENOENT) if (ret == -ENOENT)
return ret; return ret;
ret = hrtimer_try_to_cancel(&req->timeout.timer); ret = hrtimer_try_to_cancel(&req->timeout.data->timer);
if (ret == -1) if (ret == -1)
return -EALREADY; return -EALREADY;
if (req->flags & REQ_F_LINK)
req->flags |= REQ_F_FAIL_LINK;
io_cqring_fill_event(req, -ECANCELED); io_cqring_fill_event(req, -ECANCELED);
io_put_req(req); io_put_req(req);
return 0; return 0;
...@@ -2319,34 +2409,54 @@ static int io_timeout_remove(struct io_kiocb *req, ...@@ -2319,34 +2409,54 @@ static int io_timeout_remove(struct io_kiocb *req,
return 0; return 0;
} }
static int io_timeout(struct io_kiocb *req, const struct io_uring_sqe *sqe) static int io_timeout_setup(struct io_kiocb *req)
{ {
unsigned count; const struct io_uring_sqe *sqe = req->sqe;
struct io_ring_ctx *ctx = req->ctx; struct io_timeout_data *data;
struct list_head *entry;
enum hrtimer_mode mode;
struct timespec64 ts;
unsigned span = 0;
unsigned flags; unsigned flags;
if (unlikely(ctx->flags & IORING_SETUP_IOPOLL)) if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
return -EINVAL; return -EINVAL;
if (sqe->flags || sqe->ioprio || sqe->buf_index || sqe->len != 1) if (sqe->ioprio || sqe->buf_index || sqe->len != 1)
return -EINVAL; return -EINVAL;
flags = READ_ONCE(sqe->timeout_flags); flags = READ_ONCE(sqe->timeout_flags);
if (flags & ~IORING_TIMEOUT_ABS) if (flags & ~IORING_TIMEOUT_ABS)
return -EINVAL; return -EINVAL;
if (get_timespec64(&ts, u64_to_user_ptr(sqe->addr))) data = kzalloc(sizeof(struct io_timeout_data), GFP_KERNEL);
if (!data)
return -ENOMEM;
data->req = req;
req->timeout.data = data;
req->flags |= REQ_F_TIMEOUT;
if (get_timespec64(&data->ts, u64_to_user_ptr(sqe->addr)))
return -EFAULT; return -EFAULT;
if (flags & IORING_TIMEOUT_ABS) if (flags & IORING_TIMEOUT_ABS)
mode = HRTIMER_MODE_ABS; data->mode = HRTIMER_MODE_ABS;
else else
mode = HRTIMER_MODE_REL; data->mode = HRTIMER_MODE_REL;
hrtimer_init(&req->timeout.timer, CLOCK_MONOTONIC, mode); hrtimer_init(&data->timer, CLOCK_MONOTONIC, data->mode);
req->flags |= REQ_F_TIMEOUT; return 0;
}
static int io_timeout(struct io_kiocb *req, const struct io_uring_sqe *sqe)
{
unsigned count;
struct io_ring_ctx *ctx = req->ctx;
struct io_timeout_data *data;
struct list_head *entry;
unsigned span = 0;
int ret;
ret = io_timeout_setup(req);
/* common setup allows flags (like links) set, we don't */
if (!ret && sqe->flags)
ret = -EINVAL;
if (ret)
return ret;
/* /*
* sqe->off holds how many events that need to occur for this * sqe->off holds how many events that need to occur for this
...@@ -2362,8 +2472,7 @@ static int io_timeout(struct io_kiocb *req, const struct io_uring_sqe *sqe) ...@@ -2362,8 +2472,7 @@ static int io_timeout(struct io_kiocb *req, const struct io_uring_sqe *sqe)
} }
req->sequence = ctx->cached_sq_head + count - 1; req->sequence = ctx->cached_sq_head + count - 1;
/* reuse it to store the count */ req->timeout.data->seq_offset = count;
req->submit.sequence = count;
/* /*
* Insertion sort, ensuring the first entry in the list is always * Insertion sort, ensuring the first entry in the list is always
...@@ -2374,6 +2483,7 @@ static int io_timeout(struct io_kiocb *req, const struct io_uring_sqe *sqe) ...@@ -2374,6 +2483,7 @@ static int io_timeout(struct io_kiocb *req, const struct io_uring_sqe *sqe)
struct io_kiocb *nxt = list_entry(entry, struct io_kiocb, list); struct io_kiocb *nxt = list_entry(entry, struct io_kiocb, list);
unsigned nxt_sq_head; unsigned nxt_sq_head;
long long tmp, tmp_nxt; long long tmp, tmp_nxt;
u32 nxt_offset = nxt->timeout.data->seq_offset;
if (nxt->flags & REQ_F_TIMEOUT_NOSEQ) if (nxt->flags & REQ_F_TIMEOUT_NOSEQ)
continue; continue;
...@@ -2383,8 +2493,8 @@ static int io_timeout(struct io_kiocb *req, const struct io_uring_sqe *sqe) ...@@ -2383,8 +2493,8 @@ static int io_timeout(struct io_kiocb *req, const struct io_uring_sqe *sqe)
* long to store it. * long to store it.
*/ */
tmp = (long long)ctx->cached_sq_head + count - 1; tmp = (long long)ctx->cached_sq_head + count - 1;
nxt_sq_head = nxt->sequence - nxt->submit.sequence + 1; nxt_sq_head = nxt->sequence - nxt_offset + 1;
tmp_nxt = (long long)nxt_sq_head + nxt->submit.sequence - 1; tmp_nxt = (long long)nxt_sq_head + nxt_offset - 1;
/* /*
* cached_sq_head may overflow, and it will never overflow twice * cached_sq_head may overflow, and it will never overflow twice
...@@ -2406,8 +2516,9 @@ static int io_timeout(struct io_kiocb *req, const struct io_uring_sqe *sqe) ...@@ -2406,8 +2516,9 @@ static int io_timeout(struct io_kiocb *req, const struct io_uring_sqe *sqe)
req->sequence -= span; req->sequence -= span;
add: add:
list_add(&req->list, entry); list_add(&req->list, entry);
req->timeout.timer.function = io_timeout_fn; data = req->timeout.data;
hrtimer_start(&req->timeout.timer, timespec64_to_ktime(ts), mode); data->timer.function = io_timeout_fn;
hrtimer_start(&data->timer, timespec64_to_ktime(data->ts), data->mode);
spin_unlock_irq(&ctx->completion_lock); spin_unlock_irq(&ctx->completion_lock);
return 0; return 0;
} }
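Not part of this diff, but as a usage sketch: arming a timeout from userspace with liburing's io_uring_prep_timeout(), which fills sqe->addr (timespec), sqe->len = 1 and sqe->off (completion count) the way io_timeout() expects:

#include <liburing.h>

/* Hypothetical sketch: complete after 'count' other completions arrive,
 * or when the timeout expires, whichever comes first. 'ts' must stay
 * valid until the SQE has been consumed by the kernel. */
static int arm_timeout(struct io_uring *ring, struct __kernel_timespec *ts,
		       unsigned count)
{
	struct io_uring_sqe *sqe = io_uring_get_sqe(ring);

	if (!sqe)
		return -EAGAIN;
	/* flags = 0 gives a relative timeout; IORING_TIMEOUT_ABS for absolute */
	io_uring_prep_timeout(sqe, ts, count, 0);
	return io_uring_submit(ring);
}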
...@@ -2442,7 +2553,7 @@ static int io_async_cancel_one(struct io_ring_ctx *ctx, void *sqe_addr) ...@@ -2442,7 +2553,7 @@ static int io_async_cancel_one(struct io_ring_ctx *ctx, void *sqe_addr)
static void io_async_find_and_cancel(struct io_ring_ctx *ctx, static void io_async_find_and_cancel(struct io_ring_ctx *ctx,
struct io_kiocb *req, __u64 sqe_addr, struct io_kiocb *req, __u64 sqe_addr,
struct io_kiocb **nxt) struct io_kiocb **nxt, int success_ret)
{ {
unsigned long flags; unsigned long flags;
int ret; int ret;
...@@ -2459,6 +2570,8 @@ static void io_async_find_and_cancel(struct io_ring_ctx *ctx, ...@@ -2459,6 +2570,8 @@ static void io_async_find_and_cancel(struct io_ring_ctx *ctx,
goto done; goto done;
ret = io_poll_cancel(ctx, sqe_addr); ret = io_poll_cancel(ctx, sqe_addr);
done: done:
if (!ret)
ret = success_ret;
io_cqring_fill_event(req, ret); io_cqring_fill_event(req, ret);
io_commit_cqring(ctx); io_commit_cqring(ctx);
spin_unlock_irqrestore(&ctx->completion_lock, flags); spin_unlock_irqrestore(&ctx->completion_lock, flags);
...@@ -2480,13 +2593,12 @@ static int io_async_cancel(struct io_kiocb *req, const struct io_uring_sqe *sqe, ...@@ -2480,13 +2593,12 @@ static int io_async_cancel(struct io_kiocb *req, const struct io_uring_sqe *sqe,
sqe->cancel_flags) sqe->cancel_flags)
return -EINVAL; return -EINVAL;
io_async_find_and_cancel(ctx, req, READ_ONCE(sqe->addr), NULL); io_async_find_and_cancel(ctx, req, READ_ONCE(sqe->addr), nxt, 0);
return 0; return 0;
} }
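A hedged userspace sketch (not part of this diff) of issuing IORING_OP_ASYNC_CANCEL against an earlier request, assuming liburing's io_uring_prep_cancel() taking the target's user_data:

#include <liburing.h>

/* Hypothetical sketch: request cancellation of a previously submitted SQE
 * identified by its user_data; the cancel's own CQE reports 0, -ENOENT or
 * -EALREADY, mirroring io_async_find_and_cancel() above. */
static int cancel_by_user_data(struct io_uring *ring, void *target_user_data)
{
	struct io_uring_sqe *sqe = io_uring_get_sqe(ring);

	if (!sqe)
		return -EAGAIN;
	io_uring_prep_cancel(sqe, target_user_data, 0);
	return io_uring_submit(ring);
}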
static int io_req_defer(struct io_kiocb *req) static int io_req_defer(struct io_kiocb *req)
{ {
const struct io_uring_sqe *sqe = req->submit.sqe;
struct io_uring_sqe *sqe_copy; struct io_uring_sqe *sqe_copy;
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
...@@ -2505,34 +2617,35 @@ static int io_req_defer(struct io_kiocb *req) ...@@ -2505,34 +2617,35 @@ static int io_req_defer(struct io_kiocb *req)
return 0; return 0;
} }
memcpy(sqe_copy, sqe, sizeof(*sqe_copy)); memcpy(sqe_copy, req->sqe, sizeof(*sqe_copy));
req->submit.sqe = sqe_copy; req->flags |= REQ_F_FREE_SQE;
req->sqe = sqe_copy;
trace_io_uring_defer(ctx, req, false); trace_io_uring_defer(ctx, req, req->user_data);
list_add_tail(&req->list, &ctx->defer_list); list_add_tail(&req->list, &ctx->defer_list);
spin_unlock_irq(&ctx->completion_lock); spin_unlock_irq(&ctx->completion_lock);
return -EIOCBQUEUED; return -EIOCBQUEUED;
} }
static int __io_submit_sqe(struct io_kiocb *req, struct io_kiocb **nxt, __attribute__((nonnull))
bool force_nonblock) static int io_issue_sqe(struct io_kiocb *req, struct io_kiocb **nxt,
bool force_nonblock)
{ {
int ret, opcode; int ret, opcode;
struct sqe_submit *s = &req->submit;
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
opcode = READ_ONCE(s->sqe->opcode); opcode = READ_ONCE(req->sqe->opcode);
switch (opcode) { switch (opcode) {
case IORING_OP_NOP: case IORING_OP_NOP:
ret = io_nop(req); ret = io_nop(req);
break; break;
case IORING_OP_READV: case IORING_OP_READV:
if (unlikely(s->sqe->buf_index)) if (unlikely(req->sqe->buf_index))
return -EINVAL; return -EINVAL;
ret = io_read(req, nxt, force_nonblock); ret = io_read(req, nxt, force_nonblock);
break; break;
case IORING_OP_WRITEV: case IORING_OP_WRITEV:
if (unlikely(s->sqe->buf_index)) if (unlikely(req->sqe->buf_index))
return -EINVAL; return -EINVAL;
ret = io_write(req, nxt, force_nonblock); ret = io_write(req, nxt, force_nonblock);
break; break;
...@@ -2543,34 +2656,37 @@ static int __io_submit_sqe(struct io_kiocb *req, struct io_kiocb **nxt, ...@@ -2543,34 +2656,37 @@ static int __io_submit_sqe(struct io_kiocb *req, struct io_kiocb **nxt,
ret = io_write(req, nxt, force_nonblock); ret = io_write(req, nxt, force_nonblock);
break; break;
case IORING_OP_FSYNC: case IORING_OP_FSYNC:
ret = io_fsync(req, s->sqe, nxt, force_nonblock); ret = io_fsync(req, req->sqe, nxt, force_nonblock);
break; break;
case IORING_OP_POLL_ADD: case IORING_OP_POLL_ADD:
ret = io_poll_add(req, s->sqe, nxt); ret = io_poll_add(req, req->sqe, nxt);
break; break;
case IORING_OP_POLL_REMOVE: case IORING_OP_POLL_REMOVE:
ret = io_poll_remove(req, s->sqe); ret = io_poll_remove(req, req->sqe);
break; break;
case IORING_OP_SYNC_FILE_RANGE: case IORING_OP_SYNC_FILE_RANGE:
ret = io_sync_file_range(req, s->sqe, nxt, force_nonblock); ret = io_sync_file_range(req, req->sqe, nxt, force_nonblock);
break; break;
case IORING_OP_SENDMSG: case IORING_OP_SENDMSG:
ret = io_sendmsg(req, s->sqe, nxt, force_nonblock); ret = io_sendmsg(req, req->sqe, nxt, force_nonblock);
break; break;
case IORING_OP_RECVMSG: case IORING_OP_RECVMSG:
ret = io_recvmsg(req, s->sqe, nxt, force_nonblock); ret = io_recvmsg(req, req->sqe, nxt, force_nonblock);
break; break;
case IORING_OP_TIMEOUT: case IORING_OP_TIMEOUT:
ret = io_timeout(req, s->sqe); ret = io_timeout(req, req->sqe);
break; break;
case IORING_OP_TIMEOUT_REMOVE: case IORING_OP_TIMEOUT_REMOVE:
ret = io_timeout_remove(req, s->sqe); ret = io_timeout_remove(req, req->sqe);
break; break;
case IORING_OP_ACCEPT: case IORING_OP_ACCEPT:
ret = io_accept(req, s->sqe, nxt, force_nonblock); ret = io_accept(req, req->sqe, nxt, force_nonblock);
break;
case IORING_OP_CONNECT:
ret = io_connect(req, req->sqe, nxt, force_nonblock);
break; break;
case IORING_OP_ASYNC_CANCEL: case IORING_OP_ASYNC_CANCEL:
ret = io_async_cancel(req, s->sqe, nxt); ret = io_async_cancel(req, req->sqe, nxt);
break; break;
default: default:
ret = -EINVAL; ret = -EINVAL;
...@@ -2585,22 +2701,29 @@ static int __io_submit_sqe(struct io_kiocb *req, struct io_kiocb **nxt, ...@@ -2585,22 +2701,29 @@ static int __io_submit_sqe(struct io_kiocb *req, struct io_kiocb **nxt,
return -EAGAIN; return -EAGAIN;
/* workqueue context doesn't hold uring_lock, grab it now */ /* workqueue context doesn't hold uring_lock, grab it now */
if (s->in_async) if (req->in_async)
mutex_lock(&ctx->uring_lock); mutex_lock(&ctx->uring_lock);
io_iopoll_req_issued(req); io_iopoll_req_issued(req);
if (s->in_async) if (req->in_async)
mutex_unlock(&ctx->uring_lock); mutex_unlock(&ctx->uring_lock);
} }
return 0; return 0;
} }
static void io_link_work_cb(struct io_wq_work **workptr)
{
struct io_wq_work *work = *workptr;
struct io_kiocb *link = work->data;
io_queue_linked_timeout(link);
work->func = io_wq_submit_work;
}
static void io_wq_submit_work(struct io_wq_work **workptr) static void io_wq_submit_work(struct io_wq_work **workptr)
{ {
struct io_wq_work *work = *workptr; struct io_wq_work *work = *workptr;
struct io_kiocb *req = container_of(work, struct io_kiocb, work); struct io_kiocb *req = container_of(work, struct io_kiocb, work);
struct sqe_submit *s = &req->submit;
const struct io_uring_sqe *sqe = s->sqe;
struct io_kiocb *nxt = NULL; struct io_kiocb *nxt = NULL;
int ret = 0; int ret = 0;
...@@ -2611,10 +2734,10 @@ static void io_wq_submit_work(struct io_wq_work **workptr) ...@@ -2611,10 +2734,10 @@ static void io_wq_submit_work(struct io_wq_work **workptr)
ret = -ECANCELED; ret = -ECANCELED;
if (!ret) { if (!ret) {
s->has_user = (work->flags & IO_WQ_WORK_HAS_MM) != 0; req->has_user = (work->flags & IO_WQ_WORK_HAS_MM) != 0;
s->in_async = true; req->in_async = true;
do { do {
ret = __io_submit_sqe(req, &nxt, false); ret = io_issue_sqe(req, &nxt, false);
/* /*
* We can get EAGAIN for polled IO even though we're * We can get EAGAIN for polled IO even though we're
* forcing a sync submission from here, since we can't * forcing a sync submission from here, since we can't
...@@ -2636,13 +2759,17 @@ static void io_wq_submit_work(struct io_wq_work **workptr) ...@@ -2636,13 +2759,17 @@ static void io_wq_submit_work(struct io_wq_work **workptr)
io_put_req(req); io_put_req(req);
} }
/* async context always use a copy of the sqe */
kfree(sqe);
/* if a dependent link is ready, pass it back */ /* if a dependent link is ready, pass it back */
if (!ret && nxt) { if (!ret && nxt) {
io_prep_async_work(nxt); struct io_kiocb *link;
io_prep_async_work(nxt, &link);
*workptr = &nxt->work; *workptr = &nxt->work;
if (link) {
nxt->work.flags |= IO_WQ_WORK_CB;
nxt->work.func = io_link_work_cb;
nxt->work.data = link;
}
} }
} }
...@@ -2674,24 +2801,17 @@ static inline struct file *io_file_from_index(struct io_ring_ctx *ctx, ...@@ -2674,24 +2801,17 @@ static inline struct file *io_file_from_index(struct io_ring_ctx *ctx,
static int io_req_set_file(struct io_submit_state *state, struct io_kiocb *req) static int io_req_set_file(struct io_submit_state *state, struct io_kiocb *req)
{ {
struct sqe_submit *s = &req->submit;
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
unsigned flags; unsigned flags;
int fd; int fd;
flags = READ_ONCE(s->sqe->flags); flags = READ_ONCE(req->sqe->flags);
fd = READ_ONCE(s->sqe->fd); fd = READ_ONCE(req->sqe->fd);
if (flags & IOSQE_IO_DRAIN) if (flags & IOSQE_IO_DRAIN)
req->flags |= REQ_F_IO_DRAIN; req->flags |= REQ_F_IO_DRAIN;
/*
* All io need record the previous position, if LINK vs DARIN,
* it can be used to mark the position of the first IO in the
* link list.
*/
req->sequence = s->sequence;
if (!io_op_needs_file(s->sqe)) if (!io_op_needs_file(req->sqe))
return 0; return 0;
if (flags & IOSQE_FIXED_FILE) { if (flags & IOSQE_FIXED_FILE) {
...@@ -2704,7 +2824,7 @@ static int io_req_set_file(struct io_submit_state *state, struct io_kiocb *req) ...@@ -2704,7 +2824,7 @@ static int io_req_set_file(struct io_submit_state *state, struct io_kiocb *req)
return -EBADF; return -EBADF;
req->flags |= REQ_F_FIXED_FILE; req->flags |= REQ_F_FIXED_FILE;
} else { } else {
if (s->needs_fixed_file) if (req->needs_fixed_file)
return -EBADF; return -EBADF;
trace_io_uring_file_get(ctx, fd); trace_io_uring_file_get(ctx, fd);
req->file = io_file_get(state, fd); req->file = io_file_get(state, fd);
...@@ -2728,7 +2848,7 @@ static int io_grab_files(struct io_kiocb *req) ...@@ -2728,7 +2848,7 @@ static int io_grab_files(struct io_kiocb *req)
* the fd has changed since we started down this path, and disallow * the fd has changed since we started down this path, and disallow
* this operation if it has. * this operation if it has.
*/ */
if (fcheck(req->submit.ring_fd) == req->submit.ring_file) { if (fcheck(req->ring_fd) == req->ring_file) {
list_add(&req->inflight_entry, &ctx->inflight_list); list_add(&req->inflight_entry, &ctx->inflight_list);
req->flags |= REQ_F_INFLIGHT; req->flags |= REQ_F_INFLIGHT;
req->work.files = current->files; req->work.files = current->files;
...@@ -2742,8 +2862,9 @@ static int io_grab_files(struct io_kiocb *req) ...@@ -2742,8 +2862,9 @@ static int io_grab_files(struct io_kiocb *req)
static enum hrtimer_restart io_link_timeout_fn(struct hrtimer *timer) static enum hrtimer_restart io_link_timeout_fn(struct hrtimer *timer)
{ {
struct io_kiocb *req = container_of(timer, struct io_kiocb, struct io_timeout_data *data = container_of(timer,
timeout.timer); struct io_timeout_data, timer);
struct io_kiocb *req = data->req;
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
struct io_kiocb *prev = NULL; struct io_kiocb *prev = NULL;
unsigned long flags; unsigned long flags;
...@@ -2756,16 +2877,20 @@ static enum hrtimer_restart io_link_timeout_fn(struct hrtimer *timer) ...@@ -2756,16 +2877,20 @@ static enum hrtimer_restart io_link_timeout_fn(struct hrtimer *timer)
*/ */
if (!list_empty(&req->list)) { if (!list_empty(&req->list)) {
prev = list_entry(req->list.prev, struct io_kiocb, link_list); prev = list_entry(req->list.prev, struct io_kiocb, link_list);
if (refcount_inc_not_zero(&prev->refs)) if (refcount_inc_not_zero(&prev->refs)) {
list_del_init(&req->list); list_del_init(&req->list);
else prev->flags &= ~REQ_F_LINK_TIMEOUT;
} else
prev = NULL; prev = NULL;
} }
spin_unlock_irqrestore(&ctx->completion_lock, flags); spin_unlock_irqrestore(&ctx->completion_lock, flags);
if (prev) { if (prev) {
io_async_find_and_cancel(ctx, req, prev->user_data, NULL); if (prev->flags & REQ_F_LINK)
prev->flags |= REQ_F_FAIL_LINK;
io_async_find_and_cancel(ctx, req, prev->user_data, NULL,
-ETIME);
io_put_req(prev); io_put_req(prev);
} else { } else {
io_cqring_add_event(req, -ETIME); io_cqring_add_event(req, -ETIME);
...@@ -2774,8 +2899,7 @@ static enum hrtimer_restart io_link_timeout_fn(struct hrtimer *timer) ...@@ -2774,8 +2899,7 @@ static enum hrtimer_restart io_link_timeout_fn(struct hrtimer *timer)
return HRTIMER_NORESTART; return HRTIMER_NORESTART;
} }
static void io_queue_linked_timeout(struct io_kiocb *req, struct timespec64 *ts, static void io_queue_linked_timeout(struct io_kiocb *req)
enum hrtimer_mode *mode)
{ {
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
...@@ -2785,9 +2909,11 @@ static void io_queue_linked_timeout(struct io_kiocb *req, struct timespec64 *ts, ...@@ -2785,9 +2909,11 @@ static void io_queue_linked_timeout(struct io_kiocb *req, struct timespec64 *ts,
*/ */
spin_lock_irq(&ctx->completion_lock); spin_lock_irq(&ctx->completion_lock);
if (!list_empty(&req->list)) { if (!list_empty(&req->list)) {
req->timeout.timer.function = io_link_timeout_fn; struct io_timeout_data *data = req->timeout.data;
hrtimer_start(&req->timeout.timer, timespec64_to_ktime(*ts),
*mode); data->timer.function = io_link_timeout_fn;
hrtimer_start(&data->timer, timespec64_to_ktime(data->ts),
data->mode);
} }
spin_unlock_irq(&ctx->completion_lock); spin_unlock_irq(&ctx->completion_lock);
...@@ -2795,66 +2921,30 @@ static void io_queue_linked_timeout(struct io_kiocb *req, struct timespec64 *ts, ...@@ -2795,66 +2921,30 @@ static void io_queue_linked_timeout(struct io_kiocb *req, struct timespec64 *ts,
io_put_req(req); io_put_req(req);
} }
static int io_validate_link_timeout(const struct io_uring_sqe *sqe, static struct io_kiocb *io_prep_linked_timeout(struct io_kiocb *req)
struct timespec64 *ts)
{
if (sqe->ioprio || sqe->buf_index || sqe->len != 1 || sqe->off)
return -EINVAL;
if (sqe->timeout_flags & ~IORING_TIMEOUT_ABS)
return -EINVAL;
if (get_timespec64(ts, u64_to_user_ptr(sqe->addr)))
return -EFAULT;
return 0;
}
static struct io_kiocb *io_prep_linked_timeout(struct io_kiocb *req,
struct timespec64 *ts,
enum hrtimer_mode *mode)
{ {
struct io_kiocb *nxt; struct io_kiocb *nxt;
int ret;
if (!(req->flags & REQ_F_LINK)) if (!(req->flags & REQ_F_LINK))
return NULL; return NULL;
nxt = list_first_entry_or_null(&req->link_list, struct io_kiocb, list); nxt = list_first_entry_or_null(&req->link_list, struct io_kiocb, list);
if (!nxt || nxt->submit.sqe->opcode != IORING_OP_LINK_TIMEOUT) if (!nxt || nxt->sqe->opcode != IORING_OP_LINK_TIMEOUT)
return NULL; return NULL;
ret = io_validate_link_timeout(nxt->submit.sqe, ts);
if (ret) {
list_del_init(&nxt->list);
io_cqring_add_event(nxt, ret);
io_double_put_req(nxt);
return ERR_PTR(-ECANCELED);
}
if (nxt->submit.sqe->timeout_flags & IORING_TIMEOUT_ABS)
*mode = HRTIMER_MODE_ABS;
else
*mode = HRTIMER_MODE_REL;
req->flags |= REQ_F_LINK_TIMEOUT; req->flags |= REQ_F_LINK_TIMEOUT;
hrtimer_init(&nxt->timeout.timer, CLOCK_MONOTONIC, *mode);
return nxt; return nxt;
} }
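For reference (not part of this diff), a userspace sketch of pairing a request with IORING_OP_LINK_TIMEOUT, assuming liburing's io_uring_prep_link_timeout() helper; the timeout SQE must directly follow the SQE it guards, which carries IOSQE_IO_LINK:

#include <liburing.h>
#include <sys/uio.h>

/* Hypothetical sketch: a readv capped by a linked timeout. If the timeout
 * fires first, the guarded read is cancelled where possible and the
 * timeout CQE carries -ETIME. io_uring_get_sqe() checks omitted. */
static int readv_with_timeout(struct io_uring *ring, int fd,
			      struct iovec *iov, unsigned nr_vecs,
			      struct __kernel_timespec *ts)
{
	struct io_uring_sqe *sqe;

	sqe = io_uring_get_sqe(ring);
	io_uring_prep_readv(sqe, fd, iov, nr_vecs, 0);
	sqe->flags |= IOSQE_IO_LINK;		/* link the next SQE to this one */

	sqe = io_uring_get_sqe(ring);
	io_uring_prep_link_timeout(sqe, ts, 0);	/* becomes IORING_OP_LINK_TIMEOUT */

	return io_uring_submit(ring);
}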
static int __io_queue_sqe(struct io_kiocb *req) static void __io_queue_sqe(struct io_kiocb *req)
{ {
enum hrtimer_mode mode; struct io_kiocb *linked_timeout = io_prep_linked_timeout(req);
struct io_kiocb *nxt; struct io_kiocb *nxt = NULL;
struct timespec64 ts;
int ret; int ret;
nxt = io_prep_linked_timeout(req, &ts, &mode); ret = io_issue_sqe(req, &nxt, true);
if (IS_ERR(nxt)) { if (nxt)
ret = PTR_ERR(nxt); io_queue_async_work(nxt);
nxt = NULL;
goto err;
}
ret = __io_submit_sqe(req, NULL, true);
/* /*
* We async punt it if the file wasn't marked NOWAIT, or if the file * We async punt it if the file wasn't marked NOWAIT, or if the file
...@@ -2862,42 +2952,38 @@ static int __io_queue_sqe(struct io_kiocb *req) ...@@ -2862,42 +2952,38 @@ static int __io_queue_sqe(struct io_kiocb *req)
*/ */
if (ret == -EAGAIN && (!(req->flags & REQ_F_NOWAIT) || if (ret == -EAGAIN && (!(req->flags & REQ_F_NOWAIT) ||
(req->flags & REQ_F_MUST_PUNT))) { (req->flags & REQ_F_MUST_PUNT))) {
struct sqe_submit *s = &req->submit;
struct io_uring_sqe *sqe_copy; struct io_uring_sqe *sqe_copy;
sqe_copy = kmemdup(s->sqe, sizeof(*sqe_copy), GFP_KERNEL); sqe_copy = kmemdup(req->sqe, sizeof(*sqe_copy), GFP_KERNEL);
if (sqe_copy) { if (!sqe_copy)
s->sqe = sqe_copy; goto err;
if (req->work.flags & IO_WQ_WORK_NEEDS_FILES) {
ret = io_grab_files(req);
if (ret) {
kfree(sqe_copy);
goto err;
}
}
/*
* Queued up for async execution, worker will release
* submit reference when the iocb is actually submitted.
*/
io_queue_async_work(req);
if (nxt) req->sqe = sqe_copy;
io_queue_linked_timeout(nxt, &ts, &mode); req->flags |= REQ_F_FREE_SQE;
return 0; if (req->work.flags & IO_WQ_WORK_NEEDS_FILES) {
ret = io_grab_files(req);
if (ret)
goto err;
} }
/*
* Queued up for async execution, worker will release
* submit reference when the iocb is actually submitted.
*/
io_queue_async_work(req);
return;
} }
err: err:
/* drop submission reference */ /* drop submission reference */
io_put_req(req); io_put_req(req);
if (nxt) { if (linked_timeout) {
if (!ret) if (!ret)
io_queue_linked_timeout(nxt, &ts, &mode); io_queue_linked_timeout(linked_timeout);
else else
io_put_req(nxt); io_put_req(linked_timeout);
} }
/* and drop final reference, if we failed */ /* and drop final reference, if we failed */
...@@ -2907,83 +2993,52 @@ static int __io_queue_sqe(struct io_kiocb *req) ...@@ -2907,83 +2993,52 @@ static int __io_queue_sqe(struct io_kiocb *req)
req->flags |= REQ_F_FAIL_LINK; req->flags |= REQ_F_FAIL_LINK;
io_put_req(req); io_put_req(req);
} }
return ret;
} }
static int io_queue_sqe(struct io_kiocb *req) static void io_queue_sqe(struct io_kiocb *req)
{ {
int ret; int ret;
ret = io_req_defer(req); if (unlikely(req->ctx->drain_next)) {
if (ret) { req->flags |= REQ_F_IO_DRAIN;
if (ret != -EIOCBQUEUED) { req->ctx->drain_next = false;
io_cqring_add_event(req, ret);
io_double_put_req(req);
}
return 0;
} }
req->ctx->drain_next = (req->flags & REQ_F_DRAIN_LINK);
return __io_queue_sqe(req);
}
static int io_queue_link_head(struct io_kiocb *req, struct io_kiocb *shadow)
{
int ret;
int need_submit = false;
struct io_ring_ctx *ctx = req->ctx;
if (!shadow)
return io_queue_sqe(req);
/*
* Mark the first IO in link list as DRAIN, let all the following
* IOs enter the defer list. all IO needs to be completed before link
* list.
*/
req->flags |= REQ_F_IO_DRAIN;
ret = io_req_defer(req); ret = io_req_defer(req);
if (ret) { if (ret) {
if (ret != -EIOCBQUEUED) { if (ret != -EIOCBQUEUED) {
io_cqring_add_event(req, ret); io_cqring_add_event(req, ret);
if (req->flags & REQ_F_LINK)
req->flags |= REQ_F_FAIL_LINK;
io_double_put_req(req); io_double_put_req(req);
__io_free_req(shadow);
return 0;
} }
} else { } else
/* __io_queue_sqe(req);
* If ret == 0 means that all IOs in front of link io are }
* running done. let's queue link head.
*/
need_submit = true;
}
/* Insert shadow req to defer_list, blocking next IOs */
spin_lock_irq(&ctx->completion_lock);
trace_io_uring_defer(ctx, shadow, true);
list_add_tail(&shadow->list, &ctx->defer_list);
spin_unlock_irq(&ctx->completion_lock);
if (need_submit)
return __io_queue_sqe(req);
return 0; static inline void io_queue_link_head(struct io_kiocb *req)
{
if (unlikely(req->flags & REQ_F_FAIL_LINK)) {
io_cqring_add_event(req, -ECANCELED);
io_double_put_req(req);
} else
io_queue_sqe(req);
} }
#define SQE_VALID_FLAGS (IOSQE_FIXED_FILE|IOSQE_IO_DRAIN|IOSQE_IO_LINK) #define SQE_VALID_FLAGS (IOSQE_FIXED_FILE|IOSQE_IO_DRAIN|IOSQE_IO_LINK)
static void io_submit_sqe(struct io_kiocb *req, struct io_submit_state *state, static void io_submit_sqe(struct io_kiocb *req, struct io_submit_state *state,
struct io_kiocb **link) struct io_kiocb **link)
{ {
struct io_uring_sqe *sqe_copy;
struct sqe_submit *s = &req->submit;
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
int ret; int ret;
req->user_data = s->sqe->user_data; req->user_data = req->sqe->user_data;
/* enforce forwards compatibility on users */ /* enforce forwards compatibility on users */
if (unlikely(s->sqe->flags & ~SQE_VALID_FLAGS)) { if (unlikely(req->sqe->flags & ~SQE_VALID_FLAGS)) {
ret = -EINVAL; ret = -EINVAL;
goto err_req; goto err_req;
} }
...@@ -3005,25 +3060,37 @@ static void io_submit_sqe(struct io_kiocb *req, struct io_submit_state *state, ...@@ -3005,25 +3060,37 @@ static void io_submit_sqe(struct io_kiocb *req, struct io_submit_state *state,
*/ */
if (*link) { if (*link) {
struct io_kiocb *prev = *link; struct io_kiocb *prev = *link;
struct io_uring_sqe *sqe_copy;
sqe_copy = kmemdup(s->sqe, sizeof(*sqe_copy), GFP_KERNEL); if (req->sqe->flags & IOSQE_IO_DRAIN)
(*link)->flags |= REQ_F_DRAIN_LINK | REQ_F_IO_DRAIN;
if (READ_ONCE(req->sqe->opcode) == IORING_OP_LINK_TIMEOUT) {
ret = io_timeout_setup(req);
/* common setup allows offset being set, we don't */
if (!ret && req->sqe->off)
ret = -EINVAL;
if (ret) {
prev->flags |= REQ_F_FAIL_LINK;
goto err_req;
}
}
sqe_copy = kmemdup(req->sqe, sizeof(*sqe_copy), GFP_KERNEL);
if (!sqe_copy) { if (!sqe_copy) {
ret = -EAGAIN; ret = -EAGAIN;
goto err_req; goto err_req;
} }
s->sqe = sqe_copy; req->sqe = sqe_copy;
req->flags |= REQ_F_FREE_SQE;
trace_io_uring_link(ctx, req, prev); trace_io_uring_link(ctx, req, prev);
list_add_tail(&req->list, &prev->link_list); list_add_tail(&req->list, &prev->link_list);
} else if (s->sqe->flags & IOSQE_IO_LINK) { } else if (req->sqe->flags & IOSQE_IO_LINK) {
req->flags |= REQ_F_LINK; req->flags |= REQ_F_LINK;
INIT_LIST_HEAD(&req->link_list); INIT_LIST_HEAD(&req->link_list);
*link = req; *link = req;
} else if (READ_ONCE(s->sqe->opcode) == IORING_OP_LINK_TIMEOUT) {
/* Only valid as a linked SQE */
ret = -EINVAL;
goto err_req;
} else { } else {
io_queue_sqe(req); io_queue_sqe(req);
} }
...@@ -3075,7 +3142,7 @@ static void io_commit_sqring(struct io_ring_ctx *ctx) ...@@ -3075,7 +3142,7 @@ static void io_commit_sqring(struct io_ring_ctx *ctx)
* used, it's important that those reads are done through READ_ONCE() to * used, it's important that those reads are done through READ_ONCE() to
* prevent a re-load down the line. * prevent a re-load down the line.
*/ */
static bool io_get_sqring(struct io_ring_ctx *ctx, struct sqe_submit *s) static bool io_get_sqring(struct io_ring_ctx *ctx, struct io_kiocb *req)
{ {
struct io_rings *rings = ctx->rings; struct io_rings *rings = ctx->rings;
u32 *sq_array = ctx->sq_array; u32 *sq_array = ctx->sq_array;
...@@ -3091,14 +3158,18 @@ static bool io_get_sqring(struct io_ring_ctx *ctx, struct sqe_submit *s) ...@@ -3091,14 +3158,18 @@ static bool io_get_sqring(struct io_ring_ctx *ctx, struct sqe_submit *s)
*/ */
head = ctx->cached_sq_head; head = ctx->cached_sq_head;
/* make sure SQ entry isn't read before tail */ /* make sure SQ entry isn't read before tail */
if (head == smp_load_acquire(&rings->sq.tail)) if (unlikely(head == smp_load_acquire(&rings->sq.tail)))
return false; return false;
head = READ_ONCE(sq_array[head & ctx->sq_mask]); head = READ_ONCE(sq_array[head & ctx->sq_mask]);
if (head < ctx->sq_entries) { if (likely(head < ctx->sq_entries)) {
s->ring_file = NULL; /*
s->sqe = &ctx->sq_sqes[head]; * All io need record the previous position, if LINK vs DARIN,
s->sequence = ctx->cached_sq_head; * it can be used to mark the position of the first IO in the
* link list.
*/
req->sequence = ctx->cached_sq_head;
req->sqe = &ctx->sq_sqes[head];
ctx->cached_sq_head++; ctx->cached_sq_head++;
return true; return true;
} }
...@@ -3116,14 +3187,13 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr, ...@@ -3116,14 +3187,13 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
{ {
struct io_submit_state state, *statep = NULL; struct io_submit_state state, *statep = NULL;
struct io_kiocb *link = NULL; struct io_kiocb *link = NULL;
struct io_kiocb *shadow_req = NULL;
int i, submitted = 0; int i, submitted = 0;
bool mm_fault = false; bool mm_fault = false;
if (!list_empty(&ctx->cq_overflow_list)) { /* if we have a backlog and couldn't flush it all, return BUSY */
io_cqring_overflow_flush(ctx, false); if (!list_empty(&ctx->cq_overflow_list) &&
!io_cqring_overflow_flush(ctx, false))
return -EBUSY; return -EBUSY;
}
if (nr > IO_PLUG_THRESHOLD) { if (nr > IO_PLUG_THRESHOLD) {
io_submit_state_start(&state, ctx, nr); io_submit_state_start(&state, ctx, nr);
...@@ -3140,12 +3210,12 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr, ...@@ -3140,12 +3210,12 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
submitted = -EAGAIN; submitted = -EAGAIN;
break; break;
} }
if (!io_get_sqring(ctx, &req->submit)) { if (!io_get_sqring(ctx, req)) {
__io_free_req(req); __io_free_req(req);
break; break;
} }
if (io_sqe_needs_user(req->submit.sqe) && !*mm) { if (io_sqe_needs_user(req->sqe) && !*mm) {
mm_fault = mm_fault || !mmget_not_zero(ctx->sqo_mm); mm_fault = mm_fault || !mmget_not_zero(ctx->sqo_mm);
if (!mm_fault) { if (!mm_fault) {
use_mm(ctx->sqo_mm); use_mm(ctx->sqo_mm);
...@@ -3153,26 +3223,14 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr, ...@@ -3153,26 +3223,14 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
} }
} }
sqe_flags = req->submit.sqe->flags; sqe_flags = req->sqe->flags;
if (link && (sqe_flags & IOSQE_IO_DRAIN)) { req->ring_file = ring_file;
if (!shadow_req) { req->ring_fd = ring_fd;
shadow_req = io_get_req(ctx, NULL); req->has_user = *mm != NULL;
if (unlikely(!shadow_req)) req->in_async = async;
goto out; req->needs_fixed_file = async;
shadow_req->flags |= (REQ_F_IO_DRAIN | REQ_F_SHADOW_DRAIN); trace_io_uring_submit_sqe(ctx, req->sqe->user_data,
refcount_dec(&shadow_req->refs);
}
shadow_req->sequence = req->submit.sequence;
}
out:
req->submit.ring_file = ring_file;
req->submit.ring_fd = ring_fd;
req->submit.has_user = *mm != NULL;
req->submit.in_async = async;
req->submit.needs_fixed_file = async;
trace_io_uring_submit_sqe(ctx, req->submit.sqe->user_data,
true, async); true, async);
io_submit_sqe(req, statep, &link); io_submit_sqe(req, statep, &link);
submitted++; submitted++;
...@@ -3182,14 +3240,13 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr, ...@@ -3182,14 +3240,13 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
* that's the end of the chain. Submit the previous link. * that's the end of the chain. Submit the previous link.
*/ */
if (!(sqe_flags & IOSQE_IO_LINK) && link) { if (!(sqe_flags & IOSQE_IO_LINK) && link) {
io_queue_link_head(link, shadow_req); io_queue_link_head(link);
link = NULL; link = NULL;
shadow_req = NULL;
} }
} }
if (link) if (link)
io_queue_link_head(link, shadow_req); io_queue_link_head(link);
if (statep) if (statep)
io_submit_state_end(&state); io_submit_state_end(&state);
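Not part of this diff, but the visible effect for userspace is that submission can return -EBUSY while the CQ overflow backlog cannot be flushed; a hedged retry sketch with liburing:

#include <liburing.h>
#include <errno.h>

/* Hypothetical sketch: when submission reports a non-flushed backlog,
 * reap completions to make CQ room and try again. A real caller would
 * process each CQE rather than just marking it seen. */
static int submit_with_retry(struct io_uring *ring)
{
	struct io_uring_cqe *cqe;
	int ret;

	while ((ret = io_uring_submit(ring)) == -EBUSY) {
		while (io_uring_peek_cqe(ring, &cqe) == 0)
			io_uring_cqe_seen(ring, cqe);
	}
	return ret;
}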
...@@ -3203,6 +3260,7 @@ static int io_sq_thread(void *data) ...@@ -3203,6 +3260,7 @@ static int io_sq_thread(void *data)
{ {
struct io_ring_ctx *ctx = data; struct io_ring_ctx *ctx = data;
struct mm_struct *cur_mm = NULL; struct mm_struct *cur_mm = NULL;
const struct cred *old_cred;
mm_segment_t old_fs; mm_segment_t old_fs;
DEFINE_WAIT(wait); DEFINE_WAIT(wait);
unsigned inflight; unsigned inflight;
...@@ -3213,6 +3271,7 @@ static int io_sq_thread(void *data) ...@@ -3213,6 +3271,7 @@ static int io_sq_thread(void *data)
old_fs = get_fs(); old_fs = get_fs();
set_fs(USER_DS); set_fs(USER_DS);
old_cred = override_creds(ctx->creds);
ret = timeout = inflight = 0; ret = timeout = inflight = 0;
while (!kthread_should_park()) { while (!kthread_should_park()) {
...@@ -3319,6 +3378,7 @@ static int io_sq_thread(void *data) ...@@ -3319,6 +3378,7 @@ static int io_sq_thread(void *data)
unuse_mm(cur_mm); unuse_mm(cur_mm);
mmput(cur_mm); mmput(cur_mm);
} }
revert_creds(old_cred);
kthread_parkme(); kthread_parkme();
...@@ -3898,6 +3958,7 @@ static void io_get_work(struct io_wq_work *work) ...@@ -3898,6 +3958,7 @@ static void io_get_work(struct io_wq_work *work)
static int io_sq_offload_start(struct io_ring_ctx *ctx, static int io_sq_offload_start(struct io_ring_ctx *ctx,
struct io_uring_params *p) struct io_uring_params *p)
{ {
struct io_wq_data data;
unsigned concurrency; unsigned concurrency;
int ret; int ret;
...@@ -3942,10 +4003,15 @@ static int io_sq_offload_start(struct io_ring_ctx *ctx, ...@@ -3942,10 +4003,15 @@ static int io_sq_offload_start(struct io_ring_ctx *ctx,
goto err; goto err;
} }
data.mm = ctx->sqo_mm;
data.user = ctx->user;
data.creds = ctx->creds;
data.get_work = io_get_work;
data.put_work = io_put_work;
/* Do QD, or 4 * CPUS, whatever is smallest */ /* Do QD, or 4 * CPUS, whatever is smallest */
concurrency = min(ctx->sq_entries, 4 * num_online_cpus()); concurrency = min(ctx->sq_entries, 4 * num_online_cpus());
ctx->io_wq = io_wq_create(concurrency, ctx->sqo_mm, ctx->user, ctx->io_wq = io_wq_create(concurrency, &data);
io_get_work, io_put_work);
if (IS_ERR(ctx->io_wq)) { if (IS_ERR(ctx->io_wq)) {
ret = PTR_ERR(ctx->io_wq); ret = PTR_ERR(ctx->io_wq);
ctx->io_wq = NULL; ctx->io_wq = NULL;
...@@ -4294,6 +4360,7 @@ static void io_ring_ctx_free(struct io_ring_ctx *ctx) ...@@ -4294,6 +4360,7 @@ static void io_ring_ctx_free(struct io_ring_ctx *ctx)
io_unaccount_mem(ctx->user, io_unaccount_mem(ctx->user,
ring_pages(ctx->sq_entries, ctx->cq_entries)); ring_pages(ctx->sq_entries, ctx->cq_entries));
free_uid(ctx->user); free_uid(ctx->user);
put_cred(ctx->creds);
kfree(ctx->completions); kfree(ctx->completions);
kmem_cache_free(req_cachep, ctx->fallback_req); kmem_cache_free(req_cachep, ctx->fallback_req);
kfree(ctx); kfree(ctx);
...@@ -4531,12 +4598,18 @@ static int io_allocate_scq_urings(struct io_ring_ctx *ctx, ...@@ -4531,12 +4598,18 @@ static int io_allocate_scq_urings(struct io_ring_ctx *ctx,
ctx->cq_entries = rings->cq_ring_entries; ctx->cq_entries = rings->cq_ring_entries;
size = array_size(sizeof(struct io_uring_sqe), p->sq_entries); size = array_size(sizeof(struct io_uring_sqe), p->sq_entries);
if (size == SIZE_MAX) if (size == SIZE_MAX) {
io_mem_free(ctx->rings);
ctx->rings = NULL;
return -EOVERFLOW; return -EOVERFLOW;
}
ctx->sq_sqes = io_mem_alloc(size); ctx->sq_sqes = io_mem_alloc(size);
if (!ctx->sq_sqes) if (!ctx->sq_sqes) {
io_mem_free(ctx->rings);
ctx->rings = NULL;
return -ENOMEM; return -ENOMEM;
}
return 0; return 0;
} }
...@@ -4640,6 +4713,7 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p) ...@@ -4640,6 +4713,7 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p)
ctx->compat = in_compat_syscall(); ctx->compat = in_compat_syscall();
ctx->account_mem = account_mem; ctx->account_mem = account_mem;
ctx->user = user; ctx->user = user;
ctx->creds = prepare_creds();
ret = io_allocate_scq_urings(ctx, p); ret = io_allocate_scq_urings(ctx, p);
if (ret) if (ret)
......
...@@ -399,6 +399,9 @@ extern int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr, ...@@ -399,6 +399,9 @@ extern int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr,
int __user *upeer_addrlen, int flags); int __user *upeer_addrlen, int flags);
extern int __sys_socket(int family, int type, int protocol); extern int __sys_socket(int family, int type, int protocol);
extern int __sys_bind(int fd, struct sockaddr __user *umyaddr, int addrlen); extern int __sys_bind(int fd, struct sockaddr __user *umyaddr, int addrlen);
extern int __sys_connect_file(struct file *file,
struct sockaddr __user *uservaddr, int addrlen,
int file_flags);
extern int __sys_connect(int fd, struct sockaddr __user *uservaddr, extern int __sys_connect(int fd, struct sockaddr __user *uservaddr,
int addrlen); int addrlen);
extern int __sys_listen(int fd, int backlog); extern int __sys_listen(int fd, int backlog);
......
...@@ -163,35 +163,35 @@ TRACE_EVENT(io_uring_queue_async_work, ...@@ -163,35 +163,35 @@ TRACE_EVENT(io_uring_queue_async_work,
); );
/** /**
* io_uring_defer_list - called before the io_uring work added into defer_list * io_uring_defer - called when an io_uring request is deferred
* *
* @ctx: pointer to a ring context structure * @ctx: pointer to a ring context structure
* @req: pointer to a deferred request * @req: pointer to a deferred request
* @shadow: whether request is shadow or not * @user_data: user data associated with the request
* *
* Allows to track deferred requests, to get an insight about what requests are * Allows to track deferred requests, to get an insight about what requests are
* not started immediately. * not started immediately.
*/ */
TRACE_EVENT(io_uring_defer, TRACE_EVENT(io_uring_defer,
TP_PROTO(void *ctx, void *req, bool shadow), TP_PROTO(void *ctx, void *req, unsigned long long user_data),
TP_ARGS(ctx, req, shadow), TP_ARGS(ctx, req, user_data),
TP_STRUCT__entry ( TP_STRUCT__entry (
__field( void *, ctx ) __field( void *, ctx )
__field( void *, req ) __field( void *, req )
__field( bool, shadow ) __field( unsigned long long, data )
), ),
TP_fast_assign( TP_fast_assign(
__entry->ctx = ctx; __entry->ctx = ctx;
__entry->req = req; __entry->req = req;
__entry->shadow = shadow; __entry->data = user_data;
), ),
TP_printk("ring %p, request %p%s", __entry->ctx, __entry->req, TP_printk("ring %p, request %p user_data %llu", __entry->ctx,
__entry->shadow ? ", shadow": "") __entry->req, __entry->data)
); );
/** /**
......
...@@ -73,6 +73,7 @@ struct io_uring_sqe { ...@@ -73,6 +73,7 @@ struct io_uring_sqe {
#define IORING_OP_ACCEPT 13 #define IORING_OP_ACCEPT 13
#define IORING_OP_ASYNC_CANCEL 14 #define IORING_OP_ASYNC_CANCEL 14
#define IORING_OP_LINK_TIMEOUT 15 #define IORING_OP_LINK_TIMEOUT 15
#define IORING_OP_CONNECT 16
/* /*
* sqe->fsync_flags * sqe->fsync_flags
......
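A minimal userspace sketch for the new opcode (not part of this diff), assuming a liburing version that provides io_uring_prep_connect():

#include <liburing.h>
#include <sys/socket.h>

/* Hypothetical sketch: connect a socket through the ring; cqe->res is 0 on
 * success or a negative errno, matching connect(2) semantics. The sockaddr
 * must stay valid until the request completes. */
static int ring_connect(struct io_uring *ring, int sockfd,
			const struct sockaddr *addr, socklen_t addrlen)
{
	struct io_uring_sqe *sqe = io_uring_get_sqe(ring);
	struct io_uring_cqe *cqe;
	int ret;

	if (!sqe)
		return -EAGAIN;
	io_uring_prep_connect(sqe, sockfd, addr, addrlen);
	io_uring_submit(ring);

	ret = io_uring_wait_cqe(ring, &cqe);
	if (ret < 0)
		return ret;
	ret = cqe->res;
	io_uring_cqe_seen(ring, cqe);
	return ret;
}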
...@@ -1825,32 +1825,46 @@ SYSCALL_DEFINE3(accept, int, fd, struct sockaddr __user *, upeer_sockaddr, ...@@ -1825,32 +1825,46 @@ SYSCALL_DEFINE3(accept, int, fd, struct sockaddr __user *, upeer_sockaddr,
* include the -EINPROGRESS status for such sockets. * include the -EINPROGRESS status for such sockets.
*/ */
int __sys_connect(int fd, struct sockaddr __user *uservaddr, int addrlen) int __sys_connect_file(struct file *file, struct sockaddr __user *uservaddr,
int addrlen, int file_flags)
{ {
struct socket *sock; struct socket *sock;
struct sockaddr_storage address; struct sockaddr_storage address;
int err, fput_needed; int err;
sock = sockfd_lookup_light(fd, &err, &fput_needed); sock = sock_from_file(file, &err);
if (!sock) if (!sock)
goto out; goto out;
err = move_addr_to_kernel(uservaddr, addrlen, &address); err = move_addr_to_kernel(uservaddr, addrlen, &address);
if (err < 0) if (err < 0)
goto out_put; goto out;
err = err =
security_socket_connect(sock, (struct sockaddr *)&address, addrlen); security_socket_connect(sock, (struct sockaddr *)&address, addrlen);
if (err) if (err)
goto out_put; goto out;
err = sock->ops->connect(sock, (struct sockaddr *)&address, addrlen, err = sock->ops->connect(sock, (struct sockaddr *)&address, addrlen,
sock->file->f_flags); sock->file->f_flags | file_flags);
out_put:
fput_light(sock->file, fput_needed);
out: out:
return err; return err;
} }
int __sys_connect(int fd, struct sockaddr __user *uservaddr, int addrlen)
{
int ret = -EBADF;
struct fd f;
f = fdget(fd);
if (f.file) {
ret = __sys_connect_file(f.file, uservaddr, addrlen, 0);
if (f.flags)
fput(f.file);
}
return ret;
}
SYSCALL_DEFINE3(connect, int, fd, struct sockaddr __user *, uservaddr, SYSCALL_DEFINE3(connect, int, fd, struct sockaddr __user *, uservaddr,
int, addrlen) int, addrlen)
{ {
...@@ -2250,15 +2264,10 @@ static int copy_msghdr_from_user(struct msghdr *kmsg, ...@@ -2250,15 +2264,10 @@ static int copy_msghdr_from_user(struct msghdr *kmsg,
return err < 0 ? err : 0; return err < 0 ? err : 0;
} }
static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg, static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys,
struct msghdr *msg_sys, unsigned int flags, unsigned int flags, struct used_address *used_address,
struct used_address *used_address, unsigned int allowed_msghdr_flags)
unsigned int allowed_msghdr_flags)
{ {
struct compat_msghdr __user *msg_compat =
(struct compat_msghdr __user *)msg;
struct sockaddr_storage address;
struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
unsigned char ctl[sizeof(struct cmsghdr) + 20] unsigned char ctl[sizeof(struct cmsghdr) + 20]
__aligned(sizeof(__kernel_size_t)); __aligned(sizeof(__kernel_size_t));
/* 20 is size of ipv6_pktinfo */ /* 20 is size of ipv6_pktinfo */
...@@ -2266,19 +2275,10 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg, ...@@ -2266,19 +2275,10 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
int ctl_len; int ctl_len;
ssize_t err; ssize_t err;
msg_sys->msg_name = &address;
if (MSG_CMSG_COMPAT & flags)
err = get_compat_msghdr(msg_sys, msg_compat, NULL, &iov);
else
err = copy_msghdr_from_user(msg_sys, msg, NULL, &iov);
if (err < 0)
return err;
err = -ENOBUFS; err = -ENOBUFS;
if (msg_sys->msg_controllen > INT_MAX) if (msg_sys->msg_controllen > INT_MAX)
goto out_freeiov; goto out;
flags |= (msg_sys->msg_flags & allowed_msghdr_flags); flags |= (msg_sys->msg_flags & allowed_msghdr_flags);
ctl_len = msg_sys->msg_controllen; ctl_len = msg_sys->msg_controllen;
if ((MSG_CMSG_COMPAT & flags) && ctl_len) { if ((MSG_CMSG_COMPAT & flags) && ctl_len) {
...@@ -2286,7 +2286,7 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg, ...@@ -2286,7 +2286,7 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
cmsghdr_from_user_compat_to_kern(msg_sys, sock->sk, ctl, cmsghdr_from_user_compat_to_kern(msg_sys, sock->sk, ctl,
sizeof(ctl)); sizeof(ctl));
if (err) if (err)
goto out_freeiov; goto out;
ctl_buf = msg_sys->msg_control; ctl_buf = msg_sys->msg_control;
ctl_len = msg_sys->msg_controllen; ctl_len = msg_sys->msg_controllen;
} else if (ctl_len) { } else if (ctl_len) {
...@@ -2295,7 +2295,7 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg, ...@@ -2295,7 +2295,7 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
if (ctl_len > sizeof(ctl)) { if (ctl_len > sizeof(ctl)) {
ctl_buf = sock_kmalloc(sock->sk, ctl_len, GFP_KERNEL); ctl_buf = sock_kmalloc(sock->sk, ctl_len, GFP_KERNEL);
if (ctl_buf == NULL) if (ctl_buf == NULL)
goto out_freeiov; goto out;
} }
err = -EFAULT; err = -EFAULT;
/* /*
...@@ -2341,7 +2341,47 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg, ...@@ -2341,7 +2341,47 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
out_freectl: out_freectl:
if (ctl_buf != ctl) if (ctl_buf != ctl)
sock_kfree_s(sock->sk, ctl_buf, ctl_len); sock_kfree_s(sock->sk, ctl_buf, ctl_len);
out_freeiov: out:
return err;
}
static int sendmsg_copy_msghdr(struct msghdr *msg,
struct user_msghdr __user *umsg, unsigned flags,
struct iovec **iov)
{
int err;
if (flags & MSG_CMSG_COMPAT) {
struct compat_msghdr __user *msg_compat;
msg_compat = (struct compat_msghdr __user *) umsg;
err = get_compat_msghdr(msg, msg_compat, NULL, iov);
} else {
err = copy_msghdr_from_user(msg, umsg, NULL, iov);
}
if (err < 0)
return err;
return 0;
}
static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
struct msghdr *msg_sys, unsigned int flags,
struct used_address *used_address,
unsigned int allowed_msghdr_flags)
{
struct sockaddr_storage address;
struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
ssize_t err;
msg_sys->msg_name = &address;
err = sendmsg_copy_msghdr(msg_sys, msg, flags, &iov);
if (err < 0)
return err;
err = ____sys_sendmsg(sock, msg_sys, flags, used_address,
allowed_msghdr_flags);
kfree(iov); kfree(iov);
return err; return err;
} }
...@@ -2349,12 +2389,27 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg, ...@@ -2349,12 +2389,27 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
/* /*
* BSD sendmsg interface * BSD sendmsg interface
*/ */
long __sys_sendmsg_sock(struct socket *sock, struct user_msghdr __user *msg, long __sys_sendmsg_sock(struct socket *sock, struct user_msghdr __user *umsg,
unsigned int flags) unsigned int flags)
{ {
struct msghdr msg_sys; struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
struct sockaddr_storage address;
struct msghdr msg = { .msg_name = &address };
ssize_t err;
err = sendmsg_copy_msghdr(&msg, umsg, flags, &iov);
if (err)
return err;
/* disallow ancillary data requests from this path */
if (msg.msg_control || msg.msg_controllen) {
err = -EINVAL;
goto out;
}
return ___sys_sendmsg(sock, msg, &msg_sys, flags, NULL, 0); err = ____sys_sendmsg(sock, &msg, flags, NULL, 0);
out:
kfree(iov);
return err;
} }
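For io_uring users the consequence (not part of this diff) is that a sendmsg or recvmsg SQE whose msghdr carries ancillary data now completes with -EINVAL; a hedged liburing sketch:

#include <liburing.h>
#include <sys/socket.h>

/* Hypothetical sketch: submit a sendmsg whose msghdr has msg_control set;
 * after this change the CQE result is -EINVAL rather than the cmsg being
 * honoured. The same restriction applies to IORING_OP_RECVMSG. */
static int sendmsg_via_ring(struct io_uring *ring, int fd,
			    const struct msghdr *msg)
{
	struct io_uring_sqe *sqe = io_uring_get_sqe(ring);
	struct io_uring_cqe *cqe;
	int ret;

	if (!sqe)
		return -EAGAIN;
	io_uring_prep_sendmsg(sqe, fd, msg, 0);
	io_uring_submit(ring);

	ret = io_uring_wait_cqe(ring, &cqe);
	if (ret < 0)
		return ret;
	ret = cqe->res;		/* -EINVAL when msg->msg_control is non-NULL */
	io_uring_cqe_seen(ring, cqe);
	return ret;
}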
long __sys_sendmsg(int fd, struct user_msghdr __user *msg, unsigned int flags, long __sys_sendmsg(int fd, struct user_msghdr __user *msg, unsigned int flags,
...@@ -2460,33 +2515,41 @@ SYSCALL_DEFINE4(sendmmsg, int, fd, struct mmsghdr __user *, mmsg, ...@@ -2460,33 +2515,41 @@ SYSCALL_DEFINE4(sendmmsg, int, fd, struct mmsghdr __user *, mmsg,
return __sys_sendmmsg(fd, mmsg, vlen, flags, true); return __sys_sendmmsg(fd, mmsg, vlen, flags, true);
} }
static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg, static int recvmsg_copy_msghdr(struct msghdr *msg,
struct msghdr *msg_sys, unsigned int flags, int nosec) struct user_msghdr __user *umsg, unsigned flags,
struct sockaddr __user **uaddr,
struct iovec **iov)
{ {
struct compat_msghdr __user *msg_compat =
(struct compat_msghdr __user *)msg;
struct iovec iovstack[UIO_FASTIOV];
struct iovec *iov = iovstack;
unsigned long cmsg_ptr;
int len;
ssize_t err; ssize_t err;
/* kernel mode address */ if (MSG_CMSG_COMPAT & flags) {
struct sockaddr_storage addr; struct compat_msghdr __user *msg_compat;
/* user mode address pointers */
struct sockaddr __user *uaddr;
int __user *uaddr_len = COMPAT_NAMELEN(msg);
msg_sys->msg_name = &addr; msg_compat = (struct compat_msghdr __user *) umsg;
err = get_compat_msghdr(msg, msg_compat, uaddr, iov);
if (MSG_CMSG_COMPAT & flags) } else {
err = get_compat_msghdr(msg_sys, msg_compat, &uaddr, &iov); err = copy_msghdr_from_user(msg, umsg, uaddr, iov);
else }
err = copy_msghdr_from_user(msg_sys, msg, &uaddr, &iov);
if (err < 0) if (err < 0)
return err; return err;
return 0;
}
static int ____sys_recvmsg(struct socket *sock, struct msghdr *msg_sys,
struct user_msghdr __user *msg,
struct sockaddr __user *uaddr,
unsigned int flags, int nosec)
{
struct compat_msghdr __user *msg_compat =
(struct compat_msghdr __user *) msg;
int __user *uaddr_len = COMPAT_NAMELEN(msg);
struct sockaddr_storage addr;
unsigned long cmsg_ptr;
int len;
ssize_t err;
msg_sys->msg_name = &addr;
cmsg_ptr = (unsigned long)msg_sys->msg_control; cmsg_ptr = (unsigned long)msg_sys->msg_control;
msg_sys->msg_flags = flags & (MSG_CMSG_CLOEXEC|MSG_CMSG_COMPAT); msg_sys->msg_flags = flags & (MSG_CMSG_CLOEXEC|MSG_CMSG_COMPAT);
...@@ -2497,7 +2560,7 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg, ...@@ -2497,7 +2560,7 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
flags |= MSG_DONTWAIT; flags |= MSG_DONTWAIT;
err = (nosec ? sock_recvmsg_nosec : sock_recvmsg)(sock, msg_sys, flags); err = (nosec ? sock_recvmsg_nosec : sock_recvmsg)(sock, msg_sys, flags);
if (err < 0) if (err < 0)
goto out_freeiov; goto out;
len = err; len = err;
if (uaddr != NULL) { if (uaddr != NULL) {
...@@ -2505,12 +2568,12 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg, ...@@ -2505,12 +2568,12 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
msg_sys->msg_namelen, uaddr, msg_sys->msg_namelen, uaddr,
uaddr_len); uaddr_len);
if (err < 0) if (err < 0)
goto out_freeiov; goto out;
} }
err = __put_user((msg_sys->msg_flags & ~MSG_CMSG_COMPAT), err = __put_user((msg_sys->msg_flags & ~MSG_CMSG_COMPAT),
COMPAT_FLAGS(msg)); COMPAT_FLAGS(msg));
if (err) if (err)
goto out_freeiov; goto out;
if (MSG_CMSG_COMPAT & flags) if (MSG_CMSG_COMPAT & flags)
err = __put_user((unsigned long)msg_sys->msg_control - cmsg_ptr, err = __put_user((unsigned long)msg_sys->msg_control - cmsg_ptr,
&msg_compat->msg_controllen); &msg_compat->msg_controllen);
...@@ -2518,10 +2581,25 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg, ...@@ -2518,10 +2581,25 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
err = __put_user((unsigned long)msg_sys->msg_control - cmsg_ptr, err = __put_user((unsigned long)msg_sys->msg_control - cmsg_ptr,
&msg->msg_controllen); &msg->msg_controllen);
if (err) if (err)
goto out_freeiov; goto out;
err = len; err = len;
out:
return err;
}
static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
struct msghdr *msg_sys, unsigned int flags, int nosec)
{
struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
/* user mode address pointers */
struct sockaddr __user *uaddr;
ssize_t err;
out_freeiov: err = recvmsg_copy_msghdr(msg_sys, msg, flags, &uaddr, &iov);
if (err < 0)
return err;
err = ____sys_recvmsg(sock, msg_sys, msg, uaddr, flags, nosec);
kfree(iov); kfree(iov);
return err; return err;
} }
...@@ -2530,12 +2608,28 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg, ...@@ -2530,12 +2608,28 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
* BSD recvmsg interface * BSD recvmsg interface
*/ */
long __sys_recvmsg_sock(struct socket *sock, struct user_msghdr __user *msg, long __sys_recvmsg_sock(struct socket *sock, struct user_msghdr __user *umsg,
unsigned int flags) unsigned int flags)
{ {
struct msghdr msg_sys; struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
struct sockaddr_storage address;
struct msghdr msg = { .msg_name = &address };
struct sockaddr __user *uaddr;
ssize_t err;
return ___sys_recvmsg(sock, msg, &msg_sys, flags, 0); err = recvmsg_copy_msghdr(&msg, umsg, flags, &uaddr, &iov);
if (err)
return err;
/* disallow ancillary data requests from this path */
if (msg.msg_control || msg.msg_controllen) {
err = -EINVAL;
goto out;
}
err = ____sys_recvmsg(sock, &msg, umsg, uaddr, flags, 0);
out:
kfree(iov);
return err;
} }
long __sys_recvmsg(int fd, struct user_msghdr __user *msg, unsigned int flags, long __sys_recvmsg(int fd, struct user_msghdr __user *msg, unsigned int flags,
......