Commit 8aca67f0 authored by Trond Myklebust's avatar Trond Myklebust

SUNRPC: Fix a potential race in rpc_wake_up_task()

Use RCU to ensure that we can safely call rpc_finish_wakeup after we've
called __rpc_do_wake_up_task. If not, there is a theoretical race, in which
the rpc_task finishes executing, and gets freed first.
Signed-off-by: default avatarTrond Myklebust <Trond.Myklebust@netapp.com>
parent e6b3c4db
...@@ -65,13 +65,19 @@ struct nfs_read_data *nfs_readdata_alloc(size_t len) ...@@ -65,13 +65,19 @@ struct nfs_read_data *nfs_readdata_alloc(size_t len)
return p; return p;
} }
static void nfs_readdata_free(struct nfs_read_data *p) static void nfs_readdata_rcu_free(struct rcu_head *head)
{ {
struct nfs_read_data *p = container_of(head, struct nfs_read_data, task.u.tk_rcu);
if (p && (p->pagevec != &p->page_array[0])) if (p && (p->pagevec != &p->page_array[0]))
kfree(p->pagevec); kfree(p->pagevec);
mempool_free(p, nfs_rdata_mempool); mempool_free(p, nfs_rdata_mempool);
} }
static void nfs_readdata_free(struct nfs_read_data *rdata)
{
call_rcu_bh(&rdata->task.u.tk_rcu, nfs_readdata_rcu_free);
}
void nfs_readdata_release(void *data) void nfs_readdata_release(void *data)
{ {
nfs_readdata_free(data); nfs_readdata_free(data);
......
...@@ -102,13 +102,19 @@ struct nfs_write_data *nfs_commit_alloc(void) ...@@ -102,13 +102,19 @@ struct nfs_write_data *nfs_commit_alloc(void)
return p; return p;
} }
void nfs_commit_free(struct nfs_write_data *p) void nfs_commit_rcu_free(struct rcu_head *head)
{ {
struct nfs_write_data *p = container_of(head, struct nfs_write_data, task.u.tk_rcu);
if (p && (p->pagevec != &p->page_array[0])) if (p && (p->pagevec != &p->page_array[0]))
kfree(p->pagevec); kfree(p->pagevec);
mempool_free(p, nfs_commit_mempool); mempool_free(p, nfs_commit_mempool);
} }
void nfs_commit_free(struct nfs_write_data *wdata)
{
call_rcu_bh(&wdata->task.u.tk_rcu, nfs_commit_rcu_free);
}
struct nfs_write_data *nfs_writedata_alloc(size_t len) struct nfs_write_data *nfs_writedata_alloc(size_t len)
{ {
unsigned int pagecount = (len + PAGE_SIZE - 1) >> PAGE_SHIFT; unsigned int pagecount = (len + PAGE_SIZE - 1) >> PAGE_SHIFT;
...@@ -131,13 +137,19 @@ struct nfs_write_data *nfs_writedata_alloc(size_t len) ...@@ -131,13 +137,19 @@ struct nfs_write_data *nfs_writedata_alloc(size_t len)
return p; return p;
} }
static void nfs_writedata_free(struct nfs_write_data *p) static void nfs_writedata_rcu_free(struct rcu_head *head)
{ {
struct nfs_write_data *p = container_of(head, struct nfs_write_data, task.u.tk_rcu);
if (p && (p->pagevec != &p->page_array[0])) if (p && (p->pagevec != &p->page_array[0]))
kfree(p->pagevec); kfree(p->pagevec);
mempool_free(p, nfs_wdata_mempool); mempool_free(p, nfs_wdata_mempool);
} }
static void nfs_writedata_free(struct nfs_write_data *wdata)
{
call_rcu_bh(&wdata->task.u.tk_rcu, nfs_writedata_rcu_free);
}
void nfs_writedata_release(void *wdata) void nfs_writedata_release(void *wdata)
{ {
nfs_writedata_free(wdata); nfs_writedata_free(wdata);
...@@ -258,7 +270,7 @@ static int nfs_writepage_sync(struct nfs_open_context *ctx, struct inode *inode, ...@@ -258,7 +270,7 @@ static int nfs_writepage_sync(struct nfs_open_context *ctx, struct inode *inode,
io_error: io_error:
nfs_end_data_update(inode); nfs_end_data_update(inode);
end_page_writeback(page); end_page_writeback(page);
nfs_writedata_free(wdata); nfs_writedata_release(wdata);
return written ? written : result; return written ? written : result;
} }
...@@ -1043,7 +1055,7 @@ static int nfs_flush_multi(struct inode *inode, struct list_head *head, int how) ...@@ -1043,7 +1055,7 @@ static int nfs_flush_multi(struct inode *inode, struct list_head *head, int how)
while (!list_empty(&list)) { while (!list_empty(&list)) {
data = list_entry(list.next, struct nfs_write_data, pages); data = list_entry(list.next, struct nfs_write_data, pages);
list_del(&data->pages); list_del(&data->pages);
nfs_writedata_free(data); nfs_writedata_release(data);
} }
nfs_mark_request_dirty(req); nfs_mark_request_dirty(req);
nfs_clear_page_writeback(req); nfs_clear_page_writeback(req);
......
...@@ -428,11 +428,6 @@ extern int nfs_updatepage(struct file *, struct page *, unsigned int, unsigned ...@@ -428,11 +428,6 @@ extern int nfs_updatepage(struct file *, struct page *, unsigned int, unsigned
extern int nfs_writeback_done(struct rpc_task *, struct nfs_write_data *); extern int nfs_writeback_done(struct rpc_task *, struct nfs_write_data *);
extern void nfs_writedata_release(void *); extern void nfs_writedata_release(void *);
#if defined(CONFIG_NFS_V3) || defined(CONFIG_NFS_V4)
struct nfs_write_data *nfs_commit_alloc(void);
void nfs_commit_free(struct nfs_write_data *p);
#endif
/* /*
* Try to write back everything synchronously (but check the * Try to write back everything synchronously (but check the
* return value!) * return value!)
...@@ -440,6 +435,8 @@ void nfs_commit_free(struct nfs_write_data *p); ...@@ -440,6 +435,8 @@ void nfs_commit_free(struct nfs_write_data *p);
extern int nfs_sync_inode_wait(struct inode *, unsigned long, unsigned int, int); extern int nfs_sync_inode_wait(struct inode *, unsigned long, unsigned int, int);
#if defined(CONFIG_NFS_V3) || defined(CONFIG_NFS_V4) #if defined(CONFIG_NFS_V3) || defined(CONFIG_NFS_V4)
extern int nfs_commit_inode(struct inode *, int); extern int nfs_commit_inode(struct inode *, int);
extern struct nfs_write_data *nfs_commit_alloc(void);
extern void nfs_commit_free(struct nfs_write_data *wdata);
extern void nfs_commit_release(void *wdata); extern void nfs_commit_release(void *wdata);
#else #else
static inline int static inline int
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <linux/timer.h> #include <linux/timer.h>
#include <linux/sunrpc/types.h> #include <linux/sunrpc/types.h>
#include <linux/rcupdate.h>
#include <linux/spinlock.h> #include <linux/spinlock.h>
#include <linux/wait.h> #include <linux/wait.h>
#include <linux/workqueue.h> #include <linux/workqueue.h>
...@@ -85,6 +86,7 @@ struct rpc_task { ...@@ -85,6 +86,7 @@ struct rpc_task {
union { union {
struct work_struct tk_work; /* Async task work queue */ struct work_struct tk_work; /* Async task work queue */
struct rpc_wait tk_wait; /* RPC wait */ struct rpc_wait tk_wait; /* RPC wait */
struct rcu_head tk_rcu; /* for task deletion */
} u; } u;
unsigned short tk_timeouts; /* maj timeouts */ unsigned short tk_timeouts; /* maj timeouts */
......
...@@ -427,16 +427,19 @@ __rpc_default_timer(struct rpc_task *task) ...@@ -427,16 +427,19 @@ __rpc_default_timer(struct rpc_task *task)
*/ */
void rpc_wake_up_task(struct rpc_task *task) void rpc_wake_up_task(struct rpc_task *task)
{ {
rcu_read_lock_bh();
if (rpc_start_wakeup(task)) { if (rpc_start_wakeup(task)) {
if (RPC_IS_QUEUED(task)) { if (RPC_IS_QUEUED(task)) {
struct rpc_wait_queue *queue = task->u.tk_wait.rpc_waitq; struct rpc_wait_queue *queue = task->u.tk_wait.rpc_waitq;
spin_lock_bh(&queue->lock); /* Note: we're already in a bh-safe context */
spin_lock(&queue->lock);
__rpc_do_wake_up_task(task); __rpc_do_wake_up_task(task);
spin_unlock_bh(&queue->lock); spin_unlock(&queue->lock);
} }
rpc_finish_wakeup(task); rpc_finish_wakeup(task);
} }
rcu_read_unlock_bh();
} }
/* /*
...@@ -499,14 +502,16 @@ struct rpc_task * rpc_wake_up_next(struct rpc_wait_queue *queue) ...@@ -499,14 +502,16 @@ struct rpc_task * rpc_wake_up_next(struct rpc_wait_queue *queue)
struct rpc_task *task = NULL; struct rpc_task *task = NULL;
dprintk("RPC: wake_up_next(%p \"%s\")\n", queue, rpc_qname(queue)); dprintk("RPC: wake_up_next(%p \"%s\")\n", queue, rpc_qname(queue));
spin_lock_bh(&queue->lock); rcu_read_lock_bh();
spin_lock(&queue->lock);
if (RPC_IS_PRIORITY(queue)) if (RPC_IS_PRIORITY(queue))
task = __rpc_wake_up_next_priority(queue); task = __rpc_wake_up_next_priority(queue);
else { else {
task_for_first(task, &queue->tasks[0]) task_for_first(task, &queue->tasks[0])
__rpc_wake_up_task(task); __rpc_wake_up_task(task);
} }
spin_unlock_bh(&queue->lock); spin_unlock(&queue->lock);
rcu_read_unlock_bh();
return task; return task;
} }
...@@ -522,7 +527,8 @@ void rpc_wake_up(struct rpc_wait_queue *queue) ...@@ -522,7 +527,8 @@ void rpc_wake_up(struct rpc_wait_queue *queue)
struct rpc_task *task, *next; struct rpc_task *task, *next;
struct list_head *head; struct list_head *head;
spin_lock_bh(&queue->lock); rcu_read_lock_bh();
spin_lock(&queue->lock);
head = &queue->tasks[queue->maxpriority]; head = &queue->tasks[queue->maxpriority];
for (;;) { for (;;) {
list_for_each_entry_safe(task, next, head, u.tk_wait.list) list_for_each_entry_safe(task, next, head, u.tk_wait.list)
...@@ -531,7 +537,8 @@ void rpc_wake_up(struct rpc_wait_queue *queue) ...@@ -531,7 +537,8 @@ void rpc_wake_up(struct rpc_wait_queue *queue)
break; break;
head--; head--;
} }
spin_unlock_bh(&queue->lock); spin_unlock(&queue->lock);
rcu_read_unlock_bh();
} }
/** /**
...@@ -546,7 +553,8 @@ void rpc_wake_up_status(struct rpc_wait_queue *queue, int status) ...@@ -546,7 +553,8 @@ void rpc_wake_up_status(struct rpc_wait_queue *queue, int status)
struct rpc_task *task, *next; struct rpc_task *task, *next;
struct list_head *head; struct list_head *head;
spin_lock_bh(&queue->lock); rcu_read_lock_bh();
spin_lock(&queue->lock);
head = &queue->tasks[queue->maxpriority]; head = &queue->tasks[queue->maxpriority];
for (;;) { for (;;) {
list_for_each_entry_safe(task, next, head, u.tk_wait.list) { list_for_each_entry_safe(task, next, head, u.tk_wait.list) {
...@@ -557,7 +565,8 @@ void rpc_wake_up_status(struct rpc_wait_queue *queue, int status) ...@@ -557,7 +565,8 @@ void rpc_wake_up_status(struct rpc_wait_queue *queue, int status)
break; break;
head--; head--;
} }
spin_unlock_bh(&queue->lock); spin_unlock(&queue->lock);
rcu_read_unlock_bh();
} }
static void __rpc_atrun(struct rpc_task *task) static void __rpc_atrun(struct rpc_task *task)
...@@ -817,8 +826,9 @@ rpc_alloc_task(void) ...@@ -817,8 +826,9 @@ rpc_alloc_task(void)
return (struct rpc_task *)mempool_alloc(rpc_task_mempool, GFP_NOFS); return (struct rpc_task *)mempool_alloc(rpc_task_mempool, GFP_NOFS);
} }
static void rpc_free_task(struct rpc_task *task) static void rpc_free_task(struct rcu_head *rcu)
{ {
struct rpc_task *task = container_of(rcu, struct rpc_task, u.tk_rcu);
dprintk("RPC: %4d freeing task\n", task->tk_pid); dprintk("RPC: %4d freeing task\n", task->tk_pid);
mempool_free(task, rpc_task_mempool); mempool_free(task, rpc_task_mempool);
} }
...@@ -872,7 +882,7 @@ void rpc_put_task(struct rpc_task *task) ...@@ -872,7 +882,7 @@ void rpc_put_task(struct rpc_task *task)
task->tk_client = NULL; task->tk_client = NULL;
} }
if (task->tk_flags & RPC_TASK_DYNAMIC) if (task->tk_flags & RPC_TASK_DYNAMIC)
rpc_free_task(task); call_rcu_bh(&task->u.tk_rcu, rpc_free_task);
if (tk_ops->rpc_release) if (tk_ops->rpc_release)
tk_ops->rpc_release(calldata); tk_ops->rpc_release(calldata);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment