Commit 16bd188e authored by Jakub Kicinski

Merge branch 'tls-pad-strparser-internal-header-decrypt_ctx-etc'

Jakub Kicinski says:

====================
tls: pad strparser, internal header, decrypt_ctx etc.

A grab bag of non-functional refactoring to make the series
which will let us decrypt into a fresh skb smaller.

Patches in this series are not strictly required to get the
decryption into a fresh skb going, they are more in the "things
which had been annoying me for a while" category.
====================

Link: https://lore.kernel.org/r/20220708010314.1451462-1-kuba@kernel.org
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
parents 67d7ebde 35560b7f
@@ -65,15 +65,19 @@ struct _strp_msg {
struct sk_skb_cb {
#define SK_SKB_CB_PRIV_LEN 20
unsigned char data[SK_SKB_CB_PRIV_LEN];
/* align strp on cache line boundary within skb->cb[] */
unsigned char pad[4];
struct _strp_msg strp;
/* temp_reg is a temporary register used for bpf_convert_data_end_access
 * when dst_reg == src_reg.
 */
u64 temp_reg;
/* strp users' data follows */
struct tls_msg {
u8 control;
u8 decrypted;
} tls;
/* temp_reg is a temporary register used for bpf_convert_data_end_access
 * when dst_reg == src_reg.
 */
u64 temp_reg;
};
static inline struct strp_msg *strp_msg(struct sk_buff *skb)
......
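For readers following the layout change above: a minimal sketch, not part of the patch, of the compile-time checks one could drop into an init function (BUILD_BUG_ON() must live inside a function) to pin the layout down. BUILD_BUG_ON(), offsetof() and sizeof_field() are existing kernel helpers; the 24-byte offset follows from the 20-byte private area plus the new 4-byte pad, and the second check mirrors the one the series adds to strp_dev_init() further below.

/* Illustrative only: assert the sk_skb_cb layout at compile time. */
BUILD_BUG_ON(offsetof(struct sk_skb_cb, strp) != SK_SKB_CB_PRIV_LEN + 4);
BUILD_BUG_ON(sizeof(struct sk_skb_cb) > sizeof_field(struct sk_buff, cb));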
@@ -39,7 +39,6 @@
#include <linux/crypto.h>
#include <linux/socket.h>
#include <linux/tcp.h>
#include <linux/skmsg.h>
#include <linux/mutex.h>
#include <linux/netdevice.h>
#include <linux/rcupdate.h>
@@ -50,6 +49,7 @@
#include <crypto/aead.h>
#include <uapi/linux/tls.h>
struct tls_rec;
/* Maximum data size carried in a TLS record */
#define TLS_MAX_PAYLOAD_SIZE ((size_t)1 << 14)
@@ -66,6 +66,7 @@
#define MAX_IV_SIZE 16
#define TLS_TAG_SIZE 16
#define TLS_MAX_REC_SEQ_SIZE 8
#define TLS_MAX_AAD_SIZE TLS_AAD_SPACE_SIZE
/* For CCM mode, the full 16-bytes of IV is made of '4' fields of given sizes.
 *
@@ -77,13 +78,6 @@
#define TLS_AES_CCM_IV_B0_BYTE 2
#define TLS_SM4_CCM_IV_B0_BYTE 2
#define __TLS_INC_STATS(net, field) \
__SNMP_INC_STATS((net)->mib.tls_statistics, field)
#define TLS_INC_STATS(net, field) \
SNMP_INC_STATS((net)->mib.tls_statistics, field)
#define TLS_DEC_STATS(net, field) \
SNMP_DEC_STATS((net)->mib.tls_statistics, field)
enum {
TLS_BASE,
TLS_SW,
@@ -92,32 +86,6 @@ enum {
TLS_NUM_CONFIG,
};
/* TLS records are maintained in 'struct tls_rec'. It stores the memory pages
* allocated or mapped for each TLS record. After encryption, the records are
* stores in a linked list.
*/
struct tls_rec {
struct list_head list;
int tx_ready;
int tx_flags;
struct sk_msg msg_plaintext;
struct sk_msg msg_encrypted;
/* AAD | msg_plaintext.sg.data | sg_tag */
struct scatterlist sg_aead_in[2];
/* AAD | msg_encrypted.sg.data (data contains overhead for hdr & iv & tag) */
struct scatterlist sg_aead_out[2];
char content_type;
struct scatterlist sg_content_type;
char aad_space[TLS_AAD_SPACE_SIZE];
u8 iv_data[MAX_IV_SIZE];
struct aead_request aead_req;
u8 aead_req_ctx[];
};
struct tx_work {
struct delayed_work work;
struct sock *sk;
@@ -348,44 +316,6 @@ struct tls_offload_context_rx {
#define TLS_OFFLOAD_CONTEXT_SIZE_RX \
(sizeof(struct tls_offload_context_rx) + TLS_DRIVER_STATE_SIZE_RX)
struct tls_context *tls_ctx_create(struct sock *sk);
void tls_ctx_free(struct sock *sk, struct tls_context *ctx);
void update_sk_prot(struct sock *sk, struct tls_context *ctx);
int wait_on_pending_writer(struct sock *sk, long *timeo);
int tls_sk_query(struct sock *sk, int optname, char __user *optval,
int __user *optlen);
int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
unsigned int optlen);
void tls_err_abort(struct sock *sk, int err);
int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx);
void tls_update_rx_zc_capable(struct tls_context *tls_ctx);
void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx);
void tls_sw_strparser_done(struct tls_context *tls_ctx);
int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
int offset, size_t size, int flags);
int tls_sw_sendpage(struct sock *sk, struct page *page,
int offset, size_t size, int flags);
void tls_sw_cancel_work_tx(struct tls_context *tls_ctx);
void tls_sw_release_resources_tx(struct sock *sk);
void tls_sw_free_ctx_tx(struct tls_context *tls_ctx);
void tls_sw_free_resources_rx(struct sock *sk);
void tls_sw_release_resources_rx(struct sock *sk);
void tls_sw_free_ctx_rx(struct tls_context *tls_ctx);
int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int flags, int *addr_len);
bool tls_sw_sock_is_readable(struct sock *sk);
ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
struct pipe_inode_info *pipe,
size_t len, unsigned int flags);
int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
int tls_device_sendpage(struct sock *sk, struct page *page,
int offset, size_t size, int flags);
int tls_tx_records(struct sock *sk, int flags);
struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
u32 seq, u64 *p_record_sn);
@@ -399,58 +329,6 @@ static inline u32 tls_record_start_seq(struct tls_record_info *rec)
return rec->end_seq - rec->len;
}
int tls_push_sg(struct sock *sk, struct tls_context *ctx,
struct scatterlist *sg, u16 first_offset,
int flags);
int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
int flags);
void tls_free_partial_record(struct sock *sk, struct tls_context *ctx);
static inline struct tls_msg *tls_msg(struct sk_buff *skb)
{
struct sk_skb_cb *scb = (struct sk_skb_cb *)skb->cb;
return &scb->tls;
}
static inline bool tls_is_partially_sent_record(struct tls_context *ctx)
{
return !!ctx->partially_sent_record;
}
static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx)
{
return tls_ctx->pending_open_record_frags;
}
static inline bool is_tx_ready(struct tls_sw_context_tx *ctx)
{
struct tls_rec *rec;
rec = list_first_entry(&ctx->tx_list, struct tls_rec, list);
if (!rec)
return false;
return READ_ONCE(rec->tx_ready);
}
static inline u16 tls_user_config(struct tls_context *ctx, bool tx)
{
u16 config = tx ? ctx->tx_conf : ctx->rx_conf;
switch (config) {
case TLS_BASE:
return TLS_CONF_BASE;
case TLS_SW:
return TLS_CONF_SW;
case TLS_HW:
return TLS_CONF_HW;
case TLS_HW_RECORD:
return TLS_CONF_HW_RECORD;
}
return 0;
}
struct sk_buff *
tls_validate_xmit_skb(struct sock *sk, struct net_device *dev,
struct sk_buff *skb);
@@ -469,31 +347,6 @@ static inline bool tls_is_sk_tx_device_offloaded(struct sock *sk)
#endif
}
static inline bool tls_bigint_increment(unsigned char *seq, int len)
{
int i;
for (i = len - 1; i >= 0; i--) {
++seq[i];
if (seq[i] != 0)
break;
}
return (i == -1);
}
static inline void tls_bigint_subtract(unsigned char *seq, int n)
{
u64 rcd_sn;
__be64 *p;
BUILD_BUG_ON(TLS_MAX_REC_SEQ_SIZE != 8);
p = (__be64 *)seq;
rcd_sn = be64_to_cpu(*p);
*p = cpu_to_be64(rcd_sn - n);
}
static inline struct tls_context *tls_get_ctx(const struct sock *sk)
{
struct inet_connection_sock *icsk = inet_csk(sk);
@@ -504,82 +357,6 @@ static inline struct tls_context *tls_get_ctx(const struct sock *sk)
return (__force void *)icsk->icsk_ulp_data;
}
static inline void tls_advance_record_sn(struct sock *sk,
struct tls_prot_info *prot,
struct cipher_context *ctx)
{
if (tls_bigint_increment(ctx->rec_seq, prot->rec_seq_size))
tls_err_abort(sk, -EBADMSG);
if (prot->version != TLS_1_3_VERSION &&
prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305)
tls_bigint_increment(ctx->iv + prot->salt_size,
prot->iv_size);
}
static inline void tls_fill_prepend(struct tls_context *ctx,
char *buf,
size_t plaintext_len,
unsigned char record_type)
{
struct tls_prot_info *prot = &ctx->prot_info;
size_t pkt_len, iv_size = prot->iv_size;
pkt_len = plaintext_len + prot->tag_size;
if (prot->version != TLS_1_3_VERSION &&
prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305) {
pkt_len += iv_size;
memcpy(buf + TLS_NONCE_OFFSET,
ctx->tx.iv + prot->salt_size, iv_size);
}
/* we cover nonce explicit here as well, so buf should be of
* size KTLS_DTLS_HEADER_SIZE + KTLS_DTLS_NONCE_EXPLICIT_SIZE
*/
buf[0] = prot->version == TLS_1_3_VERSION ?
TLS_RECORD_TYPE_DATA : record_type;
/* Note that VERSION must be TLS_1_2 for both TLS1.2 and TLS1.3 */
buf[1] = TLS_1_2_VERSION_MINOR;
buf[2] = TLS_1_2_VERSION_MAJOR;
/* we can use IV for nonce explicit according to spec */
buf[3] = pkt_len >> 8;
buf[4] = pkt_len & 0xFF;
}
static inline void tls_make_aad(char *buf,
size_t size,
char *record_sequence,
unsigned char record_type,
struct tls_prot_info *prot)
{
if (prot->version != TLS_1_3_VERSION) {
memcpy(buf, record_sequence, prot->rec_seq_size);
buf += 8;
} else {
size += prot->tag_size;
}
buf[0] = prot->version == TLS_1_3_VERSION ?
TLS_RECORD_TYPE_DATA : record_type;
buf[1] = TLS_1_2_VERSION_MAJOR;
buf[2] = TLS_1_2_VERSION_MINOR;
buf[3] = size >> 8;
buf[4] = size & 0xFF;
}
static inline void xor_iv_with_seq(struct tls_prot_info *prot, char *iv, char *seq)
{
int i;
if (prot->version == TLS_1_3_VERSION ||
prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
for (i = 0; i < 8; i++)
iv[i + 4] ^= seq[i];
}
}
static inline struct tls_sw_context_rx *tls_sw_ctx_rx(
const struct tls_context *tls_ctx)
{
@@ -616,9 +393,6 @@ static inline bool tls_sw_has_ctx_rx(const struct sock *sk)
return !!tls_sw_ctx_rx(ctx);
}
void tls_sw_write_space(struct sock *sk, struct tls_context *ctx);
void tls_device_write_space(struct sock *sk, struct tls_context *ctx);
static inline struct tls_offload_context_rx *
tls_offload_ctx_rx(const struct tls_context *tls_ctx)
{
@@ -693,31 +467,11 @@ static inline bool tls_offload_tx_resync_pending(struct sock *sk)
return ret;
}
int __net_init tls_proc_init(struct net *net);
void __net_exit tls_proc_fini(struct net *net);
int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
unsigned char *record_type);
int decrypt_skb(struct sock *sk, struct sk_buff *skb,
struct scatterlist *sgout);
struct sk_buff *tls_encrypt_skb(struct sk_buff *skb);
int tls_sw_fallback_init(struct sock *sk,
struct tls_offload_context_tx *offload_ctx,
struct tls_crypto_info *crypto_info);
#ifdef CONFIG_TLS_DEVICE
void tls_device_init(void);
void tls_device_cleanup(void);
void tls_device_sk_destruct(struct sock *sk);
int tls_set_device_offload(struct sock *sk, struct tls_context *ctx);
void tls_device_free_resources_tx(struct sock *sk);
int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx);
void tls_device_offload_cleanup_rx(struct sock *sk);
void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq);
void tls_offload_tx_resync_request(struct sock *sk, u32 got_seq, u32 exp_seq);
int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
struct sk_buff *skb, struct strp_msg *rxm);
static inline bool tls_is_sk_rx_device_offloaded(struct sock *sk)
{
@@ -726,33 +480,5 @@ static inline bool tls_is_sk_rx_device_offloaded(struct sock *sk)
return false;
return tls_get_ctx(sk)->rx_conf == TLS_HW;
}
#else
static inline void tls_device_init(void) {}
static inline void tls_device_cleanup(void) {}
static inline int
tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
{
return -EOPNOTSUPP;
}
static inline void tls_device_free_resources_tx(struct sock *sk) {}
static inline int
tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
{
return -EOPNOTSUPP;
}
static inline void tls_device_offload_cleanup_rx(struct sock *sk) {}
static inline void
tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq) {}
static inline int
tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
struct sk_buff *skb, struct strp_msg *rxm)
{
return 0;
}
#endif
#endif /* _TLS_OFFLOAD_H */
@@ -533,6 +533,9 @@ EXPORT_SYMBOL_GPL(strp_check_rcv);
static int __init strp_dev_init(void)
{
BUILD_BUG_ON(sizeof(struct sk_skb_cb) >
sizeof_field(struct sk_buff, cb));
strp_wq = create_singlethread_workqueue("kstrp");
if (unlikely(!strp_wq))
return -ENOMEM;
......
/*
* Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
* Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
*
* This software is available to you under a choice of one of two
* licenses. You may choose to be licensed under the terms of the GNU
* General Public License (GPL) Version 2, available from the file
* COPYING in the main directory of this source tree, or the
* OpenIB.org BSD license below:
*
* Redistribution and use in source and binary forms, with or
* without modification, are permitted provided that the following
* conditions are met:
*
* - Redistributions of source code must retain the above
* copyright notice, this list of conditions and the following
* disclaimer.
*
* - Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following
* disclaimer in the documentation and/or other materials
* provided with the distribution.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
* BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
* ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef _TLS_INT_H
#define _TLS_INT_H
#include <asm/byteorder.h>
#include <linux/types.h>
#include <linux/skmsg.h>
#include <net/tls.h>
#define __TLS_INC_STATS(net, field) \
__SNMP_INC_STATS((net)->mib.tls_statistics, field)
#define TLS_INC_STATS(net, field) \
SNMP_INC_STATS((net)->mib.tls_statistics, field)
#define TLS_DEC_STATS(net, field) \
SNMP_DEC_STATS((net)->mib.tls_statistics, field)
/* TLS records are maintained in 'struct tls_rec'. It stores the memory pages
* allocated or mapped for each TLS record. After encryption, the records are
* stores in a linked list.
*/
struct tls_rec {
struct list_head list;
int tx_ready;
int tx_flags;
struct sk_msg msg_plaintext;
struct sk_msg msg_encrypted;
/* AAD | msg_plaintext.sg.data | sg_tag */
struct scatterlist sg_aead_in[2];
/* AAD | msg_encrypted.sg.data (data contains overhead for hdr & iv & tag) */
struct scatterlist sg_aead_out[2];
char content_type;
struct scatterlist sg_content_type;
char aad_space[TLS_AAD_SPACE_SIZE];
u8 iv_data[MAX_IV_SIZE];
struct aead_request aead_req;
u8 aead_req_ctx[];
};
int __net_init tls_proc_init(struct net *net);
void __net_exit tls_proc_fini(struct net *net);
struct tls_context *tls_ctx_create(struct sock *sk);
void tls_ctx_free(struct sock *sk, struct tls_context *ctx);
void update_sk_prot(struct sock *sk, struct tls_context *ctx);
int wait_on_pending_writer(struct sock *sk, long *timeo);
int tls_sk_query(struct sock *sk, int optname, char __user *optval,
int __user *optlen);
int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
unsigned int optlen);
void tls_err_abort(struct sock *sk, int err);
int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx);
void tls_update_rx_zc_capable(struct tls_context *tls_ctx);
void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx);
void tls_sw_strparser_done(struct tls_context *tls_ctx);
int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
int offset, size_t size, int flags);
int tls_sw_sendpage(struct sock *sk, struct page *page,
int offset, size_t size, int flags);
void tls_sw_cancel_work_tx(struct tls_context *tls_ctx);
void tls_sw_release_resources_tx(struct sock *sk);
void tls_sw_free_ctx_tx(struct tls_context *tls_ctx);
void tls_sw_free_resources_rx(struct sock *sk);
void tls_sw_release_resources_rx(struct sock *sk);
void tls_sw_free_ctx_rx(struct tls_context *tls_ctx);
int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int flags, int *addr_len);
bool tls_sw_sock_is_readable(struct sock *sk);
ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
struct pipe_inode_info *pipe,
size_t len, unsigned int flags);
int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
int tls_device_sendpage(struct sock *sk, struct page *page,
int offset, size_t size, int flags);
int tls_tx_records(struct sock *sk, int flags);
void tls_sw_write_space(struct sock *sk, struct tls_context *ctx);
void tls_device_write_space(struct sock *sk, struct tls_context *ctx);
int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
unsigned char *record_type);
int decrypt_skb(struct sock *sk, struct sk_buff *skb,
struct scatterlist *sgout);
int tls_sw_fallback_init(struct sock *sk,
struct tls_offload_context_tx *offload_ctx,
struct tls_crypto_info *crypto_info);
static inline struct tls_msg *tls_msg(struct sk_buff *skb)
{
struct sk_skb_cb *scb = (struct sk_skb_cb *)skb->cb;
return &scb->tls;
}
#ifdef CONFIG_TLS_DEVICE
void tls_device_init(void);
void tls_device_cleanup(void);
int tls_set_device_offload(struct sock *sk, struct tls_context *ctx);
void tls_device_free_resources_tx(struct sock *sk);
int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx);
void tls_device_offload_cleanup_rx(struct sock *sk);
void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq);
int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
struct sk_buff *skb, struct strp_msg *rxm);
#else
static inline void tls_device_init(void) {}
static inline void tls_device_cleanup(void) {}
static inline int
tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
{
return -EOPNOTSUPP;
}
static inline void tls_device_free_resources_tx(struct sock *sk) {}
static inline int
tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
{
return -EOPNOTSUPP;
}
static inline void tls_device_offload_cleanup_rx(struct sock *sk) {}
static inline void
tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq) {}
static inline int
tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
struct sk_buff *skb, struct strp_msg *rxm)
{
return 0;
}
#endif
int tls_push_sg(struct sock *sk, struct tls_context *ctx,
struct scatterlist *sg, u16 first_offset,
int flags);
int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
int flags);
void tls_free_partial_record(struct sock *sk, struct tls_context *ctx);
static inline bool tls_is_partially_sent_record(struct tls_context *ctx)
{
return !!ctx->partially_sent_record;
}
static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx)
{
return tls_ctx->pending_open_record_frags;
}
static inline bool tls_bigint_increment(unsigned char *seq, int len)
{
int i;
for (i = len - 1; i >= 0; i--) {
++seq[i];
if (seq[i] != 0)
break;
}
return (i == -1);
}
static inline void tls_bigint_subtract(unsigned char *seq, int n)
{
u64 rcd_sn;
__be64 *p;
BUILD_BUG_ON(TLS_MAX_REC_SEQ_SIZE != 8);
p = (__be64 *)seq;
rcd_sn = be64_to_cpu(*p);
*p = cpu_to_be64(rcd_sn - n);
}
static inline void
tls_advance_record_sn(struct sock *sk, struct tls_prot_info *prot,
struct cipher_context *ctx)
{
if (tls_bigint_increment(ctx->rec_seq, prot->rec_seq_size))
tls_err_abort(sk, -EBADMSG);
if (prot->version != TLS_1_3_VERSION &&
prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305)
tls_bigint_increment(ctx->iv + prot->salt_size,
prot->iv_size);
}
static inline void
tls_xor_iv_with_seq(struct tls_prot_info *prot, char *iv, char *seq)
{
int i;
if (prot->version == TLS_1_3_VERSION ||
prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
for (i = 0; i < 8; i++)
iv[i + 4] ^= seq[i];
}
}
static inline void
tls_fill_prepend(struct tls_context *ctx, char *buf, size_t plaintext_len,
unsigned char record_type)
{
struct tls_prot_info *prot = &ctx->prot_info;
size_t pkt_len, iv_size = prot->iv_size;
pkt_len = plaintext_len + prot->tag_size;
if (prot->version != TLS_1_3_VERSION &&
prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305) {
pkt_len += iv_size;
memcpy(buf + TLS_NONCE_OFFSET,
ctx->tx.iv + prot->salt_size, iv_size);
}
/* we cover nonce explicit here as well, so buf should be of
* size KTLS_DTLS_HEADER_SIZE + KTLS_DTLS_NONCE_EXPLICIT_SIZE
*/
buf[0] = prot->version == TLS_1_3_VERSION ?
TLS_RECORD_TYPE_DATA : record_type;
/* Note that VERSION must be TLS_1_2 for both TLS1.2 and TLS1.3 */
buf[1] = TLS_1_2_VERSION_MINOR;
buf[2] = TLS_1_2_VERSION_MAJOR;
/* we can use IV for nonce explicit according to spec */
buf[3] = pkt_len >> 8;
buf[4] = pkt_len & 0xFF;
}
static inline
void tls_make_aad(char *buf, size_t size, char *record_sequence,
unsigned char record_type, struct tls_prot_info *prot)
{
if (prot->version != TLS_1_3_VERSION) {
memcpy(buf, record_sequence, prot->rec_seq_size);
buf += 8;
} else {
size += prot->tag_size;
}
buf[0] = prot->version == TLS_1_3_VERSION ?
TLS_RECORD_TYPE_DATA : record_type;
buf[1] = TLS_1_2_VERSION_MAJOR;
buf[2] = TLS_1_2_VERSION_MINOR;
buf[3] = size >> 8;
buf[4] = size & 0xFF;
}
#endif
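Side note on tls_fill_prepend()/tls_make_aad() above: a minimal standalone sketch, not part of the patch, of the 5-byte record header they build. For both TLS 1.2 and TLS 1.3 the on-wire legacy version bytes are 0x03 0x03 and the length is big-endian; the function name and parameter types below are made up for illustration.

/* Illustrative only: the 5-byte TLS record header layout. */
static void example_fill_record_header(unsigned char *buf,
					unsigned char record_type,
					unsigned int payload_len)
{
	buf[0] = record_type;		/* e.g. TLS_RECORD_TYPE_DATA */
	buf[1] = 0x03;			/* legacy record version, first byte */
	buf[2] = 0x03;			/* legacy record version, second byte */
	buf[3] = payload_len >> 8;	/* length, network byte order */
	buf[4] = payload_len & 0xff;
}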
@@ -38,6 +38,7 @@
#include <net/tcp.h>
#include <net/tls.h>
#include "tls.h"
#include "trace.h"
/* device_offload_lock is used to synchronize tls_dev_add
@@ -562,7 +563,7 @@ int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
lock_sock(sk);
if (unlikely(msg->msg_controllen)) {
rc = tls_proccess_cmsg(sk, msg, &record_type);
rc = tls_process_cmsg(sk, msg, &record_type);
if (rc)
goto out;
}
......
@@ -34,6 +34,8 @@
#include <crypto/scatterwalk.h>
#include <net/ip6_checksum.h>
#include "tls.h"
static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk)
{
struct scatterlist *src = walk->sg;
......
@@ -45,6 +45,8 @@
#include <net/tls.h>
#include <net/tls_toe.h>
#include "tls.h"
MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
@@ -164,7 +166,7 @@ static int tls_handle_open_record(struct sock *sk, int flags)
return 0;
}
int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
unsigned char *record_type)
{
struct cmsghdr *cmsg;
@@ -1003,6 +1005,23 @@ static void tls_update(struct sock *sk, struct proto *p,
}
}
static u16 tls_user_config(struct tls_context *ctx, bool tx)
{
u16 config = tx ? ctx->tx_conf : ctx->rx_conf;
switch (config) {
case TLS_BASE:
return TLS_CONF_BASE;
case TLS_SW:
return TLS_CONF_SW;
case TLS_HW:
return TLS_CONF_HW;
case TLS_HW_RECORD:
return TLS_CONF_HW_RECORD;
}
return 0;
}
static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
{
u16 version, cipher_type;
......
@@ -6,6 +6,8 @@
#include <net/snmp.h>
#include <net/tls.h>
#include "tls.h"
#ifdef CONFIG_PROC_FS
static const struct snmp_mib tls_mib_list[] = {
SNMP_MIB_ITEM("TlsCurrTxSw", LINUX_MIB_TLSCURRTXSW),
......
@@ -44,12 +44,21 @@
#include <net/strparser.h>
#include <net/tls.h>
#include "tls.h"
struct tls_decrypt_arg {
bool zc;
bool async;
u8 tail;
};
struct tls_decrypt_ctx {
u8 iv[MAX_IV_SIZE];
u8 aad[TLS_MAX_AAD_SIZE];
u8 tail;
struct scatterlist sg[];
};
noinline void tls_err_abort(struct sock *sk, int err)
{
WARN_ON_ONCE(err >= 0);
@@ -517,7 +526,8 @@ static int tls_do_encryption(struct sock *sk,
memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
prot->iv_size + prot->salt_size);
xor_iv_with_seq(prot, rec->iv_data + iv_offset, tls_ctx->tx.rec_seq);
tls_xor_iv_with_seq(prot, rec->iv_data + iv_offset,
tls_ctx->tx.rec_seq);
sge->offset += prot->prepend_size;
sge->length -= prot->prepend_size;
@@ -954,7 +964,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
lock_sock(sk);
if (unlikely(msg->msg_controllen)) {
ret = tls_proccess_cmsg(sk, msg, &record_type);
ret = tls_process_cmsg(sk, msg, &record_type);
if (ret) {
if (ret == -EINPROGRESS)
num_async++;
@@ -1292,54 +1302,50 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
return ret;
}
static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
bool nonblock, long timeo, int *err)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_buff *skb;
DEFINE_WAIT_FUNC(wait, woken_wake_function);
while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) {
if (sk->sk_err) {
*err = sock_error(sk);
return NULL;
}
if (!skb_queue_empty(&sk->sk_receive_queue)) {
__strp_unpause(&ctx->strp);
if (ctx->recv_pkt)
return ctx->recv_pkt;
}
if (sk->sk_shutdown & RCV_SHUTDOWN)
return NULL;
if (sock_flag(sk, SOCK_DONE))
return NULL;
if (nonblock || !timeo) {
*err = -EAGAIN;
return NULL;
}
add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
sk_wait_event(sk, &timeo,
ctx->recv_pkt != skb ||
!sk_psock_queue_empty(psock),
&wait);
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
remove_wait_queue(sk_sleep(sk), &wait);
/* Handle signals */
if (signal_pending(current)) {
*err = sock_intr_errno(timeo);
return NULL;
}
}
return skb;
}

static int
tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
long timeo)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
DEFINE_WAIT_FUNC(wait, woken_wake_function);
while (!ctx->recv_pkt) {
if (!sk_psock_queue_empty(psock))
return 0;
if (sk->sk_err)
return sock_error(sk);
if (!skb_queue_empty(&sk->sk_receive_queue)) {
__strp_unpause(&ctx->strp);
if (ctx->recv_pkt)
break;
}
if (sk->sk_shutdown & RCV_SHUTDOWN)
return 0;
if (sock_flag(sk, SOCK_DONE))
return 0;
if (nonblock || !timeo)
return -EAGAIN;
add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
sk_wait_event(sk, &timeo,
ctx->recv_pkt || !sk_psock_queue_empty(psock),
&wait);
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
remove_wait_queue(sk_sleep(sk), &wait);
/* Handle signals */
if (signal_pending(current))
return sock_intr_errno(timeo);
}
return 1;
}
static int tls_setup_from_iter(struct iov_iter *from,
@@ -1414,17 +1420,18 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct strp_msg *rxm = strp_msg(skb);
struct tls_msg *tlm = tls_msg(skb);
int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
u8 *aad, *iv, *tail, *mem = NULL;
int n_sgin, n_sgout, aead_size, err, pages = 0;
struct aead_request *aead_req;
struct sk_buff *unused;
struct scatterlist *sgin = NULL;
struct scatterlist *sgout = NULL;
const int data_len = rxm->full_len - prot->overhead_size;
int tail_pages = !!prot->tail_size;
struct tls_decrypt_ctx *dctx;
int iv_offset = 0;
u8 *mem;
if (darg->zc && (out_iov || out_sg)) {
if (out_iov)
@@ -1446,38 +1453,30 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
/* Increment to accommodate AAD */
n_sgin = n_sgin + 1;
nsg = n_sgin + n_sgout;
aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
mem_size = aead_size + (nsg * sizeof(struct scatterlist));
mem_size = mem_size + prot->aad_size;
mem_size = mem_size + MAX_IV_SIZE;
mem_size = mem_size + prot->tail_size;
/* Allocate a single block of memory which contains
 * aead_req || sgin[] || sgout[] || aad || iv || tail.
 * This order achieves correct alignment for aead_req, sgin, sgout.
 */
mem = kmalloc(mem_size, sk->sk_allocation);
/* Allocate a single block of memory which contains
 * aead_req || tls_decrypt_ctx.
 * Both structs are variable length.
 */
aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout),
sk->sk_allocation);
if (!mem)
return -ENOMEM;
/* Segment the allocated memory */
aead_req = (struct aead_request *)mem;
sgin = (struct scatterlist *)(mem + aead_size);
sgout = sgin + n_sgin;
aad = (u8 *)(sgout + n_sgout);
iv = aad + prot->aad_size;
tail = iv + MAX_IV_SIZE;
dctx = (struct tls_decrypt_ctx *)(mem + aead_size);
sgin = &dctx->sg[0];
sgout = &dctx->sg[n_sgin];
/* For CCM based ciphers, first byte of nonce+iv is a constant */
switch (prot->cipher_type) {
case TLS_CIPHER_AES_CCM_128:
iv[0] = TLS_AES_CCM_IV_B0_BYTE;
dctx->iv[0] = TLS_AES_CCM_IV_B0_BYTE;
iv_offset = 1;
break;
case TLS_CIPHER_SM4_CCM:
iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
dctx->iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
iv_offset = 1;
break;
}
@@ -1485,40 +1484,36 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
/* Prepare IV */
if (prot->version == TLS_1_3_VERSION ||
prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
memcpy(iv + iv_offset, tls_ctx->rx.iv,
memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv,
prot->iv_size + prot->salt_size);
} else {
err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
iv + iv_offset + prot->salt_size,
&dctx->iv[iv_offset] + prot->salt_size,
prot->iv_size);
if (err < 0) {
kfree(mem);
return err;
}
memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
if (err < 0)
goto exit_free;
memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, prot->salt_size);
}
xor_iv_with_seq(prot, iv + iv_offset, tls_ctx->rx.rec_seq);
tls_xor_iv_with_seq(prot, &dctx->iv[iv_offset], tls_ctx->rx.rec_seq);
/* Prepare AAD */
tls_make_aad(aad, rxm->full_len - prot->overhead_size +
tls_make_aad(dctx->aad, rxm->full_len - prot->overhead_size +
prot->tail_size,
tls_ctx->rx.rec_seq, tlm->control, prot);
/* Prepare sgin */
sg_init_table(sgin, n_sgin);
sg_set_buf(&sgin[0], aad, prot->aad_size);
sg_set_buf(&sgin[0], dctx->aad, prot->aad_size);
err = skb_to_sgvec(skb, &sgin[1],
rxm->offset + prot->prepend_size,
rxm->full_len - prot->prepend_size);
if (err < 0) {
kfree(mem);
return err;
}
if (err < 0)
goto exit_free;
if (n_sgout) {
if (out_iov) {
sg_init_table(sgout, n_sgout);
sg_set_buf(&sgout[0], aad, prot->aad_size);
sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
err = tls_setup_from_iter(out_iov, data_len,
&pages, &sgout[1],
@@ -1528,7 +1523,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
if (prot->tail_size) {
sg_unmark_end(&sgout[pages]);
sg_set_buf(&sgout[pages + 1], tail,
sg_set_buf(&sgout[pages + 1], &dctx->tail,
prot->tail_size);
sg_mark_end(&sgout[pages + 1]);
}
@@ -1545,18 +1540,18 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
}
/* Prepare and submit AEAD request */
err = tls_do_decryption(sk, skb, sgin, sgout, iv,
err = tls_do_decryption(sk, skb, sgin, sgout, dctx->iv,
data_len + prot->tail_size, aead_req, darg);
if (darg->async)
return 0;
if (prot->tail_size)
darg->tail = *tail;
darg->tail = dctx->tail;
/* Release the pages in case iov was mapped to pages */
for (; pages > 0; pages--)
put_page(sg_page(&sgout[pages]));
exit_free:
kfree(mem);
return err;
}
@@ -1813,8 +1808,8 @@ int tls_sw_recvmsg(struct sock *sk,
struct tls_decrypt_arg darg = {};
int to_decrypt, chunk;
skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo, &err);
if (!skb) {
err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, timeo);
if (err <= 0) {
if (psock) {
chunk = sk_msg_recvmsg(sk, psock, msg, len,
flags);
@@ -1824,6 +1819,7 @@ int tls_sw_recvmsg(struct sock *sk,
goto recv_end;
}
skb = ctx->recv_pkt;
rxm = strp_msg(skb);
tlm = tls_msg(skb);
@@ -1990,11 +1986,13 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
} else {
struct tls_decrypt_arg darg = {};
skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo,
&err);
if (!skb)
err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
timeo);
if (err <= 0)
goto splice_read_end;
skb = ctx->recv_pkt;
err = decrypt_skb_update(sk, skb, NULL, &darg);
if (err < 0) {
tls_err_abort(sk, -EBADMSG);
@@ -2271,12 +2269,23 @@ static void tx_work_handler(struct work_struct *work)
mutex_unlock(&tls_ctx->tx_lock);
}
static bool tls_is_tx_ready(struct tls_sw_context_tx *ctx)
{
struct tls_rec *rec;
rec = list_first_entry(&ctx->tx_list, struct tls_rec, list);
if (!rec)
return false;
return READ_ONCE(rec->tx_ready);
}
void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
{
struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
/* Schedule the transmission if tx list is ready */
if (is_tx_ready(tx_ctx) &&
if (tls_is_tx_ready(tx_ctx) &&
!test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
schedule_delayed_work(&tx_ctx->tx_work.work, 0);
}
@@ -2474,13 +2483,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
goto free_priv;
}
/* Sanity-check the sizes for stack allocations. */
if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE) {
rc = -EINVAL;
goto free_priv;
}
if (crypto_info->version == TLS_1_3_VERSION) {
nonce_size = 0;
prot->aad_size = TLS_HEADER_SIZE;
@@ -2490,6 +2492,14 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
prot->tail_size = 0;
}
/* Sanity-check the sizes for stack allocations. */
if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE ||
prot->aad_size > TLS_MAX_AAD_SIZE) {
rc = -EINVAL;
goto free_priv;
}
prot->version = crypto_info->version;
prot->cipher_type = crypto_info->cipher_type;
prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
......
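For readers unfamiliar with struct_size(): a minimal sketch, not part of the patch and with made-up names, of the single-allocation pattern decrypt_internal() now uses, where one kmalloc() holds a fixed-size object followed by a struct that ends in a flexible array.

#include <linux/overflow.h>	/* struct_size() */
#include <linux/scatterlist.h>
#include <linux/slab.h>
#include <linux/types.h>

/* Illustrative only: a fixed-size part followed by a variable-length ctx. */
struct example_ctx {
	u8 iv[16];
	struct scatterlist sg[];	/* n_sg entries follow the header */
};

static void *example_alloc(size_t fixed_size, int n_sg, gfp_t gfp)
{
	struct example_ctx *ctx;

	/* struct_size(ctx, sg, n_sg) is sizeof(*ctx) plus n_sg trailing
	 * scatterlist entries, with overflow checking; ctx is only used
	 * for its type here, just as the patch uses dctx before it is set.
	 */
	return kmalloc(fixed_size + struct_size(ctx, sg, n_sg), gfp);
}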
@@ -38,6 +38,8 @@
#include <net/tls.h>
#include <net/tls_toe.h>
#include "tls.h"
static LIST_HEAD(device_list);
static DEFINE_SPINLOCK(device_spinlock);
......