Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make RegisterValidAudience aware of JsonWebToken #2421

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/Microsoft.Identity.Web/Resource/RegisterValidAudience.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.IdentityModel.Tokens.Jwt;
using System.Linq;
using Microsoft.Extensions.Options;
using Microsoft.IdentityModel.JsonWebTokens;
using Microsoft.IdentityModel.Tokens;

namespace Microsoft.Identity.Web.Resource
Expand Down Expand Up @@ -57,11 +58,12 @@ public void RegisterAudienceValidation(
SecurityToken securityToken,
TokenValidationParameters validationParameters)
{
JwtSecurityToken? token = securityToken as JwtSecurityToken;
if (token == null)
var claims = securityToken switch
jmprieur marked this conversation as resolved.
Show resolved Hide resolved
{
throw new SecurityTokenValidationException(IDWebErrorMessage.TokenIsNotJwtToken);
}
JwtSecurityToken jwtSecurityToken => jwtSecurityToken.Claims,
JsonWebToken jwtWebToken => jwtWebToken.Claims,
_ => throw new SecurityTokenValidationException(IDWebErrorMessage.TokenIsNotJwtToken),
};

validationParameters.AudienceValidator = null;

Expand All @@ -70,13 +72,13 @@ public void RegisterAudienceValidation(
validationParameters.ValidAudiences == null)
{
// handle v2.0 access token or Azure AD B2C tokens (even if v1.0)
if (IsB2C || token.Claims.Any(c => c.Type == Constants.Version && c.Value == Constants.V2))
if (IsB2C || claims.Any(c => c.Type == Constants.Version && c.Value == Constants.V2))
{
validationParameters.ValidAudience = $"{ClientId}";
}

// handle v1.0 access token
else if (token.Claims.Any(c => c.Type == Constants.Version && c.Value == Constants.V1))
else if (claims.Any(c => c.Type == Constants.Version && c.Value == Constants.V1))
{
validationParameters.ValidAudience = $"api://{ClientId}";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Identity.Client;
using Microsoft.Identity.Web.Internal;
using Microsoft.IdentityModel.JsonWebTokens;

namespace Microsoft.Identity.Web
{
Expand Down Expand Up @@ -99,7 +100,8 @@ internal static void CallsWebApiImplementation(

options.Events.OnTokenValidated = async context =>
{
context.HttpContext.StoreTokenUsedToCallWebAPI(context.SecurityToken as JwtSecurityToken);
// Only pass through a token if it is of an expected type
context.HttpContext.StoreTokenUsedToCallWebAPI(context.SecurityToken is JwtSecurityToken or JsonWebToken ? context.SecurityToken : null);
jmprieur marked this conversation as resolved.
Show resolved Hide resolved
await onTokenValidatedHandler(context).ConfigureAwait(false);
};
});
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.IdentityModel.Tokens.Jwt;
using System.Linq;
using System.Security.Claims;
using Microsoft.Identity.Web.Resource;
using Microsoft.Identity.Web.Test.Common;
using Microsoft.IdentityModel.JsonWebTokens;
using Microsoft.IdentityModel.Tokens;
using Xunit;

Expand All @@ -20,7 +22,7 @@ public class RegisterValidAudienceTests
private const string V2 = "2.0";
private const string V3 = "3.0";
#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor. Consider declaring as nullable.
private JwtSecurityToken _token;
private SecurityToken _token;
private RegisterValidAudience _registerValidAudience;
private TokenValidationParameters _validationParams;
private IEnumerable<string> _validAudiences;
Expand All @@ -36,7 +38,19 @@ public void ValidateAudience_FromToken(
bool isB2C,
string tokenVersion)
{
InitializeTests(isB2C, tokenVersion);
InitializeTests(isB2C, tokenVersion, claims => new JwtSecurityToken(null, null, claims));
AssertAudienceFromToken();
}

[Theory]
[InlineData(false, V1)]
[InlineData(false, V2)]
[InlineData(true, V1)]
public void ValidateAudience_FromToken_JsonWeb(
bool isB2C,
string tokenVersion)
{
InitializeTests(isB2C, tokenVersion, claims => new TestJsonWebToken(claims));
AssertAudienceFromToken();
}

Expand All @@ -48,7 +62,19 @@ public void ValidateAudience_ProvidedInValidAudience(
bool isB2C,
string tokenVersion)
{
InitializeTests(isB2C, tokenVersion);
InitializeTests(isB2C, tokenVersion, claims => new JwtSecurityToken(null, null, claims));
AssertAudienceProvidedInValidAudience();
}

[Theory]
[InlineData(false, V1)]
[InlineData(false, V2)]
[InlineData(true, V1)]
public void ValidateAudience_ProvidedInValidAudience_JsonWeb(
bool isB2C,
string tokenVersion)
{
InitializeTests(isB2C, tokenVersion, claims => new TestJsonWebToken(claims));
AssertAudienceProvidedInValidAudience();
}

Expand All @@ -60,7 +86,19 @@ public void ValidateAudience_ProvidedInValidAudiences(
bool isB2C,
string tokenVersion)
{
InitializeTests(isB2C, tokenVersion);
InitializeTests(isB2C, tokenVersion, claims => new JwtSecurityToken(null, null, claims));
AssertAudienceProvidedInValidAudiences();
}

[Theory]
[InlineData(false, V1)]
[InlineData(false, V2)]
[InlineData(true, V1)]
public void ValidateAudience_ProvidedInValidAudiences_JsonWeb(
bool isB2C,
string tokenVersion)
{
InitializeTests(isB2C, tokenVersion, claims => new TestJsonWebToken(claims));
AssertAudienceProvidedInValidAudiences();
}

Expand All @@ -70,13 +108,24 @@ public void InvalidAudience_AssertFails(
bool isB2C,
string tokenVersion)
{
InitializeTests(isB2C, tokenVersion);
InitializeTests(isB2C, tokenVersion, claims => new JwtSecurityToken(null, null, claims));
AssertFailureOnInvalidAudienceInToken();
}

[Theory]
[InlineData(false, V3)]
public void InvalidAudience_AssertFails_JsonWeb(
bool isB2C,
string tokenVersion)
{
InitializeTests(isB2C, tokenVersion, claims => new TestJsonWebToken(claims));
AssertFailureOnInvalidAudienceInToken();
}

private void InitializeTests(
bool isB2C,
string tokenVersion)
string tokenVersion,
Func<IEnumerable<Claim>, SecurityToken> tokenGenerator)
{
_options = new MicrosoftIdentityOptions
{
Expand All @@ -103,7 +152,7 @@ private void InitializeTests(
new Claim(Audience, _expectedAudience),
};

_token = new JwtSecurityToken(null, null, claims);
_token = tokenGenerator(claims);
_validationParams = new TokenValidationParameters();
_registerValidAudience = new RegisterValidAudience();
_registerValidAudience.RegisterAudienceValidation(_validationParams, _options);
Expand All @@ -116,8 +165,8 @@ private void AssertAudienceFromToken()
_validAudiences,
_token,
_validationParams));
Assert.Equal(_expectedAudience, _token.Audiences.FirstOrDefault());
Assert.Single(_token.Audiences);
Assert.Equal(_expectedAudience, Audiences.FirstOrDefault());
Assert.Single(Audiences);
}

private void AssertAudienceProvidedInValidAudience()
Expand All @@ -127,8 +176,8 @@ private void AssertAudienceProvidedInValidAudience()
_validAudiences,
_token,
_validationParams));
Assert.Equal(_expectedAudience, _token.Audiences.FirstOrDefault());
Assert.Single(_token.Audiences);
Assert.Equal(_expectedAudience, Audiences.FirstOrDefault());
Assert.Single(Audiences);
}

private void AssertAudienceProvidedInValidAudiences()
Expand All @@ -144,16 +193,38 @@ private void AssertAudienceProvidedInValidAudiences()
_validAudiences,
_token,
_validationParams));
Assert.Equal(_expectedAudience, _token.Audiences.FirstOrDefault());
Assert.Single(_token.Audiences);
Assert.Equal(_expectedAudience, Audiences.FirstOrDefault());
Assert.Single(Audiences);
}

private IEnumerable<string> Audiences => _token switch
{
JwtSecurityToken s => s.Audiences,
TestJsonWebToken w => w.Audiences,
_ => throw new System.NotImplementedException(),
};

private void AssertFailureOnInvalidAudienceInToken()
{
Assert.Throws<SecurityTokenInvalidAudienceException>(() => _registerValidAudience.ValidateAudience(
_validAudiences,
_token,
_validationParams));
}

private class TestJsonWebToken : JsonWebToken
{
private const string TestJwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";

public TestJsonWebToken(IEnumerable<Claim> claims)
: base(TestJwt)
{
Claims = claims;
}

public override IEnumerable<Claim> Claims { get; }

public new IEnumerable<string> Audiences => Claims.Where(c => c.Type == Audience).Select(c => c.Value);
}
}
}