Commit 23a251cc authored by Nathan Huckleberry's avatar Nathan Huckleberry Committed by Herbert Xu

crypto: arm64/aes-xctr - Add accelerated implementation of XCTR

Add hardware accelerated version of XCTR for ARM64 CPUs with ARMv8
Crypto Extension support.  This XCTR implementation is based on the CTR
implementation in aes-modes.S.

More information on XCTR can be found in
the HCTR2 paper: "Length-preserving encryption with HCTR2":
https://eprint.iacr.org/2021/1441.pdfSigned-off-by: default avatarNathan Huckleberry <nhuck@google.com>
Reviewed-by: default avatarArd Biesheuvel <ardb@kernel.org>
Reviewed-by: default avatarEric Biggers <ebiggers@google.com>
Signed-off-by: default avatarHerbert Xu <herbert@gondor.apana.org.au>
parent fd94fcf0
...@@ -96,13 +96,13 @@ config CRYPTO_AES_ARM64_CE_CCM ...@@ -96,13 +96,13 @@ config CRYPTO_AES_ARM64_CE_CCM
select CRYPTO_LIB_AES select CRYPTO_LIB_AES
config CRYPTO_AES_ARM64_CE_BLK config CRYPTO_AES_ARM64_CE_BLK
tristate "AES in ECB/CBC/CTR/XTS modes using ARMv8 Crypto Extensions" tristate "AES in ECB/CBC/CTR/XTS/XCTR modes using ARMv8 Crypto Extensions"
depends on KERNEL_MODE_NEON depends on KERNEL_MODE_NEON
select CRYPTO_SKCIPHER select CRYPTO_SKCIPHER
select CRYPTO_AES_ARM64_CE select CRYPTO_AES_ARM64_CE
config CRYPTO_AES_ARM64_NEON_BLK config CRYPTO_AES_ARM64_NEON_BLK
tristate "AES in ECB/CBC/CTR/XTS modes using NEON instructions" tristate "AES in ECB/CBC/CTR/XTS/XCTR modes using NEON instructions"
depends on KERNEL_MODE_NEON depends on KERNEL_MODE_NEON
select CRYPTO_SKCIPHER select CRYPTO_SKCIPHER
select CRYPTO_LIB_AES select CRYPTO_LIB_AES
......
...@@ -34,10 +34,11 @@ ...@@ -34,10 +34,11 @@
#define aes_essiv_cbc_encrypt ce_aes_essiv_cbc_encrypt #define aes_essiv_cbc_encrypt ce_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt ce_aes_essiv_cbc_decrypt #define aes_essiv_cbc_decrypt ce_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt ce_aes_ctr_encrypt #define aes_ctr_encrypt ce_aes_ctr_encrypt
#define aes_xctr_encrypt ce_aes_xctr_encrypt
#define aes_xts_encrypt ce_aes_xts_encrypt #define aes_xts_encrypt ce_aes_xts_encrypt
#define aes_xts_decrypt ce_aes_xts_decrypt #define aes_xts_decrypt ce_aes_xts_decrypt
#define aes_mac_update ce_aes_mac_update #define aes_mac_update ce_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions"); MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 Crypto Extensions");
#else #else
#define MODE "neon" #define MODE "neon"
#define PRIO 200 #define PRIO 200
...@@ -50,16 +51,18 @@ MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions"); ...@@ -50,16 +51,18 @@ MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
#define aes_essiv_cbc_encrypt neon_aes_essiv_cbc_encrypt #define aes_essiv_cbc_encrypt neon_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt neon_aes_essiv_cbc_decrypt #define aes_essiv_cbc_decrypt neon_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt neon_aes_ctr_encrypt #define aes_ctr_encrypt neon_aes_ctr_encrypt
#define aes_xctr_encrypt neon_aes_xctr_encrypt
#define aes_xts_encrypt neon_aes_xts_encrypt #define aes_xts_encrypt neon_aes_xts_encrypt
#define aes_xts_decrypt neon_aes_xts_decrypt #define aes_xts_decrypt neon_aes_xts_decrypt
#define aes_mac_update neon_aes_mac_update #define aes_mac_update neon_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON"); MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 NEON");
#endif #endif
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS) #if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
MODULE_ALIAS_CRYPTO("ecb(aes)"); MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)"); MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)"); MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)"); MODULE_ALIAS_CRYPTO("xts(aes)");
MODULE_ALIAS_CRYPTO("xctr(aes)");
#endif #endif
MODULE_ALIAS_CRYPTO("cts(cbc(aes))"); MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)"); MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
...@@ -89,6 +92,9 @@ asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[], ...@@ -89,6 +92,9 @@ asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[], asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
int rounds, int bytes, u8 ctr[]); int rounds, int bytes, u8 ctr[]);
asmlinkage void aes_xctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
int rounds, int bytes, u8 ctr[], int byte_ctr);
asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[], asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
int rounds, int bytes, u32 const rk2[], u8 iv[], int rounds, int bytes, u32 const rk2[], u8 iv[],
int first); int first);
...@@ -442,6 +448,44 @@ static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req) ...@@ -442,6 +448,44 @@ static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
return err ?: cbc_decrypt_walk(req, &walk); return err ?: cbc_decrypt_walk(req, &walk);
} }
static int __maybe_unused xctr_encrypt(struct skcipher_request *req)
{
struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
int err, rounds = 6 + ctx->key_length / 4;
struct skcipher_walk walk;
unsigned int byte_ctr = 0;
err = skcipher_walk_virt(&walk, req, false);
while (walk.nbytes > 0) {
const u8 *src = walk.src.virt.addr;
unsigned int nbytes = walk.nbytes;
u8 *dst = walk.dst.virt.addr;
u8 buf[AES_BLOCK_SIZE];
if (unlikely(nbytes < AES_BLOCK_SIZE))
src = dst = memcpy(buf + sizeof(buf) - nbytes,
src, nbytes);
else if (nbytes < walk.total)
nbytes &= ~(AES_BLOCK_SIZE - 1);
kernel_neon_begin();
aes_xctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
walk.iv, byte_ctr);
kernel_neon_end();
if (unlikely(nbytes < AES_BLOCK_SIZE))
memcpy(walk.dst.virt.addr,
buf + sizeof(buf) - nbytes, nbytes);
byte_ctr += nbytes;
err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
}
return err;
}
static int __maybe_unused ctr_encrypt(struct skcipher_request *req) static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
{ {
struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
...@@ -669,6 +713,22 @@ static struct skcipher_alg aes_algs[] = { { ...@@ -669,6 +713,22 @@ static struct skcipher_alg aes_algs[] = { {
.setkey = skcipher_aes_setkey, .setkey = skcipher_aes_setkey,
.encrypt = ctr_encrypt, .encrypt = ctr_encrypt,
.decrypt = ctr_encrypt, .decrypt = ctr_encrypt,
}, {
.base = {
.cra_name = "xctr(aes)",
.cra_driver_name = "xctr-aes-" MODE,
.cra_priority = PRIO,
.cra_blocksize = 1,
.cra_ctxsize = sizeof(struct crypto_aes_ctx),
.cra_module = THIS_MODULE,
},
.min_keysize = AES_MIN_KEY_SIZE,
.max_keysize = AES_MAX_KEY_SIZE,
.ivsize = AES_BLOCK_SIZE,
.chunksize = AES_BLOCK_SIZE,
.setkey = skcipher_aes_setkey,
.encrypt = xctr_encrypt,
.decrypt = xctr_encrypt,
}, { }, {
.base = { .base = {
.cra_name = "xts(aes)", .cra_name = "xts(aes)",
......
...@@ -318,79 +318,102 @@ AES_FUNC_END(aes_cbc_cts_decrypt) ...@@ -318,79 +318,102 @@ AES_FUNC_END(aes_cbc_cts_decrypt)
.byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff .byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
.previous .previous
/* /*
* aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds, * This macro generates the code for CTR and XCTR mode.
* int bytes, u8 ctr[])
*/ */
.macro ctr_encrypt xctr
AES_FUNC_START(aes_ctr_encrypt)
stp x29, x30, [sp, #-16]! stp x29, x30, [sp, #-16]!
mov x29, sp mov x29, sp
enc_prepare w3, x2, x12 enc_prepare w3, x2, x12
ld1 {vctr.16b}, [x5] ld1 {vctr.16b}, [x5]
umov x12, vctr.d[1] /* keep swabbed ctr in reg */ .if \xctr
rev x12, x12 umov x12, vctr.d[0]
lsr w11, w6, #4
.else
umov x12, vctr.d[1] /* keep swabbed ctr in reg */
rev x12, x12
.endif
.LctrloopNx: .LctrloopNx\xctr:
add w7, w4, #15 add w7, w4, #15
sub w4, w4, #MAX_STRIDE << 4 sub w4, w4, #MAX_STRIDE << 4
lsr w7, w7, #4 lsr w7, w7, #4
mov w8, #MAX_STRIDE mov w8, #MAX_STRIDE
cmp w7, w8 cmp w7, w8
csel w7, w7, w8, lt csel w7, w7, w8, lt
adds x12, x12, x7
.if \xctr
add x11, x11, x7
.else
adds x12, x12, x7
.endif
mov v0.16b, vctr.16b mov v0.16b, vctr.16b
mov v1.16b, vctr.16b mov v1.16b, vctr.16b
mov v2.16b, vctr.16b mov v2.16b, vctr.16b
mov v3.16b, vctr.16b mov v3.16b, vctr.16b
ST5( mov v4.16b, vctr.16b ) ST5( mov v4.16b, vctr.16b )
bcs 0f .if \xctr
sub x6, x11, #MAX_STRIDE - 1
.subsection 1 sub x7, x11, #MAX_STRIDE - 2
/* apply carry to outgoing counter */ sub x8, x11, #MAX_STRIDE - 3
0: umov x8, vctr.d[0] sub x9, x11, #MAX_STRIDE - 4
rev x8, x8 ST5( sub x10, x11, #MAX_STRIDE - 5 )
add x8, x8, #1 eor x6, x6, x12
rev x8, x8 eor x7, x7, x12
ins vctr.d[0], x8 eor x8, x8, x12
eor x9, x9, x12
/* apply carry to N counter blocks for N := x12 */ ST5( eor x10, x10, x12 )
cbz x12, 2f mov v0.d[0], x6
adr x16, 1f mov v1.d[0], x7
sub x16, x16, x12, lsl #3 mov v2.d[0], x8
br x16 mov v3.d[0], x9
bti c ST5( mov v4.d[0], x10 )
mov v0.d[0], vctr.d[0] .else
bti c bcs 0f
mov v1.d[0], vctr.d[0] .subsection 1
bti c /* apply carry to outgoing counter */
mov v2.d[0], vctr.d[0] 0: umov x8, vctr.d[0]
bti c rev x8, x8
mov v3.d[0], vctr.d[0] add x8, x8, #1
ST5( bti c ) rev x8, x8
ST5( mov v4.d[0], vctr.d[0] ) ins vctr.d[0], x8
1: b 2f
.previous /* apply carry to N counter blocks for N := x12 */
cbz x12, 2f
2: rev x7, x12 adr x16, 1f
ins vctr.d[1], x7 sub x16, x16, x12, lsl #3
sub x7, x12, #MAX_STRIDE - 1 br x16
sub x8, x12, #MAX_STRIDE - 2 bti c
sub x9, x12, #MAX_STRIDE - 3 mov v0.d[0], vctr.d[0]
rev x7, x7 bti c
rev x8, x8 mov v1.d[0], vctr.d[0]
mov v1.d[1], x7 bti c
rev x9, x9 mov v2.d[0], vctr.d[0]
ST5( sub x10, x12, #MAX_STRIDE - 4 ) bti c
mov v2.d[1], x8 mov v3.d[0], vctr.d[0]
ST5( rev x10, x10 ) ST5( bti c )
mov v3.d[1], x9 ST5( mov v4.d[0], vctr.d[0] )
ST5( mov v4.d[1], x10 ) 1: b 2f
tbnz w4, #31, .Lctrtail .previous
2: rev x7, x12
ins vctr.d[1], x7
sub x7, x12, #MAX_STRIDE - 1
sub x8, x12, #MAX_STRIDE - 2
sub x9, x12, #MAX_STRIDE - 3
rev x7, x7
rev x8, x8
mov v1.d[1], x7
rev x9, x9
ST5( sub x10, x12, #MAX_STRIDE - 4 )
mov v2.d[1], x8
ST5( rev x10, x10 )
mov v3.d[1], x9
ST5( mov v4.d[1], x10 )
.endif
tbnz w4, #31, .Lctrtail\xctr
ld1 {v5.16b-v7.16b}, [x1], #48 ld1 {v5.16b-v7.16b}, [x1], #48
ST4( bl aes_encrypt_block4x ) ST4( bl aes_encrypt_block4x )
ST5( bl aes_encrypt_block5x ) ST5( bl aes_encrypt_block5x )
...@@ -403,16 +426,17 @@ ST5( ld1 {v5.16b-v6.16b}, [x1], #32 ) ...@@ -403,16 +426,17 @@ ST5( ld1 {v5.16b-v6.16b}, [x1], #32 )
ST5( eor v4.16b, v6.16b, v4.16b ) ST5( eor v4.16b, v6.16b, v4.16b )
st1 {v0.16b-v3.16b}, [x0], #64 st1 {v0.16b-v3.16b}, [x0], #64
ST5( st1 {v4.16b}, [x0], #16 ) ST5( st1 {v4.16b}, [x0], #16 )
cbz w4, .Lctrout cbz w4, .Lctrout\xctr
b .LctrloopNx b .LctrloopNx\xctr
.Lctrout: .Lctrout\xctr:
st1 {vctr.16b}, [x5] /* return next CTR value */ .if !\xctr
st1 {vctr.16b}, [x5] /* return next CTR value */
.endif
ldp x29, x30, [sp], #16 ldp x29, x30, [sp], #16
ret ret
.Lctrtail: .Lctrtail\xctr:
/* XOR up to MAX_STRIDE * 16 - 1 bytes of in/output with v0 ... v3/v4 */
mov x16, #16 mov x16, #16
ands x6, x4, #0xf ands x6, x4, #0xf
csel x13, x6, x16, ne csel x13, x6, x16, ne
...@@ -427,7 +451,7 @@ ST5( csel x14, x16, xzr, gt ) ...@@ -427,7 +451,7 @@ ST5( csel x14, x16, xzr, gt )
adr_l x12, .Lcts_permute_table adr_l x12, .Lcts_permute_table
add x12, x12, x13 add x12, x12, x13
ble .Lctrtail1x ble .Lctrtail1x\xctr
ST5( ld1 {v5.16b}, [x1], x14 ) ST5( ld1 {v5.16b}, [x1], x14 )
ld1 {v6.16b}, [x1], x15 ld1 {v6.16b}, [x1], x15
...@@ -459,9 +483,9 @@ ST5( st1 {v5.16b}, [x0], x14 ) ...@@ -459,9 +483,9 @@ ST5( st1 {v5.16b}, [x0], x14 )
add x13, x13, x0 add x13, x13, x0
st1 {v9.16b}, [x13] // overlapping stores st1 {v9.16b}, [x13] // overlapping stores
st1 {v8.16b}, [x0] st1 {v8.16b}, [x0]
b .Lctrout b .Lctrout\xctr
.Lctrtail1x: .Lctrtail1x\xctr:
sub x7, x6, #16 sub x7, x6, #16
csel x6, x6, x7, eq csel x6, x6, x7, eq
add x1, x1, x6 add x1, x1, x6
...@@ -476,9 +500,27 @@ ST5( mov v3.16b, v4.16b ) ...@@ -476,9 +500,27 @@ ST5( mov v3.16b, v4.16b )
eor v5.16b, v5.16b, v3.16b eor v5.16b, v5.16b, v3.16b
bif v5.16b, v6.16b, v11.16b bif v5.16b, v6.16b, v11.16b
st1 {v5.16b}, [x0] st1 {v5.16b}, [x0]
b .Lctrout b .Lctrout\xctr
.endm
/*
* aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
* int bytes, u8 ctr[])
*/
AES_FUNC_START(aes_ctr_encrypt)
ctr_encrypt 0
AES_FUNC_END(aes_ctr_encrypt) AES_FUNC_END(aes_ctr_encrypt)
/*
* aes_xctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
* int bytes, u8 const iv[], int byte_ctr)
*/
AES_FUNC_START(aes_xctr_encrypt)
ctr_encrypt 1
AES_FUNC_END(aes_xctr_encrypt)
/* /*
* aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[], int rounds, * aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[], int rounds,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment