Commit 8d338c76 authored by Herbert Xu's avatar Herbert Xu

tls: Only use data field in crypto completion function

The crypto_async_request passed to the completion is not guaranteed
to be the original request object.  Only the data field can be relied
upon.

Fix this by storing the socket pointer with the AEAD request.
Signed-off-by: default avatarHerbert Xu <herbert@gondor.apana.org.au>
parent 1dbab131
...@@ -70,6 +70,8 @@ struct tls_rec { ...@@ -70,6 +70,8 @@ struct tls_rec {
char content_type; char content_type;
struct scatterlist sg_content_type; struct scatterlist sg_content_type;
struct sock *sk;
char aad_space[TLS_AAD_SPACE_SIZE]; char aad_space[TLS_AAD_SPACE_SIZE];
u8 iv_data[MAX_IV_SIZE]; u8 iv_data[MAX_IV_SIZE];
struct aead_request aead_req; struct aead_request aead_req;
......
...@@ -38,6 +38,7 @@ ...@@ -38,6 +38,7 @@
#include <linux/bug.h> #include <linux/bug.h>
#include <linux/sched/signal.h> #include <linux/sched/signal.h>
#include <linux/module.h> #include <linux/module.h>
#include <linux/kernel.h>
#include <linux/splice.h> #include <linux/splice.h>
#include <crypto/aead.h> #include <crypto/aead.h>
...@@ -57,6 +58,7 @@ struct tls_decrypt_arg { ...@@ -57,6 +58,7 @@ struct tls_decrypt_arg {
}; };
struct tls_decrypt_ctx { struct tls_decrypt_ctx {
struct sock *sk;
u8 iv[MAX_IV_SIZE]; u8 iv[MAX_IV_SIZE];
u8 aad[TLS_MAX_AAD_SIZE]; u8 aad[TLS_MAX_AAD_SIZE];
u8 tail; u8 tail;
...@@ -177,18 +179,25 @@ static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb, ...@@ -177,18 +179,25 @@ static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb,
return sub; return sub;
} }
static void tls_decrypt_done(struct crypto_async_request *req, int err) static void tls_decrypt_done(crypto_completion_data_t *data, int err)
{ {
struct aead_request *aead_req = (struct aead_request *)req; struct aead_request *aead_req = crypto_get_completion_data(data);
struct crypto_aead *aead = crypto_aead_reqtfm(aead_req);
struct scatterlist *sgout = aead_req->dst; struct scatterlist *sgout = aead_req->dst;
struct scatterlist *sgin = aead_req->src; struct scatterlist *sgin = aead_req->src;
struct tls_sw_context_rx *ctx; struct tls_sw_context_rx *ctx;
struct tls_decrypt_ctx *dctx;
struct tls_context *tls_ctx; struct tls_context *tls_ctx;
struct scatterlist *sg; struct scatterlist *sg;
unsigned int pages; unsigned int pages;
struct sock *sk; struct sock *sk;
int aead_size;
sk = (struct sock *)req->data; aead_size = sizeof(*aead_req) + crypto_aead_reqsize(aead);
aead_size = ALIGN(aead_size, __alignof__(*dctx));
dctx = (void *)((u8 *)aead_req + aead_size);
sk = dctx->sk;
tls_ctx = tls_get_ctx(sk); tls_ctx = tls_get_ctx(sk);
ctx = tls_sw_ctx_rx(tls_ctx); ctx = tls_sw_ctx_rx(tls_ctx);
...@@ -240,7 +249,7 @@ static int tls_do_decryption(struct sock *sk, ...@@ -240,7 +249,7 @@ static int tls_do_decryption(struct sock *sk,
if (darg->async) { if (darg->async) {
aead_request_set_callback(aead_req, aead_request_set_callback(aead_req,
CRYPTO_TFM_REQ_MAY_BACKLOG, CRYPTO_TFM_REQ_MAY_BACKLOG,
tls_decrypt_done, sk); tls_decrypt_done, aead_req);
atomic_inc(&ctx->decrypt_pending); atomic_inc(&ctx->decrypt_pending);
} else { } else {
aead_request_set_callback(aead_req, aead_request_set_callback(aead_req,
...@@ -336,6 +345,8 @@ static struct tls_rec *tls_get_rec(struct sock *sk) ...@@ -336,6 +345,8 @@ static struct tls_rec *tls_get_rec(struct sock *sk)
sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size); sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
sg_unmark_end(&rec->sg_aead_out[1]); sg_unmark_end(&rec->sg_aead_out[1]);
rec->sk = sk;
return rec; return rec;
} }
...@@ -417,22 +428,27 @@ int tls_tx_records(struct sock *sk, int flags) ...@@ -417,22 +428,27 @@ int tls_tx_records(struct sock *sk, int flags)
return rc; return rc;
} }
static void tls_encrypt_done(struct crypto_async_request *req, int err) static void tls_encrypt_done(crypto_completion_data_t *data, int err)
{ {
struct aead_request *aead_req = (struct aead_request *)req; struct aead_request *aead_req = crypto_get_completion_data(data);
struct sock *sk = req->data; struct tls_sw_context_tx *ctx;
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx;
struct tls_prot_info *prot = &tls_ctx->prot_info; struct tls_prot_info *prot;
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct scatterlist *sge; struct scatterlist *sge;
struct sk_msg *msg_en; struct sk_msg *msg_en;
struct tls_rec *rec; struct tls_rec *rec;
bool ready = false; bool ready = false;
struct sock *sk;
int pending; int pending;
rec = container_of(aead_req, struct tls_rec, aead_req); rec = container_of(aead_req, struct tls_rec, aead_req);
msg_en = &rec->msg_encrypted; msg_en = &rec->msg_encrypted;
sk = rec->sk;
tls_ctx = tls_get_ctx(sk);
prot = &tls_ctx->prot_info;
ctx = tls_sw_ctx_tx(tls_ctx);
sge = sk_msg_elem(msg_en, msg_en->sg.curr); sge = sk_msg_elem(msg_en, msg_en->sg.curr);
sge->offset -= prot->prepend_size; sge->offset -= prot->prepend_size;
sge->length += prot->prepend_size; sge->length += prot->prepend_size;
...@@ -520,7 +536,7 @@ static int tls_do_encryption(struct sock *sk, ...@@ -520,7 +536,7 @@ static int tls_do_encryption(struct sock *sk,
data_len, rec->iv_data); data_len, rec->iv_data);
aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
tls_encrypt_done, sk); tls_encrypt_done, aead_req);
/* Add the record in tx_list */ /* Add the record in tx_list */
list_add_tail((struct list_head *)&rec->list, &ctx->tx_list); list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
...@@ -1485,6 +1501,7 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov, ...@@ -1485,6 +1501,7 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
* Both structs are variable length. * Both structs are variable length.
*/ */
aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv); aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
aead_size = ALIGN(aead_size, __alignof__(*dctx));
mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout), mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout),
sk->sk_allocation); sk->sk_allocation);
if (!mem) { if (!mem) {
...@@ -1495,6 +1512,7 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov, ...@@ -1495,6 +1512,7 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
/* Segment the allocated memory */ /* Segment the allocated memory */
aead_req = (struct aead_request *)mem; aead_req = (struct aead_request *)mem;
dctx = (struct tls_decrypt_ctx *)(mem + aead_size); dctx = (struct tls_decrypt_ctx *)(mem + aead_size);
dctx->sk = sk;
sgin = &dctx->sg[0]; sgin = &dctx->sg[0];
sgout = &dctx->sg[n_sgin]; sgout = &dctx->sg[n_sgin];
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment