Commit 98bf40cd authored by David Howells's avatar David Howells

afs: Protect call->state changes against signals

Protect call->state changes against the call being prematurely terminated
due to a signal.

What can happen is that a signal causes afs_wait_for_call_to_complete() to
abort an afs_call because it's not yet complete whilst afs_deliver_to_call()
is delivering data to that call.

If the data delivery causes the state to change, this may overwrite the state
of the afs_call, making it not-yet-complete again - but no further
notifications will be forthcoming from AF_RXRPC as the rxrpc call has been
aborted and completed, so kAFS will just hang in various places waiting for
that call or on page bits that need clearing by that call.

A tracepoint to monitor call state changes is also provided.
Signed-off-by: default avatarDavid Howells <dhowells@redhat.com>
parent 13524ab3
...@@ -188,7 +188,6 @@ static int afs_deliver_cb_callback(struct afs_call *call) ...@@ -188,7 +188,6 @@ static int afs_deliver_cb_callback(struct afs_call *call)
switch (call->unmarshall) { switch (call->unmarshall) {
case 0: case 0:
rxrpc_kernel_get_peer(call->net->socket, call->rxcall, &srx);
call->offset = 0; call->offset = 0;
call->unmarshall++; call->unmarshall++;
...@@ -281,10 +280,12 @@ static int afs_deliver_cb_callback(struct afs_call *call) ...@@ -281,10 +280,12 @@ static int afs_deliver_cb_callback(struct afs_call *call)
break; break;
} }
call->state = AFS_CALL_REPLYING; if (!afs_check_call_state(call, AFS_CALL_SV_REPLYING))
return -EIO;
/* we'll need the file server record as that tells us which set of /* we'll need the file server record as that tells us which set of
* vnodes to operate upon */ * vnodes to operate upon */
rxrpc_kernel_get_peer(call->net->socket, call->rxcall, &srx);
server = afs_find_server(call->net, &srx); server = afs_find_server(call->net, &srx);
if (!server) if (!server)
return -ENOTCONN; return -ENOTCONN;
...@@ -325,9 +326,6 @@ static int afs_deliver_cb_init_call_back_state(struct afs_call *call) ...@@ -325,9 +326,6 @@ static int afs_deliver_cb_init_call_back_state(struct afs_call *call)
if (ret < 0) if (ret < 0)
return ret; return ret;
/* no unmarshalling required */
call->state = AFS_CALL_REPLYING;
/* we'll need the file server record as that tells us which set of /* we'll need the file server record as that tells us which set of
* vnodes to operate upon */ * vnodes to operate upon */
server = afs_find_server(call->net, &srx); server = afs_find_server(call->net, &srx);
...@@ -352,8 +350,6 @@ static int afs_deliver_cb_init_call_back_state3(struct afs_call *call) ...@@ -352,8 +350,6 @@ static int afs_deliver_cb_init_call_back_state3(struct afs_call *call)
_enter(""); _enter("");
rxrpc_kernel_get_peer(call->net->socket, call->rxcall, &srx);
_enter("{%u}", call->unmarshall); _enter("{%u}", call->unmarshall);
switch (call->unmarshall) { switch (call->unmarshall) {
...@@ -397,11 +393,12 @@ static int afs_deliver_cb_init_call_back_state3(struct afs_call *call) ...@@ -397,11 +393,12 @@ static int afs_deliver_cb_init_call_back_state3(struct afs_call *call)
break; break;
} }
/* no unmarshalling required */ if (!afs_check_call_state(call, AFS_CALL_SV_REPLYING))
call->state = AFS_CALL_REPLYING; return -EIO;
/* we'll need the file server record as that tells us which set of /* we'll need the file server record as that tells us which set of
* vnodes to operate upon */ * vnodes to operate upon */
rxrpc_kernel_get_peer(call->net->socket, call->rxcall, &srx);
server = afs_find_server(call->net, &srx); server = afs_find_server(call->net, &srx);
if (!server) if (!server)
return -ENOTCONN; return -ENOTCONN;
...@@ -436,8 +433,8 @@ static int afs_deliver_cb_probe(struct afs_call *call) ...@@ -436,8 +433,8 @@ static int afs_deliver_cb_probe(struct afs_call *call)
if (ret < 0) if (ret < 0)
return ret; return ret;
/* no unmarshalling required */ if (!afs_check_call_state(call, AFS_CALL_SV_REPLYING))
call->state = AFS_CALL_REPLYING; return -EIO;
return afs_queue_call_work(call); return afs_queue_call_work(call);
} }
...@@ -519,7 +516,8 @@ static int afs_deliver_cb_probe_uuid(struct afs_call *call) ...@@ -519,7 +516,8 @@ static int afs_deliver_cb_probe_uuid(struct afs_call *call)
break; break;
} }
call->state = AFS_CALL_REPLYING; if (!afs_check_call_state(call, AFS_CALL_SV_REPLYING))
return -EIO;
return afs_queue_call_work(call); return afs_queue_call_work(call);
} }
...@@ -600,8 +598,8 @@ static int afs_deliver_cb_tell_me_about_yourself(struct afs_call *call) ...@@ -600,8 +598,8 @@ static int afs_deliver_cb_tell_me_about_yourself(struct afs_call *call)
if (ret < 0) if (ret < 0)
return ret; return ret;
/* no unmarshalling required */ if (!afs_check_call_state(call, AFS_CALL_SV_REPLYING))
call->state = AFS_CALL_REPLYING; return -EIO;
return afs_queue_call_work(call); return afs_queue_call_work(call);
} }
...@@ -51,12 +51,13 @@ struct afs_iget_data { ...@@ -51,12 +51,13 @@ struct afs_iget_data {
}; };
enum afs_call_state { enum afs_call_state {
AFS_CALL_REQUESTING, /* request is being sent for outgoing call */ AFS_CALL_CL_REQUESTING, /* Client: Request is being sent */
AFS_CALL_AWAIT_REPLY, /* awaiting reply to outgoing call */ AFS_CALL_CL_AWAIT_REPLY, /* Client: Awaiting reply */
AFS_CALL_AWAIT_OP_ID, /* awaiting op ID on incoming call */ AFS_CALL_CL_PROC_REPLY, /* Client: rxrpc call complete; processing reply */
AFS_CALL_AWAIT_REQUEST, /* awaiting request data on incoming call */ AFS_CALL_SV_AWAIT_OP_ID, /* Server: Awaiting op ID */
AFS_CALL_REPLYING, /* replying to incoming call */ AFS_CALL_SV_AWAIT_REQUEST, /* Server: Awaiting request data */
AFS_CALL_AWAIT_ACK, /* awaiting final ACK of incoming call */ AFS_CALL_SV_REPLYING, /* Server: Replying */
AFS_CALL_SV_AWAIT_ACK, /* Server: Awaiting final ACK */
AFS_CALL_COMPLETE, /* Completed or failed */ AFS_CALL_COMPLETE, /* Completed or failed */
}; };
...@@ -97,6 +98,7 @@ struct afs_call { ...@@ -97,6 +98,7 @@ struct afs_call {
size_t offset; /* offset into received data store */ size_t offset; /* offset into received data store */
atomic_t usage; atomic_t usage;
enum afs_call_state state; enum afs_call_state state;
spinlock_t state_lock;
int error; /* error code */ int error; /* error code */
u32 abort_code; /* Remote abort ID or 0 */ u32 abort_code; /* Remote abort ID or 0 */
unsigned request_size; /* size of request data */ unsigned request_size; /* size of request data */
...@@ -543,6 +545,8 @@ struct afs_fs_cursor { ...@@ -543,6 +545,8 @@ struct afs_fs_cursor {
#define AFS_FS_CURSOR_NO_VSLEEP 0x0020 /* Set to prevent sleep on VBUSY, VOFFLINE, ... */ #define AFS_FS_CURSOR_NO_VSLEEP 0x0020 /* Set to prevent sleep on VBUSY, VOFFLINE, ... */
}; };
#include <trace/events/afs.h>
/*****************************************************************************/ /*****************************************************************************/
/* /*
* addr_list.c * addr_list.c
...@@ -788,6 +792,49 @@ static inline int afs_transfer_reply(struct afs_call *call) ...@@ -788,6 +792,49 @@ static inline int afs_transfer_reply(struct afs_call *call)
return afs_extract_data(call, call->buffer, call->reply_max, false); return afs_extract_data(call, call->buffer, call->reply_max, false);
} }
static inline bool afs_check_call_state(struct afs_call *call,
enum afs_call_state state)
{
return READ_ONCE(call->state) == state;
}
static inline bool afs_set_call_state(struct afs_call *call,
enum afs_call_state from,
enum afs_call_state to)
{
bool ok = false;
spin_lock_bh(&call->state_lock);
if (call->state == from) {
call->state = to;
trace_afs_call_state(call, from, to, 0, 0);
ok = true;
}
spin_unlock_bh(&call->state_lock);
return ok;
}
static inline void afs_set_call_complete(struct afs_call *call,
int error, u32 remote_abort)
{
enum afs_call_state state;
bool ok = false;
spin_lock_bh(&call->state_lock);
state = call->state;
if (state != AFS_CALL_COMPLETE) {
call->abort_code = remote_abort;
call->error = error;
call->state = AFS_CALL_COMPLETE;
trace_afs_call_state(call, state, AFS_CALL_COMPLETE,
error, remote_abort);
ok = true;
}
spin_unlock_bh(&call->state_lock);
if (ok)
trace_afs_call_done(call);
}
/* /*
* security.c * security.c
*/ */
...@@ -932,8 +979,6 @@ static inline void afs_check_for_remote_deletion(struct afs_fs_cursor *fc, ...@@ -932,8 +979,6 @@ static inline void afs_check_for_remote_deletion(struct afs_fs_cursor *fc,
/* /*
* debug tracing * debug tracing
*/ */
#include <trace/events/afs.h>
extern unsigned afs_debug; extern unsigned afs_debug;
#define dbgprintk(FMT,...) \ #define dbgprintk(FMT,...) \
......
...@@ -134,6 +134,7 @@ static struct afs_call *afs_alloc_call(struct afs_net *net, ...@@ -134,6 +134,7 @@ static struct afs_call *afs_alloc_call(struct afs_net *net,
atomic_set(&call->usage, 1); atomic_set(&call->usage, 1);
INIT_WORK(&call->async_work, afs_process_async_call); INIT_WORK(&call->async_work, afs_process_async_call);
init_waitqueue_head(&call->waitq); init_waitqueue_head(&call->waitq);
spin_lock_init(&call->state_lock);
o = atomic_inc_return(&net->nr_outstanding_calls); o = atomic_inc_return(&net->nr_outstanding_calls);
trace_afs_call(call, afs_call_trace_alloc, 1, o, trace_afs_call(call, afs_call_trace_alloc, 1, o,
...@@ -288,8 +289,7 @@ static void afs_notify_end_request_tx(struct sock *sock, ...@@ -288,8 +289,7 @@ static void afs_notify_end_request_tx(struct sock *sock,
{ {
struct afs_call *call = (struct afs_call *)call_user_ID; struct afs_call *call = (struct afs_call *)call_user_ID;
if (call->state == AFS_CALL_REQUESTING) afs_set_call_state(call, AFS_CALL_CL_REQUESTING, AFS_CALL_CL_AWAIT_REPLY);
call->state = AFS_CALL_AWAIT_REPLY;
} }
/* /*
...@@ -444,82 +444,87 @@ long afs_make_call(struct afs_addr_cursor *ac, struct afs_call *call, ...@@ -444,82 +444,87 @@ long afs_make_call(struct afs_addr_cursor *ac, struct afs_call *call,
*/ */
static void afs_deliver_to_call(struct afs_call *call) static void afs_deliver_to_call(struct afs_call *call)
{ {
u32 abort_code; enum afs_call_state state;
u32 abort_code, remote_abort = 0;
int ret; int ret;
_enter("%s", call->type->name); _enter("%s", call->type->name);
while (call->state == AFS_CALL_AWAIT_REPLY || while (state = READ_ONCE(call->state),
call->state == AFS_CALL_AWAIT_OP_ID || state == AFS_CALL_CL_AWAIT_REPLY ||
call->state == AFS_CALL_AWAIT_REQUEST || state == AFS_CALL_SV_AWAIT_OP_ID ||
call->state == AFS_CALL_AWAIT_ACK state == AFS_CALL_SV_AWAIT_REQUEST ||
state == AFS_CALL_SV_AWAIT_ACK
) { ) {
if (call->state == AFS_CALL_AWAIT_ACK) { if (state == AFS_CALL_SV_AWAIT_ACK) {
size_t offset = 0; size_t offset = 0;
ret = rxrpc_kernel_recv_data(call->net->socket, ret = rxrpc_kernel_recv_data(call->net->socket,
call->rxcall, call->rxcall,
NULL, 0, &offset, false, NULL, 0, &offset, false,
&call->abort_code, &remote_abort,
&call->service_id); &call->service_id);
trace_afs_recv_data(call, 0, offset, false, ret); trace_afs_recv_data(call, 0, offset, false, ret);
if (ret == -EINPROGRESS || ret == -EAGAIN) if (ret == -EINPROGRESS || ret == -EAGAIN)
return; return;
if (ret < 0) if (ret < 0 || ret == 1) {
call->error = ret; if (ret == 1)
if (ret < 0 || ret == 1) ret = 0;
goto call_complete; goto call_complete;
}
return; return;
} }
ret = call->type->deliver(call); ret = call->type->deliver(call);
state = READ_ONCE(call->state);
switch (ret) { switch (ret) {
case 0: case 0:
if (call->state == AFS_CALL_AWAIT_REPLY) if (state == AFS_CALL_CL_PROC_REPLY)
goto call_complete; goto call_complete;
ASSERTCMP(state, >, AFS_CALL_CL_PROC_REPLY);
goto done; goto done;
case -EINPROGRESS: case -EINPROGRESS:
case -EAGAIN: case -EAGAIN:
goto out; goto out;
case -EIO:
case -ECONNABORTED: case -ECONNABORTED:
goto save_error; ASSERTCMP(state, ==, AFS_CALL_COMPLETE);
goto done;
case -ENOTCONN: case -ENOTCONN:
abort_code = RX_CALL_DEAD; abort_code = RX_CALL_DEAD;
rxrpc_kernel_abort_call(call->net->socket, call->rxcall, rxrpc_kernel_abort_call(call->net->socket, call->rxcall,
abort_code, ret, "KNC"); abort_code, ret, "KNC");
goto save_error; goto local_abort;
case -ENOTSUPP: case -ENOTSUPP:
abort_code = RXGEN_OPCODE; abort_code = RXGEN_OPCODE;
rxrpc_kernel_abort_call(call->net->socket, call->rxcall, rxrpc_kernel_abort_call(call->net->socket, call->rxcall,
abort_code, ret, "KIV"); abort_code, ret, "KIV");
goto save_error; goto local_abort;
case -ENODATA: case -ENODATA:
case -EBADMSG: case -EBADMSG:
case -EMSGSIZE: case -EMSGSIZE:
default: default:
abort_code = RXGEN_CC_UNMARSHAL; abort_code = RXGEN_CC_UNMARSHAL;
if (call->state != AFS_CALL_AWAIT_REPLY) if (state != AFS_CALL_CL_AWAIT_REPLY)
abort_code = RXGEN_SS_UNMARSHAL; abort_code = RXGEN_SS_UNMARSHAL;
rxrpc_kernel_abort_call(call->net->socket, call->rxcall, rxrpc_kernel_abort_call(call->net->socket, call->rxcall,
abort_code, -EBADMSG, "KUM"); abort_code, -EBADMSG, "KUM");
goto save_error; goto local_abort;
} }
} }
done: done:
if (call->state == AFS_CALL_COMPLETE && call->incoming) if (state == AFS_CALL_COMPLETE && call->incoming)
afs_put_call(call); afs_put_call(call);
out: out:
_leave(""); _leave("");
return; return;
save_error: local_abort:
call->error = ret; abort_code = 0;
call_complete: call_complete:
if (call->state != AFS_CALL_COMPLETE) { afs_set_call_complete(call, ret, remote_abort);
call->state = AFS_CALL_COMPLETE; state = AFS_CALL_COMPLETE;
trace_afs_call_done(call);
}
goto done; goto done;
} }
...@@ -551,14 +556,15 @@ static long afs_wait_for_call_to_complete(struct afs_call *call, ...@@ -551,14 +556,15 @@ static long afs_wait_for_call_to_complete(struct afs_call *call,
set_current_state(TASK_UNINTERRUPTIBLE); set_current_state(TASK_UNINTERRUPTIBLE);
/* deliver any messages that are in the queue */ /* deliver any messages that are in the queue */
if (call->state < AFS_CALL_COMPLETE && call->need_attention) { if (!afs_check_call_state(call, AFS_CALL_COMPLETE) &&
call->need_attention) {
call->need_attention = false; call->need_attention = false;
__set_current_state(TASK_RUNNING); __set_current_state(TASK_RUNNING);
afs_deliver_to_call(call); afs_deliver_to_call(call);
continue; continue;
} }
if (call->state == AFS_CALL_COMPLETE) if (afs_check_call_state(call, AFS_CALL_COMPLETE))
break; break;
life = rxrpc_kernel_check_life(call->net->socket, call->rxcall); life = rxrpc_kernel_check_life(call->net->socket, call->rxcall);
...@@ -578,17 +584,17 @@ static long afs_wait_for_call_to_complete(struct afs_call *call, ...@@ -578,17 +584,17 @@ static long afs_wait_for_call_to_complete(struct afs_call *call,
__set_current_state(TASK_RUNNING); __set_current_state(TASK_RUNNING);
/* Kill off the call if it's still live. */ /* Kill off the call if it's still live. */
if (call->state < AFS_CALL_COMPLETE) { if (!afs_check_call_state(call, AFS_CALL_COMPLETE)) {
_debug("call interrupted"); _debug("call interrupted");
if (rxrpc_kernel_abort_call(call->net->socket, call->rxcall, if (rxrpc_kernel_abort_call(call->net->socket, call->rxcall,
RX_USER_ABORT, -EINTR, "KWI")) { RX_USER_ABORT, -EINTR, "KWI"))
call->error = -ERESTARTSYS; afs_set_call_complete(call, -EINTR, 0);
trace_afs_call_done(call);
}
} }
spin_lock_bh(&call->state_lock);
ac->abort_code = call->abort_code; ac->abort_code = call->abort_code;
ac->error = call->error; ac->error = call->error;
spin_unlock_bh(&call->state_lock);
ret = ac->error; ret = ac->error;
switch (ret) { switch (ret) {
...@@ -713,7 +719,7 @@ void afs_charge_preallocation(struct work_struct *work) ...@@ -713,7 +719,7 @@ void afs_charge_preallocation(struct work_struct *work)
break; break;
call->async = true; call->async = true;
call->state = AFS_CALL_AWAIT_OP_ID; call->state = AFS_CALL_SV_AWAIT_OP_ID;
init_waitqueue_head(&call->waitq); init_waitqueue_head(&call->waitq);
} }
...@@ -769,7 +775,7 @@ static int afs_deliver_cm_op_id(struct afs_call *call) ...@@ -769,7 +775,7 @@ static int afs_deliver_cm_op_id(struct afs_call *call)
return ret; return ret;
call->operation_ID = ntohl(call->tmp); call->operation_ID = ntohl(call->tmp);
call->state = AFS_CALL_AWAIT_REQUEST; afs_set_call_state(call, AFS_CALL_SV_AWAIT_OP_ID, AFS_CALL_SV_AWAIT_REQUEST);
call->offset = 0; call->offset = 0;
/* ask the cache manager to route the call (it'll change the call type /* ask the cache manager to route the call (it'll change the call type
...@@ -794,8 +800,7 @@ static void afs_notify_end_reply_tx(struct sock *sock, ...@@ -794,8 +800,7 @@ static void afs_notify_end_reply_tx(struct sock *sock,
{ {
struct afs_call *call = (struct afs_call *)call_user_ID; struct afs_call *call = (struct afs_call *)call_user_ID;
if (call->state == AFS_CALL_REPLYING) afs_set_call_state(call, AFS_CALL_SV_REPLYING, AFS_CALL_SV_AWAIT_ACK);
call->state = AFS_CALL_AWAIT_ACK;
} }
/* /*
...@@ -879,6 +884,8 @@ int afs_extract_data(struct afs_call *call, void *buf, size_t count, ...@@ -879,6 +884,8 @@ int afs_extract_data(struct afs_call *call, void *buf, size_t count,
bool want_more) bool want_more)
{ {
struct afs_net *net = call->net; struct afs_net *net = call->net;
enum afs_call_state state;
u32 remote_abort;
int ret; int ret;
_enter("{%s,%zu},,%zu,%d", _enter("{%s,%zu},,%zu,%d",
...@@ -888,29 +895,30 @@ int afs_extract_data(struct afs_call *call, void *buf, size_t count, ...@@ -888,29 +895,30 @@ int afs_extract_data(struct afs_call *call, void *buf, size_t count,
ret = rxrpc_kernel_recv_data(net->socket, call->rxcall, ret = rxrpc_kernel_recv_data(net->socket, call->rxcall,
buf, count, &call->offset, buf, count, &call->offset,
want_more, &call->abort_code, want_more, &remote_abort,
&call->service_id); &call->service_id);
trace_afs_recv_data(call, count, call->offset, want_more, ret); trace_afs_recv_data(call, count, call->offset, want_more, ret);
if (ret == 0 || ret == -EAGAIN) if (ret == 0 || ret == -EAGAIN)
return ret; return ret;
state = READ_ONCE(call->state);
if (ret == 1) { if (ret == 1) {
switch (call->state) { switch (state) {
case AFS_CALL_AWAIT_REPLY: case AFS_CALL_CL_AWAIT_REPLY:
call->state = AFS_CALL_COMPLETE; afs_set_call_state(call, state, AFS_CALL_CL_PROC_REPLY);
trace_afs_call_done(call);
break; break;
case AFS_CALL_AWAIT_REQUEST: case AFS_CALL_SV_AWAIT_REQUEST:
call->state = AFS_CALL_REPLYING; afs_set_call_state(call, state, AFS_CALL_SV_REPLYING);
break; break;
case AFS_CALL_COMPLETE:
kdebug("prem complete %d", call->error);
return -EIO;
default: default:
break; break;
} }
return 0; return 0;
} }
call->error = ret; afs_set_call_complete(call, ret, remote_abort);
call->state = AFS_CALL_COMPLETE;
trace_afs_call_done(call);
return ret; return ret;
} }
...@@ -441,6 +441,36 @@ TRACE_EVENT(afs_page_dirty, ...@@ -441,6 +441,36 @@ TRACE_EVENT(afs_page_dirty,
__entry->priv >> AFS_PRIV_SHIFT) __entry->priv >> AFS_PRIV_SHIFT)
); );
TRACE_EVENT(afs_call_state,
TP_PROTO(struct afs_call *call,
enum afs_call_state from,
enum afs_call_state to,
int ret, u32 remote_abort),
TP_ARGS(call, from, to, ret, remote_abort),
TP_STRUCT__entry(
__field(struct afs_call *, call )
__field(enum afs_call_state, from )
__field(enum afs_call_state, to )
__field(int, ret )
__field(u32, abort )
),
TP_fast_assign(
__entry->call = call;
__entry->from = from;
__entry->to = to;
__entry->ret = ret;
__entry->abort = remote_abort;
),
TP_printk("c=%p %u->%u r=%d ab=%d",
__entry->call,
__entry->from, __entry->to,
__entry->ret, __entry->abort)
);
#endif /* _TRACE_AFS_H */ #endif /* _TRACE_AFS_H */
/* This part must be outside protection */ /* This part must be outside protection */
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment