Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate with EnableDetailedErrors #694

Merged
merged 2 commits into from
Oct 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ internal class NegotiateMiddleware : OwinMiddleware

private readonly string _serverName;
private readonly ServerStickyMode _mode;
private readonly bool _enableDetailedErrors;

public NegotiateMiddleware(OwinMiddleware next, HubConfiguration configuration, string appName, IServiceEndpointManager endpointManager, IEndpointRouter router, ServiceOptions options, IServerNameProvider serverNameProvider, IConnectionRequestIdProvider connectionRequestIdProvider, ILoggerFactory loggerFactory)
: base(next)
Expand All @@ -55,6 +56,7 @@ public NegotiateMiddleware(OwinMiddleware next, HubConfiguration configuration,
_logger = loggerFactory?.CreateLogger<NegotiateMiddleware>() ?? throw new ArgumentNullException(nameof(loggerFactory));
_serverName = serverNameProvider?.GetName();
_mode = options.ServerStickyMode;
_enableDetailedErrors = configuration.EnableDetailedErrors;
}

public override Task Invoke(IOwinContext owinContext)
Expand Down Expand Up @@ -196,7 +198,7 @@ private IEnumerable<Claim> BuildClaims(IOwinContext owinContext, IRequest reques
var user = owinContext.Authentication?.User;
var userId = _provider?.GetUserId(request);

var claims = ClaimsUtility.BuildJwtClaims(user, userId, GetClaimsProvider(owinContext), _serverName, _mode);
var claims = ClaimsUtility.BuildJwtClaims(user, userId, GetClaimsProvider(owinContext), _serverName, _mode, _enableDetailedErrors);

yield return new Claim(Constants.ClaimType.Version, AssemblyVersion);

Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public static class ClaimType
public const string Id = AzureSignalRSysPrefix + "id";
public const string AppName = AzureSignalRSysPrefix + "apn";
public const string Version = AzureSignalRSysPrefix + "vn";
public const string EnableDetailedErrors = AzureSignalRSysPrefix + "derror";

public const string AzureSignalRUserPrefix = "asrs.u.";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ internal static class ClaimsUtility
private static readonly string DefaultNameClaimType = DefaultClaimsIdentity.NameClaimType;
private static readonly string DefaultRoleClaimType = DefaultClaimsIdentity.RoleClaimType;

public static IEnumerable<Claim> BuildJwtClaims(ClaimsPrincipal user, string userId, Func<IEnumerable<Claim>> claimsProvider, string serverName = null, ServerStickyMode mode = ServerStickyMode.Disabled)
public static IEnumerable<Claim> BuildJwtClaims(ClaimsPrincipal user, string userId, Func<IEnumerable<Claim>> claimsProvider, string serverName = null, ServerStickyMode mode = ServerStickyMode.Disabled, bool enableDetailedErrors = false)
{
if (userId != null)
{
Expand All @@ -46,6 +46,11 @@ public static IEnumerable<Claim> BuildJwtClaims(ClaimsPrincipal user, string use
yield return new Claim(Constants.ClaimType.AuthenticationType, authenticationType);
}

if (enableDetailedErrors)
{
yield return new Claim(Constants.ClaimType.EnableDetailedErrors, true.ToString());
}

// Return custom NameClaimType and RoleClaimType
// We can have multiple Identities, for now, choose the default one
if (user?.Identity is ClaimsIdentity identity)
Expand Down
8 changes: 6 additions & 2 deletions src/Microsoft.Azure.SignalR/HubHost/NegotiateHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ internal class NegotiateHandler
private readonly IEndpointRouter _router;
private readonly string _serverName;
private readonly ServerStickyMode _mode;
private readonly bool _enableDetailedErrors;

public NegotiateHandler(IServiceEndpointManager endpointManager, IEndpointRouter router, IUserIdProvider userIdProvider, IServerNameProvider nameProvider, IConnectionRequestIdProvider connectionRequestIdProvider, IOptions<ServiceOptions> options)
public NegotiateHandler(
IOptions<HubOptions> hubOptions,
IServiceEndpointManager endpointManager, IEndpointRouter router, IUserIdProvider userIdProvider, IServerNameProvider nameProvider, IConnectionRequestIdProvider connectionRequestIdProvider, IOptions<ServiceOptions> options)
{
_endpointManager = endpointManager ?? throw new ArgumentNullException(nameof(endpointManager));
_router = router ?? throw new ArgumentNullException(nameof(router));
Expand All @@ -32,6 +35,7 @@ public NegotiateHandler(IServiceEndpointManager endpointManager, IEndpointRouter
_connectionRequestIdProvider = connectionRequestIdProvider ?? throw new ArgumentNullException(nameof(connectionRequestIdProvider));
_claimsProvider = options?.Value?.ClaimsProvider;
_mode = options.Value.ServerStickyMode;
_enableDetailedErrors = hubOptions.Value.EnableDetailedErrors == true;
}

public NegotiationResponse Process(HttpContext context, string hubName)
Expand Down Expand Up @@ -75,7 +79,7 @@ private string GetQueryString(string originalQueryString)
private IEnumerable<Claim> BuildClaims(HttpContext context)
{
var userId = _userIdProvider.GetUserId(new ServiceHubConnectionContext(context));
return ClaimsUtility.BuildJwtClaims(context.User, userId, GetClaimsProvider(context), _serverName, _mode).ToList();
return ClaimsUtility.BuildJwtClaims(context.User, userId, GetClaimsProvider(context), _serverName, _mode, _enableDetailedErrors).ToList();
}

private Func<IEnumerable<Claim>> GetClaimsProvider(HttpContext context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ private async Task WaitOnApplicationTask(ServiceConnectionContext connection)
try
{
// Wait for the application task to complete
// application task can end when exception, or Context.Abort() from hub
await connection.ApplicationTask;
}
catch (Exception ex)
Expand All @@ -197,7 +198,7 @@ private async Task WaitOnApplicationTask(ServiceConnectionContext connection)
if (connection.AbortOnClose)
{
// Inform the Service that we will remove the client because SignalR told us it is disconnected.
var serviceMessage = new CloseConnectionMessage(connection.ConnectionId, errorMessage: "Web application task completed, close the client.");
var serviceMessage = new CloseConnectionMessage(connection.ConnectionId, errorMessage: exception?.Message);
await WriteAsync(serviceMessage);
Log.CloseConnection(Logger, connection.ConnectionId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,7 @@ public async Task TestClaimsProviderInServiceOptionsTakeEffect()
using (StartVerifiableLog(out var loggerFactory, LogLevel.Debug))
{
var hubConfiguration = Utility.GetTestHubConfig(loggerFactory);
hubConfiguration.EnableDetailedErrors = true;
using (WebApp.Start(ServiceUrl, a => a.RunAzureSignalR(AppName, hubConfiguration, options =>
{
options.ConnectionString = ConnectionString;
Expand Down Expand Up @@ -623,6 +624,9 @@ public async Task TestClaimsProviderInServiceOptionsTakeEffect()
var requestId = token.Claims.FirstOrDefault(s => s.Type == Constants.ClaimType.Id);
Assert.Null(requestId);
Assert.Equal(TimeSpan.FromDays(1), token.ValidTo - token.ValidFrom);

var enableDetailedErrors = token.Claims.FirstOrDefault(s => s.Type == Constants.ClaimType.EnableDetailedErrors);
Assert.Equal("True", enableDetailedErrors.Value);
}
}
}
Expand Down Expand Up @@ -672,6 +676,8 @@ public async Task TestStickyServerInServiceOptionsTakeEffect()
var requestId = token.Claims.FirstOrDefault(s => s.Type == Constants.ClaimType.Id);
Assert.Null(requestId);
Assert.Equal(TimeSpan.FromDays(1), token.ValidTo - token.ValidFrom);
var enableDetailedErrors = token.Claims.FirstOrDefault(s => s.Type == Constants.ClaimType.EnableDetailedErrors);
Assert.Null(enableDetailedErrors);
}
}
}
Expand Down
8 changes: 6 additions & 2 deletions test/Microsoft.Azure.SignalR.Tests/NegotiateHandlerFacts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ public class NegotiateHandlerFacts
public void GenerateNegotiateResponseWithUserId(Type type, string expectedUserId)
{
var config = new ConfigurationBuilder().Build();
var serviceProvider = new ServiceCollection().AddSignalR()
var serviceProvider = new ServiceCollection()
.AddSignalR(o => o.EnableDetailedErrors = false)
.AddAzureSignalR(
o =>
{
Expand Down Expand Up @@ -79,6 +80,7 @@ public void GenerateNegotiateResponseWithUserId(Type type, string expectedUserId
Assert.Equal(TimeSpan.FromDays(1), token.ValidTo - token.ValidFrom);
Assert.Null(token.Claims.FirstOrDefault(s => s.Type == Constants.ClaimType.ServerName));
Assert.Null(token.Claims.FirstOrDefault(s => s.Type == Constants.ClaimType.ServerStickyMode));
Assert.Null(token.Claims.FirstOrDefault(s => s.Type == Constants.ClaimType.EnableDetailedErrors));
}

[Fact]
Expand All @@ -87,7 +89,8 @@ public void GenerateNegotiateResponseWithUserIdAndServerSticky()
var name = nameof(GenerateNegotiateResponseWithUserIdAndServerSticky);
var serverNameProvider = new TestServerNameProvider(name);
var config = new ConfigurationBuilder().Build();
var serviceProvider = new ServiceCollection().AddSignalR()
var serviceProvider = new ServiceCollection()
.AddSignalR(o => o.EnableDetailedErrors = true)
.AddAzureSignalR(
o =>
{
Expand Down Expand Up @@ -130,6 +133,7 @@ public void GenerateNegotiateResponseWithUserIdAndServerSticky()
Assert.Equal(name, serverName);
var mode = token.Claims.FirstOrDefault(s => s.Type == Constants.ClaimType.ServerStickyMode)?.Value;
Assert.Equal("Required", mode);
Assert.Equal("True", token.Claims.FirstOrDefault(s => s.Type == Constants.ClaimType.EnableDetailedErrors)?.Value);
}

[Theory]
Expand Down