Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

20240516-wc_AesXtsEnDecryptFinal #7549

Merged
merged 1 commit into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 38 additions & 16 deletions linuxkm/lkcapi_glue.c
Original file line number Diff line number Diff line change
Expand Up @@ -945,9 +945,14 @@ static int km_AesXtsEncrypt(struct skcipher_request *req)
if (nbytes < walk.total)
nbytes &= ~(AES_BLOCK_SIZE - 1);

err = wc_AesXtsEncryptUpdate(ctx->aesXts, walk.dst.virt.addr,
walk.src.virt.addr, nbytes,
walk.iv);
if (nbytes & ((unsigned int)AES_BLOCK_SIZE - 1U))
err = wc_AesXtsEncryptFinal(ctx->aesXts, walk.dst.virt.addr,
walk.src.virt.addr, nbytes,
walk.iv);
else
err = wc_AesXtsEncryptUpdate(ctx->aesXts, walk.dst.virt.addr,
walk.src.virt.addr, nbytes,
walk.iv);

if (unlikely(err)) {
pr_err("%s: wc_AesXtsEncryptUpdate failed: %d\n",
Expand Down Expand Up @@ -979,12 +984,12 @@ static int km_AesXtsEncrypt(struct skcipher_request *req)
if (err)
return err;

err = wc_AesXtsEncryptUpdate(ctx->aesXts, walk.dst.virt.addr,
err = wc_AesXtsEncryptFinal(ctx->aesXts, walk.dst.virt.addr,
walk.src.virt.addr, walk.nbytes,
walk.iv);

if (unlikely(err)) {
pr_err("%s: wc_AesXtsEncryptUpdate failed: %d\n",
pr_err("%s: wc_AesXtsEncryptFinal failed: %d\n",
crypto_tfm_alg_driver_name(crypto_skcipher_tfm(tfm)), err);
return -EINVAL;
}
Expand Down Expand Up @@ -1071,9 +1076,14 @@ static int km_AesXtsDecrypt(struct skcipher_request *req)
if (nbytes < walk.total)
nbytes &= ~(AES_BLOCK_SIZE - 1);

err = wc_AesXtsDecryptUpdate(ctx->aesXts, walk.dst.virt.addr,
walk.src.virt.addr, nbytes,
walk.iv);
if (nbytes & ((unsigned int)AES_BLOCK_SIZE - 1U))
err = wc_AesXtsDecryptFinal(ctx->aesXts, walk.dst.virt.addr,
walk.src.virt.addr, nbytes,
walk.iv);
else
err = wc_AesXtsDecryptUpdate(ctx->aesXts, walk.dst.virt.addr,
walk.src.virt.addr, nbytes,
walk.iv);

if (unlikely(err)) {
pr_err("%s: wc_AesXtsDecryptUpdate failed: %d\n",
Expand Down Expand Up @@ -1105,12 +1115,12 @@ static int km_AesXtsDecrypt(struct skcipher_request *req)
if (err)
return err;

err = wc_AesXtsDecryptUpdate(ctx->aesXts, walk.dst.virt.addr,
err = wc_AesXtsDecryptFinal(ctx->aesXts, walk.dst.virt.addr,
walk.src.virt.addr, walk.nbytes,
walk.iv);

if (unlikely(err)) {
pr_err("%s: wc_AesXtsDecryptUpdate failed: %d\n",
pr_err("%s: wc_AesXtsDecryptFinal failed: %d\n",
crypto_tfm_alg_driver_name(crypto_skcipher_tfm(tfm)), err);
return -EINVAL;
}
Expand Down Expand Up @@ -2029,7 +2039,7 @@ static int aes_xts_128_test(void)
ret = wc_AesXtsEncryptUpdate(aes, buf, p2, AES_BLOCK_SIZE, iv);
if (ret != 0)
goto out;
ret = wc_AesXtsEncryptUpdate(aes, buf + AES_BLOCK_SIZE,
ret = wc_AesXtsEncryptFinal(aes, buf + AES_BLOCK_SIZE,
p2 + AES_BLOCK_SIZE,
sizeof(p2) - AES_BLOCK_SIZE, iv);
if (ret != 0)
Expand Down Expand Up @@ -2214,7 +2224,10 @@ static int aes_xts_128_test(void)
if (ret != 0)
goto out;
for (k = 0; k < j; k += AES_BLOCK_SIZE) {
ret = wc_AesXtsEncryptUpdate(aes, large_input + k, large_input + k, (j - k) < AES_BLOCK_SIZE*2 ? j - k : AES_BLOCK_SIZE, iv);
if ((j - k) < AES_BLOCK_SIZE*2)
ret = wc_AesXtsEncryptFinal(aes, large_input + k, large_input + k, j - k, iv);
else
ret = wc_AesXtsEncryptUpdate(aes, large_input + k, large_input + k, AES_BLOCK_SIZE, iv);
if (ret != 0)
goto out;
if ((j - k) < AES_BLOCK_SIZE*2)
Expand Down Expand Up @@ -2252,7 +2265,10 @@ static int aes_xts_128_test(void)
if (ret != 0)
goto out;
for (k = 0; k < j; k += AES_BLOCK_SIZE) {
ret = wc_AesXtsDecryptUpdate(aes, large_input + k, large_input + k, (j - k) < AES_BLOCK_SIZE*2 ? j - k : AES_BLOCK_SIZE, iv);
if ((j - k) < AES_BLOCK_SIZE*2)
ret = wc_AesXtsDecryptFinal(aes, large_input + k, large_input + k, j - k, iv);
else
ret = wc_AesXtsDecryptUpdate(aes, large_input + k, large_input + k, AES_BLOCK_SIZE, iv);
if (ret != 0)
goto out;
if ((j - k) < AES_BLOCK_SIZE*2)
Expand Down Expand Up @@ -2611,7 +2627,7 @@ static int aes_xts_256_test(void)
ret = wc_AesXtsEncryptUpdate(aes, buf, p2, AES_BLOCK_SIZE, iv);
if (ret != 0)
goto out;
ret = wc_AesXtsEncryptUpdate(aes, buf + AES_BLOCK_SIZE,
ret = wc_AesXtsEncryptFinal(aes, buf + AES_BLOCK_SIZE,
p2 + AES_BLOCK_SIZE,
sizeof(p2) - AES_BLOCK_SIZE, iv);
if (ret != 0)
Expand Down Expand Up @@ -2700,7 +2716,10 @@ static int aes_xts_256_test(void)
if (ret != 0)
goto out;
for (k = 0; k < j; k += AES_BLOCK_SIZE) {
ret = wc_AesXtsEncryptUpdate(aes, large_input + k, large_input + k, (j - k) < AES_BLOCK_SIZE*2 ? j - k : AES_BLOCK_SIZE, iv);
if ((j - k) < AES_BLOCK_SIZE*2)
ret = wc_AesXtsEncryptFinal(aes, large_input + k, large_input + k, j - k, iv);
else
ret = wc_AesXtsEncryptUpdate(aes, large_input + k, large_input + k, AES_BLOCK_SIZE, iv);
if (ret != 0)
goto out;
if ((j - k) < AES_BLOCK_SIZE*2)
Expand Down Expand Up @@ -2738,7 +2757,10 @@ static int aes_xts_256_test(void)
if (ret != 0)
goto out;
for (k = 0; k < j; k += AES_BLOCK_SIZE) {
ret = wc_AesXtsDecryptUpdate(aes, large_input + k, large_input + k, (j - k) < AES_BLOCK_SIZE*2 ? j - k : AES_BLOCK_SIZE, iv);
if ((j - k) < AES_BLOCK_SIZE*2)
ret = wc_AesXtsDecryptFinal(aes, large_input + k, large_input + k, j - k, iv);
else
ret = wc_AesXtsDecryptUpdate(aes, large_input + k, large_input + k, AES_BLOCK_SIZE, iv);
if (ret != 0)
goto out;
if ((j - k) < AES_BLOCK_SIZE*2)
Expand Down
60 changes: 54 additions & 6 deletions wolfcrypt/src/aes.c
Original file line number Diff line number Diff line change
Expand Up @@ -12907,8 +12907,9 @@ int wc_AesXtsEncryptInit(XtsAes* xaes, byte* i, word32 iSz)

/* Block-streaming AES-XTS
*
* Note that sz must be greater than AES_BLOCK_SIZE in each call, and must be a
* multiple of AES_BLOCK_SIZE in all but the final call.
* Note that sz must be >= AES_BLOCK_SIZE in each call, and must be a multiple
* of AES_BLOCK_SIZE in each call to wc_AesXtsEncryptUpdate().
* wc_AesXtsEncryptFinal() can handle any length >= AES_BLOCK_SIZE.
*
* xaes AES keys to use for block encrypt/decrypt
* out output buffer to hold cipher text
Expand All @@ -12920,7 +12921,7 @@ int wc_AesXtsEncryptInit(XtsAes* xaes, byte* i, word32 iSz)
*
* returns 0 on success
*/
int wc_AesXtsEncryptUpdate(XtsAes* xaes, byte* out, const byte* in, word32 sz,
static int AesXtsEncryptUpdate(XtsAes* xaes, byte* out, const byte* in, word32 sz,
byte *i)
{
int ret;
Expand Down Expand Up @@ -12975,6 +12976,29 @@ int wc_AesXtsEncryptUpdate(XtsAes* xaes, byte* out, const byte* in, word32 sz,
return ret;
}

int wc_AesXtsEncryptUpdate(XtsAes* xaes, byte* out, const byte* in, word32 sz,
byte *i)
{
if (sz & ((word32)AES_BLOCK_SIZE - 1U))
return BAD_FUNC_ARG;
return AesXtsEncryptUpdate(xaes, out, in, sz, i);
}

int wc_AesXtsEncryptFinal(XtsAes* xaes, byte* out, const byte* in, word32 sz,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What value does "final" add? Is it the zero'ing of "i"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a couple things:

(1) lets us add error-checking for all calls to Update() to make sure they're correctly block-aligned (trying to get ahead of ZenDesk tickets on that).

(2) the zeroing lets us check for API abuse, in that it guarantees a wrong result if the user calls Update after a Final, and of course it also prevents anything valuable from leaking out through the IV.

byte *i)
{
int ret;
if (sz > 0)
ret = AesXtsEncryptUpdate(xaes, out, in, sz, i);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this input "sz" not have to be multiple of block size?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's right. for the last call, the input can have any length >= AES_BLOCK_SIZE. if the length isn't a multiple of the block size, then the ciphertext stealing stuff is used to finish out the message. and btw that's exactly why only the last call can be non-block-aligned.

else
ret = 0;
ForceZero(i, AES_BLOCK_SIZE);
#ifdef WOLFSSL_CHECK_MEM_ZERO
wc_MemZero_Check(i, AES_BLOCK_SIZE);
#endif
return ret;
}

#endif /* WOLFSSL_AESXTS_STREAM */


Expand Down Expand Up @@ -13284,8 +13308,9 @@ int wc_AesXtsDecryptInit(XtsAes* xaes, byte* i, word32 iSz)

/* Block-streaming AES-XTS
*
* Note that sz must be greater than AES_BLOCK_SIZE in each call, and must be a
* multiple of AES_BLOCK_SIZE in all but the final call.
* Note that sz must be >= AES_BLOCK_SIZE in each call, and must be a multiple
* of AES_BLOCK_SIZE in each call to wc_AesXtsDecryptUpdate().
* wc_AesXtsDecryptFinal() can handle any length >= AES_BLOCK_SIZE.
*
* xaes AES keys to use for block encrypt/decrypt
* out output buffer to hold plain text
Expand All @@ -13295,7 +13320,7 @@ int wc_AesXtsDecryptInit(XtsAes* xaes, byte* i, word32 iSz)
*
* returns 0 on success
*/
int wc_AesXtsDecryptUpdate(XtsAes* xaes, byte* out, const byte* in, word32 sz,
static int AesXtsDecryptUpdate(XtsAes* xaes, byte* out, const byte* in, word32 sz,
byte *i)
{
int ret;
Expand Down Expand Up @@ -13353,6 +13378,29 @@ int wc_AesXtsDecryptUpdate(XtsAes* xaes, byte* out, const byte* in, word32 sz,
return ret;
}

int wc_AesXtsDecryptUpdate(XtsAes* xaes, byte* out, const byte* in, word32 sz,
byte *i)
{
if (sz & ((word32)AES_BLOCK_SIZE - 1U))
return BAD_FUNC_ARG;
return AesXtsDecryptUpdate(xaes, out, in, sz, i);
}

int wc_AesXtsDecryptFinal(XtsAes* xaes, byte* out, const byte* in, word32 sz,
byte *i)
{
int ret;
if (sz > 0)
ret = AesXtsDecryptUpdate(xaes, out, in, sz, i);
else
ret = 0;
ForceZero(i, AES_BLOCK_SIZE);
#ifdef WOLFSSL_CHECK_MEM_ZERO
wc_MemZero_Check(i, AES_BLOCK_SIZE);
#endif
return ret;
}

#endif /* WOLFSSL_AESXTS_STREAM */

#endif /* !WOLFSSL_ARMASM || WOLFSSL_ARMASM_NO_HW_CRYPTO */
Expand Down
Loading