diff --git a/sdk/keyvault/Azure.Security.KeyVault.Keys/src/Cryptography/AesCryptographyProvider.cs b/sdk/keyvault/Azure.Security.KeyVault.Keys/src/Cryptography/AesCryptographyProvider.cs index 3cf95308f6379..adc4486c82250 100644 --- a/sdk/keyvault/Azure.Security.KeyVault.Keys/src/Cryptography/AesCryptographyProvider.cs +++ b/sdk/keyvault/Azure.Security.KeyVault.Keys/src/Cryptography/AesCryptographyProvider.cs @@ -81,6 +81,9 @@ public override EncryptResult Encrypt(EncryptOptions options, CancellationToken EncryptionAlgorithm algorithm = options.Algorithm; if (algorithm.GetAesCbcEncryptionAlgorithm() is AesCbc aesCbc) { + // Make sure the IV is initialized. + options.Initialize(); + using ICryptoTransform encryptor = aesCbc.CreateEncryptor(KeyMaterial.K, options.Iv); byte[] plaintext = options.Plaintext; diff --git a/sdk/keyvault/Azure.Security.KeyVault.Keys/src/Cryptography/RemoteCryptographyClient.cs b/sdk/keyvault/Azure.Security.KeyVault.Keys/src/Cryptography/RemoteCryptographyClient.cs index 42d2a592ba776..8258d77d10312 100644 --- a/sdk/keyvault/Azure.Security.KeyVault.Keys/src/Cryptography/RemoteCryptographyClient.cs +++ b/sdk/keyvault/Azure.Security.KeyVault.Keys/src/Cryptography/RemoteCryptographyClient.cs @@ -49,6 +49,10 @@ public virtual async Task> EncryptAsync(EncryptOptions o try { + // Make sure the IV is initialized. + // TODO: Remove this call once the service will initialized it: https://github.com/Azure/azure-sdk-for-net/issues/16175 + options.Initialize(); + return await Pipeline.SendRequestAsync(RequestMethod.Post, options, () => new EncryptResult { Algorithm = options.Algorithm }, cancellationToken, "/encrypt").ConfigureAwait(false); } catch (Exception e) @@ -66,6 +70,10 @@ public virtual Response Encrypt(EncryptOptions options, Cancellat try { + // Make sure the IV is initialized. + // TODO: Remove this call once the service will initialized it: https://github.com/Azure/azure-sdk-for-net/issues/16175 + options.Initialize(); + return Pipeline.SendRequest(RequestMethod.Post, options, () => new EncryptResult { Algorithm = options.Algorithm }, cancellationToken, "/encrypt"); } catch (Exception e) diff --git a/sdk/keyvault/Azure.Security.KeyVault.Keys/tests/AesCryptographyProviderTests.cs b/sdk/keyvault/Azure.Security.KeyVault.Keys/tests/AesCryptographyProviderTests.cs index 56df0118d0b88..059d81778a16e 100644 --- a/sdk/keyvault/Azure.Security.KeyVault.Keys/tests/AesCryptographyProviderTests.cs +++ b/sdk/keyvault/Azure.Security.KeyVault.Keys/tests/AesCryptographyProviderTests.cs @@ -137,7 +137,7 @@ public void DecryptionAlgorithmNotSupported() Assert.AreEqual("invalid", e.GetProperty("algorithm")); } - [TestCaseSource(nameof(EncryptDecryptRoundtripsData))] + [TestCaseSource(nameof(GetEncryptionAlgorithms), methodParams: new object[] { true })] public void EncryptDecryptRoundtrips(EncryptionAlgorithm algorithm) { // Use a 256-bit key which will be truncated based on the selected algorithm. @@ -171,6 +171,73 @@ public void EncryptDecryptRoundtrips(EncryptionAlgorithm algorithm) #endif Assert.IsNotNull(encrypted); + switch (algorithm.ToString()) + { + // TODO: Move to new test to make sure CryptoClient and LocalCryptoClient initialize a null ICM for AES-CBC(PAD). + case EncryptionAlgorithm.A128CbcValue: + CollectionAssert.AreEqual( + new byte[] { 0x63, 0x23, 0x21, 0xaf, 0x94, 0xf9, 0xe1, 0x21, 0xc2, 0xbd, 0xb1, 0x1b, 0x04, 0x89, 0x8c, 0x3a }, + encrypted.Ciphertext); + CollectionAssert.AreEqual(iv, encrypted.Iv); + Assert.IsNull(encrypted.AuthenticationTag); + Assert.IsNull(encrypted.AdditionalAuthenticatedData); + break; + + case EncryptionAlgorithm.A192CbcValue: + CollectionAssert.AreEqual( + new byte[] { 0x95, 0x9d, 0x75, 0x91, 0x09, 0x8b, 0x70, 0x0b, 0x9c, 0xfe, 0xaf, 0xcd, 0x60, 0x1f, 0xaa, 0x79 }, + encrypted.Ciphertext); + CollectionAssert.AreEqual(iv, encrypted.Iv); + Assert.IsNull(encrypted.AuthenticationTag); + Assert.IsNull(encrypted.AdditionalAuthenticatedData); + break; + + case EncryptionAlgorithm.A256CbcValue: + CollectionAssert.AreEqual( + new byte[] { 0xf4, 0xe8, 0x5a, 0xa4, 0xa8, 0xb3, 0xff, 0xc3, 0x85, 0x89, 0x17, 0x9a, 0x70, 0x09, 0x96, 0x7f }, + encrypted.Ciphertext); + CollectionAssert.AreEqual(iv, encrypted.Iv); + Assert.IsNull(encrypted.AuthenticationTag); + Assert.IsNull(encrypted.AdditionalAuthenticatedData); + break; + + case EncryptionAlgorithm.A128CbcPadValue: + CollectionAssert.AreEqual( + new byte[] { 0xec, 0xb2, 0x63, 0x4c, 0xe0, 0x04, 0xe0, 0x31, 0x2d, 0x9a, 0x77, 0xb2, 0x11, 0xe5, 0x28, 0x7f }, + encrypted.Ciphertext); + CollectionAssert.AreEqual(iv, encrypted.Iv); + Assert.IsNull(encrypted.AuthenticationTag); + Assert.IsNull(encrypted.AdditionalAuthenticatedData); + break; + + case EncryptionAlgorithm.A192CbcPadValue: + CollectionAssert.AreEqual( + new byte[] { 0xc3, 0x4e, 0x1b, 0xe7, 0x6e, 0xa1, 0xf1, 0xc3, 0x24, 0xae, 0x05, 0x1b, 0x0e, 0x32, 0xac, 0xb4 }, + encrypted.Ciphertext); + CollectionAssert.AreEqual(iv, encrypted.Iv); + Assert.IsNull(encrypted.AuthenticationTag); + Assert.IsNull(encrypted.AdditionalAuthenticatedData); + break; + + case EncryptionAlgorithm.A256CbcPadValue: + CollectionAssert.AreEqual( + new byte[] { 0x4e, 0xbd, 0x78, 0xda, 0x90, 0x73, 0xc8, 0x97, 0x67, 0x2b, 0xa1, 0x0a, 0x41, 0x67, 0xf8, 0x99 }, + encrypted.Ciphertext); + CollectionAssert.AreEqual(iv, encrypted.Iv); + Assert.IsNull(encrypted.AuthenticationTag); + Assert.IsNull(encrypted.AdditionalAuthenticatedData); + break; + + case EncryptionAlgorithm.A128GcmValue: + case EncryptionAlgorithm.A192GcmValue: + case EncryptionAlgorithm.A256GcmValue: + Assert.IsNotNull(encrypted.Ciphertext); + Assert.IsNotNull(encrypted.Iv); + Assert.IsNotNull(encrypted.AuthenticationTag); + CollectionAssert.AreEqual(aad, encrypted.AdditionalAuthenticatedData); + break; + } + DecryptOptions decryptOptions = algorithm.IsAesGcm() ? new DecryptOptions(algorithm, encrypted.Ciphertext, encrypted.Iv, encrypted.AuthenticationTag, encrypted.AdditionalAuthenticatedData) : new DecryptOptions(algorithm, encrypted.Ciphertext, encrypted.Iv); @@ -182,19 +249,52 @@ public void EncryptDecryptRoundtrips(EncryptionAlgorithm algorithm) StringAssert.StartsWith("plaintext", Encoding.UTF8.GetString(decrypted.Plaintext)); } - private static IEnumerable EncryptDecryptRoundtripsData => new[] + [TestCaseSource(nameof(GetEncryptionAlgorithms), methodParams: new object[] { false })] + public void InitializesIv(EncryptionAlgorithm algorithm) + { + // Use a 256-bit key which will be truncated based on the selected algorithm. + byte[] k = new byte[] { 0xe2, 0x7e, 0xd0, 0xc8, 0x45, 0x12, 0xbb, 0xd5, 0x5b, 0x6a, 0xf4, 0x34, 0xd2, 0x37, 0xc1, 0x1f, 0xeb, 0xa3, 0x11, 0x87, 0x0f, 0x80, 0xf2, 0xc2, 0xe3, 0x36, 0x42, 0x60, 0xf3, 0x1c, 0x82, 0xc8 }; + + JsonWebKey key = new JsonWebKey(new[] { KeyOperation.Encrypt, KeyOperation.Decrypt }) + { + K = k, + }; + + AesCryptographyProvider provider = new AesCryptographyProvider(key, null); + + byte[] plaintext = Encoding.UTF8.GetBytes("plaintext"); + + EncryptOptions encryptOptions = new EncryptOptions(algorithm, plaintext, null, null); + EncryptResult encrypted = provider.Encrypt(encryptOptions, default); + + Assert.IsNotNull(encryptOptions.Iv); + CollectionAssert.AreEqual(encryptOptions.Iv, encrypted.Iv); + + DecryptOptions decryptOptions = new DecryptOptions(algorithm, encrypted.Ciphertext, encrypted.Iv); + DecryptResult decrypted = provider.Decrypt(decryptOptions, default); + + Assert.IsNotNull(decrypted); + + // AES-CBC will be zero-padded. + StringAssert.StartsWith("plaintext", Encoding.UTF8.GetString(decrypted.Plaintext)); + } + + private static IEnumerable GetEncryptionAlgorithms(bool includeAesGcm) { - EncryptionAlgorithm.A128Cbc, - EncryptionAlgorithm.A192Cbc, - EncryptionAlgorithm.A256Cbc, - - EncryptionAlgorithm.A128CbcPad, - EncryptionAlgorithm.A192CbcPad, - EncryptionAlgorithm.A256CbcPad, - - EncryptionAlgorithm.A128Gcm, - EncryptionAlgorithm.A192Gcm, - EncryptionAlgorithm.A256Gcm, - }; + yield return EncryptionAlgorithm.A128Cbc; + yield return EncryptionAlgorithm.A192Cbc; + yield return EncryptionAlgorithm.A256Cbc; + + yield return EncryptionAlgorithm.A128CbcPad; + yield return EncryptionAlgorithm.A192CbcPad; + yield return EncryptionAlgorithm.A256CbcPad; + + if (includeAesGcm) + { + yield return EncryptionAlgorithm.A128Gcm; + yield return EncryptionAlgorithm.A192Gcm; + yield return EncryptionAlgorithm.A256Gcm; + } + } } }