Skip to content

Commit

Permalink
Make RegisterValidAudience aware of JsonWebToken (#2421)
Browse files Browse the repository at this point in the history
RegisterValidAudience needed the claims but only knew about JwtSecurityToken. This updates it to retrieve the claims from JwtSecurityToken or a JsonWebToken.
  • Loading branch information
twsouthwick authored Sep 4, 2023
1 parent a8bf5aa commit 85c1a75
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 20 deletions.
14 changes: 8 additions & 6 deletions src/Microsoft.Identity.Web/Resource/RegisterValidAudience.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.IdentityModel.Tokens.Jwt;
using System.Linq;
using Microsoft.Extensions.Options;
using Microsoft.IdentityModel.JsonWebTokens;
using Microsoft.IdentityModel.Tokens;

namespace Microsoft.Identity.Web.Resource
Expand Down Expand Up @@ -57,11 +58,12 @@ public void RegisterAudienceValidation(
SecurityToken securityToken,
TokenValidationParameters validationParameters)
{
JwtSecurityToken? token = securityToken as JwtSecurityToken;
if (token == null)
var claims = securityToken switch
{
throw new SecurityTokenValidationException(IDWebErrorMessage.TokenIsNotJwtToken);
}
JwtSecurityToken jwtSecurityToken => jwtSecurityToken.Claims,
JsonWebToken jwtWebToken => jwtWebToken.Claims,
_ => throw new SecurityTokenValidationException(IDWebErrorMessage.TokenIsNotJwtToken),
};

validationParameters.AudienceValidator = null;

Expand All @@ -70,13 +72,13 @@ public void RegisterAudienceValidation(
validationParameters.ValidAudiences == null)
{
// handle v2.0 access token or Azure AD B2C tokens (even if v1.0)
if (IsB2C || token.Claims.Any(c => c.Type == Constants.Version && c.Value == Constants.V2))
if (IsB2C || claims.Any(c => c.Type == Constants.Version && c.Value == Constants.V2))
{
validationParameters.ValidAudience = $"{ClientId}";
}

// handle v1.0 access token
else if (token.Claims.Any(c => c.Type == Constants.Version && c.Value == Constants.V1))
else if (claims.Any(c => c.Type == Constants.Version && c.Value == Constants.V1))
{
validationParameters.ValidAudience = $"api://{ClientId}";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Identity.Client;
using Microsoft.Identity.Web.Internal;
using Microsoft.IdentityModel.JsonWebTokens;

namespace Microsoft.Identity.Web
{
Expand Down Expand Up @@ -99,7 +100,8 @@ internal static void CallsWebApiImplementation(

options.Events.OnTokenValidated = async context =>
{
context.HttpContext.StoreTokenUsedToCallWebAPI(context.SecurityToken as JwtSecurityToken);
// Only pass through a token if it is of an expected type
context.HttpContext.StoreTokenUsedToCallWebAPI(context.SecurityToken is JwtSecurityToken or JsonWebToken ? context.SecurityToken : null);
await onTokenValidatedHandler(context).ConfigureAwait(false);
};
});
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.IdentityModel.Tokens.Jwt;
using System.Linq;
using System.Security.Claims;
using Microsoft.Identity.Web.Resource;
using Microsoft.Identity.Web.Test.Common;
using Microsoft.IdentityModel.JsonWebTokens;
using Microsoft.IdentityModel.Tokens;
using Xunit;

Expand All @@ -20,7 +22,7 @@ public class RegisterValidAudienceTests
private const string V2 = "2.0";
private const string V3 = "3.0";
#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable.
private JwtSecurityToken _token;
private SecurityToken _token;
private RegisterValidAudience _registerValidAudience;
private TokenValidationParameters _validationParams;
private IEnumerable<string> _validAudiences;
Expand All @@ -36,7 +38,19 @@ public void ValidateAudience_FromToken(
bool isB2C,
string tokenVersion)
{
InitializeTests(isB2C, tokenVersion);
InitializeTests(isB2C, tokenVersion, claims => new JwtSecurityToken(null, null, claims));
AssertAudienceFromToken();
}

[Theory]
[InlineData(false, V1)]
[InlineData(false, V2)]
[InlineData(true, V1)]
public void ValidateAudience_FromToken_JsonWeb(
bool isB2C,
string tokenVersion)
{
InitializeTests(isB2C, tokenVersion, claims => new TestJsonWebToken(claims));
AssertAudienceFromToken();
}

Expand All @@ -48,7 +62,19 @@ public void ValidateAudience_ProvidedInValidAudience(
bool isB2C,
string tokenVersion)
{
InitializeTests(isB2C, tokenVersion);
InitializeTests(isB2C, tokenVersion, claims => new JwtSecurityToken(null, null, claims));
AssertAudienceProvidedInValidAudience();
}

[Theory]
[InlineData(false, V1)]
[InlineData(false, V2)]
[InlineData(true, V1)]
public void ValidateAudience_ProvidedInValidAudience_JsonWeb(
bool isB2C,
string tokenVersion)
{
InitializeTests(isB2C, tokenVersion, claims => new TestJsonWebToken(claims));
AssertAudienceProvidedInValidAudience();
}

Expand All @@ -60,7 +86,19 @@ public void ValidateAudience_ProvidedInValidAudiences(
bool isB2C,
string tokenVersion)
{
InitializeTests(isB2C, tokenVersion);
InitializeTests(isB2C, tokenVersion, claims => new JwtSecurityToken(null, null, claims));
AssertAudienceProvidedInValidAudiences();
}

[Theory]
[InlineData(false, V1)]
[InlineData(false, V2)]
[InlineData(true, V1)]
public void ValidateAudience_ProvidedInValidAudiences_JsonWeb(
bool isB2C,
string tokenVersion)
{
InitializeTests(isB2C, tokenVersion, claims => new TestJsonWebToken(claims));
AssertAudienceProvidedInValidAudiences();
}

Expand All @@ -70,13 +108,24 @@ public void InvalidAudience_AssertFails(
bool isB2C,
string tokenVersion)
{
InitializeTests(isB2C, tokenVersion);
InitializeTests(isB2C, tokenVersion, claims => new JwtSecurityToken(null, null, claims));
AssertFailureOnInvalidAudienceInToken();
}

[Theory]
[InlineData(false, V3)]
public void InvalidAudience_AssertFails_JsonWeb(
bool isB2C,
string tokenVersion)
{
InitializeTests(isB2C, tokenVersion, claims => new TestJsonWebToken(claims));
AssertFailureOnInvalidAudienceInToken();
}

private void InitializeTests(
bool isB2C,
string tokenVersion)
string tokenVersion,
Func<IEnumerable<Claim>, SecurityToken> tokenGenerator)
{
_options = new MicrosoftIdentityOptions
{
Expand All @@ -103,7 +152,7 @@ private void InitializeTests(
new Claim(Audience, _expectedAudience),
};

_token = new JwtSecurityToken(null, null, claims);
_token = tokenGenerator(claims);
_validationParams = new TokenValidationParameters();
_registerValidAudience = new RegisterValidAudience();
_registerValidAudience.RegisterAudienceValidation(_validationParams, _options);
Expand All @@ -116,8 +165,8 @@ private void AssertAudienceFromToken()
_validAudiences,
_token,
_validationParams));
Assert.Equal(_expectedAudience, _token.Audiences.FirstOrDefault());
Assert.Single(_token.Audiences);
Assert.Equal(_expectedAudience, Audiences.FirstOrDefault());
Assert.Single(Audiences);
}

private void AssertAudienceProvidedInValidAudience()
Expand All @@ -127,8 +176,8 @@ private void AssertAudienceProvidedInValidAudience()
_validAudiences,
_token,
_validationParams));
Assert.Equal(_expectedAudience, _token.Audiences.FirstOrDefault());
Assert.Single(_token.Audiences);
Assert.Equal(_expectedAudience, Audiences.FirstOrDefault());
Assert.Single(Audiences);
}

private void AssertAudienceProvidedInValidAudiences()
Expand All @@ -144,16 +193,38 @@ private void AssertAudienceProvidedInValidAudiences()
_validAudiences,
_token,
_validationParams));
Assert.Equal(_expectedAudience, _token.Audiences.FirstOrDefault());
Assert.Single(_token.Audiences);
Assert.Equal(_expectedAudience, Audiences.FirstOrDefault());
Assert.Single(Audiences);
}

private IEnumerable<string> Audiences => _token switch
{
JwtSecurityToken s => s.Audiences,
TestJsonWebToken w => w.Audiences,
_ => throw new System.NotImplementedException(),
};

private void AssertFailureOnInvalidAudienceInToken()
{
Assert.Throws<SecurityTokenInvalidAudienceException>(() => _registerValidAudience.ValidateAudience(
_validAudiences,
_token,
_validationParams));
}

private class TestJsonWebToken : JsonWebToken
{
private const string TestJwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";

public TestJsonWebToken(IEnumerable<Claim> claims)
: base(TestJwt)
{
Claims = claims;
}

public override IEnumerable<Claim> Claims { get; }

public new IEnumerable<string> Audiences => Claims.Where(c => c.Type == Audience).Select(c => c.Value);
}
}
}

0 comments on commit 85c1a75

Please sign in to comment.