Skip to content

Commit

Permalink
feat: add national cloud support (#332)
Browse files Browse the repository at this point in the history
feat: add national cloud support

Allows logging in to national clouds using the `--environment` CLI option.
See microsoftgraph/msgraph-cli#396

perf: enable concurrent io when clearing the token cache
  • Loading branch information
calebkiage committed Feb 1, 2024
1 parent e9b73fc commit f78f591
Show file tree
Hide file tree
Showing 13 changed files with 192 additions and 62 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -453,4 +453,6 @@ FodyWeavers.xsd
*.sln.iml

### VisualStudio Patch ###
# Additional files built by Visual Studio
# Additional files built by Visual Studio

.env.local
55 changes: 45 additions & 10 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"request": "launch",
"preLaunchTask": "build sample",
// If you have changed target frameworks, make sure to update the program path.
"program": "${workspaceFolder}/src/sample/bin/Debug/net7.0/sample.dll",
"program": "${workspaceFolder}/src/sample/bin/Debug/net8.0/sample.dll",
"args": [
"login",
"--strategy",
Expand All @@ -30,7 +30,7 @@
"request": "launch",
"preLaunchTask": "build sample",
// If you have changed target frameworks, make sure to update the program path.
"program": "${workspaceFolder}/src/sample/bin/Debug/net7.0/sample.dll",
"program": "${workspaceFolder}/src/sample/bin/Debug/net8.0/sample.dll",
"args": [
"login",
"--strategy",
Expand All @@ -46,7 +46,7 @@
"type": "coreclr",
"request": "launch",
"preLaunchTask": "build sample",
"program": "${workspaceFolder}/src/sample/bin/Debug/net7.0/sample.dll",
"program": "${workspaceFolder}/src/sample/bin/Debug/net8.0/sample.dll",
"args": [
"login",
"--strategy",
Expand All @@ -67,7 +67,7 @@
"type": "coreclr",
"request": "launch",
"preLaunchTask": "build sample",
"program": "${workspaceFolder}/src/sample/bin/Debug/net7.0/sample.dll",
"program": "${workspaceFolder}/src/sample/bin/Debug/net8.0/sample.dll",
"args": [
"login",
"--strategy",
Expand All @@ -77,6 +77,23 @@
"--client-id",
"e49807f2-94cc-4f59-9e14-be2a37eab7c2"
],
"envFile": "${workspaceFolder}/.env.local",
"cwd": "${workspaceFolder}/src/sample",
"console": "internalConsole",
"stopAtEntry": false
},
{
"name": "login national cloud (sample)",
"type": "coreclr",
"request": "launch",
"preLaunchTask": "build sample",
"program": "${workspaceFolder}/src/sample/bin/Debug/net8.0/sample.dll",
"args": [
"login",
"--environment",
"US_GOV"
],
"envFile": "${workspaceFolder}/.env.local",
"cwd": "${workspaceFolder}/src/sample",
"console": "internalConsole",
"stopAtEntry": false
Expand All @@ -86,13 +103,32 @@
"type": "coreclr",
"request": "launch",
"preLaunchTask": "build sample",
"program": "${workspaceFolder}/src/sample/bin/Debug/net7.0/sample.dll",
"program": "${workspaceFolder}/src/sample/bin/Debug/net8.0/sample.dll",
"args": [
"users",
"list",
"--debug",
"--top",
"2"
"2",
"--headers",
"sample=header"
],
"envFile": "${workspaceFolder}/.env.local",
"cwd": "${workspaceFolder}/src/sample",
"console": "integratedTerminal",
"stopAtEntry": false,
"justMyCode": false
},
{
"name": "me get (sample)",
"type": "coreclr",
"request": "launch",
"preLaunchTask": "build sample",
"program": "${workspaceFolder}/src/sample/bin/Debug/net8.0/sample.dll",
"args": [
"me",
"get",
"--debug"
],
"cwd": "${workspaceFolder}/src/sample",
"console": "integratedTerminal",
Expand All @@ -103,7 +139,7 @@
"type": "coreclr",
"request": "launch",
"preLaunchTask": "build sample",
"program": "${workspaceFolder}/src/sample/bin/Debug/net7.0/sample.dll",
"program": "${workspaceFolder}/src/sample/bin/Debug/net8.0/sample.dll",
"args": [
"logout",
"--debug"
Expand All @@ -117,9 +153,8 @@
"type": "coreclr",
"request": "launch",
"preLaunchTask": "build sample",
"program": "${workspaceFolder}/src/sample/bin/Debug/net7.0/sample.dll",
"program": "${workspaceFolder}/src/sample/bin/Debug/net8.0/sample.dll",
"args": [
"login",
"--help"
],
"cwd": "${workspaceFolder}/src/sample",
Expand All @@ -132,4 +167,4 @@
"request": "attach"
}
]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ public AuthenticationServiceFactory(IPathUtility pathUtility, IAuthenticationCac
/// <param name="clientId">Client Id</param>
/// <param name="certificateName">Certificate name</param>
/// <param name="certificateThumbPrint">Certificate thumb-print</param>
/// <param name="environment">The national cloud environment. Either 'Global', 'US_GOV', 'US_GOV_DOD' or 'China'</param>
/// <param name="cancellationToken">Cancellation token</param>
/// <returns>Returns a login service instance.</returns>
/// <exception cref="InvalidOperationException">When an unsupported authentication strategy is provided.</exception>
public virtual async Task<LoginServiceBase> GetAuthenticationServiceAsync(AuthenticationStrategy strategy, string? tenantId, string? clientId, string? certificateName, string? certificateThumbPrint, CancellationToken cancellationToken = default)
public virtual async Task<LoginServiceBase> GetAuthenticationServiceAsync(AuthenticationStrategy strategy, string? tenantId, string? clientId, string? certificateName, string? certificateThumbPrint, CloudEnvironment environment, CancellationToken cancellationToken = default)
{
var credential = await GetTokenCredentialAsync(strategy, tenantId, clientId, certificateName, certificateThumbPrint, cancellationToken);
var credential = await GetTokenCredentialAsync(strategy, tenantId, clientId, certificateName, certificateThumbPrint, environment, cancellationToken);
if (strategy == AuthenticationStrategy.DeviceCode && credential is DeviceCodeCredential deviceCred)
{
return new InteractiveLoginService<DeviceCodeCredential>(deviceCred, pathUtility);
Expand Down Expand Up @@ -81,35 +82,33 @@ public virtual async Task<LoginServiceBase> GetAuthenticationServiceAsync(Authen
/// <param name="clientId">Client Id</param>
/// <param name="certificateName">Certificate name</param>
/// <param name="certificateThumbPrint">Certificate thumb-print</param>
/// <param name="environment">The cloud environment. <see cref="CloudEnvironment"/></param>
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>A token credential instance.</returns>
/// <exception cref="InvalidOperationException">When an unsupported authentication strategy is provided.</exception>
public virtual async Task<TokenCredential> GetTokenCredentialAsync(AuthenticationStrategy strategy, string? tenantId, string? clientId, string? certificateName, string? certificateThumbPrint, CancellationToken cancellationToken = default)
/// <exception cref="ArgumentNullException">When a null url is provided for the authority host.</exception>
public virtual async Task<TokenCredential> GetTokenCredentialAsync(AuthenticationStrategy strategy, string? tenantId, string? clientId, string? certificateName, string? certificateThumbPrint, CloudEnvironment environment, CancellationToken cancellationToken = default)
{
switch (strategy)
var authorityHost = environment.Authority();
return strategy switch
{
case AuthenticationStrategy.DeviceCode:
return await GetDeviceCodeCredentialAsync(tenantId, clientId, cancellationToken);
case AuthenticationStrategy.InteractiveBrowser:
return await GetInteractiveBrowserCredentialAsync(tenantId, clientId, cancellationToken);
case AuthenticationStrategy.ClientCertificate:
return GetClientCertificateCredential(tenantId, clientId, certificateName, certificateThumbPrint);
case AuthenticationStrategy.Environment:
return new EnvironmentCredential(tenantId, clientId);
case AuthenticationStrategy.ManagedIdentity:
return new ManagedIdentityCredential(clientId);
default:
throw new InvalidOperationException($"The authentication strategy {strategy} is not supported");
}
AuthenticationStrategy.DeviceCode => await GetDeviceCodeCredentialAsync(tenantId, clientId, authorityHost, cancellationToken),
AuthenticationStrategy.InteractiveBrowser => await GetInteractiveBrowserCredentialAsync(tenantId, clientId, authorityHost, cancellationToken),
AuthenticationStrategy.ClientCertificate => GetClientCertificateCredential(tenantId, clientId, certificateName, certificateThumbPrint, authorityHost),
AuthenticationStrategy.Environment => new EnvironmentCredential(tenantId, clientId, new TokenCredentialOptions { AuthorityHost = authorityHost }),
AuthenticationStrategy.ManagedIdentity => new ManagedIdentityCredential(clientId, new TokenCredentialOptions { AuthorityHost = authorityHost }),
_ => throw new InvalidOperationException($"The authentication strategy {strategy} is not supported"),
};
}

private async Task<DeviceCodeCredential> GetDeviceCodeCredentialAsync(string? tenantId, string? clientId, CancellationToken cancellationToken = default)
private async Task<DeviceCodeCredential> GetDeviceCodeCredentialAsync(string? tenantId, string? clientId, Uri authorityHost, CancellationToken cancellationToken = default)
{
DeviceCodeCredentialOptions credOptions = new()
{
ClientId = clientId ?? Constants.DefaultAppId,
TenantId = tenantId ?? Constants.DefaultTenant,
DisableAutomaticAuthentication = true,
AuthorityHost = authorityHost
};

TokenCachePersistenceOptions tokenCacheOptions = new() { Name = Constants.TokenCacheName };
Expand All @@ -119,13 +118,14 @@ private async Task<DeviceCodeCredential> GetDeviceCodeCredentialAsync(string? te
return new DeviceCodeCredential(credOptions);
}

private async Task<InteractiveBrowserCredential> GetInteractiveBrowserCredentialAsync(string? tenantId, string? clientId, CancellationToken cancellationToken = default)
private async Task<InteractiveBrowserCredential> GetInteractiveBrowserCredentialAsync(string? tenantId, string? clientId, Uri authorityHost, CancellationToken cancellationToken = default)
{
InteractiveBrowserCredentialOptions credOptions = new()
{
ClientId = clientId ?? Constants.DefaultAppId,
TenantId = tenantId ?? Constants.DefaultTenant,
DisableAutomaticAuthentication = true,
AuthorityHost = authorityHost
};

TokenCachePersistenceOptions tokenCacheOptions = new() { Name = Constants.TokenCacheName };
Expand All @@ -135,8 +135,8 @@ private async Task<InteractiveBrowserCredential> GetInteractiveBrowserCredential
return new InteractiveBrowserCredential(credOptions);
}

private ClientCertificateCredential GetClientCertificateCredential(string? tenantId, string? clientId, string? certificateName, string? certificateThumbPrint)
private ClientCertificateCredential GetClientCertificateCredential(string? tenantId, string? clientId, string? certificateName, string? certificateThumbPrint, Uri authorityHost)
{
return ClientCertificateCredentialFactory.GetClientCertificateCredential(tenantId ?? Constants.DefaultTenant, clientId ?? Constants.DefaultAppId, certificateName, certificateThumbPrint);
return ClientCertificateCredentialFactory.GetClientCertificateCredential(tenantId ?? Constants.DefaultTenant, clientId ?? Constants.DefaultAppId, certificateName, certificateThumbPrint, authorityHost);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using Azure.Identity;
using Microsoft.Graph.Cli.Core.Utils;

namespace Microsoft.Graph.Cli.Core.Authentication;

Expand All @@ -19,15 +20,17 @@ public static class ClientCertificateCredentialFactory
/// <param name="clientId">ClientId</param>
/// <param name="certificateName">Subject name of the certificate.</param>
/// <param name="certificateThumbPrint">Thumb print of the certificate.</param>
/// <param name="authorityHost">The entra authentication endpoint (to use with national clouds)</param>
/// <returns>A ClientCertificateCredential</returns>
public static ClientCertificateCredential GetClientCertificateCredential(string? tenantId, string? clientId, string? certificateName, string? certificateThumbPrint)
/// <exception cref="ArgumentNullException">When a null url is provided for the authority host.</exception>
public static ClientCertificateCredential GetClientCertificateCredential(string? tenantId, string? clientId, string? certificateName, string? certificateThumbPrint, Uri authorityHost)
{
if (string.IsNullOrWhiteSpace(certificateName) && string.IsNullOrWhiteSpace(certificateThumbPrint))
{
throw new ArgumentException("Either a certificate name or a certificate thumb print must be provided.");
}

ClientCertificateCredentialOptions credOptions = new();
ClientCertificateCredentialOptions credOptions = new() { AuthorityHost = authorityHost };

// // TODO: Enable token caching

Check warning on line 35 in src/Microsoft.Graph.Cli.Core/Authentication/ClientCertificateCredentialFactory.cs

View workflow job for this annotation

GitHub Actions / Build

Complete the task associated to this 'TODO' comment. (https://rules.sonarsource.com/csharp/RSPEC-1135)
// // Fix error:
Expand Down
73 changes: 73 additions & 0 deletions src/Microsoft.Graph.Cli.Core/Authentication/CloudEnvironment.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
using System;
using Azure.Identity;
using Microsoft.Graph;
using Microsoft.Graph.Cli.Core.Utils;


/// <summary>
/// The cloud environment to use
/// </summary>
public enum CloudEnvironment

Check warning on line 10 in src/Microsoft.Graph.Cli.Core/Authentication/CloudEnvironment.cs

View workflow job for this annotation

GitHub Actions / Build

Move 'CloudEnvironment' into a named namespace. (https://rules.sonarsource.com/csharp/RSPEC-3903)
{
/// <summary>
/// Global environment.
/// </summary>
Global,
/// <summary>
/// US Government cloud environment.
/// </summary>
USGov,
/// <summary>
/// US Government Department of Defense (DoD) cloud environment.
/// </summary>
USGovDoD,
/// <summary>
/// China cloud environment.
/// </summary>
China,
}

/// <summary>
/// Provides methods for the <see cref="CloudEnvironment"/> class.
/// </summary>
public static class CloudEnvironmentExtensions

Check warning on line 33 in src/Microsoft.Graph.Cli.Core/Authentication/CloudEnvironment.cs

View workflow job for this annotation

GitHub Actions / Build

Move 'CloudEnvironmentExtensions' into a named namespace. (https://rules.sonarsource.com/csharp/RSPEC-3903)
{
/// <summary>
/// Gets the authority URL for the specified cloud environment.
/// </summary>
/// <param name="environment">The cloud environment.</param>
/// <returns>The authority URL.</returns>
/// <exception cref="ArgumentException">
/// If the cloud environment is not one of the <see cref="CloudEnvironment"/> members.
/// </exception>
public static Uri Authority(this CloudEnvironment environment)
{
return environment switch
{
CloudEnvironment.Global => AzureAuthorityHosts.AzurePublicCloud,
CloudEnvironment.USGov or CloudEnvironment.USGovDoD => AzureAuthorityHosts.AzureGovernment,
CloudEnvironment.China => AzureAuthorityHosts.AzureChina,
_ => throw new ArgumentException("Unknown cloud environment", nameof(environment))
};
}

/// <summary>
/// Gets the GraphClient Cloud identifier.
/// </summary>
/// <param name="environment">The cloud environment.</param>
/// <returns>The cloud identifier to be used by the graph client.</returns>
/// <exception cref="ArgumentException">
/// If the cloud environment is not one of the <see cref="CloudEnvironment"/> members.
/// </exception>
public static string GraphClientCloud(this CloudEnvironment environment)
{
return environment switch
{
CloudEnvironment.Global => GraphClientFactory.Global_Cloud,
CloudEnvironment.USGov => GraphClientFactory.USGOV_Cloud,
CloudEnvironment.USGovDoD => GraphClientFactory.USGOV_DOD_Cloud,
CloudEnvironment.China => GraphClientFactory.China_Cloud,
_ => throw new ArgumentException("Unknown cloud environment", nameof(environment))
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ public EnvironmentCredential(string? tenantId, string? clientId, TokenCredential
bool sendCertificateChain = !string.IsNullOrEmpty(clientSendCertificateChain) &&
(clientSendCertificateChain == "1" || clientSendCertificateChain == "true");

ClientCertificateCredentialOptions clientCertificateCredentialOptions = new ClientCertificateCredentialOptions
ClientCertificateCredentialOptions clientCertificateCredentialOptions = new()
{
AuthorityHost = _options.AuthorityHost,
Transport = _options.Transport,
SendCertificateChain = sendCertificateChain
SendCertificateChain = sendCertificateChain,
};
// Use reflection to set internal properties.
X509Certificate2? cert;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ public sealed class LoginCommand : Command

private Option<string> certificateThumbPrintOption = new("--certificate-thumb-print", "The thumbprint of your certificate. The certificate will be retrieved from the current user's certificate store.");

private Option<CloudEnvironment> environmentOption = new("--environment", () => CloudEnvironment.Global, "Select the cloud environment to log in to. If login is run without providing an environment, Global is used.");

private Option<AuthenticationStrategy> strategyOption = new("--strategy", () => Constants.defaultAuthStrategy);

internal LoginCommand() : base("login", "Login and store the session for use in subsequent commands")
Expand All @@ -36,6 +38,7 @@ public sealed class LoginCommand : Command
AddOption(tenantIdOption);
AddOption(certificateNameOption);
AddOption(certificateThumbPrintOption);
AddOption(environmentOption);
AddOption(strategyOption);
this.SetHandler(async (context) =>
{
Expand All @@ -44,15 +47,16 @@ public sealed class LoginCommand : Command
var tenantId = context.ParseResult.GetValueForOption(tenantIdOption);
var certificateName = context.ParseResult.GetValueForOption(certificateNameOption);
var certificateThumbPrint = context.ParseResult.GetValueForOption(certificateThumbPrintOption);
var environment = context.ParseResult.GetValueForOption(environmentOption);
var strategy = context.ParseResult.GetValueForOption(strategyOption);
var cancellationToken = context.GetCancellationToken();
var authUtil = context.BindingContext.GetRequiredService<IAuthenticationCacheManager>();
var authSvcFactory = context.BindingContext.GetRequiredService<AuthenticationServiceFactory>();
var authService = await authSvcFactory.GetAuthenticationServiceAsync(strategy, tenantId, clientId, certificateName, certificateThumbPrint, cancellationToken);
var authService = await authSvcFactory.GetAuthenticationServiceAsync(strategy, tenantId, clientId, certificateName, certificateThumbPrint, environment, cancellationToken);
await authService.LoginAsync(scopes, cancellationToken);
await authUtil.SaveAuthenticationIdentifiersAsync(clientId, tenantId, certificateName, certificateThumbPrint, strategy, cancellationToken);
await authUtil.SaveAuthenticationIdentifiersAsync(clientId, tenantId, certificateName, certificateThumbPrint, strategy, environment, cancellationToken);
});
}

Expand Down
Loading

0 comments on commit f78f591

Please sign in to comment.