Commit 11031c0d authored by Ard Biesheuvel's avatar Ard Biesheuvel Committed by Herbert Xu

crypto: arm64/gcm-ce - implement 4 way interleave

To improve performance on cores with deep pipelines such as ThunderX2,
reimplement gcm(aes) using a 4-way interleave rather than the 2-way
interleave we use currently.

This comes down to a complete rewrite of the GCM part of the combined
GCM/GHASH driver, and instead of interleaving two invocations of AES
with the GHASH handling at the instruction level, the new version
uses a more coarse grained approach where each chunk of 64 bytes is
encrypted first and then ghashed (or ghashed and then decrypted in
the converse case).

The core NEON routine is now able to consume inputs of any size,
and tail blocks of less than 64 bytes are handled using overlapping
loads and stores, and processed by the same 4-way encryption and
hashing routines. This gets rid of most of the branches, and avoids
having to return to the C code to handle the tail block using a
stack buffer.

The table below compares the performance of the old driver and the new
one on various micro-architectures and running in various modes.

        |     AES-128      |     AES-192      |     AES-256      |
 #bytes | 512 | 1500 |  4k | 512 | 1500 |  4k | 512 | 1500 |  4k |
 -------+-----+------+-----+-----+------+-----+-----+------+-----+
    TX2 | 35% |  23% | 11% | 34% |  20% |  9% | 38% |  25% | 16% |
   EMAG | 11% |   6% |  3% | 12% |   4% |  2% | 11% |   4% |  2% |
    A72 |  8% |   5% | -4% |  9% |   4% | -5% |  7% |   4% | -5% |
    A53 | 11% |   6% | -1% | 10% |   8% | -1% | 10% |   8% | -2% |
Signed-off-by: default avatarArd Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: default avatarHerbert Xu <herbert@gondor.apana.org.au>
parent ec05a74f
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
T1 .req v2 T1 .req v2
T2 .req v3 T2 .req v3
MASK .req v4 MASK .req v4
XL .req v5 XM .req v5
XM .req v6 XL .req v6
XH .req v7 XH .req v7
IN1 .req v7 IN1 .req v7
...@@ -358,20 +358,37 @@ ENTRY(pmull_ghash_update_p8) ...@@ -358,20 +358,37 @@ ENTRY(pmull_ghash_update_p8)
__pmull_ghash p8 __pmull_ghash p8
ENDPROC(pmull_ghash_update_p8) ENDPROC(pmull_ghash_update_p8)
KS0 .req v12 KS0 .req v8
KS1 .req v13 KS1 .req v9
INP0 .req v14 KS2 .req v10
INP1 .req v15 KS3 .req v11
.macro load_round_keys, rounds, rk INP0 .req v21
cmp \rounds, #12 INP1 .req v22
blo 2222f /* 128 bits */ INP2 .req v23
beq 1111f /* 192 bits */ INP3 .req v24
ld1 {v17.4s-v18.4s}, [\rk], #32
1111: ld1 {v19.4s-v20.4s}, [\rk], #32 K0 .req v25
2222: ld1 {v21.4s-v24.4s}, [\rk], #64 K1 .req v26
ld1 {v25.4s-v28.4s}, [\rk], #64 K2 .req v27
ld1 {v29.4s-v31.4s}, [\rk] K3 .req v28
K4 .req v12
K5 .req v13
K6 .req v4
K7 .req v5
K8 .req v14
K9 .req v15
KK .req v29
KL .req v30
KM .req v31
.macro load_round_keys, rounds, rk, tmp
add \tmp, \rk, #64
ld1 {K0.4s-K3.4s}, [\rk]
ld1 {K4.4s-K5.4s}, [\tmp]
add \tmp, \rk, \rounds, lsl #4
sub \tmp, \tmp, #32
ld1 {KK.4s-KM.4s}, [\tmp]
.endm .endm
.macro enc_round, state, key .macro enc_round, state, key
...@@ -379,197 +396,367 @@ ENDPROC(pmull_ghash_update_p8) ...@@ -379,197 +396,367 @@ ENDPROC(pmull_ghash_update_p8)
aesmc \state\().16b, \state\().16b aesmc \state\().16b, \state\().16b
.endm .endm
.macro enc_block, state, rounds .macro enc_qround, s0, s1, s2, s3, key
cmp \rounds, #12 enc_round \s0, \key
b.lo 2222f /* 128 bits */ enc_round \s1, \key
b.eq 1111f /* 192 bits */ enc_round \s2, \key
enc_round \state, v17 enc_round \s3, \key
enc_round \state, v18 .endm
1111: enc_round \state, v19
enc_round \state, v20 .macro enc_block, state, rounds, rk, tmp
2222: .irp key, v21, v22, v23, v24, v25, v26, v27, v28, v29 add \tmp, \rk, #96
ld1 {K6.4s-K7.4s}, [\tmp], #32
.irp key, K0, K1, K2, K3, K4 K5
enc_round \state, \key enc_round \state, \key
.endr .endr
aese \state\().16b, v30.16b
eor \state\().16b, \state\().16b, v31.16b tbnz \rounds, #2, .Lnot128_\@
.Lout256_\@:
enc_round \state, K6
enc_round \state, K7
.Lout192_\@:
enc_round \state, KK
aese \state\().16b, KL.16b
eor \state\().16b, \state\().16b, KM.16b
.subsection 1
.Lnot128_\@:
ld1 {K8.4s-K9.4s}, [\tmp], #32
enc_round \state, K6
enc_round \state, K7
ld1 {K6.4s-K7.4s}, [\tmp]
enc_round \state, K8
enc_round \state, K9
tbz \rounds, #1, .Lout192_\@
b .Lout256_\@
.previous
.endm .endm
.align 6
.macro pmull_gcm_do_crypt, enc .macro pmull_gcm_do_crypt, enc
ld1 {SHASH.2d}, [x4], #16 stp x29, x30, [sp, #-32]!
ld1 {HH.2d}, [x4] mov x29, sp
ld1 {XL.2d}, [x1] str x19, [sp, #24]
ldr x8, [x5, #8] // load lower counter
load_round_keys x7, x6, x8
ld1 {SHASH.2d}, [x3], #16
ld1 {HH.2d-HH4.2d}, [x3]
movi MASK.16b, #0xe1
trn1 SHASH2.2d, SHASH.2d, HH.2d trn1 SHASH2.2d, SHASH.2d, HH.2d
trn2 T1.2d, SHASH.2d, HH.2d trn2 T1.2d, SHASH.2d, HH.2d
CPU_LE( rev x8, x8 )
shl MASK.2d, MASK.2d, #57
eor SHASH2.16b, SHASH2.16b, T1.16b eor SHASH2.16b, SHASH2.16b, T1.16b
.if \enc == 1 trn1 HH34.2d, HH3.2d, HH4.2d
ldr x10, [sp] trn2 T1.2d, HH3.2d, HH4.2d
ld1 {KS0.16b-KS1.16b}, [x10] eor HH34.16b, HH34.16b, T1.16b
.endif
cbnz x6, 4f ld1 {XL.2d}, [x4]
0: ld1 {INP0.16b-INP1.16b}, [x3], #32 cbz x0, 3f // tag only?
rev x9, x8 ldr w8, [x5, #12] // load lower counter
add x11, x8, #1 CPU_LE( rev w8, w8 )
add x8, x8, #2
.if \enc == 1 0: mov w9, #4 // max blocks per round
eor INP0.16b, INP0.16b, KS0.16b // encrypt input add x10, x0, #0xf
eor INP1.16b, INP1.16b, KS1.16b lsr x10, x10, #4 // remaining blocks
subs x0, x0, #64
csel w9, w10, w9, mi
add w8, w8, w9
bmi 1f
ld1 {INP0.16b-INP3.16b}, [x2], #64
.subsection 1
/*
* Populate the four input registers right to left with up to 63 bytes
* of data, using overlapping loads to avoid branches.
*
* INP0 INP1 INP2 INP3
* 1 byte | | | |x |
* 16 bytes | | | |xxxxxxxx|
* 17 bytes | | |xxxxxxxx|x |
* 47 bytes | |xxxxxxxx|xxxxxxxx|xxxxxxx |
* etc etc
*
* Note that this code may read up to 15 bytes before the start of
* the input. It is up to the calling code to ensure this is safe if
* this happens in the first iteration of the loop (i.e., when the
* input size is < 16 bytes)
*/
1: mov x15, #16
ands x19, x0, #0xf
csel x19, x19, x15, ne
adr_l x17, .Lpermute_table + 16
sub x11, x15, x19
add x12, x17, x11
sub x17, x17, x11
ld1 {T1.16b}, [x12]
sub x10, x1, x11
sub x11, x2, x11
cmp x0, #-16
csel x14, x15, xzr, gt
cmp x0, #-32
csel x15, x15, xzr, gt
cmp x0, #-48
csel x16, x19, xzr, gt
csel x1, x1, x10, gt
csel x2, x2, x11, gt
ld1 {INP0.16b}, [x2], x14
ld1 {INP1.16b}, [x2], x15
ld1 {INP2.16b}, [x2], x16
ld1 {INP3.16b}, [x2]
tbl INP3.16b, {INP3.16b}, T1.16b
b 2f
.previous
2: .if \enc == 0
bl pmull_gcm_ghash_4x
.endif .endif
ld1 {KS0.8b}, [x5] // load upper counter bl pmull_gcm_enc_4x
rev x11, x11
sub w0, w0, #2
mov KS1.8b, KS0.8b
ins KS0.d[1], x9 // set lower counter
ins KS1.d[1], x11
rev64 T1.16b, INP1.16b tbnz x0, #63, 6f
st1 {INP0.16b-INP3.16b}, [x1], #64
.if \enc == 1
bl pmull_gcm_ghash_4x
.endif
bne 0b
cmp w7, #12 3: ldp x19, x10, [sp, #24]
b.ge 2f // AES-192/256? cbz x10, 5f // output tag?
1: enc_round KS0, v21 ld1 {INP3.16b}, [x10] // load lengths[]
ext IN1.16b, T1.16b, T1.16b, #8 mov w9, #1
bl pmull_gcm_ghash_4x
enc_round KS1, v21 mov w11, #(0x1 << 24) // BE '1U'
pmull2 XH2.1q, SHASH.2d, IN1.2d // a1 * b1 ld1 {KS0.16b}, [x5]
mov KS0.s[3], w11
enc_round KS0, v22 enc_block KS0, x7, x6, x12
eor T1.16b, T1.16b, IN1.16b
enc_round KS1, v22 ext XL.16b, XL.16b, XL.16b, #8
pmull XL2.1q, SHASH.1d, IN1.1d // a0 * b0 rev64 XL.16b, XL.16b
eor XL.16b, XL.16b, KS0.16b
st1 {XL.16b}, [x10] // store tag
enc_round KS0, v23 4: ldp x29, x30, [sp], #32
pmull XM2.1q, SHASH2.1d, T1.1d // (a1 + a0)(b1 + b0) ret
enc_round KS1, v23 5:
rev64 T1.16b, INP0.16b CPU_LE( rev w8, w8 )
ext T2.16b, XL.16b, XL.16b, #8 str w8, [x5, #12] // store lower counter
st1 {XL.2d}, [x4]
b 4b
6: ld1 {T1.16b-T2.16b}, [x17], #32 // permute vectors
sub x17, x17, x19, lsl #1
cmp w9, #1
beq 7f
.subsection 1
7: ld1 {INP2.16b}, [x1]
tbx INP2.16b, {INP3.16b}, T1.16b
mov INP3.16b, INP2.16b
b 8f
.previous
st1 {INP0.16b}, [x1], x14
st1 {INP1.16b}, [x1], x15
st1 {INP2.16b}, [x1], x16
tbl INP3.16b, {INP3.16b}, T1.16b
tbx INP3.16b, {INP2.16b}, T2.16b
8: st1 {INP3.16b}, [x1]
enc_round KS0, v24 .if \enc == 1
ext IN1.16b, T1.16b, T1.16b, #8 ld1 {T1.16b}, [x17]
eor T1.16b, T1.16b, T2.16b tbl INP3.16b, {INP3.16b}, T1.16b // clear non-data bits
bl pmull_gcm_ghash_4x
.endif
b 3b
.endm
enc_round KS1, v24 /*
eor XL.16b, XL.16b, IN1.16b * void pmull_gcm_encrypt(int blocks, u8 dst[], const u8 src[],
* struct ghash_key const *k, u64 dg[], u8 ctr[],
* int rounds, u8 tag)
*/
ENTRY(pmull_gcm_encrypt)
pmull_gcm_do_crypt 1
ENDPROC(pmull_gcm_encrypt)
enc_round KS0, v25 /*
eor T1.16b, T1.16b, XL.16b * void pmull_gcm_decrypt(int blocks, u8 dst[], const u8 src[],
* struct ghash_key const *k, u64 dg[], u8 ctr[],
* int rounds, u8 tag)
*/
ENTRY(pmull_gcm_decrypt)
pmull_gcm_do_crypt 0
ENDPROC(pmull_gcm_decrypt)
enc_round KS1, v25 pmull_gcm_ghash_4x:
pmull2 XH.1q, HH.2d, XL.2d // a1 * b1 movi MASK.16b, #0xe1
shl MASK.2d, MASK.2d, #57
enc_round KS0, v26 rev64 T1.16b, INP0.16b
pmull XL.1q, HH.1d, XL.1d // a0 * b0 rev64 T2.16b, INP1.16b
rev64 TT3.16b, INP2.16b
rev64 TT4.16b, INP3.16b
enc_round KS1, v26 ext XL.16b, XL.16b, XL.16b, #8
pmull2 XM.1q, SHASH2.2d, T1.2d // (a1 + a0)(b1 + b0)
enc_round KS0, v27 tbz w9, #2, 0f // <4 blocks?
eor XL.16b, XL.16b, XL2.16b .subsection 1
eor XH.16b, XH.16b, XH2.16b 0: movi XH2.16b, #0
movi XM2.16b, #0
movi XL2.16b, #0
enc_round KS1, v27 tbz w9, #0, 1f // 2 blocks?
eor XM.16b, XM.16b, XM2.16b tbz w9, #1, 2f // 1 block?
ext T1.16b, XL.16b, XH.16b, #8
enc_round KS0, v28 eor T2.16b, T2.16b, XL.16b
eor T2.16b, XL.16b, XH.16b ext T1.16b, T2.16b, T2.16b, #8
eor XM.16b, XM.16b, T1.16b b .Lgh3
enc_round KS1, v28 1: eor TT3.16b, TT3.16b, XL.16b
eor XM.16b, XM.16b, T2.16b ext T2.16b, TT3.16b, TT3.16b, #8
b .Lgh2
enc_round KS0, v29 2: eor TT4.16b, TT4.16b, XL.16b
pmull T2.1q, XL.1d, MASK.1d ext IN1.16b, TT4.16b, TT4.16b, #8
b .Lgh1
.previous
enc_round KS1, v29 eor T1.16b, T1.16b, XL.16b
mov XH.d[0], XM.d[1] ext IN1.16b, T1.16b, T1.16b, #8
mov XM.d[1], XL.d[0]
aese KS0.16b, v30.16b pmull2 XH2.1q, HH4.2d, IN1.2d // a1 * b1
eor XL.16b, XM.16b, T2.16b eor T1.16b, T1.16b, IN1.16b
pmull XL2.1q, HH4.1d, IN1.1d // a0 * b0
pmull2 XM2.1q, HH34.2d, T1.2d // (a1 + a0)(b1 + b0)
aese KS1.16b, v30.16b ext T1.16b, T2.16b, T2.16b, #8
ext T2.16b, XL.16b, XL.16b, #8 .Lgh3: eor T2.16b, T2.16b, T1.16b
pmull2 XH.1q, HH3.2d, T1.2d // a1 * b1
pmull XL.1q, HH3.1d, T1.1d // a0 * b0
pmull XM.1q, HH34.1d, T2.1d // (a1 + a0)(b1 + b0)
eor KS0.16b, KS0.16b, v31.16b eor XH2.16b, XH2.16b, XH.16b
pmull XL.1q, XL.1d, MASK.1d eor XL2.16b, XL2.16b, XL.16b
eor T2.16b, T2.16b, XH.16b eor XM2.16b, XM2.16b, XM.16b
eor KS1.16b, KS1.16b, v31.16b ext T2.16b, TT3.16b, TT3.16b, #8
eor XL.16b, XL.16b, T2.16b .Lgh2: eor TT3.16b, TT3.16b, T2.16b
pmull2 XH.1q, HH.2d, T2.2d // a1 * b1
pmull XL.1q, HH.1d, T2.1d // a0 * b0
pmull2 XM.1q, SHASH2.2d, TT3.2d // (a1 + a0)(b1 + b0)
.if \enc == 0 eor XH2.16b, XH2.16b, XH.16b
eor INP0.16b, INP0.16b, KS0.16b eor XL2.16b, XL2.16b, XL.16b
eor INP1.16b, INP1.16b, KS1.16b eor XM2.16b, XM2.16b, XM.16b
.endif
st1 {INP0.16b-INP1.16b}, [x2], #32 ext IN1.16b, TT4.16b, TT4.16b, #8
.Lgh1: eor TT4.16b, TT4.16b, IN1.16b
pmull XL.1q, SHASH.1d, IN1.1d // a0 * b0
pmull2 XH.1q, SHASH.2d, IN1.2d // a1 * b1
pmull XM.1q, SHASH2.1d, TT4.1d // (a1 + a0)(b1 + b0)
cbnz w0, 0b eor XH.16b, XH.16b, XH2.16b
eor XL.16b, XL.16b, XL2.16b
eor XM.16b, XM.16b, XM2.16b
CPU_LE( rev x8, x8 ) eor T2.16b, XL.16b, XH.16b
st1 {XL.2d}, [x1] ext T1.16b, XL.16b, XH.16b, #8
str x8, [x5, #8] // store lower counter eor XM.16b, XM.16b, T2.16b
.if \enc == 1 __pmull_reduce_p64
st1 {KS0.16b-KS1.16b}, [x10]
.endif eor T2.16b, T2.16b, XH.16b
eor XL.16b, XL.16b, T2.16b
ret ret
ENDPROC(pmull_gcm_ghash_4x)
pmull_gcm_enc_4x:
ld1 {KS0.16b}, [x5] // load upper counter
sub w10, w8, #4
sub w11, w8, #3
sub w12, w8, #2
sub w13, w8, #1
rev w10, w10
rev w11, w11
rev w12, w12
rev w13, w13
mov KS1.16b, KS0.16b
mov KS2.16b, KS0.16b
mov KS3.16b, KS0.16b
ins KS0.s[3], w10 // set lower counter
ins KS1.s[3], w11
ins KS2.s[3], w12
ins KS3.s[3], w13
add x10, x6, #96 // round key pointer
ld1 {K6.4s-K7.4s}, [x10], #32
.irp key, K0, K1, K2, K3, K4, K5
enc_qround KS0, KS1, KS2, KS3, \key
.endr
2: b.eq 3f // AES-192? tbnz x7, #2, .Lnot128
enc_round KS0, v17 .subsection 1
enc_round KS1, v17 .Lnot128:
enc_round KS0, v18 ld1 {K8.4s-K9.4s}, [x10], #32
enc_round KS1, v18 .irp key, K6, K7
3: enc_round KS0, v19 enc_qround KS0, KS1, KS2, KS3, \key
enc_round KS1, v19 .endr
enc_round KS0, v20 ld1 {K6.4s-K7.4s}, [x10]
enc_round KS1, v20 .irp key, K8, K9
b 1b enc_qround KS0, KS1, KS2, KS3, \key
.endr
tbz x7, #1, .Lout192
b .Lout256
.previous
4: load_round_keys w7, x6 .Lout256:
b 0b .irp key, K6, K7
.endm enc_qround KS0, KS1, KS2, KS3, \key
.endr
/* .Lout192:
* void pmull_gcm_encrypt(int blocks, u64 dg[], u8 dst[], const u8 src[], enc_qround KS0, KS1, KS2, KS3, KK
* struct ghash_key const *k, u8 ctr[],
* int rounds, u8 ks[])
*/
ENTRY(pmull_gcm_encrypt)
pmull_gcm_do_crypt 1
ENDPROC(pmull_gcm_encrypt)
/* aese KS0.16b, KL.16b
* void pmull_gcm_decrypt(int blocks, u64 dg[], u8 dst[], const u8 src[], aese KS1.16b, KL.16b
* struct ghash_key const *k, u8 ctr[], aese KS2.16b, KL.16b
* int rounds) aese KS3.16b, KL.16b
*/
ENTRY(pmull_gcm_decrypt) eor KS0.16b, KS0.16b, KM.16b
pmull_gcm_do_crypt 0 eor KS1.16b, KS1.16b, KM.16b
ENDPROC(pmull_gcm_decrypt) eor KS2.16b, KS2.16b, KM.16b
eor KS3.16b, KS3.16b, KM.16b
eor INP0.16b, INP0.16b, KS0.16b
eor INP1.16b, INP1.16b, KS1.16b
eor INP2.16b, INP2.16b, KS2.16b
eor INP3.16b, INP3.16b, KS3.16b
/*
* void pmull_gcm_encrypt_block(u8 dst[], u8 src[], u8 rk[], int rounds)
*/
ENTRY(pmull_gcm_encrypt_block)
cbz x2, 0f
load_round_keys w3, x2
0: ld1 {v0.16b}, [x1]
enc_block v0, w3
st1 {v0.16b}, [x0]
ret ret
ENDPROC(pmull_gcm_encrypt_block) ENDPROC(pmull_gcm_enc_4x)
.section ".rodata", "a"
.align 6
.Lpermute_table:
.byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
.byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
.byte 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7
.byte 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf
.byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
.byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
.byte 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7
.byte 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf
.previous
...@@ -58,17 +58,15 @@ asmlinkage void pmull_ghash_update_p8(int blocks, u64 dg[], const char *src, ...@@ -58,17 +58,15 @@ asmlinkage void pmull_ghash_update_p8(int blocks, u64 dg[], const char *src,
struct ghash_key const *k, struct ghash_key const *k,
const char *head); const char *head);
asmlinkage void pmull_gcm_encrypt(int blocks, u64 dg[], u8 dst[], asmlinkage void pmull_gcm_encrypt(int bytes, u8 dst[], const u8 src[],
const u8 src[], struct ghash_key const *k, struct ghash_key const *k, u64 dg[],
u8 ctr[], u32 const rk[], int rounds, u8 ctr[], u32 const rk[], int rounds,
u8 ks[]); u8 tag[]);
asmlinkage void pmull_gcm_decrypt(int blocks, u64 dg[], u8 dst[], asmlinkage void pmull_gcm_decrypt(int bytes, u8 dst[], const u8 src[],
const u8 src[], struct ghash_key const *k, struct ghash_key const *k, u64 dg[],
u8 ctr[], u32 const rk[], int rounds); u8 ctr[], u32 const rk[], int rounds,
u8 tag[]);
asmlinkage void pmull_gcm_encrypt_block(u8 dst[], u8 const src[],
u32 const rk[], int rounds);
static int ghash_init(struct shash_desc *desc) static int ghash_init(struct shash_desc *desc)
{ {
...@@ -85,7 +83,7 @@ static void ghash_do_update(int blocks, u64 dg[], const char *src, ...@@ -85,7 +83,7 @@ static void ghash_do_update(int blocks, u64 dg[], const char *src,
struct ghash_key const *k, struct ghash_key const *k,
const char *head)) const char *head))
{ {
if (likely(crypto_simd_usable())) { if (likely(crypto_simd_usable() && simd_update)) {
kernel_neon_begin(); kernel_neon_begin();
simd_update(blocks, dg, src, key, head); simd_update(blocks, dg, src, key, head);
kernel_neon_end(); kernel_neon_end();
...@@ -398,136 +396,112 @@ static void gcm_calculate_auth_mac(struct aead_request *req, u64 dg[]) ...@@ -398,136 +396,112 @@ static void gcm_calculate_auth_mac(struct aead_request *req, u64 dg[])
} }
} }
static void gcm_final(struct aead_request *req, struct gcm_aes_ctx *ctx,
u64 dg[], u8 tag[], int cryptlen)
{
u8 mac[AES_BLOCK_SIZE];
u128 lengths;
lengths.a = cpu_to_be64(req->assoclen * 8);
lengths.b = cpu_to_be64(cryptlen * 8);
ghash_do_update(1, dg, (void *)&lengths, &ctx->ghash_key, NULL,
pmull_ghash_update_p64);
put_unaligned_be64(dg[1], mac);
put_unaligned_be64(dg[0], mac + 8);
crypto_xor(tag, mac, AES_BLOCK_SIZE);
}
static int gcm_encrypt(struct aead_request *req) static int gcm_encrypt(struct aead_request *req)
{ {
struct crypto_aead *aead = crypto_aead_reqtfm(req); struct crypto_aead *aead = crypto_aead_reqtfm(req);
struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead); struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead);
int nrounds = num_rounds(&ctx->aes_key);
struct skcipher_walk walk; struct skcipher_walk walk;
u8 buf[AES_BLOCK_SIZE];
u8 iv[AES_BLOCK_SIZE]; u8 iv[AES_BLOCK_SIZE];
u8 ks[2 * AES_BLOCK_SIZE];
u8 tag[AES_BLOCK_SIZE];
u64 dg[2] = {}; u64 dg[2] = {};
int nrounds = num_rounds(&ctx->aes_key); u128 lengths;
u8 *tag;
int err; int err;
lengths.a = cpu_to_be64(req->assoclen * 8);
lengths.b = cpu_to_be64(req->cryptlen * 8);
if (req->assoclen) if (req->assoclen)
gcm_calculate_auth_mac(req, dg); gcm_calculate_auth_mac(req, dg);
memcpy(iv, req->iv, GCM_IV_SIZE); memcpy(iv, req->iv, GCM_IV_SIZE);
put_unaligned_be32(1, iv + GCM_IV_SIZE); put_unaligned_be32(2, iv + GCM_IV_SIZE);
err = skcipher_walk_aead_encrypt(&walk, req, false); err = skcipher_walk_aead_encrypt(&walk, req, false);
if (likely(crypto_simd_usable() && walk.total >= 2 * AES_BLOCK_SIZE)) { if (likely(crypto_simd_usable())) {
u32 const *rk = NULL;
kernel_neon_begin();
pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc, nrounds);
put_unaligned_be32(2, iv + GCM_IV_SIZE);
pmull_gcm_encrypt_block(ks, iv, NULL, nrounds);
put_unaligned_be32(3, iv + GCM_IV_SIZE);
pmull_gcm_encrypt_block(ks + AES_BLOCK_SIZE, iv, NULL, nrounds);
put_unaligned_be32(4, iv + GCM_IV_SIZE);
do { do {
int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2; const u8 *src = walk.src.virt.addr;
u8 *dst = walk.dst.virt.addr;
int nbytes = walk.nbytes;
tag = (u8 *)&lengths;
if (rk) if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE)) {
kernel_neon_begin(); src = dst = memcpy(buf + sizeof(buf) - nbytes,
src, nbytes);
} else if (nbytes < walk.total) {
nbytes &= ~(AES_BLOCK_SIZE - 1);
tag = NULL;
}
pmull_gcm_encrypt(blocks, dg, walk.dst.virt.addr, kernel_neon_begin();
walk.src.virt.addr, &ctx->ghash_key, pmull_gcm_encrypt(nbytes, dst, src, &ctx->ghash_key, dg,
iv, rk, nrounds, ks); iv, ctx->aes_key.key_enc, nrounds,
tag);
kernel_neon_end(); kernel_neon_end();
err = skcipher_walk_done(&walk, if (unlikely(!nbytes))
walk.nbytes % (2 * AES_BLOCK_SIZE)); break;
rk = ctx->aes_key.key_enc; if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE))
} while (walk.nbytes >= 2 * AES_BLOCK_SIZE); memcpy(walk.dst.virt.addr,
} else { buf + sizeof(buf) - nbytes, nbytes);
aes_encrypt(&ctx->aes_key, tag, iv);
put_unaligned_be32(2, iv + GCM_IV_SIZE);
while (walk.nbytes >= (2 * AES_BLOCK_SIZE)) { err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
const int blocks = } while (walk.nbytes);
walk.nbytes / (2 * AES_BLOCK_SIZE) * 2; } else {
while (walk.nbytes >= AES_BLOCK_SIZE) {
int blocks = walk.nbytes / AES_BLOCK_SIZE;
const u8 *src = walk.src.virt.addr;
u8 *dst = walk.dst.virt.addr; u8 *dst = walk.dst.virt.addr;
u8 *src = walk.src.virt.addr;
int remaining = blocks; int remaining = blocks;
do { do {
aes_encrypt(&ctx->aes_key, ks, iv); aes_encrypt(&ctx->aes_key, buf, iv);
crypto_xor_cpy(dst, src, ks, AES_BLOCK_SIZE); crypto_xor_cpy(dst, src, buf, AES_BLOCK_SIZE);
crypto_inc(iv, AES_BLOCK_SIZE); crypto_inc(iv, AES_BLOCK_SIZE);
dst += AES_BLOCK_SIZE; dst += AES_BLOCK_SIZE;
src += AES_BLOCK_SIZE; src += AES_BLOCK_SIZE;
} while (--remaining > 0); } while (--remaining > 0);
ghash_do_update(blocks, dg, ghash_do_update(blocks, dg, walk.dst.virt.addr,
walk.dst.virt.addr, &ctx->ghash_key, &ctx->ghash_key, NULL, NULL);
NULL, pmull_ghash_update_p64);
err = skcipher_walk_done(&walk, err = skcipher_walk_done(&walk,
walk.nbytes % (2 * AES_BLOCK_SIZE)); walk.nbytes % AES_BLOCK_SIZE);
}
if (walk.nbytes) {
aes_encrypt(&ctx->aes_key, ks, iv);
if (walk.nbytes > AES_BLOCK_SIZE) {
crypto_inc(iv, AES_BLOCK_SIZE);
aes_encrypt(&ctx->aes_key, ks + AES_BLOCK_SIZE, iv);
}
} }
}
/* handle the tail */ /* handle the tail */
if (walk.nbytes) { if (walk.nbytes) {
u8 buf[GHASH_BLOCK_SIZE]; aes_encrypt(&ctx->aes_key, buf, iv);
unsigned int nbytes = walk.nbytes;
u8 *dst = walk.dst.virt.addr;
u8 *head = NULL;
crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, ks, crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr,
walk.nbytes); buf, walk.nbytes);
if (walk.nbytes > GHASH_BLOCK_SIZE) { memcpy(buf, walk.dst.virt.addr, walk.nbytes);
head = dst; memset(buf + walk.nbytes, 0, sizeof(buf) - walk.nbytes);
dst += GHASH_BLOCK_SIZE;
nbytes %= GHASH_BLOCK_SIZE;
} }
memcpy(buf, dst, nbytes); tag = (u8 *)&lengths;
memset(buf + nbytes, 0, GHASH_BLOCK_SIZE - nbytes); ghash_do_update(1, dg, tag, &ctx->ghash_key,
ghash_do_update(!!nbytes, dg, buf, &ctx->ghash_key, head, walk.nbytes ? buf : NULL, NULL);
pmull_ghash_update_p64);
err = skcipher_walk_done(&walk, 0); if (walk.nbytes)
err = skcipher_walk_done(&walk, 0);
put_unaligned_be64(dg[1], tag);
put_unaligned_be64(dg[0], tag + 8);
put_unaligned_be32(1, iv + GCM_IV_SIZE);
aes_encrypt(&ctx->aes_key, iv, iv);
crypto_xor(tag, iv, AES_BLOCK_SIZE);
} }
if (err) if (err)
return err; return err;
gcm_final(req, ctx, dg, tag, req->cryptlen);
/* copy authtag to end of dst */ /* copy authtag to end of dst */
scatterwalk_map_and_copy(tag, req->dst, req->assoclen + req->cryptlen, scatterwalk_map_and_copy(tag, req->dst, req->assoclen + req->cryptlen,
crypto_aead_authsize(aead), 1); crypto_aead_authsize(aead), 1);
...@@ -540,75 +514,65 @@ static int gcm_decrypt(struct aead_request *req) ...@@ -540,75 +514,65 @@ static int gcm_decrypt(struct aead_request *req)
struct crypto_aead *aead = crypto_aead_reqtfm(req); struct crypto_aead *aead = crypto_aead_reqtfm(req);
struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead); struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead);
unsigned int authsize = crypto_aead_authsize(aead); unsigned int authsize = crypto_aead_authsize(aead);
int nrounds = num_rounds(&ctx->aes_key);
struct skcipher_walk walk; struct skcipher_walk walk;
u8 iv[2 * AES_BLOCK_SIZE]; u8 buf[AES_BLOCK_SIZE];
u8 tag[AES_BLOCK_SIZE]; u8 iv[AES_BLOCK_SIZE];
u8 buf[2 * GHASH_BLOCK_SIZE];
u64 dg[2] = {}; u64 dg[2] = {};
int nrounds = num_rounds(&ctx->aes_key); u128 lengths;
u8 *tag;
int err; int err;
lengths.a = cpu_to_be64(req->assoclen * 8);
lengths.b = cpu_to_be64((req->cryptlen - authsize) * 8);
if (req->assoclen) if (req->assoclen)
gcm_calculate_auth_mac(req, dg); gcm_calculate_auth_mac(req, dg);
memcpy(iv, req->iv, GCM_IV_SIZE); memcpy(iv, req->iv, GCM_IV_SIZE);
put_unaligned_be32(1, iv + GCM_IV_SIZE); put_unaligned_be32(2, iv + GCM_IV_SIZE);
err = skcipher_walk_aead_decrypt(&walk, req, false); err = skcipher_walk_aead_decrypt(&walk, req, false);
if (likely(crypto_simd_usable() && walk.total >= 2 * AES_BLOCK_SIZE)) { if (likely(crypto_simd_usable())) {
u32 const *rk = NULL;
kernel_neon_begin();
pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc, nrounds);
put_unaligned_be32(2, iv + GCM_IV_SIZE);
do { do {
int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2; const u8 *src = walk.src.virt.addr;
int rem = walk.total - blocks * AES_BLOCK_SIZE; u8 *dst = walk.dst.virt.addr;
int nbytes = walk.nbytes;
if (rk)
kernel_neon_begin();
pmull_gcm_decrypt(blocks, dg, walk.dst.virt.addr,
walk.src.virt.addr, &ctx->ghash_key,
iv, rk, nrounds);
/* check if this is the final iteration of the loop */
if (rem < (2 * AES_BLOCK_SIZE)) {
u8 *iv2 = iv + AES_BLOCK_SIZE;
if (rem > AES_BLOCK_SIZE) {
memcpy(iv2, iv, AES_BLOCK_SIZE);
crypto_inc(iv2, AES_BLOCK_SIZE);
}
pmull_gcm_encrypt_block(iv, iv, NULL, nrounds); tag = (u8 *)&lengths;
if (rem > AES_BLOCK_SIZE) if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE)) {
pmull_gcm_encrypt_block(iv2, iv2, NULL, src = dst = memcpy(buf + sizeof(buf) - nbytes,
nrounds); src, nbytes);
} else if (nbytes < walk.total) {
nbytes &= ~(AES_BLOCK_SIZE - 1);
tag = NULL;
} }
kernel_neon_begin();
pmull_gcm_decrypt(nbytes, dst, src, &ctx->ghash_key, dg,
iv, ctx->aes_key.key_enc, nrounds,
tag);
kernel_neon_end(); kernel_neon_end();
err = skcipher_walk_done(&walk, if (unlikely(!nbytes))
walk.nbytes % (2 * AES_BLOCK_SIZE)); break;
rk = ctx->aes_key.key_enc; if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE))
} while (walk.nbytes >= 2 * AES_BLOCK_SIZE); memcpy(walk.dst.virt.addr,
} else { buf + sizeof(buf) - nbytes, nbytes);
aes_encrypt(&ctx->aes_key, tag, iv);
put_unaligned_be32(2, iv + GCM_IV_SIZE);
while (walk.nbytes >= (2 * AES_BLOCK_SIZE)) { err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2; } while (walk.nbytes);
} else {
while (walk.nbytes >= AES_BLOCK_SIZE) {
int blocks = walk.nbytes / AES_BLOCK_SIZE;
const u8 *src = walk.src.virt.addr;
u8 *dst = walk.dst.virt.addr; u8 *dst = walk.dst.virt.addr;
u8 *src = walk.src.virt.addr;
ghash_do_update(blocks, dg, walk.src.virt.addr, ghash_do_update(blocks, dg, walk.src.virt.addr,
&ctx->ghash_key, NULL, &ctx->ghash_key, NULL, NULL);
pmull_ghash_update_p64);
do { do {
aes_encrypt(&ctx->aes_key, buf, iv); aes_encrypt(&ctx->aes_key, buf, iv);
...@@ -620,49 +584,38 @@ static int gcm_decrypt(struct aead_request *req) ...@@ -620,49 +584,38 @@ static int gcm_decrypt(struct aead_request *req)
} while (--blocks > 0); } while (--blocks > 0);
err = skcipher_walk_done(&walk, err = skcipher_walk_done(&walk,
walk.nbytes % (2 * AES_BLOCK_SIZE)); walk.nbytes % AES_BLOCK_SIZE);
} }
if (walk.nbytes) {
if (walk.nbytes > AES_BLOCK_SIZE) {
u8 *iv2 = iv + AES_BLOCK_SIZE;
memcpy(iv2, iv, AES_BLOCK_SIZE);
crypto_inc(iv2, AES_BLOCK_SIZE);
aes_encrypt(&ctx->aes_key, iv2, iv2); /* handle the tail */
} if (walk.nbytes) {
aes_encrypt(&ctx->aes_key, iv, iv); memcpy(buf, walk.src.virt.addr, walk.nbytes);
memset(buf + walk.nbytes, 0, sizeof(buf) - walk.nbytes);
} }
}
/* handle the tail */ tag = (u8 *)&lengths;
if (walk.nbytes) { ghash_do_update(1, dg, tag, &ctx->ghash_key,
const u8 *src = walk.src.virt.addr; walk.nbytes ? buf : NULL, NULL);
const u8 *head = NULL;
unsigned int nbytes = walk.nbytes;
if (walk.nbytes > GHASH_BLOCK_SIZE) { if (walk.nbytes) {
head = src; aes_encrypt(&ctx->aes_key, buf, iv);
src += GHASH_BLOCK_SIZE;
nbytes %= GHASH_BLOCK_SIZE;
}
memcpy(buf, src, nbytes); crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr,
memset(buf + nbytes, 0, GHASH_BLOCK_SIZE - nbytes); buf, walk.nbytes);
ghash_do_update(!!nbytes, dg, buf, &ctx->ghash_key, head,
pmull_ghash_update_p64);
crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, iv, err = skcipher_walk_done(&walk, 0);
walk.nbytes); }
err = skcipher_walk_done(&walk, 0); put_unaligned_be64(dg[1], tag);
put_unaligned_be64(dg[0], tag + 8);
put_unaligned_be32(1, iv + GCM_IV_SIZE);
aes_encrypt(&ctx->aes_key, iv, iv);
crypto_xor(tag, iv, AES_BLOCK_SIZE);
} }
if (err) if (err)
return err; return err;
gcm_final(req, ctx, dg, tag, req->cryptlen - authsize);
/* compare calculated auth tag with the stored one */ /* compare calculated auth tag with the stored one */
scatterwalk_map_and_copy(buf, req->src, scatterwalk_map_and_copy(buf, req->src,
req->assoclen + req->cryptlen - authsize, req->assoclen + req->cryptlen - authsize,
...@@ -675,7 +628,7 @@ static int gcm_decrypt(struct aead_request *req) ...@@ -675,7 +628,7 @@ static int gcm_decrypt(struct aead_request *req)
static struct aead_alg gcm_aes_alg = { static struct aead_alg gcm_aes_alg = {
.ivsize = GCM_IV_SIZE, .ivsize = GCM_IV_SIZE,
.chunksize = 2 * AES_BLOCK_SIZE, .chunksize = AES_BLOCK_SIZE,
.maxauthsize = AES_BLOCK_SIZE, .maxauthsize = AES_BLOCK_SIZE,
.setkey = gcm_setkey, .setkey = gcm_setkey,
.setauthsize = gcm_setauthsize, .setauthsize = gcm_setauthsize,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment