Skip to content

Commit

Permalink
Add the option to have UserAssignedMsi when authentication uses KeyVa…
Browse files Browse the repository at this point in the history
…ult certificate. Previously, the code would only use the system assigned MSI if any. (#14676)

Co-authored-by: dnicolescu <dnicolescu@hotmail.com>
  • Loading branch information
danicole and dnicolescu authored Sep 12, 2020
1 parent 3cc9799 commit 6ae13e6
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,11 @@ private enum CertIdentifierType
/// <param name="certIdentifierType"></param>
/// <returns></returns>
[Theory]
[InlineData(CertIdentifierType.KeyVaultCertificateSecretIdentifier)]
[InlineData(CertIdentifierType.KeyVaultCertificateSecretIdentifier, false)]
[InlineData(CertIdentifierType.KeyVaultCertificateSecretIdentifier, true)]
[InlineData(CertIdentifierType.SubjectName)]
[InlineData(CertIdentifierType.Thumbprint)]
private async Task GetTokenUsingServicePrincipalWithCertTest(CertIdentifierType certIdentifierType)
private async Task GetTokenUsingServicePrincipalWithCertTest(CertIdentifierType certIdentifierType, bool useUserAssignedMsi = false)
{
string testCertUrl = Environment.GetEnvironmentVariable(Constants.TestCertUrlEnv);

Expand Down Expand Up @@ -208,7 +209,9 @@ private async Task GetTokenUsingServicePrincipalWithCertTest(CertIdentifierType
connectionString = $"RunAs=App;AppId={app.AppId};TenantId={_tenantId};{thumbprintOrSubjectName};CertificateStoreLocation={Constants.CurrentUserStore};";
break;
case CertIdentifierType.KeyVaultCertificateSecretIdentifier:
connectionString = $"RunAs=App;AppId={app.AppId};KeyVaultCertificateSecretIdentifier={testCertUrl};";
connectionString = useUserAssignedMsi
? $"RunAs=App;AppId={app.AppId};KeyVaultCertificateSecretIdentifier={testCertUrl};KeyVaultUserAssignedManagedIdentityId={Constants.TestUserAssignedManagedIdentityId}" //TODO: figure out real MSI to use here. Also, does the test really use MSI or does it rely on the fallback?
: $"RunAs=App;AppId={app.AppId};KeyVaultCertificateSecretIdentifier={testCertUrl};";
break;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public class Constants
public static readonly string CertificateConnStringThumbprintCurrentUser = $"RunAs=App;AppId={TestAppId};TenantId={TenantId};CertificateThumbprint=123;CertificateStoreLocation=CurrentUser";
public static readonly string CertificateConnStringSubjectNameCurrentUser = $"RunAs=App;AppId={TestAppId};TenantId={TenantId};CertificateSubjectName=123;CertificateStoreLocation=CurrentUser";
public static readonly string CertificateConnStringKeyVaultCertificateSecretIdentifier = $"RunAs=App;AppId={TestAppId};KeyVaultCertificateSecretIdentifier=SecretIdentifier";
public static readonly string CertificateConnStringKeyVaultCertificateSecretIdentifierUserAssignedMsi = $"RunAs=App;AppId={TestAppId};KeyVaultCertificateSecretIdentifier=SecretIdentifier;KeyVaultAppId={TestUserAssignedManagedIdentityId}";
public static readonly string CertificateConnStringKeyVaultCertificateSecretIdentifierWithOptionalTenantId = $"RunAs=App;AppId={TestAppId};TenantId={TenantId};KeyVaultCertificateSecretIdentifier=SecretIdentifier";
public static readonly string ClientSecretConnString = $"RunAs=App;AppId={TestAppId};TenantId={TenantId};AppKey={ClientSecret}";
public static readonly string ConnectionStringEnvironmentVariableName = "AzureServicesAuthConnectionString";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ public void CertValidTest()
Assert.Equal(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifier, provider.ConnectionString);
Assert.IsType<ClientCertificateAzureServiceTokenProvider>(provider);

provider = AzureServiceTokenProviderFactory.Create(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifierUserAssignedMsi, Constants.AzureAdInstance);
Assert.NotNull(provider);
Assert.Equal(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifierUserAssignedMsi, provider.ConnectionString);
Assert.IsType<ClientCertificateAzureServiceTokenProvider>(provider);

provider = AzureServiceTokenProviderFactory.Create(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifierWithOptionalTenantId, Constants.AzureAdInstance);
Assert.NotNull(provider);
Assert.Equal(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifierWithOptionalTenantId, provider.ConnectionString);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public async Task ThumbprintSuccessTest()

// Create ClientCertificateAzureServiceTokenProvider instance
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext);
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext);

// Get the token. This will test that ClientCertificateAzureServiceTokenProvider could fetch the cert from CurrentUser store based on thumbprint in the connection string.
var authResult = await provider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId).ConfigureAwait(false);
Expand All @@ -64,7 +64,7 @@ public async Task ThumbprintFailTest()

// Create ClientCertificateAzureServiceTokenProvider instance
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext);
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext);

// Ensure exception is thrown when getting the token
var exception = await Assert.ThrowsAsync<AzureServiceTokenProviderException>(() => provider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId));
Expand All @@ -89,12 +89,12 @@ public void ClientIdNullOrEmptyTest()

// Create ClientCertificateAzureServiceTokenProvider instance
var exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(null,
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));

Assert.Contains(Constants.CannotBeNullError, exception.ToString());

exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(string.Empty,
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));

Assert.Contains(Constants.CannotBeNullError, exception.ToString());
}
Expand All @@ -114,12 +114,12 @@ public void StoreLocationNullOrEmptyTest()

// Create ClientCertificateAzureServiceTokenProvider instance
var exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
cert.Thumbprint, CertificateIdentifierType.Thumbprint, null, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
cert.Thumbprint, CertificateIdentifierType.Thumbprint, null, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));

Assert.Contains(Constants.CannotBeNullError, exception.ToString());

exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
cert.Thumbprint, CertificateIdentifierType.Thumbprint, string.Empty, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
cert.Thumbprint, CertificateIdentifierType.Thumbprint, string.Empty, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));

Assert.Contains(Constants.CannotBeNullError, exception.ToString());
}
Expand All @@ -135,12 +135,12 @@ public void CertSubjectNameOrThumbprintNullOrEmptyTest()

// Create ClientCertificateAzureServiceTokenProvider instance
var exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
null, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
null, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));

Assert.Contains(Constants.CannotBeNullError, exception.ToString());

exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
string.Empty, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
string.Empty, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));

Assert.Contains(Constants.CannotBeNullError, exception.ToString());
}
Expand All @@ -160,7 +160,7 @@ public void InvalidStoreLocationTest()

// Create ClientCertificateAzureServiceTokenProvider instance
var exception = Assert.Throws<ArgumentException>(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.InvalidString, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.InvalidString, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));

Assert.Contains(Constants.InvalidCertLocationError, exception.ToString());
}
Expand All @@ -177,7 +177,7 @@ public async Task SubjectNameSuccessTest()

// Create ClientCertificateAzureServiceTokenProvider instance with a subject name
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
cert.Subject, CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext);
cert.Subject, CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext);

// Get the token. This will test that ClientCertificateAzureServiceTokenProvider could fetch the cert from CurrentUser store based on subject name in the connection string.
var authResult = await provider.GetAuthResultAsync(Constants.KeyVaultResourceId, string.Empty).ConfigureAwait(false);
Expand All @@ -204,7 +204,7 @@ public void CannotAcquireTokenThroughCertTest()

// Create ClientCertificateAzureServiceTokenProvider instance with a subject name
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
cert.Subject, CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext);
cert.Subject, CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext);

// Get the token. This will test that ClientCertificateAzureServiceTokenProvider could fetch the cert from CurrentUser store based on subject name in the connection string.
var exception = Assert.ThrowsAsync<AzureServiceTokenProviderException>(() => provider.GetAuthResultAsync(Constants.KeyVaultResourceId, string.Empty));
Expand All @@ -226,7 +226,7 @@ public async Task CertificateNotFoundTest()
MockAuthenticationContext mockAuthenticationContext = new MockAuthenticationContext(MockAuthenticationContext.MockAuthenticationContextTestType.AcquireTokenAsyncClientCertificateSuccess);

ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
Guid.NewGuid().ToString(), CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext);
Guid.NewGuid().ToString(), CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext);

var exception = await Assert.ThrowsAsync<AzureServiceTokenProviderException>(() => Task.Run(() => provider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId)));

Expand Down Expand Up @@ -257,7 +257,7 @@ public async Task KeyVaultCertificateSecretIdentifierSuccessTest(bool includeTen

// Create ClientCertificateAzureServiceTokenProvider instance with a subject name
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
Constants.TestKeyVaultCertificateSecretIdentifier, CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, null, Constants.AzureAdInstance, tenantIdParam, 0, mockAuthenticationContext, keyVaultClient);
Constants.TestKeyVaultCertificateSecretIdentifier, CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, null, Constants.AzureAdInstance, tenantIdParam, 0, authenticationContext: mockAuthenticationContext, keyVaultClient: keyVaultClient);

// Get the token. This will test that ClientCertificateAzureServiceTokenProvider could fetch the cert from CurrentUser store based on subject name in the connection string.
var authResult = await provider.GetAuthResultAsync(Constants.ArmResourceId, string.Empty).ConfigureAwait(false);
Expand All @@ -283,7 +283,7 @@ public async Task KeyVaultCertificateNotFoundTest()

string SecretIdentifier = "https://testbedkeyvault.vault.azure.net/secrets/secret/";
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
SecretIdentifier, CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext, keyVaultClient);
SecretIdentifier, CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext, keyVaultClient: keyVaultClient);

var exception = await Assert.ThrowsAsync<AzureServiceTokenProviderException>(() => Task.Run(() => provider.GetAuthResultAsync(Constants.ArmResourceId, Constants.TenantId)));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ internal class AzureServiceTokenProviderFactory
private const string CertificateSubjectName = "CertificateSubjectName";
private const string CertificateThumbprint = "CertificateThumbprint";
private const string KeyVaultCertificateSecretIdentifier = "KeyVaultCertificateSecretIdentifier";
private const string KeyVaultUserAssignedManagedIdentityId = "KeyVaultUserAssignedManagedIdentityId";
private const string CertificateStoreLocation = "CertificateStoreLocation";
private const string MsiRetryTimeout = "MsiRetryTimeout";

Expand Down Expand Up @@ -125,7 +126,7 @@ internal static NonInteractiveAzureServiceTokenProviderBase Create(string connec
azureAdInstance,
connectionSettings[TenantId],
0,
new AdalAuthenticationContext(httpClientFactory));
authenticationContext: new AdalAuthenticationContext(httpClientFactory));
}
else if (connectionSettings.ContainsKey(CertificateThumbprint) ||
connectionSettings.ContainsKey(CertificateSubjectName))
Expand All @@ -138,6 +139,11 @@ internal static NonInteractiveAzureServiceTokenProviderBase Create(string connec
{
ValidateMsiRetryTimeout(connectionSettings, connectionString);

var msiRetryTimeout = connectionSettings.ContainsKey(MsiRetryTimeout)
? int.Parse(connectionSettings[MsiRetryTimeout])
: 0;
connectionSettings.TryGetValue(KeyVaultUserAssignedManagedIdentityId, out var keyVaultUserAssignedManagedIdentityId);

azureServiceTokenProvider =
new ClientCertificateAzureServiceTokenProvider(
connectionSettings[AppId],
Expand All @@ -148,9 +154,8 @@ internal static NonInteractiveAzureServiceTokenProviderBase Create(string connec
connectionSettings.ContainsKey(TenantId) // tenantId can be specified in connection string or retrieved from Key Vault access token later
? connectionSettings[TenantId]
: default,
connectionSettings.ContainsKey(MsiRetryTimeout)
? int.Parse(connectionSettings[MsiRetryTimeout])
: 0,
msiRetryTimeout,
keyVaultUserAssignedManagedIdentityId,
new AdalAuthenticationContext(httpClientFactory));
}
else if (connectionSettings.ContainsKey(AppKey))
Expand Down
Loading

0 comments on commit 6ae13e6

Please sign in to comment.