Commit 5273a191 authored by David Howells's avatar David Howells

rxrpc: Fix NULL pointer deref due to call->conn being cleared on disconnect

When a call is disconnected, the connection pointer from the call is
cleared to make sure it isn't used again and to prevent further attempted
transmission for the call.  Unfortunately, there might be a daemon trying
to use it at the same time to transmit a packet.

Fix this by keeping call->conn set, but setting a flag on the call to
indicate disconnection instead.

Remove also the bits in the transmission functions where the conn pointer is
checked and a ref taken under spinlock as this is now redundant.

Fixes: 8d94aa38 ("rxrpc: Calls shouldn't hold socket refs")
Signed-off-by: default avatarDavid Howells <dhowells@redhat.com>
parent 04d36d74
...@@ -490,6 +490,7 @@ enum rxrpc_call_flag { ...@@ -490,6 +490,7 @@ enum rxrpc_call_flag {
RXRPC_CALL_RX_HEARD, /* The peer responded at least once to this call */ RXRPC_CALL_RX_HEARD, /* The peer responded at least once to this call */
RXRPC_CALL_RX_UNDERRUN, /* Got data underrun */ RXRPC_CALL_RX_UNDERRUN, /* Got data underrun */
RXRPC_CALL_IS_INTR, /* The call is interruptible */ RXRPC_CALL_IS_INTR, /* The call is interruptible */
RXRPC_CALL_DISCONNECTED, /* The call has been disconnected */
}; };
/* /*
......
...@@ -493,7 +493,7 @@ void rxrpc_release_call(struct rxrpc_sock *rx, struct rxrpc_call *call) ...@@ -493,7 +493,7 @@ void rxrpc_release_call(struct rxrpc_sock *rx, struct rxrpc_call *call)
_debug("RELEASE CALL %p (%d CONN %p)", call, call->debug_id, conn); _debug("RELEASE CALL %p (%d CONN %p)", call, call->debug_id, conn);
if (conn) if (conn && !test_bit(RXRPC_CALL_DISCONNECTED, &call->flags))
rxrpc_disconnect_call(call); rxrpc_disconnect_call(call);
if (call->security) if (call->security)
call->security->free_call_crypto(call); call->security->free_call_crypto(call);
...@@ -569,6 +569,7 @@ static void rxrpc_rcu_destroy_call(struct rcu_head *rcu) ...@@ -569,6 +569,7 @@ static void rxrpc_rcu_destroy_call(struct rcu_head *rcu)
struct rxrpc_call *call = container_of(rcu, struct rxrpc_call, rcu); struct rxrpc_call *call = container_of(rcu, struct rxrpc_call, rcu);
struct rxrpc_net *rxnet = call->rxnet; struct rxrpc_net *rxnet = call->rxnet;
rxrpc_put_connection(call->conn);
rxrpc_put_peer(call->peer); rxrpc_put_peer(call->peer);
kfree(call->rxtx_buffer); kfree(call->rxtx_buffer);
kfree(call->rxtx_annotations); kfree(call->rxtx_annotations);
...@@ -590,7 +591,6 @@ void rxrpc_cleanup_call(struct rxrpc_call *call) ...@@ -590,7 +591,6 @@ void rxrpc_cleanup_call(struct rxrpc_call *call)
ASSERTCMP(call->state, ==, RXRPC_CALL_COMPLETE); ASSERTCMP(call->state, ==, RXRPC_CALL_COMPLETE);
ASSERT(test_bit(RXRPC_CALL_RELEASED, &call->flags)); ASSERT(test_bit(RXRPC_CALL_RELEASED, &call->flags));
ASSERTCMP(call->conn, ==, NULL);
rxrpc_cleanup_ring(call); rxrpc_cleanup_ring(call);
rxrpc_free_skb(call->tx_pending, rxrpc_skb_cleaned); rxrpc_free_skb(call->tx_pending, rxrpc_skb_cleaned);
......
...@@ -785,6 +785,7 @@ void rxrpc_disconnect_client_call(struct rxrpc_call *call) ...@@ -785,6 +785,7 @@ void rxrpc_disconnect_client_call(struct rxrpc_call *call)
u32 cid; u32 cid;
spin_lock(&conn->channel_lock); spin_lock(&conn->channel_lock);
set_bit(RXRPC_CALL_DISCONNECTED, &call->flags);
cid = call->cid; cid = call->cid;
if (cid) { if (cid) {
...@@ -792,7 +793,6 @@ void rxrpc_disconnect_client_call(struct rxrpc_call *call) ...@@ -792,7 +793,6 @@ void rxrpc_disconnect_client_call(struct rxrpc_call *call)
chan = &conn->channels[channel]; chan = &conn->channels[channel];
} }
trace_rxrpc_client(conn, channel, rxrpc_client_chan_disconnect); trace_rxrpc_client(conn, channel, rxrpc_client_chan_disconnect);
call->conn = NULL;
/* Calls that have never actually been assigned a channel can simply be /* Calls that have never actually been assigned a channel can simply be
* discarded. If the conn didn't get used either, it will follow * discarded. If the conn didn't get used either, it will follow
...@@ -908,7 +908,6 @@ void rxrpc_disconnect_client_call(struct rxrpc_call *call) ...@@ -908,7 +908,6 @@ void rxrpc_disconnect_client_call(struct rxrpc_call *call)
spin_unlock(&rxnet->client_conn_cache_lock); spin_unlock(&rxnet->client_conn_cache_lock);
out_2: out_2:
spin_unlock(&conn->channel_lock); spin_unlock(&conn->channel_lock);
rxrpc_put_connection(conn);
_leave(""); _leave("");
return; return;
......
...@@ -171,6 +171,8 @@ void __rxrpc_disconnect_call(struct rxrpc_connection *conn, ...@@ -171,6 +171,8 @@ void __rxrpc_disconnect_call(struct rxrpc_connection *conn,
_enter("%d,%x", conn->debug_id, call->cid); _enter("%d,%x", conn->debug_id, call->cid);
set_bit(RXRPC_CALL_DISCONNECTED, &call->flags);
if (rcu_access_pointer(chan->call) == call) { if (rcu_access_pointer(chan->call) == call) {
/* Save the result of the call so that we can repeat it if necessary /* Save the result of the call so that we can repeat it if necessary
* through the channel, whilst disposing of the actual call record. * through the channel, whilst disposing of the actual call record.
...@@ -223,9 +225,7 @@ void rxrpc_disconnect_call(struct rxrpc_call *call) ...@@ -223,9 +225,7 @@ void rxrpc_disconnect_call(struct rxrpc_call *call)
__rxrpc_disconnect_call(conn, call); __rxrpc_disconnect_call(conn, call);
spin_unlock(&conn->channel_lock); spin_unlock(&conn->channel_lock);
call->conn = NULL;
conn->idle_timestamp = jiffies; conn->idle_timestamp = jiffies;
rxrpc_put_connection(conn);
} }
/* /*
......
...@@ -129,7 +129,7 @@ static size_t rxrpc_fill_out_ack(struct rxrpc_connection *conn, ...@@ -129,7 +129,7 @@ static size_t rxrpc_fill_out_ack(struct rxrpc_connection *conn,
int rxrpc_send_ack_packet(struct rxrpc_call *call, bool ping, int rxrpc_send_ack_packet(struct rxrpc_call *call, bool ping,
rxrpc_serial_t *_serial) rxrpc_serial_t *_serial)
{ {
struct rxrpc_connection *conn = NULL; struct rxrpc_connection *conn;
struct rxrpc_ack_buffer *pkt; struct rxrpc_ack_buffer *pkt;
struct msghdr msg; struct msghdr msg;
struct kvec iov[2]; struct kvec iov[2];
...@@ -139,18 +139,14 @@ int rxrpc_send_ack_packet(struct rxrpc_call *call, bool ping, ...@@ -139,18 +139,14 @@ int rxrpc_send_ack_packet(struct rxrpc_call *call, bool ping,
int ret; int ret;
u8 reason; u8 reason;
spin_lock_bh(&call->lock); if (test_bit(RXRPC_CALL_DISCONNECTED, &call->flags))
if (call->conn)
conn = rxrpc_get_connection_maybe(call->conn);
spin_unlock_bh(&call->lock);
if (!conn)
return -ECONNRESET; return -ECONNRESET;
pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
if (!pkt) { if (!pkt)
rxrpc_put_connection(conn);
return -ENOMEM; return -ENOMEM;
}
conn = call->conn;
msg.msg_name = &call->peer->srx.transport; msg.msg_name = &call->peer->srx.transport;
msg.msg_namelen = call->peer->srx.transport_len; msg.msg_namelen = call->peer->srx.transport_len;
...@@ -244,7 +240,6 @@ int rxrpc_send_ack_packet(struct rxrpc_call *call, bool ping, ...@@ -244,7 +240,6 @@ int rxrpc_send_ack_packet(struct rxrpc_call *call, bool ping,
} }
out: out:
rxrpc_put_connection(conn);
kfree(pkt); kfree(pkt);
return ret; return ret;
} }
...@@ -254,7 +249,7 @@ int rxrpc_send_ack_packet(struct rxrpc_call *call, bool ping, ...@@ -254,7 +249,7 @@ int rxrpc_send_ack_packet(struct rxrpc_call *call, bool ping,
*/ */
int rxrpc_send_abort_packet(struct rxrpc_call *call) int rxrpc_send_abort_packet(struct rxrpc_call *call)
{ {
struct rxrpc_connection *conn = NULL; struct rxrpc_connection *conn;
struct rxrpc_abort_buffer pkt; struct rxrpc_abort_buffer pkt;
struct msghdr msg; struct msghdr msg;
struct kvec iov[1]; struct kvec iov[1];
...@@ -271,13 +266,11 @@ int rxrpc_send_abort_packet(struct rxrpc_call *call) ...@@ -271,13 +266,11 @@ int rxrpc_send_abort_packet(struct rxrpc_call *call)
test_bit(RXRPC_CALL_TX_LAST, &call->flags)) test_bit(RXRPC_CALL_TX_LAST, &call->flags))
return 0; return 0;
spin_lock_bh(&call->lock); if (test_bit(RXRPC_CALL_DISCONNECTED, &call->flags))
if (call->conn)
conn = rxrpc_get_connection_maybe(call->conn);
spin_unlock_bh(&call->lock);
if (!conn)
return -ECONNRESET; return -ECONNRESET;
conn = call->conn;
msg.msg_name = &call->peer->srx.transport; msg.msg_name = &call->peer->srx.transport;
msg.msg_namelen = call->peer->srx.transport_len; msg.msg_namelen = call->peer->srx.transport_len;
msg.msg_control = NULL; msg.msg_control = NULL;
...@@ -312,8 +305,6 @@ int rxrpc_send_abort_packet(struct rxrpc_call *call) ...@@ -312,8 +305,6 @@ int rxrpc_send_abort_packet(struct rxrpc_call *call)
trace_rxrpc_tx_packet(call->debug_id, &pkt.whdr, trace_rxrpc_tx_packet(call->debug_id, &pkt.whdr,
rxrpc_tx_point_call_abort); rxrpc_tx_point_call_abort);
rxrpc_tx_backoff(call, ret); rxrpc_tx_backoff(call, ret);
rxrpc_put_connection(conn);
return ret; return ret;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment