Skip to content

Commit

Permalink
AAD: with InMemoryChannel (#2290)
Browse files Browse the repository at this point in the history
* aad with InMemoryChannel
  • Loading branch information
TimothyMothra authored May 29, 2021
1 parent 9ecce67 commit 72fbd32
Show file tree
Hide file tree
Showing 11 changed files with 277 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ namespace Microsoft.ApplicationInsights.TestFramework.Extensibility.Implementati
using System;
using System.Threading.Tasks;

using Microsoft.ApplicationInsights.Channel;
using Microsoft.ApplicationInsights.Extensibility;
using Microsoft.ApplicationInsights.Extensibility.Implementation.Authentication;
using Microsoft.VisualStudio.TestTools.UnitTesting;
Expand Down Expand Up @@ -47,6 +48,50 @@ public void VerifyCannotSetInvalidObjectOnTelemetryConfiguration()
var telemetryConfiguration = new TelemetryConfiguration();
telemetryConfiguration.SetAzureTokenCredential(Guid.Empty);
}

[TestMethod]
public void VerifySetCredential_CorrectlySetsTelemetryChannel_CredentialFirst()
{
// SETUP
var tc = TelemetryConfiguration.CreateDefault();
Assert.IsInstanceOfType(tc.TelemetryChannel, typeof(InMemoryChannel));
Assert.IsTrue(tc.TelemetryChannel.EndpointAddress.Contains("v2")); // defaults to old api

// ACT
// set credential first
tc.SetAzureTokenCredential(new MockCredential());
Assert.IsTrue(tc.TelemetryChannel.EndpointAddress.Contains("v2.1")); // api switch

// test new channel
var channel = new InMemoryChannel();
Assert.IsNull(channel.EndpointAddress); // new channel defaults null

// change config channel
tc.TelemetryChannel = channel;
Assert.IsTrue(channel.EndpointAddress.Contains("v2.1")); // configuration sets new api
}

[TestMethod]
public void VerifySetCredential_CorrectlySetsTelemetryChannel_TelemetryChannelFirst()
{
// SETUP
var tc = TelemetryConfiguration.CreateDefault();
Assert.IsInstanceOfType(tc.TelemetryChannel, typeof(InMemoryChannel));
Assert.IsTrue(tc.TelemetryChannel.EndpointAddress.Contains("v2")); // defaults to old api

// ACT
// set new channel first
var channel = new InMemoryChannel();
Assert.IsNull(channel.EndpointAddress); // new channel defaults null

// change config channel
tc.TelemetryChannel = channel;
Assert.IsTrue(channel.EndpointAddress.Contains("v2")); // configuration sets new api

// set credential second
tc.SetAzureTokenCredential(new MockCredential());
Assert.IsTrue(tc.TelemetryChannel.EndpointAddress.Contains("v2.1")); // api switch
}
}
}
#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#if !NET452 && !NET46
namespace Microsoft.ApplicationInsights.TestFramework.Extensibility.Implementation.Authentication
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Text;
using System.Threading.Tasks;

using Microsoft.ApplicationInsights.Channel;
using Microsoft.ApplicationInsights.DataContracts;
using Microsoft.ApplicationInsights.Extensibility.Implementation.Authentication;
using Microsoft.VisualStudio.TestTools.UnitTesting;

/// <summary>
/// These tests verify that <see cref="Transmission"/> can receive and store an instance of <see cref="Azure.Core.TokenCredential"/>.
/// </summary>
/// <remarks>
/// These tests do not run in NET452 OR NET46.
/// In these cases, the test runner is NET452 or NET46 and Azure.Core.TokenCredential is NOT SUPPORTED in these frameworks.
/// This does not affect the end user because we REQUIRE the end user to create their own instance of TokenCredential.
/// This ensures that the end user is consuming the AI SDK in one of the newer frameworks.
/// </remarks>
[TestClass]
[TestCategory("AAD")]
public class TransmissionCredentialEnvelopeTests
{
private readonly Uri testUri = new Uri("https://127.0.0.1/");

[TestMethod]
public async Task VerifyTransmissionSendAsync_Default()
{
var handler = new HandlerForFakeHttpClient
{
InnerHandler = new HttpClientHandler(),
OnSendAsync = (req, cancellationToken) =>
{
// VALIDATE
Assert.IsNull(req.Headers.Authorization);
return Task.FromResult<HttpResponseMessage>(new HttpResponseMessage());
}
};

using (var fakeHttpClient = new HttpClient(handler))
{
var expectedContentType = "content/type";
var expectedContentEncoding = "contentEncoding";
var items = new List<ITelemetry> { new EventTelemetry() };

// Instantiate Transmission with the mock HttpClient
var transmission = new Transmission(testUri, new byte[] { 1, 2, 3, 4, 5 }, fakeHttpClient, expectedContentType, expectedContentEncoding);

var result = await transmission.SendAsync();
}
}

[TestMethod]
public async Task VerifyTransmissionSendAsync_WithCredential_SetsAuthHeader()
{
var credendialEnvelope = new ReflectionCredentialEnvelope(new MockCredential());
var token = credendialEnvelope.GetToken();


var handler = new HandlerForFakeHttpClient
{
InnerHandler = new HttpClientHandler(),
OnSendAsync = (req, cancellationToken) =>
{
// VALIDATE
Assert.AreEqual(AuthConstants.AuthorizationTokenPrefix.Trim(), req.Headers.Authorization.Scheme);
Assert.AreEqual(token, req.Headers.Authorization.Parameter);
return Task.FromResult<HttpResponseMessage>(new HttpResponseMessage());
}
};

using (var fakeHttpClient = new HttpClient(handler))
{
var expectedContentType = "content/type";
var expectedContentEncoding = "contentEncoding";
var items = new List<ITelemetry> { new EventTelemetry() };

// Instantiate Transmission with the mock HttpClient
var transmission = new Transmission(testUri, new byte[] { 1, 2, 3, 4, 5 }, fakeHttpClient, expectedContentType, expectedContentEncoding);
transmission.CredentialEnvelope = credendialEnvelope;

var result = await transmission.SendAsync();
}
}
}
}
#endif
18 changes: 18 additions & 0 deletions BASE/src/Microsoft.ApplicationInsights/Channel/InMemoryChannel.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
namespace Microsoft.ApplicationInsights.Channel
{
using System;
using System.ComponentModel;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.ApplicationInsights.Common;
using Microsoft.ApplicationInsights.Extensibility;
using Microsoft.ApplicationInsights.Extensibility.Implementation.Authentication;
using Microsoft.ApplicationInsights.Extensibility.Implementation.Tracing;

/// <summary>
Expand Down Expand Up @@ -122,6 +126,20 @@ public int BacklogSize
set { this.buffer.BacklogSize = value; }
}

/// <summary>
/// Gets or sets the <see cref="CredentialEnvelope"/> which is used for AAD.
/// FOR INTERNAL USE. Customers should use <see cref="TelemetryConfiguration.SetAzureTokenCredential"/> instead.
/// </summary>
/// <remarks>
/// <see cref="InMemoryChannel.CredentialEnvelope"/> sets <see cref="InMemoryTransmitter.CredentialEnvelope"/>
/// which is used to set <see cref="Transmission.CredentialEnvelope"/> just before calling <see cref="Transmission.SendAsync"/>.
/// </remarks>
internal CredentialEnvelope CredentialEnvelope
{
get => this.transmitter.CredentialEnvelope;
set => this.transmitter.CredentialEnvelope = value;
}

internal bool IsDisposed => this.isDisposed;

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ namespace Microsoft.ApplicationInsights.Channel
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.ApplicationInsights.Common.Extensions;
using Microsoft.ApplicationInsights.Extensibility;
using Microsoft.ApplicationInsights.Extensibility.Implementation;
using Microsoft.ApplicationInsights.Extensibility.Implementation.Authentication;
using Microsoft.ApplicationInsights.Extensibility.Implementation.Tracing;

/// <summary>
Expand All @@ -32,13 +34,13 @@ internal class InMemoryTransmitter : IDisposable
[SuppressMessage("Microsoft.Usage", "CA2213:DisposableFieldsShouldBeDisposed", Justification = "Object is disposed within the using statement of the " + nameof(Runner) + " method.")]
private AutoResetEvent startRunnerEvent;
private bool enabled = true;

/// <summary>
/// The number of times this object was disposed.
/// </summary>
private int disposeCount = 0;
private TimeSpan sendingInterval = TimeSpan.FromSeconds(30);

internal InMemoryTransmitter(TelemetryBuffer buffer)
{
this.buffer = buffer;
Expand All @@ -47,11 +49,11 @@ internal InMemoryTransmitter(TelemetryBuffer buffer)
// Starting the Runner
Task.Factory.StartNew(this.Runner, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default)
.ContinueWith(
task =>
task =>
{
string msg = string.Format(CultureInfo.InvariantCulture, "InMemoryTransmitter: Unhandled exception in Runner: {0}", task.Exception);
CoreEventSource.Log.LogVerbose(msg);
},
},
TaskContinuationOptions.OnlyOnFaulted);
}

Expand All @@ -63,6 +65,15 @@ internal TimeSpan SendingInterval
set { this.sendingInterval = value; }
}

/// <summary>
/// Gets or sets the <see cref="CredentialEnvelope"/> which is used for AAD.
/// </summary>
/// <remarks>
/// <see cref="InMemoryChannel.CredentialEnvelope"/> sets <see cref="InMemoryTransmitter.CredentialEnvelope"/>
/// which is used to set <see cref="Transmission.CredentialEnvelope"/> just before calling <see cref="Transmission.SendAsync"/>.
/// </remarks>
internal CredentialEnvelope CredentialEnvelope { get; set; }

/// <summary>
/// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
/// </summary>
Expand Down Expand Up @@ -163,7 +174,7 @@ private Task Send(IEnumerable<ITelemetry> telemetryItems, TimeSpan timeout)
}

var transmission = new Transmission(this.EndpointAddress, data, JsonSerializer.ContentType, JsonSerializer.CompressionType, timeout);

transmission.CredentialEnvelope = this.CredentialEnvelope;
return transmission.SendAsync();
}

Expand Down
24 changes: 24 additions & 0 deletions BASE/src/Microsoft.ApplicationInsights/Channel/Transmission.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
using System.Net.Http.Headers;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.ApplicationInsights.Extensibility.Implementation;
using Microsoft.ApplicationInsights.Extensibility.Implementation.Authentication;
using Microsoft.ApplicationInsights.Extensibility.Implementation.Tracing;

/// <summary>
Expand Down Expand Up @@ -141,6 +143,12 @@ public ICollection<ITelemetry> TelemetryItems
get; private set;
}

/// <summary>
/// Gets or sets the <see cref="CredentialEnvelope"/> which is used for AAD.
/// This is used include an AAD token on HTTP Requests sent to ingestion.
/// </summary>
internal CredentialEnvelope CredentialEnvelope { get; set; }

/// <summary>
/// Gets the flush async id for the transmission.
/// </summary>
Expand Down Expand Up @@ -404,6 +412,22 @@ protected virtual HttpRequestMessage CreateRequestMessage(Uri address, Stream co
request.Content.Headers.Add(ContentEncodingHeader, this.ContentEncoding);
}

if (this.CredentialEnvelope != null)
{
// TODO: NEED TO USE CACHING HERE
var authToken = this.CredentialEnvelope.GetToken();

if (authToken == null)
{
// TODO: DO NOT SEND. RETURN FAILURE AND LET CHANNEL DECIDE WHEN TO RETRY.
// This could be either a configuration error or the AAD service is unavailable.
}
else
{
request.Headers.TryAddWithoutValidation(AuthConstants.AuthorizationHeaderName, AuthConstants.AuthorizationTokenPrefix + authToken);
}
}

return request;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
{
internal static class AuthConstants
{
public const string AuthorizationHeaderName = "Authorization";

public const string AuthorizationTokenPrefix = "Bearer ";

/// <summary>
/// Source:
/// (https://docs.microsoft.com/azure/active-directory/develop/msal-acquire-cache-tokens#scopes-when-acquiring-tokens).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,19 @@ public abstract class CredentialEnvelope
/// <summary>
/// Gets an Azure.Core.AccessToken.
/// </summary>
/// <remarks>
/// Whomever uses this MUST verify that it's called within <see cref="SdkInternalOperationsMonitor.Enter"/> otherwise dependency calls will be tracked.
/// </remarks>
/// <param name="cancellationToken">The System.Threading.CancellationToken to use.</param>
/// <returns>A valid Azure.Core.AccessToken.</returns>
public abstract string GetToken(CancellationToken cancellationToken = default);

/// <summary>
/// Gets an Azure.Core.AccessToken.
/// </summary>
/// <remarks>
/// Whomever uses this MUST verify that it's called within <see cref="SdkInternalOperationsMonitor.Enter"/> otherwise dependency calls will be tracked.
/// </remarks>
/// <param name="cancellationToken">The System.Threading.CancellationToken to use.</param>
/// <returns>A valid Azure.Core.AccessToken.</returns>
public abstract Task<string> GetTokenAsync(CancellationToken cancellationToken = default);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,14 @@ public ReflectionCredentialEnvelope(object tokenCredential)
/// <summary>
/// Gets an Azure.Core.AccessToken.
/// </summary>
/// <remarks>
/// Whomever uses this MUST verify that it's called within <see cref="SdkInternalOperationsMonitor.Enter"/> otherwise dependency calls will be tracked.
/// </remarks>
/// <param name="cancellationToken">The System.Threading.CancellationToken to use.</param>
/// <returns>A valid Azure.Core.AccessToken.</returns>
public override string GetToken(CancellationToken cancellationToken = default)
{
SdkInternalOperationsMonitor.Enter();
// TODO: NEED TO FULLY TEST IF WE NEED TO CALL SdkInternalOperationsMonitor.Enter
try
{
return AzureCore.InvokeGetToken(this.tokenCredential, this.tokenRequestContext, cancellationToken);
Expand All @@ -60,20 +63,19 @@ public override string GetToken(CancellationToken cancellationToken = default)
CoreEventSource.Log.FailedToGetToken(ex.ToInvariantString());
return null;
}
finally
{
SdkInternalOperationsMonitor.Exit();
}
}

/// <summary>
/// Gets an Azure.Core.AccessToken.
/// </summary>
/// <remarks>
/// Whomever uses this MUST verify that it's called within <see cref="SdkInternalOperationsMonitor.Enter"/> otherwise dependency calls will be tracked.
/// </remarks>
/// <param name="cancellationToken">The System.Threading.CancellationToken to use.</param>
/// <returns>A valid Azure.Core.AccessToken.</returns>
public override async Task<string> GetTokenAsync(CancellationToken cancellationToken = default)
{
SdkInternalOperationsMonitor.Enter();
// TODO: NEED TO FULLY TEST IF WE NEED TO CALL SdkInternalOperationsMonitor.Enter
try
{
return await AzureCore.InvokeGetTokenAsync(this.tokenCredential, this.tokenRequestContext, cancellationToken).ConfigureAwait(false);
Expand All @@ -83,10 +85,6 @@ public override async Task<string> GetTokenAsync(CancellationToken cancellationT
CoreEventSource.Log.FailedToGetToken(ex.ToInvariantString());
return null;
}
finally
{
SdkInternalOperationsMonitor.Exit();
}
}

/// <summary>
Expand Down
Loading

0 comments on commit 72fbd32

Please sign in to comment.