Commit 4509de14 authored by Vakul Garg's avatar Vakul Garg Committed by David S. Miller

net/tls: Move protocol constants from cipher context to tls context

Each tls context maintains two cipher contexts (one each for tx and rx
directions). For each tls session, the constants such as protocol
version, ciphersuite, iv size, associated data size etc are same for
both the directions and need to be stored only once per tls context.
Hence these are moved from 'struct cipher_context' to 'struct
tls_prot_info' and stored only once in 'struct tls_context'.
Signed-off-by: default avatarVakul Garg <vakul.garg@nxp.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent c9b747db
...@@ -199,15 +199,8 @@ enum { ...@@ -199,15 +199,8 @@ enum {
}; };
struct cipher_context { struct cipher_context {
u16 prepend_size;
u16 tag_size;
u16 overhead_size;
u16 iv_size;
char *iv; char *iv;
u16 rec_seq_size;
char *rec_seq; char *rec_seq;
u16 aad_size;
u16 tail_size;
}; };
union tls_crypto_context { union tls_crypto_context {
...@@ -218,7 +211,21 @@ union tls_crypto_context { ...@@ -218,7 +211,21 @@ union tls_crypto_context {
}; };
}; };
struct tls_prot_info {
u16 version;
u16 cipher_type;
u16 prepend_size;
u16 tag_size;
u16 overhead_size;
u16 iv_size;
u16 rec_seq_size;
u16 aad_size;
u16 tail_size;
};
struct tls_context { struct tls_context {
struct tls_prot_info prot_info;
union tls_crypto_context crypto_send; union tls_crypto_context crypto_send;
union tls_crypto_context crypto_recv; union tls_crypto_context crypto_recv;
...@@ -401,16 +408,26 @@ static inline bool tls_bigint_increment(unsigned char *seq, int len) ...@@ -401,16 +408,26 @@ static inline bool tls_bigint_increment(unsigned char *seq, int len)
return (i == -1); return (i == -1);
} }
static inline struct tls_context *tls_get_ctx(const struct sock *sk)
{
struct inet_connection_sock *icsk = inet_csk(sk);
return icsk->icsk_ulp_data;
}
static inline void tls_advance_record_sn(struct sock *sk, static inline void tls_advance_record_sn(struct sock *sk,
struct cipher_context *ctx, struct cipher_context *ctx,
int version) int version)
{ {
if (tls_bigint_increment(ctx->rec_seq, ctx->rec_seq_size)) struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
if (tls_bigint_increment(ctx->rec_seq, prot->rec_seq_size))
tls_err_abort(sk, EBADMSG); tls_err_abort(sk, EBADMSG);
if (version != TLS_1_3_VERSION) { if (version != TLS_1_3_VERSION) {
tls_bigint_increment(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, tls_bigint_increment(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
ctx->iv_size); prot->iv_size);
} }
} }
...@@ -420,9 +437,10 @@ static inline void tls_fill_prepend(struct tls_context *ctx, ...@@ -420,9 +437,10 @@ static inline void tls_fill_prepend(struct tls_context *ctx,
unsigned char record_type, unsigned char record_type,
int version) int version)
{ {
size_t pkt_len, iv_size = ctx->tx.iv_size; struct tls_prot_info *prot = &ctx->prot_info;
size_t pkt_len, iv_size = prot->iv_size;
pkt_len = plaintext_len + ctx->tx.tag_size; pkt_len = plaintext_len + prot->tag_size;
if (version != TLS_1_3_VERSION) { if (version != TLS_1_3_VERSION) {
pkt_len += iv_size; pkt_len += iv_size;
...@@ -475,12 +493,6 @@ static inline void xor_iv_with_seq(int version, char *iv, char *seq) ...@@ -475,12 +493,6 @@ static inline void xor_iv_with_seq(int version, char *iv, char *seq)
} }
} }
static inline struct tls_context *tls_get_ctx(const struct sock *sk)
{
struct inet_connection_sock *icsk = inet_csk(sk);
return icsk->icsk_ulp_data;
}
static inline struct tls_sw_context_rx *tls_sw_ctx_rx( static inline struct tls_sw_context_rx *tls_sw_ctx_rx(
const struct tls_context *tls_ctx) const struct tls_context *tls_ctx)
......
...@@ -247,6 +247,7 @@ static int tls_push_record(struct sock *sk, ...@@ -247,6 +247,7 @@ static int tls_push_record(struct sock *sk,
int flags, int flags,
unsigned char record_type) unsigned char record_type)
{ {
struct tls_prot_info *prot = &ctx->prot_info;
struct tcp_sock *tp = tcp_sk(sk); struct tcp_sock *tp = tcp_sk(sk);
struct page_frag dummy_tag_frag; struct page_frag dummy_tag_frag;
skb_frag_t *frag; skb_frag_t *frag;
...@@ -256,7 +257,7 @@ static int tls_push_record(struct sock *sk, ...@@ -256,7 +257,7 @@ static int tls_push_record(struct sock *sk,
frag = &record->frags[0]; frag = &record->frags[0];
tls_fill_prepend(ctx, tls_fill_prepend(ctx,
skb_frag_address(frag), skb_frag_address(frag),
record->len - ctx->tx.prepend_size, record->len - prot->prepend_size,
record_type, record_type,
ctx->crypto_send.info.version); ctx->crypto_send.info.version);
...@@ -264,7 +265,7 @@ static int tls_push_record(struct sock *sk, ...@@ -264,7 +265,7 @@ static int tls_push_record(struct sock *sk,
dummy_tag_frag.page = skb_frag_page(frag); dummy_tag_frag.page = skb_frag_page(frag);
dummy_tag_frag.offset = 0; dummy_tag_frag.offset = 0;
tls_append_frag(record, &dummy_tag_frag, ctx->tx.tag_size); tls_append_frag(record, &dummy_tag_frag, prot->tag_size);
record->end_seq = tp->write_seq + record->len; record->end_seq = tp->write_seq + record->len;
spin_lock_irq(&offload_ctx->lock); spin_lock_irq(&offload_ctx->lock);
list_add_tail(&record->list, &offload_ctx->records_list); list_add_tail(&record->list, &offload_ctx->records_list);
...@@ -347,6 +348,7 @@ static int tls_push_data(struct sock *sk, ...@@ -347,6 +348,7 @@ static int tls_push_data(struct sock *sk,
unsigned char record_type) unsigned char record_type)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx); struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST; int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE); int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
...@@ -376,10 +378,10 @@ static int tls_push_data(struct sock *sk, ...@@ -376,10 +378,10 @@ static int tls_push_data(struct sock *sk,
* we need to leave room for an authentication tag. * we need to leave room for an authentication tag.
*/ */
max_open_record_len = TLS_MAX_PAYLOAD_SIZE + max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
tls_ctx->tx.prepend_size; prot->prepend_size;
do { do {
rc = tls_do_allocation(sk, ctx, pfrag, rc = tls_do_allocation(sk, ctx, pfrag,
tls_ctx->tx.prepend_size); prot->prepend_size);
if (rc) { if (rc) {
rc = sk_stream_wait_memory(sk, &timeo); rc = sk_stream_wait_memory(sk, &timeo);
if (!rc) if (!rc)
...@@ -397,7 +399,7 @@ static int tls_push_data(struct sock *sk, ...@@ -397,7 +399,7 @@ static int tls_push_data(struct sock *sk,
size = orig_size; size = orig_size;
destroy_record(record); destroy_record(record);
ctx->open_record = NULL; ctx->open_record = NULL;
} else if (record->len > tls_ctx->tx.prepend_size) { } else if (record->len > prot->prepend_size) {
goto last_record; goto last_record;
} }
...@@ -658,6 +660,8 @@ int tls_device_decrypted(struct sock *sk, struct sk_buff *skb) ...@@ -658,6 +660,8 @@ int tls_device_decrypted(struct sock *sk, struct sk_buff *skb)
int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
{ {
u16 nonce_size, tag_size, iv_size, rec_seq_size; u16 nonce_size, tag_size, iv_size, rec_seq_size;
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_record_info *start_marker_record; struct tls_record_info *start_marker_record;
struct tls_offload_context_tx *offload_ctx; struct tls_offload_context_tx *offload_ctx;
struct tls_crypto_info *crypto_info; struct tls_crypto_info *crypto_info;
...@@ -703,10 +707,10 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) ...@@ -703,10 +707,10 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
goto free_offload_ctx; goto free_offload_ctx;
} }
ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size; prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
ctx->tx.tag_size = tag_size; prot->tag_size = tag_size;
ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size; prot->overhead_size = prot->prepend_size + prot->tag_size;
ctx->tx.iv_size = iv_size; prot->iv_size = iv_size;
ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
GFP_KERNEL); GFP_KERNEL);
if (!ctx->tx.iv) { if (!ctx->tx.iv) {
...@@ -716,7 +720,7 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) ...@@ -716,7 +720,7 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
ctx->tx.rec_seq_size = rec_seq_size; prot->rec_seq_size = rec_seq_size;
ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL); ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
if (!ctx->tx.rec_seq) { if (!ctx->tx.rec_seq) {
rc = -ENOMEM; rc = -ENOMEM;
......
...@@ -435,6 +435,7 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, ...@@ -435,6 +435,7 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
unsigned int optlen, int tx) unsigned int optlen, int tx)
{ {
struct tls_crypto_info *crypto_info; struct tls_crypto_info *crypto_info;
struct tls_crypto_info *alt_crypto_info;
struct tls_context *ctx = tls_get_ctx(sk); struct tls_context *ctx = tls_get_ctx(sk);
size_t optsize; size_t optsize;
int rc = 0; int rc = 0;
...@@ -445,10 +446,13 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, ...@@ -445,10 +446,13 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
goto out; goto out;
} }
if (tx) if (tx) {
crypto_info = &ctx->crypto_send.info; crypto_info = &ctx->crypto_send.info;
else alt_crypto_info = &ctx->crypto_recv.info;
} else {
crypto_info = &ctx->crypto_recv.info; crypto_info = &ctx->crypto_recv.info;
alt_crypto_info = &ctx->crypto_send.info;
}
/* Currently we don't support set crypto info more than one time */ /* Currently we don't support set crypto info more than one time */
if (TLS_CRYPTO_INFO_READY(crypto_info)) { if (TLS_CRYPTO_INFO_READY(crypto_info)) {
...@@ -469,6 +473,15 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, ...@@ -469,6 +473,15 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
goto err_crypto_info; goto err_crypto_info;
} }
/* Ensure that TLS version and ciphers are same in both directions */
if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
if (alt_crypto_info->version != crypto_info->version ||
alt_crypto_info->cipher_type != crypto_info->cipher_type) {
rc = -EINVAL;
goto err_crypto_info;
}
}
switch (crypto_info->cipher_type) { switch (crypto_info->cipher_type) {
case TLS_CIPHER_AES_GCM_128: case TLS_CIPHER_AES_GCM_128:
case TLS_CIPHER_AES_GCM_256: { case TLS_CIPHER_AES_GCM_256: {
......
...@@ -127,7 +127,7 @@ static int padding_length(struct tls_sw_context_rx *ctx, ...@@ -127,7 +127,7 @@ static int padding_length(struct tls_sw_context_rx *ctx,
int sub = 0; int sub = 0;
/* Determine zero-padding length */ /* Determine zero-padding length */
if (tls_ctx->crypto_recv.info.version == TLS_1_3_VERSION) { if (tls_ctx->prot_info.version == TLS_1_3_VERSION) {
char content_type = 0; char content_type = 0;
int err; int err;
int back = 17; int back = 17;
...@@ -155,6 +155,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err) ...@@ -155,6 +155,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
struct scatterlist *sgin = aead_req->src; struct scatterlist *sgin = aead_req->src;
struct tls_sw_context_rx *ctx; struct tls_sw_context_rx *ctx;
struct tls_context *tls_ctx; struct tls_context *tls_ctx;
struct tls_prot_info *prot;
struct scatterlist *sg; struct scatterlist *sg;
struct sk_buff *skb; struct sk_buff *skb;
unsigned int pages; unsigned int pages;
...@@ -163,6 +164,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err) ...@@ -163,6 +164,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
skb = (struct sk_buff *)req->data; skb = (struct sk_buff *)req->data;
tls_ctx = tls_get_ctx(skb->sk); tls_ctx = tls_get_ctx(skb->sk);
ctx = tls_sw_ctx_rx(tls_ctx); ctx = tls_sw_ctx_rx(tls_ctx);
prot = &tls_ctx->prot_info;
/* Propagate if there was an err */ /* Propagate if there was an err */
if (err) { if (err) {
...@@ -171,8 +173,8 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err) ...@@ -171,8 +173,8 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
} else { } else {
struct strp_msg *rxm = strp_msg(skb); struct strp_msg *rxm = strp_msg(skb);
rxm->full_len -= padding_length(ctx, tls_ctx, skb); rxm->full_len -= padding_length(ctx, tls_ctx, skb);
rxm->offset += tls_ctx->rx.prepend_size; rxm->offset += prot->prepend_size;
rxm->full_len -= tls_ctx->rx.overhead_size; rxm->full_len -= prot->overhead_size;
} }
/* After using skb->sk to propagate sk through crypto async callback /* After using skb->sk to propagate sk through crypto async callback
...@@ -209,13 +211,14 @@ static int tls_do_decryption(struct sock *sk, ...@@ -209,13 +211,14 @@ static int tls_do_decryption(struct sock *sk,
bool async) bool async)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
int ret; int ret;
aead_request_set_tfm(aead_req, ctx->aead_recv); aead_request_set_tfm(aead_req, ctx->aead_recv);
aead_request_set_ad(aead_req, tls_ctx->rx.aad_size); aead_request_set_ad(aead_req, prot->aad_size);
aead_request_set_crypt(aead_req, sgin, sgout, aead_request_set_crypt(aead_req, sgin, sgout,
data_len + tls_ctx->rx.tag_size, data_len + prot->tag_size,
(u8 *)iv_recv); (u8 *)iv_recv);
if (async) { if (async) {
...@@ -253,12 +256,13 @@ static int tls_do_decryption(struct sock *sk, ...@@ -253,12 +256,13 @@ static int tls_do_decryption(struct sock *sk,
static void tls_trim_both_msgs(struct sock *sk, int target_size) static void tls_trim_both_msgs(struct sock *sk, int target_size)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct tls_rec *rec = ctx->open_rec; struct tls_rec *rec = ctx->open_rec;
sk_msg_trim(sk, &rec->msg_plaintext, target_size); sk_msg_trim(sk, &rec->msg_plaintext, target_size);
if (target_size > 0) if (target_size > 0)
target_size += tls_ctx->tx.overhead_size; target_size += prot->overhead_size;
sk_msg_trim(sk, &rec->msg_encrypted, target_size); sk_msg_trim(sk, &rec->msg_encrypted, target_size);
} }
...@@ -275,6 +279,7 @@ static int tls_alloc_encrypted_msg(struct sock *sk, int len) ...@@ -275,6 +279,7 @@ static int tls_alloc_encrypted_msg(struct sock *sk, int len)
static int tls_clone_plaintext_msg(struct sock *sk, int required) static int tls_clone_plaintext_msg(struct sock *sk, int required)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct tls_rec *rec = ctx->open_rec; struct tls_rec *rec = ctx->open_rec;
struct sk_msg *msg_pl = &rec->msg_plaintext; struct sk_msg *msg_pl = &rec->msg_plaintext;
...@@ -290,7 +295,7 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required) ...@@ -290,7 +295,7 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required)
/* Skip initial bytes in msg_en's data to be able to use /* Skip initial bytes in msg_en's data to be able to use
* same offset of both plain and encrypted data. * same offset of both plain and encrypted data.
*/ */
skip = tls_ctx->tx.prepend_size + msg_pl->sg.size; skip = prot->prepend_size + msg_pl->sg.size;
return sk_msg_clone(sk, msg_pl, msg_en, skip, len); return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
} }
...@@ -298,6 +303,7 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required) ...@@ -298,6 +303,7 @@ static int tls_clone_plaintext_msg(struct sock *sk, int required)
static struct tls_rec *tls_get_rec(struct sock *sk) static struct tls_rec *tls_get_rec(struct sock *sk)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct sk_msg *msg_pl, *msg_en; struct sk_msg *msg_pl, *msg_en;
struct tls_rec *rec; struct tls_rec *rec;
...@@ -316,13 +322,11 @@ static struct tls_rec *tls_get_rec(struct sock *sk) ...@@ -316,13 +322,11 @@ static struct tls_rec *tls_get_rec(struct sock *sk)
sk_msg_init(msg_en); sk_msg_init(msg_en);
sg_init_table(rec->sg_aead_in, 2); sg_init_table(rec->sg_aead_in, 2);
sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
tls_ctx->tx.aad_size);
sg_unmark_end(&rec->sg_aead_in[1]); sg_unmark_end(&rec->sg_aead_in[1]);
sg_init_table(rec->sg_aead_out, 2); sg_init_table(rec->sg_aead_out, 2);
sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
tls_ctx->tx.aad_size);
sg_unmark_end(&rec->sg_aead_out[1]); sg_unmark_end(&rec->sg_aead_out[1]);
return rec; return rec;
...@@ -411,6 +415,7 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err) ...@@ -411,6 +415,7 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err)
struct aead_request *aead_req = (struct aead_request *)req; struct aead_request *aead_req = (struct aead_request *)req;
struct sock *sk = req->data; struct sock *sk = req->data;
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct scatterlist *sge; struct scatterlist *sge;
struct sk_msg *msg_en; struct sk_msg *msg_en;
...@@ -422,8 +427,8 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err) ...@@ -422,8 +427,8 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err)
msg_en = &rec->msg_encrypted; msg_en = &rec->msg_encrypted;
sge = sk_msg_elem(msg_en, msg_en->sg.curr); sge = sk_msg_elem(msg_en, msg_en->sg.curr);
sge->offset -= tls_ctx->tx.prepend_size; sge->offset -= prot->prepend_size;
sge->length += tls_ctx->tx.prepend_size; sge->length += prot->prepend_size;
/* Check if error is previously set on socket */ /* Check if error is previously set on socket */
if (err || sk->sk_err) { if (err || sk->sk_err) {
...@@ -470,22 +475,23 @@ static int tls_do_encryption(struct sock *sk, ...@@ -470,22 +475,23 @@ static int tls_do_encryption(struct sock *sk,
struct aead_request *aead_req, struct aead_request *aead_req,
size_t data_len, u32 start) size_t data_len, u32 start)
{ {
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_rec *rec = ctx->open_rec; struct tls_rec *rec = ctx->open_rec;
struct sk_msg *msg_en = &rec->msg_encrypted; struct sk_msg *msg_en = &rec->msg_encrypted;
struct scatterlist *sge = sk_msg_elem(msg_en, start); struct scatterlist *sge = sk_msg_elem(msg_en, start);
int rc; int rc;
memcpy(rec->iv_data, tls_ctx->tx.iv, sizeof(rec->iv_data)); memcpy(rec->iv_data, tls_ctx->tx.iv, sizeof(rec->iv_data));
xor_iv_with_seq(tls_ctx->crypto_send.info.version, rec->iv_data, xor_iv_with_seq(prot->version, rec->iv_data,
tls_ctx->tx.rec_seq); tls_ctx->tx.rec_seq);
sge->offset += tls_ctx->tx.prepend_size; sge->offset += prot->prepend_size;
sge->length -= tls_ctx->tx.prepend_size; sge->length -= prot->prepend_size;
msg_en->sg.curr = start; msg_en->sg.curr = start;
aead_request_set_tfm(aead_req, ctx->aead_send); aead_request_set_tfm(aead_req, ctx->aead_send);
aead_request_set_ad(aead_req, tls_ctx->tx.aad_size); aead_request_set_ad(aead_req, prot->aad_size);
aead_request_set_crypt(aead_req, rec->sg_aead_in, aead_request_set_crypt(aead_req, rec->sg_aead_in,
rec->sg_aead_out, rec->sg_aead_out,
data_len, rec->iv_data); data_len, rec->iv_data);
...@@ -500,8 +506,8 @@ static int tls_do_encryption(struct sock *sk, ...@@ -500,8 +506,8 @@ static int tls_do_encryption(struct sock *sk,
rc = crypto_aead_encrypt(aead_req); rc = crypto_aead_encrypt(aead_req);
if (!rc || rc != -EINPROGRESS) { if (!rc || rc != -EINPROGRESS) {
atomic_dec(&ctx->encrypt_pending); atomic_dec(&ctx->encrypt_pending);
sge->offset -= tls_ctx->tx.prepend_size; sge->offset -= prot->prepend_size;
sge->length += tls_ctx->tx.prepend_size; sge->length += prot->prepend_size;
} }
if (!rc) { if (!rc) {
...@@ -513,8 +519,7 @@ static int tls_do_encryption(struct sock *sk, ...@@ -513,8 +519,7 @@ static int tls_do_encryption(struct sock *sk,
/* Unhook the record from context if encryption is not failure */ /* Unhook the record from context if encryption is not failure */
ctx->open_rec = NULL; ctx->open_rec = NULL;
tls_advance_record_sn(sk, &tls_ctx->tx, tls_advance_record_sn(sk, &tls_ctx->tx, prot->version);
tls_ctx->crypto_send.info.version);
return rc; return rc;
} }
...@@ -640,6 +645,7 @@ static int tls_push_record(struct sock *sk, int flags, ...@@ -640,6 +645,7 @@ static int tls_push_record(struct sock *sk, int flags,
unsigned char record_type) unsigned char record_type)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct tls_rec *rec = ctx->open_rec, *tmp = NULL; struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
u32 i, split_point, uninitialized_var(orig_end); u32 i, split_point, uninitialized_var(orig_end);
...@@ -658,12 +664,12 @@ static int tls_push_record(struct sock *sk, int flags, ...@@ -658,12 +664,12 @@ static int tls_push_record(struct sock *sk, int flags,
split = split_point && split_point < msg_pl->sg.size; split = split_point && split_point < msg_pl->sg.size;
if (split) { if (split) {
rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en, rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
split_point, tls_ctx->tx.overhead_size, split_point, prot->overhead_size,
&orig_end); &orig_end);
if (rc < 0) if (rc < 0)
return rc; return rc;
sk_msg_trim(sk, msg_en, msg_pl->sg.size + sk_msg_trim(sk, msg_en, msg_pl->sg.size +
tls_ctx->tx.overhead_size); prot->overhead_size);
} }
rec->tx_flags = flags; rec->tx_flags = flags;
...@@ -673,7 +679,7 @@ static int tls_push_record(struct sock *sk, int flags, ...@@ -673,7 +679,7 @@ static int tls_push_record(struct sock *sk, int flags,
sk_msg_iter_var_prev(i); sk_msg_iter_var_prev(i);
rec->content_type = record_type; rec->content_type = record_type;
if (tls_ctx->crypto_send.info.version == TLS_1_3_VERSION) { if (prot->version == TLS_1_3_VERSION) {
/* Add content type to end of message. No padding added */ /* Add content type to end of message. No padding added */
sg_set_buf(&rec->sg_content_type, &rec->content_type, 1); sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
sg_mark_end(&rec->sg_content_type); sg_mark_end(&rec->sg_content_type);
...@@ -694,22 +700,20 @@ static int tls_push_record(struct sock *sk, int flags, ...@@ -694,22 +700,20 @@ static int tls_push_record(struct sock *sk, int flags,
i = msg_en->sg.start; i = msg_en->sg.start;
sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]); sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);
tls_make_aad(rec->aad_space, msg_pl->sg.size + tls_ctx->tx.tail_size, tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size, tls_ctx->tx.rec_seq, prot->rec_seq_size,
record_type, record_type, prot->version);
tls_ctx->crypto_send.info.version);
tls_fill_prepend(tls_ctx, tls_fill_prepend(tls_ctx,
page_address(sg_page(&msg_en->sg.data[i])) + page_address(sg_page(&msg_en->sg.data[i])) +
msg_en->sg.data[i].offset, msg_en->sg.data[i].offset,
msg_pl->sg.size + tls_ctx->tx.tail_size, msg_pl->sg.size + prot->tail_size,
record_type, record_type, prot->version);
tls_ctx->crypto_send.info.version);
tls_ctx->pending_open_record_frags = false; tls_ctx->pending_open_record_frags = false;
rc = tls_do_encryption(sk, tls_ctx, ctx, req, rc = tls_do_encryption(sk, tls_ctx, ctx, req,
msg_pl->sg.size + tls_ctx->tx.tail_size, i); msg_pl->sg.size + prot->tail_size, i);
if (rc < 0) { if (rc < 0) {
if (rc != -EINPROGRESS) { if (rc != -EINPROGRESS) {
tls_err_abort(sk, EBADMSG); tls_err_abort(sk, EBADMSG);
...@@ -723,8 +727,7 @@ static int tls_push_record(struct sock *sk, int flags, ...@@ -723,8 +727,7 @@ static int tls_push_record(struct sock *sk, int flags,
} else if (split) { } else if (split) {
msg_pl = &tmp->msg_plaintext; msg_pl = &tmp->msg_plaintext;
msg_en = &tmp->msg_encrypted; msg_en = &tmp->msg_encrypted;
sk_msg_trim(sk, msg_en, msg_pl->sg.size + sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
tls_ctx->tx.overhead_size);
tls_ctx->pending_open_record_frags = true; tls_ctx->pending_open_record_frags = true;
ctx->open_rec = tmp; ctx->open_rec = tmp;
} }
...@@ -859,6 +862,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) ...@@ -859,6 +862,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{ {
long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
bool async_capable = ctx->async_capable; bool async_capable = ctx->async_capable;
unsigned char record_type = TLS_RECORD_TYPE_DATA; unsigned char record_type = TLS_RECORD_TYPE_DATA;
...@@ -925,7 +929,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) ...@@ -925,7 +929,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
} }
required_size = msg_pl->sg.size + try_to_copy + required_size = msg_pl->sg.size + try_to_copy +
tls_ctx->tx.overhead_size; prot->overhead_size;
if (!sk_stream_memory_free(sk)) if (!sk_stream_memory_free(sk))
goto wait_for_sndbuf; goto wait_for_sndbuf;
...@@ -994,8 +998,8 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) ...@@ -994,8 +998,8 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
*/ */
try_to_copy -= required_size - msg_pl->sg.size; try_to_copy -= required_size - msg_pl->sg.size;
full_record = true; full_record = true;
sk_msg_trim(sk, msg_en, msg_pl->sg.size + sk_msg_trim(sk, msg_en,
tls_ctx->tx.overhead_size); msg_pl->sg.size + prot->overhead_size);
} }
if (try_to_copy) { if (try_to_copy) {
...@@ -1081,6 +1085,7 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page, ...@@ -1081,6 +1085,7 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
unsigned char record_type = TLS_RECORD_TYPE_DATA; unsigned char record_type = TLS_RECORD_TYPE_DATA;
struct sk_msg *msg_pl; struct sk_msg *msg_pl;
struct tls_rec *rec; struct tls_rec *rec;
...@@ -1130,8 +1135,7 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page, ...@@ -1130,8 +1135,7 @@ static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
full_record = true; full_record = true;
} }
required_size = msg_pl->sg.size + copy + required_size = msg_pl->sg.size + copy + prot->overhead_size;
tls_ctx->tx.overhead_size;
if (!sk_stream_memory_free(sk)) if (!sk_stream_memory_free(sk))
goto wait_for_sndbuf; goto wait_for_sndbuf;
...@@ -1330,6 +1334,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, ...@@ -1330,6 +1334,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct strp_msg *rxm = strp_msg(skb); struct strp_msg *rxm = strp_msg(skb);
int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0; int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
struct aead_request *aead_req; struct aead_request *aead_req;
...@@ -1337,16 +1342,16 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, ...@@ -1337,16 +1342,16 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
u8 *aad, *iv, *mem = NULL; u8 *aad, *iv, *mem = NULL;
struct scatterlist *sgin = NULL; struct scatterlist *sgin = NULL;
struct scatterlist *sgout = NULL; struct scatterlist *sgout = NULL;
const int data_len = rxm->full_len - tls_ctx->rx.overhead_size + const int data_len = rxm->full_len - prot->overhead_size +
tls_ctx->rx.tail_size; prot->tail_size;
if (*zc && (out_iov || out_sg)) { if (*zc && (out_iov || out_sg)) {
if (out_iov) if (out_iov)
n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1; n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
else else
n_sgout = sg_nents(out_sg); n_sgout = sg_nents(out_sg);
n_sgin = skb_nsg(skb, rxm->offset + tls_ctx->rx.prepend_size, n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
rxm->full_len - tls_ctx->rx.prepend_size); rxm->full_len - prot->prepend_size);
} else { } else {
n_sgout = 0; n_sgout = 0;
*zc = false; *zc = false;
...@@ -1363,7 +1368,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, ...@@ -1363,7 +1368,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv); aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
mem_size = aead_size + (nsg * sizeof(struct scatterlist)); mem_size = aead_size + (nsg * sizeof(struct scatterlist));
mem_size = mem_size + tls_ctx->rx.aad_size; mem_size = mem_size + prot->aad_size;
mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv); mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
/* Allocate a single block of memory which contains /* Allocate a single block of memory which contains
...@@ -1379,37 +1384,35 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, ...@@ -1379,37 +1384,35 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
sgin = (struct scatterlist *)(mem + aead_size); sgin = (struct scatterlist *)(mem + aead_size);
sgout = sgin + n_sgin; sgout = sgin + n_sgin;
aad = (u8 *)(sgout + n_sgout); aad = (u8 *)(sgout + n_sgout);
iv = aad + tls_ctx->rx.aad_size; iv = aad + prot->aad_size;
/* Prepare IV */ /* Prepare IV */
err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
tls_ctx->rx.iv_size); prot->iv_size);
if (err < 0) { if (err < 0) {
kfree(mem); kfree(mem);
return err; return err;
} }
if (tls_ctx->crypto_recv.info.version == TLS_1_3_VERSION) if (prot->version == TLS_1_3_VERSION)
memcpy(iv, tls_ctx->rx.iv, crypto_aead_ivsize(ctx->aead_recv)); memcpy(iv, tls_ctx->rx.iv, crypto_aead_ivsize(ctx->aead_recv));
else else
memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE); memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
xor_iv_with_seq(tls_ctx->crypto_recv.info.version, iv, xor_iv_with_seq(prot->version, iv, tls_ctx->rx.rec_seq);
tls_ctx->rx.rec_seq);
/* Prepare AAD */ /* Prepare AAD */
tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size + tls_make_aad(aad, rxm->full_len - prot->overhead_size +
tls_ctx->rx.tail_size, prot->tail_size,
tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size, tls_ctx->rx.rec_seq, prot->rec_seq_size,
ctx->control, ctx->control, prot->version);
tls_ctx->crypto_recv.info.version);
/* Prepare sgin */ /* Prepare sgin */
sg_init_table(sgin, n_sgin); sg_init_table(sgin, n_sgin);
sg_set_buf(&sgin[0], aad, tls_ctx->rx.aad_size); sg_set_buf(&sgin[0], aad, prot->aad_size);
err = skb_to_sgvec(skb, &sgin[1], err = skb_to_sgvec(skb, &sgin[1],
rxm->offset + tls_ctx->rx.prepend_size, rxm->offset + prot->prepend_size,
rxm->full_len - tls_ctx->rx.prepend_size); rxm->full_len - prot->prepend_size);
if (err < 0) { if (err < 0) {
kfree(mem); kfree(mem);
return err; return err;
...@@ -1418,7 +1421,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb, ...@@ -1418,7 +1421,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
if (n_sgout) { if (n_sgout) {
if (out_iov) { if (out_iov) {
sg_init_table(sgout, n_sgout); sg_init_table(sgout, n_sgout);
sg_set_buf(&sgout[0], aad, tls_ctx->rx.aad_size); sg_set_buf(&sgout[0], aad, prot->aad_size);
*chunk = 0; *chunk = 0;
err = tls_setup_from_iter(sk, out_iov, data_len, err = tls_setup_from_iter(sk, out_iov, data_len,
...@@ -1459,7 +1462,8 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, ...@@ -1459,7 +1462,8 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
int version = tls_ctx->crypto_recv.info.version; struct tls_prot_info *prot = &tls_ctx->prot_info;
int version = prot->version;
struct strp_msg *rxm = strp_msg(skb); struct strp_msg *rxm = strp_msg(skb);
int err = 0; int err = 0;
...@@ -1480,8 +1484,8 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, ...@@ -1480,8 +1484,8 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
rxm->full_len -= padding_length(ctx, tls_ctx, skb); rxm->full_len -= padding_length(ctx, tls_ctx, skb);
rxm->offset += tls_ctx->rx.prepend_size; rxm->offset += prot->prepend_size;
rxm->full_len -= tls_ctx->rx.overhead_size; rxm->full_len -= prot->overhead_size;
tls_advance_record_sn(sk, &tls_ctx->rx, version); tls_advance_record_sn(sk, &tls_ctx->rx, version);
ctx->decrypted = true; ctx->decrypted = true;
ctx->saved_data_ready(sk); ctx->saved_data_ready(sk);
...@@ -1605,6 +1609,7 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1605,6 +1609,7 @@ int tls_sw_recvmsg(struct sock *sk,
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct sk_psock *psock; struct sk_psock *psock;
unsigned char control = 0; unsigned char control = 0;
ssize_t decrypted = 0; ssize_t decrypted = 0;
...@@ -1667,11 +1672,11 @@ int tls_sw_recvmsg(struct sock *sk, ...@@ -1667,11 +1672,11 @@ int tls_sw_recvmsg(struct sock *sk,
rxm = strp_msg(skb); rxm = strp_msg(skb);
to_decrypt = rxm->full_len - tls_ctx->rx.overhead_size; to_decrypt = rxm->full_len - prot->overhead_size;
if (to_decrypt <= len && !is_kvec && !is_peek && if (to_decrypt <= len && !is_kvec && !is_peek &&
ctx->control == TLS_RECORD_TYPE_DATA && ctx->control == TLS_RECORD_TYPE_DATA &&
tls_ctx->crypto_recv.info.version != TLS_1_3_VERSION) prot->version != TLS_1_3_VERSION)
zc = true; zc = true;
/* Do not use async mode if record is non-data */ /* Do not use async mode if record is non-data */
...@@ -1875,6 +1880,7 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) ...@@ -1875,6 +1880,7 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
{ {
struct tls_context *tls_ctx = tls_get_ctx(strp->sk); struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
char header[TLS_HEADER_SIZE + MAX_IV_SIZE]; char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
struct strp_msg *rxm = strp_msg(skb); struct strp_msg *rxm = strp_msg(skb);
size_t cipher_overhead; size_t cipher_overhead;
...@@ -1882,17 +1888,17 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) ...@@ -1882,17 +1888,17 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
int ret; int ret;
/* Verify that we have a full TLS header, or wait for more data */ /* Verify that we have a full TLS header, or wait for more data */
if (rxm->offset + tls_ctx->rx.prepend_size > skb->len) if (rxm->offset + prot->prepend_size > skb->len)
return 0; return 0;
/* Sanity-check size of on-stack buffer. */ /* Sanity-check size of on-stack buffer. */
if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) { if (WARN_ON(prot->prepend_size > sizeof(header))) {
ret = -EINVAL; ret = -EINVAL;
goto read_failure; goto read_failure;
} }
/* Linearize header to local buffer */ /* Linearize header to local buffer */
ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size); ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
if (ret < 0) if (ret < 0)
goto read_failure; goto read_failure;
...@@ -1901,12 +1907,12 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb) ...@@ -1901,12 +1907,12 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
data_len = ((header[4] & 0xFF) | (header[3] << 8)); data_len = ((header[4] & 0xFF) | (header[3] << 8));
cipher_overhead = tls_ctx->rx.tag_size; cipher_overhead = prot->tag_size;
if (tls_ctx->crypto_recv.info.version != TLS_1_3_VERSION) if (prot->version != TLS_1_3_VERSION)
cipher_overhead += tls_ctx->rx.iv_size; cipher_overhead += prot->iv_size;
if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead + if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
tls_ctx->rx.tail_size) { prot->tail_size) {
ret = -EMSGSIZE; ret = -EMSGSIZE;
goto read_failure; goto read_failure;
} }
...@@ -2066,6 +2072,8 @@ static void tx_work_handler(struct work_struct *work) ...@@ -2066,6 +2072,8 @@ static void tx_work_handler(struct work_struct *work)
int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_crypto_info *crypto_info; struct tls_crypto_info *crypto_info;
struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
struct tls12_crypto_info_aes_gcm_256 *gcm_256_info; struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
...@@ -2171,18 +2179,20 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) ...@@ -2171,18 +2179,20 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
if (crypto_info->version == TLS_1_3_VERSION) { if (crypto_info->version == TLS_1_3_VERSION) {
nonce_size = 0; nonce_size = 0;
cctx->aad_size = TLS_HEADER_SIZE; prot->aad_size = TLS_HEADER_SIZE;
cctx->tail_size = 1; prot->tail_size = 1;
} else { } else {
cctx->aad_size = TLS_AAD_SPACE_SIZE; prot->aad_size = TLS_AAD_SPACE_SIZE;
cctx->tail_size = 0; prot->tail_size = 0;
} }
cctx->prepend_size = TLS_HEADER_SIZE + nonce_size; prot->version = crypto_info->version;
cctx->tag_size = tag_size; prot->cipher_type = crypto_info->cipher_type;
cctx->overhead_size = cctx->prepend_size + cctx->tag_size + prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
cctx->tail_size; prot->tag_size = tag_size;
cctx->iv_size = iv_size; prot->overhead_size = prot->prepend_size +
prot->tag_size + prot->tail_size;
prot->iv_size = iv_size;
cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
GFP_KERNEL); GFP_KERNEL);
if (!cctx->iv) { if (!cctx->iv) {
...@@ -2192,7 +2202,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) ...@@ -2192,7 +2202,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
/* Note: 128 & 256 bit salt are the same size */ /* Note: 128 & 256 bit salt are the same size */
memcpy(cctx->iv, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE); memcpy(cctx->iv, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
cctx->rec_seq_size = rec_seq_size; prot->rec_seq_size = rec_seq_size;
cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL); cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
if (!cctx->rec_seq) { if (!cctx->rec_seq) {
rc = -ENOMEM; rc = -ENOMEM;
...@@ -2215,7 +2225,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) ...@@ -2215,7 +2225,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
if (rc) if (rc)
goto free_aead; goto free_aead;
rc = crypto_aead_setauthsize(*aead, cctx->tag_size); rc = crypto_aead_setauthsize(*aead, prot->tag_size);
if (rc) if (rc)
goto free_aead; goto free_aead;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment