Commit 71e52c27 authored by Ard Biesheuvel's avatar Ard Biesheuvel Committed by Herbert Xu

crypto: arm64/aes-ce-gcm - operate on two input blocks at a time

Update the core AES/GCM transform and the associated plumbing to operate
on 2 AES/GHASH blocks at a time. By itself, this is not expected to
result in a noticeable speedup, but it paves the way for reimplementing
the GHASH component using 2-way aggregation.
Signed-off-by: default avatarArd Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: default avatarHerbert Xu <herbert@gondor.apana.org.au>
parent 3465893d
...@@ -286,9 +286,10 @@ ENTRY(pmull_ghash_update_p8) ...@@ -286,9 +286,10 @@ ENTRY(pmull_ghash_update_p8)
__pmull_ghash p8 __pmull_ghash p8
ENDPROC(pmull_ghash_update_p8) ENDPROC(pmull_ghash_update_p8)
KS .req v8 KS0 .req v8
CTR .req v9 KS1 .req v9
INP .req v10 INP0 .req v10
INP1 .req v11
.macro load_round_keys, rounds, rk .macro load_round_keys, rounds, rk
cmp \rounds, #12 cmp \rounds, #12
...@@ -336,84 +337,146 @@ CPU_LE( rev x8, x8 ) ...@@ -336,84 +337,146 @@ CPU_LE( rev x8, x8 )
.if \enc == 1 .if \enc == 1
ldr x10, [sp] ldr x10, [sp]
ld1 {KS.16b}, [x10] ld1 {KS0.16b-KS1.16b}, [x10]
.endif .endif
0: ld1 {CTR.8b}, [x5] // load upper counter 0: ld1 {INP0.16b-INP1.16b}, [x3], #32
ld1 {INP.16b}, [x3], #16
rev x9, x8 rev x9, x8
add x8, x8, #1 add x11, x8, #1
sub w0, w0, #1 add x8, x8, #2
ins CTR.d[1], x9 // set lower counter
.if \enc == 1 .if \enc == 1
eor INP.16b, INP.16b, KS.16b // encrypt input eor INP0.16b, INP0.16b, KS0.16b // encrypt input
st1 {INP.16b}, [x2], #16 eor INP1.16b, INP1.16b, KS1.16b
.endif .endif
rev64 T1.16b, INP.16b ld1 {KS0.8b}, [x5] // load upper counter
rev x11, x11
sub w0, w0, #2
mov KS1.8b, KS0.8b
ins KS0.d[1], x9 // set lower counter
ins KS1.d[1], x11
rev64 T1.16b, INP0.16b
cmp w7, #12 cmp w7, #12
b.ge 2f // AES-192/256? b.ge 2f // AES-192/256?
1: enc_round CTR, v21 1: enc_round KS0, v21
ext T2.16b, XL.16b, XL.16b, #8 ext T2.16b, XL.16b, XL.16b, #8
ext IN1.16b, T1.16b, T1.16b, #8 ext IN1.16b, T1.16b, T1.16b, #8
enc_round CTR, v22 enc_round KS1, v21
eor T1.16b, T1.16b, T2.16b eor T1.16b, T1.16b, T2.16b
eor XL.16b, XL.16b, IN1.16b eor XL.16b, XL.16b, IN1.16b
enc_round CTR, v23 enc_round KS0, v22
pmull2 XH.1q, SHASH.2d, XL.2d // a1 * b1 pmull2 XH.1q, SHASH.2d, XL.2d // a1 * b1
eor T1.16b, T1.16b, XL.16b eor T1.16b, T1.16b, XL.16b
enc_round CTR, v24 enc_round KS1, v22
pmull XL.1q, SHASH.1d, XL.1d // a0 * b0 pmull XL.1q, SHASH.1d, XL.1d // a0 * b0
pmull XM.1q, SHASH2.1d, T1.1d // (a1 + a0)(b1 + b0) pmull XM.1q, SHASH2.1d, T1.1d // (a1 + a0)(b1 + b0)
enc_round CTR, v25 enc_round KS0, v23
ext T1.16b, XL.16b, XH.16b, #8 ext T1.16b, XL.16b, XH.16b, #8
eor T2.16b, XL.16b, XH.16b eor T2.16b, XL.16b, XH.16b
eor XM.16b, XM.16b, T1.16b eor XM.16b, XM.16b, T1.16b
enc_round CTR, v26 enc_round KS1, v23
eor XM.16b, XM.16b, T2.16b eor XM.16b, XM.16b, T2.16b
pmull T2.1q, XL.1d, MASK.1d pmull T2.1q, XL.1d, MASK.1d
enc_round CTR, v27 enc_round KS0, v24
mov XH.d[0], XM.d[1] mov XH.d[0], XM.d[1]
mov XM.d[1], XL.d[0] mov XM.d[1], XL.d[0]
enc_round CTR, v28 enc_round KS1, v24
eor XL.16b, XM.16b, T2.16b eor XL.16b, XM.16b, T2.16b
enc_round CTR, v29 enc_round KS0, v25
ext T2.16b, XL.16b, XL.16b, #8 ext T2.16b, XL.16b, XL.16b, #8
aese CTR.16b, v30.16b enc_round KS1, v25
pmull XL.1q, XL.1d, MASK.1d pmull XL.1q, XL.1d, MASK.1d
eor T2.16b, T2.16b, XH.16b eor T2.16b, T2.16b, XH.16b
eor KS.16b, CTR.16b, v31.16b enc_round KS0, v26
eor XL.16b, XL.16b, T2.16b
rev64 T1.16b, INP1.16b
enc_round KS1, v26
ext T2.16b, XL.16b, XL.16b, #8
ext IN1.16b, T1.16b, T1.16b, #8
enc_round KS0, v27
eor T1.16b, T1.16b, T2.16b
eor XL.16b, XL.16b, IN1.16b
enc_round KS1, v27
pmull2 XH.1q, SHASH.2d, XL.2d // a1 * b1
eor T1.16b, T1.16b, XL.16b
enc_round KS0, v28
pmull XL.1q, SHASH.1d, XL.1d // a0 * b0
pmull XM.1q, SHASH2.1d, T1.1d // (a1 + a0)(b1 + b0)
enc_round KS1, v28
ext T1.16b, XL.16b, XH.16b, #8
eor T2.16b, XL.16b, XH.16b
eor XM.16b, XM.16b, T1.16b
enc_round KS0, v29
eor XM.16b, XM.16b, T2.16b
pmull T2.1q, XL.1d, MASK.1d
enc_round KS1, v29
mov XH.d[0], XM.d[1]
mov XM.d[1], XL.d[0]
aese KS0.16b, v30.16b
eor XL.16b, XM.16b, T2.16b
aese KS1.16b, v30.16b
ext T2.16b, XL.16b, XL.16b, #8
eor KS0.16b, KS0.16b, v31.16b
pmull XL.1q, XL.1d, MASK.1d
eor T2.16b, T2.16b, XH.16b
eor KS1.16b, KS1.16b, v31.16b
eor XL.16b, XL.16b, T2.16b eor XL.16b, XL.16b, T2.16b
.if \enc == 0 .if \enc == 0
eor INP.16b, INP.16b, KS.16b eor INP0.16b, INP0.16b, KS0.16b
st1 {INP.16b}, [x2], #16 eor INP1.16b, INP1.16b, KS1.16b
.endif .endif
st1 {INP0.16b-INP1.16b}, [x2], #32
cbnz w0, 0b cbnz w0, 0b
CPU_LE( rev x8, x8 ) CPU_LE( rev x8, x8 )
...@@ -421,16 +484,20 @@ CPU_LE( rev x8, x8 ) ...@@ -421,16 +484,20 @@ CPU_LE( rev x8, x8 )
str x8, [x5, #8] // store lower counter str x8, [x5, #8] // store lower counter
.if \enc == 1 .if \enc == 1
st1 {KS.16b}, [x10] st1 {KS0.16b-KS1.16b}, [x10]
.endif .endif
ret ret
2: b.eq 3f // AES-192? 2: b.eq 3f // AES-192?
enc_round CTR, v17 enc_round KS0, v17
enc_round CTR, v18 enc_round KS1, v17
3: enc_round CTR, v19 enc_round KS0, v18
enc_round CTR, v20 enc_round KS1, v18
3: enc_round KS0, v19
enc_round KS1, v19
enc_round KS0, v20
enc_round KS1, v20
b 1b b 1b
.endm .endm
......
...@@ -348,9 +348,10 @@ static int gcm_encrypt(struct aead_request *req) ...@@ -348,9 +348,10 @@ static int gcm_encrypt(struct aead_request *req)
struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead); struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead);
struct skcipher_walk walk; struct skcipher_walk walk;
u8 iv[AES_BLOCK_SIZE]; u8 iv[AES_BLOCK_SIZE];
u8 ks[AES_BLOCK_SIZE]; u8 ks[2 * AES_BLOCK_SIZE];
u8 tag[AES_BLOCK_SIZE]; u8 tag[AES_BLOCK_SIZE];
u64 dg[2] = {}; u64 dg[2] = {};
int nrounds = num_rounds(&ctx->aes_key);
int err; int err;
if (req->assoclen) if (req->assoclen)
...@@ -362,32 +363,31 @@ static int gcm_encrypt(struct aead_request *req) ...@@ -362,32 +363,31 @@ static int gcm_encrypt(struct aead_request *req)
if (likely(may_use_simd())) { if (likely(may_use_simd())) {
kernel_neon_begin(); kernel_neon_begin();
pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc, pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc, nrounds);
num_rounds(&ctx->aes_key));
put_unaligned_be32(2, iv + GCM_IV_SIZE); put_unaligned_be32(2, iv + GCM_IV_SIZE);
pmull_gcm_encrypt_block(ks, iv, NULL, pmull_gcm_encrypt_block(ks, iv, NULL, nrounds);
num_rounds(&ctx->aes_key));
put_unaligned_be32(3, iv + GCM_IV_SIZE); put_unaligned_be32(3, iv + GCM_IV_SIZE);
pmull_gcm_encrypt_block(ks + AES_BLOCK_SIZE, iv, NULL, nrounds);
put_unaligned_be32(4, iv + GCM_IV_SIZE);
kernel_neon_end(); kernel_neon_end();
err = skcipher_walk_aead_encrypt(&walk, req, false); err = skcipher_walk_aead_encrypt(&walk, req, false);
while (walk.nbytes >= AES_BLOCK_SIZE) { while (walk.nbytes >= 2 * AES_BLOCK_SIZE) {
int blocks = walk.nbytes / AES_BLOCK_SIZE; int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2;
kernel_neon_begin(); kernel_neon_begin();
pmull_gcm_encrypt(blocks, dg, walk.dst.virt.addr, pmull_gcm_encrypt(blocks, dg, walk.dst.virt.addr,
walk.src.virt.addr, &ctx->ghash_key, walk.src.virt.addr, &ctx->ghash_key,
iv, ctx->aes_key.key_enc, iv, ctx->aes_key.key_enc, nrounds,
num_rounds(&ctx->aes_key), ks); ks);
kernel_neon_end(); kernel_neon_end();
err = skcipher_walk_done(&walk, err = skcipher_walk_done(&walk,
walk.nbytes % AES_BLOCK_SIZE); walk.nbytes % (2 * AES_BLOCK_SIZE));
} }
} else { } else {
__aes_arm64_encrypt(ctx->aes_key.key_enc, tag, iv, __aes_arm64_encrypt(ctx->aes_key.key_enc, tag, iv, nrounds);
num_rounds(&ctx->aes_key));
put_unaligned_be32(2, iv + GCM_IV_SIZE); put_unaligned_be32(2, iv + GCM_IV_SIZE);
err = skcipher_walk_aead_encrypt(&walk, req, false); err = skcipher_walk_aead_encrypt(&walk, req, false);
...@@ -399,8 +399,7 @@ static int gcm_encrypt(struct aead_request *req) ...@@ -399,8 +399,7 @@ static int gcm_encrypt(struct aead_request *req)
do { do {
__aes_arm64_encrypt(ctx->aes_key.key_enc, __aes_arm64_encrypt(ctx->aes_key.key_enc,
ks, iv, ks, iv, nrounds);
num_rounds(&ctx->aes_key));
crypto_xor_cpy(dst, src, ks, AES_BLOCK_SIZE); crypto_xor_cpy(dst, src, ks, AES_BLOCK_SIZE);
crypto_inc(iv, AES_BLOCK_SIZE); crypto_inc(iv, AES_BLOCK_SIZE);
...@@ -417,19 +416,28 @@ static int gcm_encrypt(struct aead_request *req) ...@@ -417,19 +416,28 @@ static int gcm_encrypt(struct aead_request *req)
} }
if (walk.nbytes) if (walk.nbytes)
__aes_arm64_encrypt(ctx->aes_key.key_enc, ks, iv, __aes_arm64_encrypt(ctx->aes_key.key_enc, ks, iv,
num_rounds(&ctx->aes_key)); nrounds);
} }
/* handle the tail */ /* handle the tail */
if (walk.nbytes) { if (walk.nbytes) {
u8 buf[GHASH_BLOCK_SIZE]; u8 buf[GHASH_BLOCK_SIZE];
unsigned int nbytes = walk.nbytes;
u8 *dst = walk.dst.virt.addr;
u8 *head = NULL;
crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, ks, crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, ks,
walk.nbytes); walk.nbytes);
memcpy(buf, walk.dst.virt.addr, walk.nbytes); if (walk.nbytes > GHASH_BLOCK_SIZE) {
memset(buf + walk.nbytes, 0, GHASH_BLOCK_SIZE - walk.nbytes); head = dst;
ghash_do_update(1, dg, buf, &ctx->ghash_key, NULL); dst += GHASH_BLOCK_SIZE;
nbytes %= GHASH_BLOCK_SIZE;
}
memcpy(buf, dst, nbytes);
memset(buf + nbytes, 0, GHASH_BLOCK_SIZE - nbytes);
ghash_do_update(!!nbytes, dg, buf, &ctx->ghash_key, head);
err = skcipher_walk_done(&walk, 0); err = skcipher_walk_done(&walk, 0);
} }
...@@ -452,10 +460,11 @@ static int gcm_decrypt(struct aead_request *req) ...@@ -452,10 +460,11 @@ static int gcm_decrypt(struct aead_request *req)
struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead); struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead);
unsigned int authsize = crypto_aead_authsize(aead); unsigned int authsize = crypto_aead_authsize(aead);
struct skcipher_walk walk; struct skcipher_walk walk;
u8 iv[AES_BLOCK_SIZE]; u8 iv[2 * AES_BLOCK_SIZE];
u8 tag[AES_BLOCK_SIZE]; u8 tag[AES_BLOCK_SIZE];
u8 buf[GHASH_BLOCK_SIZE]; u8 buf[2 * GHASH_BLOCK_SIZE];
u64 dg[2] = {}; u64 dg[2] = {};
int nrounds = num_rounds(&ctx->aes_key);
int err; int err;
if (req->assoclen) if (req->assoclen)
...@@ -466,37 +475,44 @@ static int gcm_decrypt(struct aead_request *req) ...@@ -466,37 +475,44 @@ static int gcm_decrypt(struct aead_request *req)
if (likely(may_use_simd())) { if (likely(may_use_simd())) {
kernel_neon_begin(); kernel_neon_begin();
pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc, nrounds);
pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc,
num_rounds(&ctx->aes_key));
put_unaligned_be32(2, iv + GCM_IV_SIZE); put_unaligned_be32(2, iv + GCM_IV_SIZE);
kernel_neon_end(); kernel_neon_end();
err = skcipher_walk_aead_decrypt(&walk, req, false); err = skcipher_walk_aead_decrypt(&walk, req, false);
while (walk.nbytes >= AES_BLOCK_SIZE) { while (walk.nbytes >= 2 * AES_BLOCK_SIZE) {
int blocks = walk.nbytes / AES_BLOCK_SIZE; int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2;
kernel_neon_begin(); kernel_neon_begin();
pmull_gcm_decrypt(blocks, dg, walk.dst.virt.addr, pmull_gcm_decrypt(blocks, dg, walk.dst.virt.addr,
walk.src.virt.addr, &ctx->ghash_key, walk.src.virt.addr, &ctx->ghash_key,
iv, ctx->aes_key.key_enc, iv, ctx->aes_key.key_enc, nrounds);
num_rounds(&ctx->aes_key));
kernel_neon_end(); kernel_neon_end();
err = skcipher_walk_done(&walk, err = skcipher_walk_done(&walk,
walk.nbytes % AES_BLOCK_SIZE); walk.nbytes % (2 * AES_BLOCK_SIZE));
} }
if (walk.nbytes) { if (walk.nbytes) {
u8 *iv2 = iv + AES_BLOCK_SIZE;
if (walk.nbytes > AES_BLOCK_SIZE) {
memcpy(iv2, iv, AES_BLOCK_SIZE);
crypto_inc(iv2, AES_BLOCK_SIZE);
}
kernel_neon_begin(); kernel_neon_begin();
pmull_gcm_encrypt_block(iv, iv, ctx->aes_key.key_enc, pmull_gcm_encrypt_block(iv, iv, ctx->aes_key.key_enc,
num_rounds(&ctx->aes_key)); nrounds);
if (walk.nbytes > AES_BLOCK_SIZE)
pmull_gcm_encrypt_block(iv2, iv2, NULL,
nrounds);
kernel_neon_end(); kernel_neon_end();
} }
} else { } else {
__aes_arm64_encrypt(ctx->aes_key.key_enc, tag, iv, __aes_arm64_encrypt(ctx->aes_key.key_enc, tag, iv, nrounds);
num_rounds(&ctx->aes_key));
put_unaligned_be32(2, iv + GCM_IV_SIZE); put_unaligned_be32(2, iv + GCM_IV_SIZE);
err = skcipher_walk_aead_decrypt(&walk, req, false); err = skcipher_walk_aead_decrypt(&walk, req, false);
...@@ -511,8 +527,7 @@ static int gcm_decrypt(struct aead_request *req) ...@@ -511,8 +527,7 @@ static int gcm_decrypt(struct aead_request *req)
do { do {
__aes_arm64_encrypt(ctx->aes_key.key_enc, __aes_arm64_encrypt(ctx->aes_key.key_enc,
buf, iv, buf, iv, nrounds);
num_rounds(&ctx->aes_key));
crypto_xor_cpy(dst, src, buf, AES_BLOCK_SIZE); crypto_xor_cpy(dst, src, buf, AES_BLOCK_SIZE);
crypto_inc(iv, AES_BLOCK_SIZE); crypto_inc(iv, AES_BLOCK_SIZE);
...@@ -525,14 +540,24 @@ static int gcm_decrypt(struct aead_request *req) ...@@ -525,14 +540,24 @@ static int gcm_decrypt(struct aead_request *req)
} }
if (walk.nbytes) if (walk.nbytes)
__aes_arm64_encrypt(ctx->aes_key.key_enc, iv, iv, __aes_arm64_encrypt(ctx->aes_key.key_enc, iv, iv,
num_rounds(&ctx->aes_key)); nrounds);
} }
/* handle the tail */ /* handle the tail */
if (walk.nbytes) { if (walk.nbytes) {
memcpy(buf, walk.src.virt.addr, walk.nbytes); const u8 *src = walk.src.virt.addr;
memset(buf + walk.nbytes, 0, GHASH_BLOCK_SIZE - walk.nbytes); const u8 *head = NULL;
ghash_do_update(1, dg, buf, &ctx->ghash_key, NULL); unsigned int nbytes = walk.nbytes;
if (walk.nbytes > GHASH_BLOCK_SIZE) {
head = src;
src += GHASH_BLOCK_SIZE;
nbytes %= GHASH_BLOCK_SIZE;
}
memcpy(buf, src, nbytes);
memset(buf + nbytes, 0, GHASH_BLOCK_SIZE - nbytes);
ghash_do_update(!!nbytes, dg, buf, &ctx->ghash_key, head);
crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, iv, crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, iv,
walk.nbytes); walk.nbytes);
...@@ -557,7 +582,7 @@ static int gcm_decrypt(struct aead_request *req) ...@@ -557,7 +582,7 @@ static int gcm_decrypt(struct aead_request *req)
static struct aead_alg gcm_aes_alg = { static struct aead_alg gcm_aes_alg = {
.ivsize = GCM_IV_SIZE, .ivsize = GCM_IV_SIZE,
.chunksize = AES_BLOCK_SIZE, .chunksize = 2 * AES_BLOCK_SIZE,
.maxauthsize = AES_BLOCK_SIZE, .maxauthsize = AES_BLOCK_SIZE,
.setkey = gcm_setkey, .setkey = gcm_setkey,
.setauthsize = gcm_setauthsize, .setauthsize = gcm_setauthsize,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment