Commit ec1af27e authored by David S. Miller's avatar David S. Miller

Merge tag 'rxrpc-rewrite-20170406' of...

Merge tag 'rxrpc-rewrite-20170406' of git://git.kernel.org/pub/scm/linux/kernel/git/dhowells/linux-fs

David Howells says:

====================
rxrpc: Miscellany

Here's a set of patches that make some minor changes to AF_RXRPC:

 (1) Store error codes in struct rxrpc_call::error as negative codes and
     only convert to positive in recvmsg() to avoid confusion inside the
     kernel.

 (2) Note the result of trying to abort a call (this fails if the call is
     already 'completed').

 (3) Don't abort on temporary errors whilst processing challenge and
     response packets, but rather drop the packet and wait for
     retransmission.

And also adds some more tracing:

 (4) Protocol errors.

 (5) Received abort packets.

 (6) Changes in the Rx window size due to ACK packet information.

 (7) Client call initiation (to allow the rxrpc_call struct pointer, the
     wire call ID and the user ID/afs_call pointer to be cross-referenced).
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents b4041278 89ca6948
...@@ -419,7 +419,7 @@ int afs_make_call(struct in_addr *addr, struct afs_call *call, gfp_t gfp, ...@@ -419,7 +419,7 @@ int afs_make_call(struct in_addr *addr, struct afs_call *call, gfp_t gfp,
call->state = AFS_CALL_COMPLETE; call->state = AFS_CALL_COMPLETE;
if (ret != -ECONNABORTED) { if (ret != -ECONNABORTED) {
rxrpc_kernel_abort_call(afs_socket, rxcall, RX_USER_ABORT, rxrpc_kernel_abort_call(afs_socket, rxcall, RX_USER_ABORT,
-ret, "KSD"); ret, "KSD");
} else { } else {
abort_code = 0; abort_code = 0;
offset = 0; offset = 0;
...@@ -478,12 +478,12 @@ static void afs_deliver_to_call(struct afs_call *call) ...@@ -478,12 +478,12 @@ static void afs_deliver_to_call(struct afs_call *call)
case -ENOTCONN: case -ENOTCONN:
abort_code = RX_CALL_DEAD; abort_code = RX_CALL_DEAD;
rxrpc_kernel_abort_call(afs_socket, call->rxcall, rxrpc_kernel_abort_call(afs_socket, call->rxcall,
abort_code, -ret, "KNC"); abort_code, ret, "KNC");
goto save_error; goto save_error;
case -ENOTSUPP: case -ENOTSUPP:
abort_code = RXGEN_OPCODE; abort_code = RXGEN_OPCODE;
rxrpc_kernel_abort_call(afs_socket, call->rxcall, rxrpc_kernel_abort_call(afs_socket, call->rxcall,
abort_code, -ret, "KIV"); abort_code, ret, "KIV");
goto save_error; goto save_error;
case -ENODATA: case -ENODATA:
case -EBADMSG: case -EBADMSG:
...@@ -493,7 +493,7 @@ static void afs_deliver_to_call(struct afs_call *call) ...@@ -493,7 +493,7 @@ static void afs_deliver_to_call(struct afs_call *call)
if (call->state != AFS_CALL_AWAIT_REPLY) if (call->state != AFS_CALL_AWAIT_REPLY)
abort_code = RXGEN_SS_UNMARSHAL; abort_code = RXGEN_SS_UNMARSHAL;
rxrpc_kernel_abort_call(afs_socket, call->rxcall, rxrpc_kernel_abort_call(afs_socket, call->rxcall,
abort_code, EBADMSG, "KUM"); abort_code, -EBADMSG, "KUM");
goto save_error; goto save_error;
} }
} }
...@@ -754,7 +754,7 @@ void afs_send_empty_reply(struct afs_call *call) ...@@ -754,7 +754,7 @@ void afs_send_empty_reply(struct afs_call *call)
case -ENOMEM: case -ENOMEM:
_debug("oom"); _debug("oom");
rxrpc_kernel_abort_call(afs_socket, call->rxcall, rxrpc_kernel_abort_call(afs_socket, call->rxcall,
RX_USER_ABORT, ENOMEM, "KOO"); RX_USER_ABORT, -ENOMEM, "KOO");
default: default:
_leave(" [error]"); _leave(" [error]");
return; return;
...@@ -792,7 +792,7 @@ void afs_send_simple_reply(struct afs_call *call, const void *buf, size_t len) ...@@ -792,7 +792,7 @@ void afs_send_simple_reply(struct afs_call *call, const void *buf, size_t len)
if (n == -ENOMEM) { if (n == -ENOMEM) {
_debug("oom"); _debug("oom");
rxrpc_kernel_abort_call(afs_socket, call->rxcall, rxrpc_kernel_abort_call(afs_socket, call->rxcall,
RX_USER_ABORT, ENOMEM, "KOO"); RX_USER_ABORT, -ENOMEM, "KOO");
} }
_leave(" [error]"); _leave(" [error]");
} }
......
...@@ -39,7 +39,7 @@ int rxrpc_kernel_send_data(struct socket *, struct rxrpc_call *, ...@@ -39,7 +39,7 @@ int rxrpc_kernel_send_data(struct socket *, struct rxrpc_call *,
struct msghdr *, size_t); struct msghdr *, size_t);
int rxrpc_kernel_recv_data(struct socket *, struct rxrpc_call *, int rxrpc_kernel_recv_data(struct socket *, struct rxrpc_call *,
void *, size_t, size_t *, bool, u32 *); void *, size_t, size_t *, bool, u32 *);
void rxrpc_kernel_abort_call(struct socket *, struct rxrpc_call *, bool rxrpc_kernel_abort_call(struct socket *, struct rxrpc_call *,
u32, int, const char *); u32, int, const char *);
void rxrpc_kernel_end_call(struct socket *, struct rxrpc_call *); void rxrpc_kernel_end_call(struct socket *, struct rxrpc_call *);
void rxrpc_kernel_get_peer(struct socket *, struct rxrpc_call *, void rxrpc_kernel_get_peer(struct socket *, struct rxrpc_call *,
......
...@@ -683,6 +683,57 @@ TRACE_EVENT(rxrpc_rx_ack, ...@@ -683,6 +683,57 @@ TRACE_EVENT(rxrpc_rx_ack,
__entry->n_acks) __entry->n_acks)
); );
TRACE_EVENT(rxrpc_rx_abort,
TP_PROTO(struct rxrpc_call *call, rxrpc_serial_t serial,
u32 abort_code),
TP_ARGS(call, serial, abort_code),
TP_STRUCT__entry(
__field(struct rxrpc_call *, call )
__field(rxrpc_serial_t, serial )
__field(u32, abort_code )
),
TP_fast_assign(
__entry->call = call;
__entry->serial = serial;
__entry->abort_code = abort_code;
),
TP_printk("c=%p ABORT %08x ac=%d",
__entry->call,
__entry->serial,
__entry->abort_code)
);
TRACE_EVENT(rxrpc_rx_rwind_change,
TP_PROTO(struct rxrpc_call *call, rxrpc_serial_t serial,
u32 rwind, bool wake),
TP_ARGS(call, serial, rwind, wake),
TP_STRUCT__entry(
__field(struct rxrpc_call *, call )
__field(rxrpc_serial_t, serial )
__field(u32, rwind )
__field(bool, wake )
),
TP_fast_assign(
__entry->call = call;
__entry->serial = serial;
__entry->rwind = rwind;
__entry->wake = wake;
),
TP_printk("c=%p %08x rw=%u%s",
__entry->call,
__entry->serial,
__entry->rwind,
__entry->wake ? " wake" : "")
);
TRACE_EVENT(rxrpc_tx_data, TRACE_EVENT(rxrpc_tx_data,
TP_PROTO(struct rxrpc_call *call, rxrpc_seq_t seq, TP_PROTO(struct rxrpc_call *call, rxrpc_seq_t seq,
rxrpc_serial_t serial, u8 flags, bool retrans, bool lose), rxrpc_serial_t serial, u8 flags, bool retrans, bool lose),
...@@ -1087,6 +1138,56 @@ TRACE_EVENT(rxrpc_improper_term, ...@@ -1087,6 +1138,56 @@ TRACE_EVENT(rxrpc_improper_term,
__entry->abort_code) __entry->abort_code)
); );
TRACE_EVENT(rxrpc_rx_eproto,
TP_PROTO(struct rxrpc_call *call, rxrpc_serial_t serial,
const char *why),
TP_ARGS(call, serial, why),
TP_STRUCT__entry(
__field(struct rxrpc_call *, call )
__field(rxrpc_serial_t, serial )
__field(const char *, why )
),
TP_fast_assign(
__entry->call = call;
__entry->serial = serial;
__entry->why = why;
),
TP_printk("c=%p EPROTO %08x %s",
__entry->call,
__entry->serial,
__entry->why)
);
TRACE_EVENT(rxrpc_connect_call,
TP_PROTO(struct rxrpc_call *call),
TP_ARGS(call),
TP_STRUCT__entry(
__field(struct rxrpc_call *, call )
__field(unsigned long, user_call_ID )
__field(u32, cid )
__field(u32, call_id )
),
TP_fast_assign(
__entry->call = call;
__entry->user_call_ID = call->user_call_ID;
__entry->cid = call->cid;
__entry->call_id = call->call_id;
),
TP_printk("c=%p u=%p %08x:%08x",
__entry->call,
(void *)__entry->user_call_ID,
__entry->cid,
__entry->call_id)
);
#endif /* _TRACE_RXRPC_H */ #endif /* _TRACE_RXRPC_H */
/* This part must be outside protection */ /* This part must be outside protection */
......
...@@ -739,6 +739,25 @@ static inline bool rxrpc_abort_call(const char *why, struct rxrpc_call *call, ...@@ -739,6 +739,25 @@ static inline bool rxrpc_abort_call(const char *why, struct rxrpc_call *call,
return ret; return ret;
} }
/*
* Abort a call due to a protocol error.
*/
static inline bool __rxrpc_abort_eproto(struct rxrpc_call *call,
struct sk_buff *skb,
const char *eproto_why,
const char *why,
u32 abort_code)
{
struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
trace_rxrpc_rx_eproto(call, sp->hdr.serial, eproto_why);
return rxrpc_abort_call(why, call, sp->hdr.seq, abort_code, -EPROTO);
}
#define rxrpc_abort_eproto(call, skb, eproto_why, abort_why, abort_code) \
__rxrpc_abort_eproto((call), (skb), tracepoint_string(eproto_why), \
(abort_why), (abort_code))
/* /*
* conn_client.c * conn_client.c
*/ */
......
...@@ -413,11 +413,11 @@ struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *local, ...@@ -413,11 +413,11 @@ struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *local,
case RXRPC_CONN_REMOTELY_ABORTED: case RXRPC_CONN_REMOTELY_ABORTED:
rxrpc_set_call_completion(call, RXRPC_CALL_REMOTELY_ABORTED, rxrpc_set_call_completion(call, RXRPC_CALL_REMOTELY_ABORTED,
conn->remote_abort, ECONNABORTED); conn->remote_abort, -ECONNABORTED);
break; break;
case RXRPC_CONN_LOCALLY_ABORTED: case RXRPC_CONN_LOCALLY_ABORTED:
rxrpc_abort_call("CON", call, sp->hdr.seq, rxrpc_abort_call("CON", call, sp->hdr.seq,
conn->local_abort, ECONNABORTED); conn->local_abort, -ECONNABORTED);
break; break;
default: default:
BUG(); BUG();
...@@ -600,7 +600,7 @@ int rxrpc_reject_call(struct rxrpc_sock *rx) ...@@ -600,7 +600,7 @@ int rxrpc_reject_call(struct rxrpc_sock *rx)
write_lock_bh(&call->state_lock); write_lock_bh(&call->state_lock);
switch (call->state) { switch (call->state) {
case RXRPC_CALL_SERVER_ACCEPTING: case RXRPC_CALL_SERVER_ACCEPTING:
__rxrpc_abort_call("REJ", call, 1, RX_USER_ABORT, ECONNABORTED); __rxrpc_abort_call("REJ", call, 1, RX_USER_ABORT, -ECONNABORTED);
abort = true; abort = true;
/* fall through */ /* fall through */
case RXRPC_CALL_COMPLETE: case RXRPC_CALL_COMPLETE:
......
...@@ -386,7 +386,7 @@ void rxrpc_process_call(struct work_struct *work) ...@@ -386,7 +386,7 @@ void rxrpc_process_call(struct work_struct *work)
now = ktime_get_real(); now = ktime_get_real();
if (ktime_before(call->expire_at, now)) { if (ktime_before(call->expire_at, now)) {
rxrpc_abort_call("EXP", call, 0, RX_CALL_TIMEOUT, ETIME); rxrpc_abort_call("EXP", call, 0, RX_CALL_TIMEOUT, -ETIME);
set_bit(RXRPC_CALL_EV_ABORT, &call->events); set_bit(RXRPC_CALL_EV_ABORT, &call->events);
goto recheck_state; goto recheck_state;
} }
......
...@@ -486,7 +486,7 @@ void rxrpc_release_calls_on_socket(struct rxrpc_sock *rx) ...@@ -486,7 +486,7 @@ void rxrpc_release_calls_on_socket(struct rxrpc_sock *rx)
call = list_entry(rx->to_be_accepted.next, call = list_entry(rx->to_be_accepted.next,
struct rxrpc_call, accept_link); struct rxrpc_call, accept_link);
list_del(&call->accept_link); list_del(&call->accept_link);
rxrpc_abort_call("SKR", call, 0, RX_CALL_DEAD, ECONNRESET); rxrpc_abort_call("SKR", call, 0, RX_CALL_DEAD, -ECONNRESET);
rxrpc_put_call(call, rxrpc_call_put); rxrpc_put_call(call, rxrpc_call_put);
} }
...@@ -494,7 +494,7 @@ void rxrpc_release_calls_on_socket(struct rxrpc_sock *rx) ...@@ -494,7 +494,7 @@ void rxrpc_release_calls_on_socket(struct rxrpc_sock *rx)
call = list_entry(rx->sock_calls.next, call = list_entry(rx->sock_calls.next,
struct rxrpc_call, sock_link); struct rxrpc_call, sock_link);
rxrpc_get_call(call, rxrpc_call_got); rxrpc_get_call(call, rxrpc_call_got);
rxrpc_abort_call("SKT", call, 0, RX_CALL_DEAD, ECONNRESET); rxrpc_abort_call("SKT", call, 0, RX_CALL_DEAD, -ECONNRESET);
rxrpc_send_abort_packet(call); rxrpc_send_abort_packet(call);
rxrpc_release_call(rx, call); rxrpc_release_call(rx, call);
rxrpc_put_call(call, rxrpc_call_put); rxrpc_put_call(call, rxrpc_call_put);
......
...@@ -550,6 +550,7 @@ static void rxrpc_activate_one_channel(struct rxrpc_connection *conn, ...@@ -550,6 +550,7 @@ static void rxrpc_activate_one_channel(struct rxrpc_connection *conn,
call->cid = conn->proto.cid | channel; call->cid = conn->proto.cid | channel;
call->call_id = call_id; call->call_id = call_id;
trace_rxrpc_connect_call(call);
_net("CONNECT call %08x:%08x as call %d on conn %d", _net("CONNECT call %08x:%08x as call %d on conn %d",
call->cid, call->call_id, call->debug_id, conn->debug_id); call->cid, call->call_id, call->debug_id, conn->debug_id);
......
...@@ -168,7 +168,7 @@ static void rxrpc_abort_calls(struct rxrpc_connection *conn, ...@@ -168,7 +168,7 @@ static void rxrpc_abort_calls(struct rxrpc_connection *conn,
* generate a connection-level abort * generate a connection-level abort
*/ */
static int rxrpc_abort_connection(struct rxrpc_connection *conn, static int rxrpc_abort_connection(struct rxrpc_connection *conn,
u32 error, u32 abort_code) int error, u32 abort_code)
{ {
struct rxrpc_wire_header whdr; struct rxrpc_wire_header whdr;
struct msghdr msg; struct msghdr msg;
...@@ -281,14 +281,17 @@ static int rxrpc_process_event(struct rxrpc_connection *conn, ...@@ -281,14 +281,17 @@ static int rxrpc_process_event(struct rxrpc_connection *conn,
case RXRPC_PACKET_TYPE_ABORT: case RXRPC_PACKET_TYPE_ABORT:
if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
&wtmp, sizeof(wtmp)) < 0) &wtmp, sizeof(wtmp)) < 0) {
trace_rxrpc_rx_eproto(NULL, sp->hdr.serial,
tracepoint_string("bad_abort"));
return -EPROTO; return -EPROTO;
}
abort_code = ntohl(wtmp); abort_code = ntohl(wtmp);
_proto("Rx ABORT %%%u { ac=%d }", sp->hdr.serial, abort_code); _proto("Rx ABORT %%%u { ac=%d }", sp->hdr.serial, abort_code);
conn->state = RXRPC_CONN_REMOTELY_ABORTED; conn->state = RXRPC_CONN_REMOTELY_ABORTED;
rxrpc_abort_calls(conn, RXRPC_CALL_REMOTELY_ABORTED, rxrpc_abort_calls(conn, RXRPC_CALL_REMOTELY_ABORTED,
abort_code, ECONNABORTED); abort_code, -ECONNABORTED);
return -ECONNABORTED; return -ECONNABORTED;
case RXRPC_PACKET_TYPE_CHALLENGE: case RXRPC_PACKET_TYPE_CHALLENGE:
...@@ -327,7 +330,8 @@ static int rxrpc_process_event(struct rxrpc_connection *conn, ...@@ -327,7 +330,8 @@ static int rxrpc_process_event(struct rxrpc_connection *conn,
return 0; return 0;
default: default:
_leave(" = -EPROTO [%u]", sp->hdr.type); trace_rxrpc_rx_eproto(NULL, sp->hdr.serial,
tracepoint_string("bad_conn_pkt"));
return -EPROTO; return -EPROTO;
} }
} }
...@@ -370,7 +374,7 @@ static void rxrpc_secure_connection(struct rxrpc_connection *conn) ...@@ -370,7 +374,7 @@ static void rxrpc_secure_connection(struct rxrpc_connection *conn)
abort: abort:
_debug("abort %d, %d", ret, abort_code); _debug("abort %d, %d", ret, abort_code);
rxrpc_abort_connection(conn, -ret, abort_code); rxrpc_abort_connection(conn, ret, abort_code);
_leave(" [aborted]"); _leave(" [aborted]");
} }
...@@ -419,9 +423,8 @@ void rxrpc_process_connection(struct work_struct *work) ...@@ -419,9 +423,8 @@ void rxrpc_process_connection(struct work_struct *work)
goto out; goto out;
protocol_error: protocol_error:
if (rxrpc_abort_connection(conn, -ret, abort_code) < 0) if (rxrpc_abort_connection(conn, ret, abort_code) < 0)
goto requeue_and_leave; goto requeue_and_leave;
rxrpc_free_skb(skb, rxrpc_skb_rx_freed); rxrpc_free_skb(skb, rxrpc_skb_rx_freed);
_leave(" [EPROTO]");
goto out; goto out;
} }
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
static void rxrpc_proto_abort(const char *why, static void rxrpc_proto_abort(const char *why,
struct rxrpc_call *call, rxrpc_seq_t seq) struct rxrpc_call *call, rxrpc_seq_t seq)
{ {
if (rxrpc_abort_call(why, call, seq, RX_PROTOCOL_ERROR, EBADMSG)) { if (rxrpc_abort_call(why, call, seq, RX_PROTOCOL_ERROR, -EBADMSG)) {
set_bit(RXRPC_CALL_EV_ABORT, &call->events); set_bit(RXRPC_CALL_EV_ABORT, &call->events);
rxrpc_queue_call(call); rxrpc_queue_call(call);
} }
...@@ -665,6 +665,8 @@ static void rxrpc_input_ackinfo(struct rxrpc_call *call, struct sk_buff *skb, ...@@ -665,6 +665,8 @@ static void rxrpc_input_ackinfo(struct rxrpc_call *call, struct sk_buff *skb,
rwind = RXRPC_RXTX_BUFF_SIZE - 1; rwind = RXRPC_RXTX_BUFF_SIZE - 1;
if (rwind > call->tx_winsize) if (rwind > call->tx_winsize)
wake = true; wake = true;
trace_rxrpc_rx_rwind_change(call, sp->hdr.serial,
ntohl(ackinfo->rwind), wake);
call->tx_winsize = rwind; call->tx_winsize = rwind;
} }
...@@ -877,7 +879,7 @@ static void rxrpc_input_ackall(struct rxrpc_call *call, struct sk_buff *skb) ...@@ -877,7 +879,7 @@ static void rxrpc_input_ackall(struct rxrpc_call *call, struct sk_buff *skb)
} }
/* /*
* Process an ABORT packet. * Process an ABORT packet directed at a call.
*/ */
static void rxrpc_input_abort(struct rxrpc_call *call, struct sk_buff *skb) static void rxrpc_input_abort(struct rxrpc_call *call, struct sk_buff *skb)
{ {
...@@ -892,10 +894,12 @@ static void rxrpc_input_abort(struct rxrpc_call *call, struct sk_buff *skb) ...@@ -892,10 +894,12 @@ static void rxrpc_input_abort(struct rxrpc_call *call, struct sk_buff *skb)
&wtmp, sizeof(wtmp)) >= 0) &wtmp, sizeof(wtmp)) >= 0)
abort_code = ntohl(wtmp); abort_code = ntohl(wtmp);
trace_rxrpc_rx_abort(call, sp->hdr.serial, abort_code);
_proto("Rx ABORT %%%u { %x }", sp->hdr.serial, abort_code); _proto("Rx ABORT %%%u { %x }", sp->hdr.serial, abort_code);
if (rxrpc_set_call_completion(call, RXRPC_CALL_REMOTELY_ABORTED, if (rxrpc_set_call_completion(call, RXRPC_CALL_REMOTELY_ABORTED,
abort_code, ECONNABORTED)) abort_code, -ECONNABORTED))
rxrpc_notify_socket(call); rxrpc_notify_socket(call);
} }
...@@ -958,7 +962,7 @@ static void rxrpc_input_implicit_end_call(struct rxrpc_connection *conn, ...@@ -958,7 +962,7 @@ static void rxrpc_input_implicit_end_call(struct rxrpc_connection *conn,
case RXRPC_CALL_COMPLETE: case RXRPC_CALL_COMPLETE:
break; break;
default: default:
if (rxrpc_abort_call("IMP", call, 0, RX_CALL_DEAD, ESHUTDOWN)) { if (rxrpc_abort_call("IMP", call, 0, RX_CALL_DEAD, -ESHUTDOWN)) {
set_bit(RXRPC_CALL_EV_ABORT, &call->events); set_bit(RXRPC_CALL_EV_ABORT, &call->events);
rxrpc_queue_call(call); rxrpc_queue_call(call);
} }
...@@ -1017,8 +1021,11 @@ int rxrpc_extract_header(struct rxrpc_skb_priv *sp, struct sk_buff *skb) ...@@ -1017,8 +1021,11 @@ int rxrpc_extract_header(struct rxrpc_skb_priv *sp, struct sk_buff *skb)
struct rxrpc_wire_header whdr; struct rxrpc_wire_header whdr;
/* dig out the RxRPC connection details */ /* dig out the RxRPC connection details */
if (skb_copy_bits(skb, 0, &whdr, sizeof(whdr)) < 0) if (skb_copy_bits(skb, 0, &whdr, sizeof(whdr)) < 0) {
trace_rxrpc_rx_eproto(NULL, sp->hdr.serial,
tracepoint_string("bad_hdr"));
return -EBADMSG; return -EBADMSG;
}
memset(sp, 0, sizeof(*sp)); memset(sp, 0, sizeof(*sp));
sp->hdr.epoch = ntohl(whdr.epoch); sp->hdr.epoch = ntohl(whdr.epoch);
......
...@@ -46,7 +46,10 @@ static int none_respond_to_challenge(struct rxrpc_connection *conn, ...@@ -46,7 +46,10 @@ static int none_respond_to_challenge(struct rxrpc_connection *conn,
struct sk_buff *skb, struct sk_buff *skb,
u32 *_abort_code) u32 *_abort_code)
{ {
*_abort_code = RX_PROTOCOL_ERROR; struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
trace_rxrpc_rx_eproto(NULL, sp->hdr.serial,
tracepoint_string("chall_none"));
return -EPROTO; return -EPROTO;
} }
...@@ -54,7 +57,10 @@ static int none_verify_response(struct rxrpc_connection *conn, ...@@ -54,7 +57,10 @@ static int none_verify_response(struct rxrpc_connection *conn,
struct sk_buff *skb, struct sk_buff *skb,
u32 *_abort_code) u32 *_abort_code)
{ {
*_abort_code = RX_PROTOCOL_ERROR; struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
trace_rxrpc_rx_eproto(NULL, sp->hdr.serial,
tracepoint_string("resp_none"));
return -EPROTO; return -EPROTO;
} }
......
...@@ -296,7 +296,7 @@ void rxrpc_peer_error_distributor(struct work_struct *work) ...@@ -296,7 +296,7 @@ void rxrpc_peer_error_distributor(struct work_struct *work)
hlist_del_init(&call->error_link); hlist_del_init(&call->error_link);
rxrpc_see_call(call); rxrpc_see_call(call);
if (rxrpc_set_call_completion(call, compl, 0, error)) if (rxrpc_set_call_completion(call, compl, 0, -error))
rxrpc_notify_socket(call); rxrpc_notify_socket(call);
} }
......
...@@ -83,11 +83,11 @@ static int rxrpc_recvmsg_term(struct rxrpc_call *call, struct msghdr *msg) ...@@ -83,11 +83,11 @@ static int rxrpc_recvmsg_term(struct rxrpc_call *call, struct msghdr *msg)
ret = put_cmsg(msg, SOL_RXRPC, RXRPC_ABORT, 4, &tmp); ret = put_cmsg(msg, SOL_RXRPC, RXRPC_ABORT, 4, &tmp);
break; break;
case RXRPC_CALL_NETWORK_ERROR: case RXRPC_CALL_NETWORK_ERROR:
tmp = call->error; tmp = -call->error;
ret = put_cmsg(msg, SOL_RXRPC, RXRPC_NET_ERROR, 4, &tmp); ret = put_cmsg(msg, SOL_RXRPC, RXRPC_NET_ERROR, 4, &tmp);
break; break;
case RXRPC_CALL_LOCAL_ERROR: case RXRPC_CALL_LOCAL_ERROR:
tmp = call->error; tmp = -call->error;
ret = put_cmsg(msg, SOL_RXRPC, RXRPC_LOCAL_ERROR, 4, &tmp); ret = put_cmsg(msg, SOL_RXRPC, RXRPC_LOCAL_ERROR, 4, &tmp);
break; break;
default: default:
...@@ -682,14 +682,16 @@ int rxrpc_kernel_recv_data(struct socket *sock, struct rxrpc_call *call, ...@@ -682,14 +682,16 @@ int rxrpc_kernel_recv_data(struct socket *sock, struct rxrpc_call *call,
return ret; return ret;
short_data: short_data:
trace_rxrpc_rx_eproto(call, 0, tracepoint_string("short_data"));
ret = -EBADMSG; ret = -EBADMSG;
goto out; goto out;
excess_data: excess_data:
trace_rxrpc_rx_eproto(call, 0, tracepoint_string("excess_data"));
ret = -EMSGSIZE; ret = -EMSGSIZE;
goto out; goto out;
call_complete: call_complete:
*_abort = call->abort_code; *_abort = call->abort_code;
ret = -call->error; ret = call->error;
if (call->completion == RXRPC_CALL_SUCCEEDED) { if (call->completion == RXRPC_CALL_SUCCEEDED) {
ret = 1; ret = 1;
if (size > 0) if (size > 0)
......
...@@ -148,15 +148,13 @@ static int rxkad_secure_packet_auth(const struct rxrpc_call *call, ...@@ -148,15 +148,13 @@ static int rxkad_secure_packet_auth(const struct rxrpc_call *call,
u32 data_size, u32 data_size,
void *sechdr) void *sechdr)
{ {
struct rxrpc_skb_priv *sp; struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher); SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher);
struct rxkad_level1_hdr hdr; struct rxkad_level1_hdr hdr;
struct rxrpc_crypt iv; struct rxrpc_crypt iv;
struct scatterlist sg; struct scatterlist sg;
u16 check; u16 check;
sp = rxrpc_skb(skb);
_enter(""); _enter("");
check = sp->hdr.seq ^ call->call_id; check = sp->hdr.seq ^ call->call_id;
...@@ -323,6 +321,7 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb, ...@@ -323,6 +321,7 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb,
struct rxrpc_crypt iv; struct rxrpc_crypt iv;
struct scatterlist sg[16]; struct scatterlist sg[16];
struct sk_buff *trailer; struct sk_buff *trailer;
bool aborted;
u32 data_size, buf; u32 data_size, buf;
u16 check; u16 check;
int nsg; int nsg;
...@@ -330,7 +329,8 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb, ...@@ -330,7 +329,8 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb,
_enter(""); _enter("");
if (len < 8) { if (len < 8) {
rxrpc_abort_call("V1H", call, seq, RXKADSEALEDINCON, EPROTO); aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_hdr", "V1H",
RXKADSEALEDINCON);
goto protocol_error; goto protocol_error;
} }
...@@ -355,7 +355,8 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb, ...@@ -355,7 +355,8 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb,
/* Extract the decrypted packet length */ /* Extract the decrypted packet length */
if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) { if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) {
rxrpc_abort_call("XV1", call, seq, RXKADDATALEN, EPROTO); aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_len", "XV1",
RXKADDATALEN);
goto protocol_error; goto protocol_error;
} }
offset += sizeof(sechdr); offset += sizeof(sechdr);
...@@ -368,12 +369,14 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb, ...@@ -368,12 +369,14 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb,
check ^= seq ^ call->call_id; check ^= seq ^ call->call_id;
check &= 0xffff; check &= 0xffff;
if (check != 0) { if (check != 0) {
rxrpc_abort_call("V1C", call, seq, RXKADSEALEDINCON, EPROTO); aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_check", "V1C",
RXKADSEALEDINCON);
goto protocol_error; goto protocol_error;
} }
if (data_size > len) { if (data_size > len) {
rxrpc_abort_call("V1L", call, seq, RXKADDATALEN, EPROTO); aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_datalen", "V1L",
RXKADDATALEN);
goto protocol_error; goto protocol_error;
} }
...@@ -381,8 +384,8 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb, ...@@ -381,8 +384,8 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb,
return 0; return 0;
protocol_error: protocol_error:
if (aborted)
rxrpc_send_abort_packet(call); rxrpc_send_abort_packet(call);
_leave(" = -EPROTO");
return -EPROTO; return -EPROTO;
nomem: nomem:
...@@ -403,6 +406,7 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb, ...@@ -403,6 +406,7 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb,
struct rxrpc_crypt iv; struct rxrpc_crypt iv;
struct scatterlist _sg[4], *sg; struct scatterlist _sg[4], *sg;
struct sk_buff *trailer; struct sk_buff *trailer;
bool aborted;
u32 data_size, buf; u32 data_size, buf;
u16 check; u16 check;
int nsg; int nsg;
...@@ -410,7 +414,8 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb, ...@@ -410,7 +414,8 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb,
_enter(",{%d}", skb->len); _enter(",{%d}", skb->len);
if (len < 8) { if (len < 8) {
rxrpc_abort_call("V2H", call, seq, RXKADSEALEDINCON, EPROTO); aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_hdr", "V2H",
RXKADSEALEDINCON);
goto protocol_error; goto protocol_error;
} }
...@@ -445,7 +450,8 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb, ...@@ -445,7 +450,8 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb,
/* Extract the decrypted packet length */ /* Extract the decrypted packet length */
if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) { if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) {
rxrpc_abort_call("XV2", call, seq, RXKADDATALEN, EPROTO); aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_len", "XV2",
RXKADDATALEN);
goto protocol_error; goto protocol_error;
} }
offset += sizeof(sechdr); offset += sizeof(sechdr);
...@@ -458,12 +464,14 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb, ...@@ -458,12 +464,14 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb,
check ^= seq ^ call->call_id; check ^= seq ^ call->call_id;
check &= 0xffff; check &= 0xffff;
if (check != 0) { if (check != 0) {
rxrpc_abort_call("V2C", call, seq, RXKADSEALEDINCON, EPROTO); aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_check", "V2C",
RXKADSEALEDINCON);
goto protocol_error; goto protocol_error;
} }
if (data_size > len) { if (data_size > len) {
rxrpc_abort_call("V2L", call, seq, RXKADDATALEN, EPROTO); aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_datalen", "V2L",
RXKADDATALEN);
goto protocol_error; goto protocol_error;
} }
...@@ -471,8 +479,8 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb, ...@@ -471,8 +479,8 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb,
return 0; return 0;
protocol_error: protocol_error:
if (aborted)
rxrpc_send_abort_packet(call); rxrpc_send_abort_packet(call);
_leave(" = -EPROTO");
return -EPROTO; return -EPROTO;
nomem: nomem:
...@@ -491,6 +499,7 @@ static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb, ...@@ -491,6 +499,7 @@ static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb,
SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher); SKCIPHER_REQUEST_ON_STACK(req, call->conn->cipher);
struct rxrpc_crypt iv; struct rxrpc_crypt iv;
struct scatterlist sg; struct scatterlist sg;
bool aborted;
u16 cksum; u16 cksum;
u32 x, y; u32 x, y;
...@@ -522,10 +531,9 @@ static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb, ...@@ -522,10 +531,9 @@ static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb,
cksum = 1; /* zero checksums are not permitted */ cksum = 1; /* zero checksums are not permitted */
if (cksum != expected_cksum) { if (cksum != expected_cksum) {
rxrpc_abort_call("VCK", call, seq, RXKADSEALEDINCON, EPROTO); aborted = rxrpc_abort_eproto(call, skb, "rxkad_csum", "VCK",
rxrpc_send_abort_packet(call); RXKADSEALEDINCON);
_leave(" = -EPROTO [csum failed]"); goto protocol_error;
return -EPROTO;
} }
switch (call->conn->params.security_level) { switch (call->conn->params.security_level) {
...@@ -538,6 +546,11 @@ static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb, ...@@ -538,6 +546,11 @@ static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb,
default: default:
return -ENOANO; return -ENOANO;
} }
protocol_error:
if (aborted)
rxrpc_send_abort_packet(call);
return -EPROTO;
} }
/* /*
...@@ -754,22 +767,23 @@ static int rxkad_respond_to_challenge(struct rxrpc_connection *conn, ...@@ -754,22 +767,23 @@ static int rxkad_respond_to_challenge(struct rxrpc_connection *conn,
struct rxkad_response resp struct rxkad_response resp
__attribute__((aligned(8))); /* must be aligned for crypto */ __attribute__((aligned(8))); /* must be aligned for crypto */
struct rxrpc_skb_priv *sp = rxrpc_skb(skb); struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
const char *eproto;
u32 version, nonce, min_level, abort_code; u32 version, nonce, min_level, abort_code;
int ret; int ret;
_enter("{%d,%x}", conn->debug_id, key_serial(conn->params.key)); _enter("{%d,%x}", conn->debug_id, key_serial(conn->params.key));
if (!conn->params.key) { eproto = tracepoint_string("chall_no_key");
_leave(" = -EPROTO [no key]"); abort_code = RX_PROTOCOL_ERROR;
return -EPROTO; if (!conn->params.key)
} goto protocol_error;
abort_code = RXKADEXPIRED;
ret = key_validate(conn->params.key); ret = key_validate(conn->params.key);
if (ret < 0) { if (ret < 0)
*_abort_code = RXKADEXPIRED; goto other_error;
return ret;
}
eproto = tracepoint_string("chall_short");
abort_code = RXKADPACKETSHORT; abort_code = RXKADPACKETSHORT;
if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
&challenge, sizeof(challenge)) < 0) &challenge, sizeof(challenge)) < 0)
...@@ -782,13 +796,15 @@ static int rxkad_respond_to_challenge(struct rxrpc_connection *conn, ...@@ -782,13 +796,15 @@ static int rxkad_respond_to_challenge(struct rxrpc_connection *conn,
_proto("Rx CHALLENGE %%%u { v=%u n=%u ml=%u }", _proto("Rx CHALLENGE %%%u { v=%u n=%u ml=%u }",
sp->hdr.serial, version, nonce, min_level); sp->hdr.serial, version, nonce, min_level);
eproto = tracepoint_string("chall_ver");
abort_code = RXKADINCONSISTENCY; abort_code = RXKADINCONSISTENCY;
if (version != RXKAD_VERSION) if (version != RXKAD_VERSION)
goto protocol_error; goto protocol_error;
abort_code = RXKADLEVELFAIL; abort_code = RXKADLEVELFAIL;
ret = -EACCES;
if (conn->params.security_level < min_level) if (conn->params.security_level < min_level)
goto protocol_error; goto other_error;
token = conn->params.key->payload.data[0]; token = conn->params.key->payload.data[0];
...@@ -815,28 +831,34 @@ static int rxkad_respond_to_challenge(struct rxrpc_connection *conn, ...@@ -815,28 +831,34 @@ static int rxkad_respond_to_challenge(struct rxrpc_connection *conn,
return rxkad_send_response(conn, &sp->hdr, &resp, token->kad); return rxkad_send_response(conn, &sp->hdr, &resp, token->kad);
protocol_error: protocol_error:
trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
ret = -EPROTO;
other_error:
*_abort_code = abort_code; *_abort_code = abort_code;
_leave(" = -EPROTO [%d]", abort_code); return ret;
return -EPROTO;
} }
/* /*
* decrypt the kerberos IV ticket in the response * decrypt the kerberos IV ticket in the response
*/ */
static int rxkad_decrypt_ticket(struct rxrpc_connection *conn, static int rxkad_decrypt_ticket(struct rxrpc_connection *conn,
struct sk_buff *skb,
void *ticket, size_t ticket_len, void *ticket, size_t ticket_len,
struct rxrpc_crypt *_session_key, struct rxrpc_crypt *_session_key,
time_t *_expiry, time_t *_expiry,
u32 *_abort_code) u32 *_abort_code)
{ {
struct skcipher_request *req; struct skcipher_request *req;
struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
struct rxrpc_crypt iv, key; struct rxrpc_crypt iv, key;
struct scatterlist sg[1]; struct scatterlist sg[1];
struct in_addr addr; struct in_addr addr;
unsigned int life; unsigned int life;
const char *eproto;
time_t issue, now; time_t issue, now;
bool little_endian; bool little_endian;
int ret; int ret;
u32 abort_code;
u8 *p, *q, *name, *end; u8 *p, *q, *name, *end;
_enter("{%d},{%x}", conn->debug_id, key_serial(conn->server_key)); _enter("{%d},{%x}", conn->debug_id, key_serial(conn->server_key));
...@@ -847,11 +869,11 @@ static int rxkad_decrypt_ticket(struct rxrpc_connection *conn, ...@@ -847,11 +869,11 @@ static int rxkad_decrypt_ticket(struct rxrpc_connection *conn,
if (ret < 0) { if (ret < 0) {
switch (ret) { switch (ret) {
case -EKEYEXPIRED: case -EKEYEXPIRED:
*_abort_code = RXKADEXPIRED; abort_code = RXKADEXPIRED;
goto error; goto other_error;
default: default:
*_abort_code = RXKADNOAUTH; abort_code = RXKADNOAUTH;
goto error; goto other_error;
} }
} }
...@@ -860,13 +882,11 @@ static int rxkad_decrypt_ticket(struct rxrpc_connection *conn, ...@@ -860,13 +882,11 @@ static int rxkad_decrypt_ticket(struct rxrpc_connection *conn,
memcpy(&iv, &conn->server_key->payload.data[2], sizeof(iv)); memcpy(&iv, &conn->server_key->payload.data[2], sizeof(iv));
ret = -ENOMEM;
req = skcipher_request_alloc(conn->server_key->payload.data[0], req = skcipher_request_alloc(conn->server_key->payload.data[0],
GFP_NOFS); GFP_NOFS);
if (!req) { if (!req)
*_abort_code = RXKADNOAUTH; goto temporary_error;
ret = -ENOMEM;
goto error;
}
sg_init_one(&sg[0], ticket, ticket_len); sg_init_one(&sg[0], ticket, ticket_len);
skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_callback(req, 0, NULL, NULL);
...@@ -877,11 +897,12 @@ static int rxkad_decrypt_ticket(struct rxrpc_connection *conn, ...@@ -877,11 +897,12 @@ static int rxkad_decrypt_ticket(struct rxrpc_connection *conn,
p = ticket; p = ticket;
end = p + ticket_len; end = p + ticket_len;
#define Z(size) \ #define Z(field) \
({ \ ({ \
u8 *__str = p; \ u8 *__str = p; \
eproto = tracepoint_string("rxkad_bad_"#field); \
q = memchr(p, 0, end - p); \ q = memchr(p, 0, end - p); \
if (!q || q - p > (size)) \ if (!q || q - p > (field##_SZ)) \
goto bad_ticket; \ goto bad_ticket; \
for (; p < q; p++) \ for (; p < q; p++) \
if (!isprint(*p)) \ if (!isprint(*p)) \
...@@ -896,17 +917,18 @@ static int rxkad_decrypt_ticket(struct rxrpc_connection *conn, ...@@ -896,17 +917,18 @@ static int rxkad_decrypt_ticket(struct rxrpc_connection *conn,
p++; p++;
/* extract the authentication name */ /* extract the authentication name */
name = Z(ANAME_SZ); name = Z(ANAME);
_debug("KIV ANAME: %s", name); _debug("KIV ANAME: %s", name);
/* extract the principal's instance */ /* extract the principal's instance */
name = Z(INST_SZ); name = Z(INST);
_debug("KIV INST : %s", name); _debug("KIV INST : %s", name);
/* extract the principal's authentication domain */ /* extract the principal's authentication domain */
name = Z(REALM_SZ); name = Z(REALM);
_debug("KIV REALM: %s", name); _debug("KIV REALM: %s", name);
eproto = tracepoint_string("rxkad_bad_len");
if (end - p < 4 + 8 + 4 + 2) if (end - p < 4 + 8 + 4 + 2)
goto bad_ticket; goto bad_ticket;
...@@ -941,36 +963,37 @@ static int rxkad_decrypt_ticket(struct rxrpc_connection *conn, ...@@ -941,36 +963,37 @@ static int rxkad_decrypt_ticket(struct rxrpc_connection *conn,
/* check the ticket is in date */ /* check the ticket is in date */
if (issue > now) { if (issue > now) {
*_abort_code = RXKADNOAUTH; abort_code = RXKADNOAUTH;
ret = -EKEYREJECTED; ret = -EKEYREJECTED;
goto error; goto other_error;
} }
if (issue < now - life) { if (issue < now - life) {
*_abort_code = RXKADEXPIRED; abort_code = RXKADEXPIRED;
ret = -EKEYEXPIRED; ret = -EKEYEXPIRED;
goto error; goto other_error;
} }
*_expiry = issue + life; *_expiry = issue + life;
/* get the service name */ /* get the service name */
name = Z(SNAME_SZ); name = Z(SNAME);
_debug("KIV SNAME: %s", name); _debug("KIV SNAME: %s", name);
/* get the service instance name */ /* get the service instance name */
name = Z(INST_SZ); name = Z(INST);
_debug("KIV SINST: %s", name); _debug("KIV SINST: %s", name);
return 0;
ret = 0;
error:
_leave(" = %d", ret);
return ret;
bad_ticket: bad_ticket:
*_abort_code = RXKADBADTICKET; trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
ret = -EBADMSG; abort_code = RXKADBADTICKET;
goto error; ret = -EPROTO;
other_error:
*_abort_code = abort_code;
return ret;
temporary_error:
return ret;
} }
/* /*
...@@ -1020,6 +1043,7 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, ...@@ -1020,6 +1043,7 @@ static int rxkad_verify_response(struct rxrpc_connection *conn,
__attribute__((aligned(8))); /* must be aligned for crypto */ __attribute__((aligned(8))); /* must be aligned for crypto */
struct rxrpc_skb_priv *sp = rxrpc_skb(skb); struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
struct rxrpc_crypt session_key; struct rxrpc_crypt session_key;
const char *eproto;
time_t expiry; time_t expiry;
void *ticket; void *ticket;
u32 abort_code, version, kvno, ticket_len, level; u32 abort_code, version, kvno, ticket_len, level;
...@@ -1028,6 +1052,7 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, ...@@ -1028,6 +1052,7 @@ static int rxkad_verify_response(struct rxrpc_connection *conn,
_enter("{%d,%x}", conn->debug_id, key_serial(conn->server_key)); _enter("{%d,%x}", conn->debug_id, key_serial(conn->server_key));
eproto = tracepoint_string("rxkad_rsp_short");
abort_code = RXKADPACKETSHORT; abort_code = RXKADPACKETSHORT;
if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
&response, sizeof(response)) < 0) &response, sizeof(response)) < 0)
...@@ -1041,40 +1066,43 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, ...@@ -1041,40 +1066,43 @@ static int rxkad_verify_response(struct rxrpc_connection *conn,
_proto("Rx RESPONSE %%%u { v=%u kv=%u tl=%u }", _proto("Rx RESPONSE %%%u { v=%u kv=%u tl=%u }",
sp->hdr.serial, version, kvno, ticket_len); sp->hdr.serial, version, kvno, ticket_len);
eproto = tracepoint_string("rxkad_rsp_ver");
abort_code = RXKADINCONSISTENCY; abort_code = RXKADINCONSISTENCY;
if (version != RXKAD_VERSION) if (version != RXKAD_VERSION)
goto protocol_error; goto protocol_error;
eproto = tracepoint_string("rxkad_rsp_tktlen");
abort_code = RXKADTICKETLEN; abort_code = RXKADTICKETLEN;
if (ticket_len < 4 || ticket_len > MAXKRB5TICKETLEN) if (ticket_len < 4 || ticket_len > MAXKRB5TICKETLEN)
goto protocol_error; goto protocol_error;
eproto = tracepoint_string("rxkad_rsp_unkkey");
abort_code = RXKADUNKNOWNKEY; abort_code = RXKADUNKNOWNKEY;
if (kvno >= RXKAD_TKT_TYPE_KERBEROS_V5) if (kvno >= RXKAD_TKT_TYPE_KERBEROS_V5)
goto protocol_error; goto protocol_error;
/* extract the kerberos ticket and decrypt and decode it */ /* extract the kerberos ticket and decrypt and decode it */
ret = -ENOMEM;
ticket = kmalloc(ticket_len, GFP_NOFS); ticket = kmalloc(ticket_len, GFP_NOFS);
if (!ticket) if (!ticket)
return -ENOMEM; goto temporary_error;
eproto = tracepoint_string("rxkad_tkt_short");
abort_code = RXKADPACKETSHORT; abort_code = RXKADPACKETSHORT;
if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header),
ticket, ticket_len) < 0) ticket, ticket_len) < 0)
goto protocol_error_free; goto protocol_error_free;
ret = rxkad_decrypt_ticket(conn, ticket, ticket_len, &session_key, ret = rxkad_decrypt_ticket(conn, skb, ticket, ticket_len, &session_key,
&expiry, &abort_code); &expiry, _abort_code);
if (ret < 0) { if (ret < 0)
*_abort_code = abort_code; goto temporary_error_free;
kfree(ticket);
return ret;
}
/* use the session key from inside the ticket to decrypt the /* use the session key from inside the ticket to decrypt the
* response */ * response */
rxkad_decrypt_response(conn, &response, &session_key); rxkad_decrypt_response(conn, &response, &session_key);
eproto = tracepoint_string("rxkad_rsp_param");
abort_code = RXKADSEALEDINCON; abort_code = RXKADSEALEDINCON;
if (ntohl(response.encrypted.epoch) != conn->proto.epoch) if (ntohl(response.encrypted.epoch) != conn->proto.epoch)
goto protocol_error_free; goto protocol_error_free;
...@@ -1085,6 +1113,7 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, ...@@ -1085,6 +1113,7 @@ static int rxkad_verify_response(struct rxrpc_connection *conn,
csum = response.encrypted.checksum; csum = response.encrypted.checksum;
response.encrypted.checksum = 0; response.encrypted.checksum = 0;
rxkad_calc_response_checksum(&response); rxkad_calc_response_checksum(&response);
eproto = tracepoint_string("rxkad_rsp_csum");
if (response.encrypted.checksum != csum) if (response.encrypted.checksum != csum)
goto protocol_error_free; goto protocol_error_free;
...@@ -1093,11 +1122,15 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, ...@@ -1093,11 +1122,15 @@ static int rxkad_verify_response(struct rxrpc_connection *conn,
struct rxrpc_call *call; struct rxrpc_call *call;
u32 call_id = ntohl(response.encrypted.call_id[i]); u32 call_id = ntohl(response.encrypted.call_id[i]);
eproto = tracepoint_string("rxkad_rsp_callid");
if (call_id > INT_MAX) if (call_id > INT_MAX)
goto protocol_error_unlock; goto protocol_error_unlock;
eproto = tracepoint_string("rxkad_rsp_callctr");
if (call_id < conn->channels[i].call_counter) if (call_id < conn->channels[i].call_counter)
goto protocol_error_unlock; goto protocol_error_unlock;
eproto = tracepoint_string("rxkad_rsp_callst");
if (call_id > conn->channels[i].call_counter) { if (call_id > conn->channels[i].call_counter) {
call = rcu_dereference_protected( call = rcu_dereference_protected(
conn->channels[i].call, conn->channels[i].call,
...@@ -1109,10 +1142,12 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, ...@@ -1109,10 +1142,12 @@ static int rxkad_verify_response(struct rxrpc_connection *conn,
} }
spin_unlock(&conn->channel_lock); spin_unlock(&conn->channel_lock);
eproto = tracepoint_string("rxkad_rsp_seq");
abort_code = RXKADOUTOFSEQUENCE; abort_code = RXKADOUTOFSEQUENCE;
if (ntohl(response.encrypted.inc_nonce) != conn->security_nonce + 1) if (ntohl(response.encrypted.inc_nonce) != conn->security_nonce + 1)
goto protocol_error_free; goto protocol_error_free;
eproto = tracepoint_string("rxkad_rsp_level");
abort_code = RXKADLEVELFAIL; abort_code = RXKADLEVELFAIL;
level = ntohl(response.encrypted.level); level = ntohl(response.encrypted.level);
if (level > RXRPC_SECURITY_ENCRYPT) if (level > RXRPC_SECURITY_ENCRYPT)
...@@ -1123,10 +1158,8 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, ...@@ -1123,10 +1158,8 @@ static int rxkad_verify_response(struct rxrpc_connection *conn,
* this the connection security can be handled in exactly the same way * this the connection security can be handled in exactly the same way
* as for a client connection */ * as for a client connection */
ret = rxrpc_get_server_data_key(conn, &session_key, expiry, kvno); ret = rxrpc_get_server_data_key(conn, &session_key, expiry, kvno);
if (ret < 0) { if (ret < 0)
kfree(ticket); goto temporary_error_free;
return ret;
}
kfree(ticket); kfree(ticket);
_leave(" = 0"); _leave(" = 0");
...@@ -1137,9 +1170,18 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, ...@@ -1137,9 +1170,18 @@ static int rxkad_verify_response(struct rxrpc_connection *conn,
protocol_error_free: protocol_error_free:
kfree(ticket); kfree(ticket);
protocol_error: protocol_error:
trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto);
*_abort_code = abort_code; *_abort_code = abort_code;
_leave(" = -EPROTO [%d]", abort_code);
return -EPROTO; return -EPROTO;
temporary_error_free:
kfree(ticket);
temporary_error:
/* Ignore the response packet if we got a temporary error such as
* ENOMEM. We just want to send the challenge again. Note that we
* also come out this way if the ticket decryption fails.
*/
return ret;
} }
/* /*
......
...@@ -556,7 +556,7 @@ int rxrpc_do_sendmsg(struct rxrpc_sock *rx, struct msghdr *msg, size_t len) ...@@ -556,7 +556,7 @@ int rxrpc_do_sendmsg(struct rxrpc_sock *rx, struct msghdr *msg, size_t len)
ret = -ESHUTDOWN; ret = -ESHUTDOWN;
} else if (cmd == RXRPC_CMD_SEND_ABORT) { } else if (cmd == RXRPC_CMD_SEND_ABORT) {
ret = 0; ret = 0;
if (rxrpc_abort_call("CMD", call, 0, abort_code, ECONNABORTED)) if (rxrpc_abort_call("CMD", call, 0, abort_code, -ECONNABORTED))
ret = rxrpc_send_abort_packet(call); ret = rxrpc_send_abort_packet(call);
} else if (cmd != RXRPC_CMD_SEND_DATA) { } else if (cmd != RXRPC_CMD_SEND_DATA) {
ret = -EINVAL; ret = -EINVAL;
...@@ -624,6 +624,7 @@ int rxrpc_kernel_send_data(struct socket *sock, struct rxrpc_call *call, ...@@ -624,6 +624,7 @@ int rxrpc_kernel_send_data(struct socket *sock, struct rxrpc_call *call,
break; break;
default: default:
/* Request phase complete for this client call */ /* Request phase complete for this client call */
trace_rxrpc_rx_eproto(call, 0, tracepoint_string("late_send"));
ret = -EPROTO; ret = -EPROTO;
break; break;
} }
...@@ -642,20 +643,24 @@ EXPORT_SYMBOL(rxrpc_kernel_send_data); ...@@ -642,20 +643,24 @@ EXPORT_SYMBOL(rxrpc_kernel_send_data);
* @error: Local error value * @error: Local error value
* @why: 3-char string indicating why. * @why: 3-char string indicating why.
* *
* Allow a kernel service to abort a call, if it's still in an abortable state. * Allow a kernel service to abort a call, if it's still in an abortable state
* and return true if the call was aborted, false if it was already complete.
*/ */
void rxrpc_kernel_abort_call(struct socket *sock, struct rxrpc_call *call, bool rxrpc_kernel_abort_call(struct socket *sock, struct rxrpc_call *call,
u32 abort_code, int error, const char *why) u32 abort_code, int error, const char *why)
{ {
bool aborted;
_enter("{%d},%d,%d,%s", call->debug_id, abort_code, error, why); _enter("{%d},%d,%d,%s", call->debug_id, abort_code, error, why);
mutex_lock(&call->user_mutex); mutex_lock(&call->user_mutex);
if (rxrpc_abort_call(why, call, 0, abort_code, error)) aborted = rxrpc_abort_call(why, call, 0, abort_code, error);
if (aborted)
rxrpc_send_abort_packet(call); rxrpc_send_abort_packet(call);
mutex_unlock(&call->user_mutex); mutex_unlock(&call->user_mutex);
_leave(""); return aborted;
} }
EXPORT_SYMBOL(rxrpc_kernel_abort_call); EXPORT_SYMBOL(rxrpc_kernel_abort_call);
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment