Skip to content

Commit

Permalink
Merge pull request #7549 from douzzer/20240516-wc_AesXtsEnDecryptFinal
Browse files Browse the repository at this point in the history
20240516-wc_AesXtsEnDecryptFinal
  • Loading branch information
SparkiDev authored May 16, 2024
2 parents 219a338 + 6d0f611 commit c0015cb
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 42 deletions.
54 changes: 38 additions & 16 deletions linuxkm/lkcapi_glue.c
Original file line number Diff line number Diff line change
Expand Up @@ -945,9 +945,14 @@ static int km_AesXtsEncrypt(struct skcipher_request *req)
if (nbytes < walk.total)
nbytes &= ~(AES_BLOCK_SIZE - 1);

err = wc_AesXtsEncryptUpdate(ctx->aesXts, walk.dst.virt.addr,
walk.src.virt.addr, nbytes,
walk.iv);
if (nbytes & ((unsigned int)AES_BLOCK_SIZE - 1U))
err = wc_AesXtsEncryptFinal(ctx->aesXts, walk.dst.virt.addr,
walk.src.virt.addr, nbytes,
walk.iv);
else
err = wc_AesXtsEncryptUpdate(ctx->aesXts, walk.dst.virt.addr,
walk.src.virt.addr, nbytes,
walk.iv);

if (unlikely(err)) {
pr_err("%s: wc_AesXtsEncryptUpdate failed: %d\n",
Expand Down Expand Up @@ -979,12 +984,12 @@ static int km_AesXtsEncrypt(struct skcipher_request *req)
if (err)
return err;

err = wc_AesXtsEncryptUpdate(ctx->aesXts, walk.dst.virt.addr,
err = wc_AesXtsEncryptFinal(ctx->aesXts, walk.dst.virt.addr,
walk.src.virt.addr, walk.nbytes,
walk.iv);

if (unlikely(err)) {
pr_err("%s: wc_AesXtsEncryptUpdate failed: %d\n",
pr_err("%s: wc_AesXtsEncryptFinal failed: %d\n",
crypto_tfm_alg_driver_name(crypto_skcipher_tfm(tfm)), err);
return -EINVAL;
}
Expand Down Expand Up @@ -1071,9 +1076,14 @@ static int km_AesXtsDecrypt(struct skcipher_request *req)
if (nbytes < walk.total)
nbytes &= ~(AES_BLOCK_SIZE - 1);

err = wc_AesXtsDecryptUpdate(ctx->aesXts, walk.dst.virt.addr,
walk.src.virt.addr, nbytes,
walk.iv);
if (nbytes & ((unsigned int)AES_BLOCK_SIZE - 1U))
err = wc_AesXtsDecryptFinal(ctx->aesXts, walk.dst.virt.addr,
walk.src.virt.addr, nbytes,
walk.iv);
else
err = wc_AesXtsDecryptUpdate(ctx->aesXts, walk.dst.virt.addr,
walk.src.virt.addr, nbytes,
walk.iv);

if (unlikely(err)) {
pr_err("%s: wc_AesXtsDecryptUpdate failed: %d\n",
Expand Down Expand Up @@ -1105,12 +1115,12 @@ static int km_AesXtsDecrypt(struct skcipher_request *req)
if (err)
return err;

err = wc_AesXtsDecryptUpdate(ctx->aesXts, walk.dst.virt.addr,
err = wc_AesXtsDecryptFinal(ctx->aesXts, walk.dst.virt.addr,
walk.src.virt.addr, walk.nbytes,
walk.iv);

if (unlikely(err)) {
pr_err("%s: wc_AesXtsDecryptUpdate failed: %d\n",
pr_err("%s: wc_AesXtsDecryptFinal failed: %d\n",
crypto_tfm_alg_driver_name(crypto_skcipher_tfm(tfm)), err);
return -EINVAL;
}
Expand Down Expand Up @@ -2029,7 +2039,7 @@ static int aes_xts_128_test(void)
ret = wc_AesXtsEncryptUpdate(aes, buf, p2, AES_BLOCK_SIZE, iv);
if (ret != 0)
goto out;
ret = wc_AesXtsEncryptUpdate(aes, buf + AES_BLOCK_SIZE,
ret = wc_AesXtsEncryptFinal(aes, buf + AES_BLOCK_SIZE,
p2 + AES_BLOCK_SIZE,
sizeof(p2) - AES_BLOCK_SIZE, iv);
if (ret != 0)
Expand Down Expand Up @@ -2214,7 +2224,10 @@ static int aes_xts_128_test(void)
if (ret != 0)
goto out;
for (k = 0; k < j; k += AES_BLOCK_SIZE) {
ret = wc_AesXtsEncryptUpdate(aes, large_input + k, large_input + k, (j - k) < AES_BLOCK_SIZE*2 ? j - k : AES_BLOCK_SIZE, iv);
if ((j - k) < AES_BLOCK_SIZE*2)
ret = wc_AesXtsEncryptFinal(aes, large_input + k, large_input + k, j - k, iv);
else
ret = wc_AesXtsEncryptUpdate(aes, large_input + k, large_input + k, AES_BLOCK_SIZE, iv);
if (ret != 0)
goto out;
if ((j - k) < AES_BLOCK_SIZE*2)
Expand Down Expand Up @@ -2252,7 +2265,10 @@ static int aes_xts_128_test(void)
if (ret != 0)
goto out;
for (k = 0; k < j; k += AES_BLOCK_SIZE) {
ret = wc_AesXtsDecryptUpdate(aes, large_input + k, large_input + k, (j - k) < AES_BLOCK_SIZE*2 ? j - k : AES_BLOCK_SIZE, iv);
if ((j - k) < AES_BLOCK_SIZE*2)
ret = wc_AesXtsDecryptFinal(aes, large_input + k, large_input + k, j - k, iv);
else
ret = wc_AesXtsDecryptUpdate(aes, large_input + k, large_input + k, AES_BLOCK_SIZE, iv);
if (ret != 0)
goto out;
if ((j - k) < AES_BLOCK_SIZE*2)
Expand Down Expand Up @@ -2611,7 +2627,7 @@ static int aes_xts_256_test(void)
ret = wc_AesXtsEncryptUpdate(aes, buf, p2, AES_BLOCK_SIZE, iv);
if (ret != 0)
goto out;
ret = wc_AesXtsEncryptUpdate(aes, buf + AES_BLOCK_SIZE,
ret = wc_AesXtsEncryptFinal(aes, buf + AES_BLOCK_SIZE,
p2 + AES_BLOCK_SIZE,
sizeof(p2) - AES_BLOCK_SIZE, iv);
if (ret != 0)
Expand Down Expand Up @@ -2700,7 +2716,10 @@ static int aes_xts_256_test(void)
if (ret != 0)
goto out;
for (k = 0; k < j; k += AES_BLOCK_SIZE) {
ret = wc_AesXtsEncryptUpdate(aes, large_input + k, large_input + k, (j - k) < AES_BLOCK_SIZE*2 ? j - k : AES_BLOCK_SIZE, iv);
if ((j - k) < AES_BLOCK_SIZE*2)
ret = wc_AesXtsEncryptFinal(aes, large_input + k, large_input + k, j - k, iv);
else
ret = wc_AesXtsEncryptUpdate(aes, large_input + k, large_input + k, AES_BLOCK_SIZE, iv);
if (ret != 0)
goto out;
if ((j - k) < AES_BLOCK_SIZE*2)
Expand Down Expand Up @@ -2738,7 +2757,10 @@ static int aes_xts_256_test(void)
if (ret != 0)
goto out;
for (k = 0; k < j; k += AES_BLOCK_SIZE) {
ret = wc_AesXtsDecryptUpdate(aes, large_input + k, large_input + k, (j - k) < AES_BLOCK_SIZE*2 ? j - k : AES_BLOCK_SIZE, iv);
if ((j - k) < AES_BLOCK_SIZE*2)
ret = wc_AesXtsDecryptFinal(aes, large_input + k, large_input + k, j - k, iv);
else
ret = wc_AesXtsDecryptUpdate(aes, large_input + k, large_input + k, AES_BLOCK_SIZE, iv);
if (ret != 0)
goto out;
if ((j - k) < AES_BLOCK_SIZE*2)
Expand Down
60 changes: 54 additions & 6 deletions wolfcrypt/src/aes.c
Original file line number Diff line number Diff line change
Expand Up @@ -12907,8 +12907,9 @@ int wc_AesXtsEncryptInit(XtsAes* xaes, byte* i, word32 iSz)

/* Block-streaming AES-XTS
*
* Note that sz must be greater than AES_BLOCK_SIZE in each call, and must be a
* multiple of AES_BLOCK_SIZE in all but the final call.
* Note that sz must be >= AES_BLOCK_SIZE in each call, and must be a multiple
* of AES_BLOCK_SIZE in each call to wc_AesXtsEncryptUpdate().
* wc_AesXtsEncryptFinal() can handle any length >= AES_BLOCK_SIZE.
*
* xaes AES keys to use for block encrypt/decrypt
* out output buffer to hold cipher text
Expand All @@ -12920,7 +12921,7 @@ int wc_AesXtsEncryptInit(XtsAes* xaes, byte* i, word32 iSz)
*
* returns 0 on success
*/
int wc_AesXtsEncryptUpdate(XtsAes* xaes, byte* out, const byte* in, word32 sz,
static int AesXtsEncryptUpdate(XtsAes* xaes, byte* out, const byte* in, word32 sz,
byte *i)
{
int ret;
Expand Down Expand Up @@ -12975,6 +12976,29 @@ int wc_AesXtsEncryptUpdate(XtsAes* xaes, byte* out, const byte* in, word32 sz,
return ret;
}

int wc_AesXtsEncryptUpdate(XtsAes* xaes, byte* out, const byte* in, word32 sz,
byte *i)
{
if (sz & ((word32)AES_BLOCK_SIZE - 1U))
return BAD_FUNC_ARG;
return AesXtsEncryptUpdate(xaes, out, in, sz, i);
}

int wc_AesXtsEncryptFinal(XtsAes* xaes, byte* out, const byte* in, word32 sz,
byte *i)
{
int ret;
if (sz > 0)
ret = AesXtsEncryptUpdate(xaes, out, in, sz, i);
else
ret = 0;
ForceZero(i, AES_BLOCK_SIZE);
#ifdef WOLFSSL_CHECK_MEM_ZERO
wc_MemZero_Check(i, AES_BLOCK_SIZE);
#endif
return ret;
}

#endif /* WOLFSSL_AESXTS_STREAM */


Expand Down Expand Up @@ -13284,8 +13308,9 @@ int wc_AesXtsDecryptInit(XtsAes* xaes, byte* i, word32 iSz)

/* Block-streaming AES-XTS
*
* Note that sz must be greater than AES_BLOCK_SIZE in each call, and must be a
* multiple of AES_BLOCK_SIZE in all but the final call.
* Note that sz must be >= AES_BLOCK_SIZE in each call, and must be a multiple
* of AES_BLOCK_SIZE in each call to wc_AesXtsDecryptUpdate().
* wc_AesXtsDecryptFinal() can handle any length >= AES_BLOCK_SIZE.
*
* xaes AES keys to use for block encrypt/decrypt
* out output buffer to hold plain text
Expand All @@ -13295,7 +13320,7 @@ int wc_AesXtsDecryptInit(XtsAes* xaes, byte* i, word32 iSz)
*
* returns 0 on success
*/
int wc_AesXtsDecryptUpdate(XtsAes* xaes, byte* out, const byte* in, word32 sz,
static int AesXtsDecryptUpdate(XtsAes* xaes, byte* out, const byte* in, word32 sz,
byte *i)
{
int ret;
Expand Down Expand Up @@ -13353,6 +13378,29 @@ int wc_AesXtsDecryptUpdate(XtsAes* xaes, byte* out, const byte* in, word32 sz,
return ret;
}

int wc_AesXtsDecryptUpdate(XtsAes* xaes, byte* out, const byte* in, word32 sz,
byte *i)
{
if (sz & ((word32)AES_BLOCK_SIZE - 1U))
return BAD_FUNC_ARG;
return AesXtsDecryptUpdate(xaes, out, in, sz, i);
}

int wc_AesXtsDecryptFinal(XtsAes* xaes, byte* out, const byte* in, word32 sz,
byte *i)
{
int ret;
if (sz > 0)
ret = AesXtsDecryptUpdate(xaes, out, in, sz, i);
else
ret = 0;
ForceZero(i, AES_BLOCK_SIZE);
#ifdef WOLFSSL_CHECK_MEM_ZERO
wc_MemZero_Check(i, AES_BLOCK_SIZE);
#endif
return ret;
}

#endif /* WOLFSSL_AESXTS_STREAM */

#endif /* !WOLFSSL_ARMASM || WOLFSSL_ARMASM_NO_HW_CRYPTO */
Expand Down
Loading

0 comments on commit c0015cb

Please sign in to comment.