Skip to content

Commit

Permalink
add configurable timeout (#1103)
Browse files Browse the repository at this point in the history
* add configurable timeout

* move reflection to cache (_negotiateEndpointCache)

* move get handshake timeout part outside of buildClaims

* remove method info outside of lambda to avoid re-evaluating

* make static readonly
  • Loading branch information
wanlwanl authored Nov 17, 2020
1 parent d0b179d commit 346970a
Show file tree
Hide file tree
Showing 10 changed files with 204 additions and 52 deletions.
5 changes: 5 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ public static class Periods
public static readonly TimeSpan DefaultServersPingInterval = TimeSpan.FromSeconds(5);
// Depends on DefaultStatusPingInterval, make 1/2 to fast check.
public static readonly TimeSpan DefaultCloseDelayInterval = TimeSpan.FromSeconds(5);

// Custom handshake timeout of SignalR Service
public const int DefaultHandshakeTimeout = 15;
public const int MaxCustomHandshakeTimeout = 30;
}

public static class ClaimType
Expand All @@ -59,6 +63,7 @@ public static class ClaimType
public const string ServiceEndpointsCount = AzureSignalRSysPrefix + "secn";
public const string MaxPollInterval = AzureSignalRSysPrefix + "ttl";
public const string DiagnosticClient = AzureSignalRSysPrefix + "dc";
public const string CustomHandshakeTimeout = AzureSignalRSysPrefix + "cht";

public const string AzureSignalRUserPrefix = "asrs.u.";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public static IEnumerable<Claim> BuildJwtClaims(
bool enableDetailedErrors = false,
int endpointsCount = 1,
int? maxPollInterval = null,
bool isDiagnosticClient = false)
bool isDiagnosticClient = false, int handshakeTimeout = Constants.Periods.DefaultHandshakeTimeout)
{
if (userId != null)
{
Expand All @@ -52,6 +52,11 @@ public static IEnumerable<Claim> BuildJwtClaims(
yield return new Claim(Constants.ClaimType.DiagnosticClient, "true");
}

if (handshakeTimeout != Constants.Periods.DefaultHandshakeTimeout)
{
yield return new Claim(Constants.ClaimType.CustomHandshakeTimeout, handshakeTimeout.ToString());
}

var authenticationType = user?.Identity?.AuthenticationType;

// No need to pass it when the authentication type is Bearer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ private static ISignalRServerBuilder AddAzureSignalRCore(this ISignalRServerBuil
.AddSingleton(typeof(AzureSignalRMarkerService))
.AddSingleton<IClientConnectionFactory, ClientConnectionFactory>()
.AddSingleton<IHostedService, HeartBeat>()
.AddSingleton<NegotiateHandler>();
.AddSingleton(typeof(NegotiateHandler<>));

// If a custom router is added, do not add the default router
builder.Services.TryAddSingleton(typeof(IEndpointRouter), typeof(DefaultEndpointRouter));
Expand Down
89 changes: 77 additions & 12 deletions src/Microsoft.Azure.SignalR/HubHost/NegotiateHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Localization;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Microsoft.Azure.SignalR
{
internal class NegotiateHandler
internal class NegotiateHandler<THub> where THub : Hub
{
private readonly IUserIdProvider _userIdProvider;
private readonly IConnectionRequestIdProvider _connectionRequestIdProvider;
Expand All @@ -29,17 +30,23 @@ internal class NegotiateHandler
private readonly bool _enableDetailedErrors;
private readonly int _endpointsCount;
private readonly int? _maxPollInterval;
private readonly int _customHandshakeTimeout;
private readonly string _hubName;
private readonly ILogger<NegotiateHandler<THub>> _logger;

public NegotiateHandler(
IOptions<HubOptions> hubOptions,
IOptions<HubOptions> globalHubOptions,
IOptions<HubOptions<THub>> hubOptions,
IServiceEndpointManager endpointManager,
IEndpointRouter router,
IUserIdProvider userIdProvider,
IServerNameProvider nameProvider,
IConnectionRequestIdProvider connectionRequestIdProvider,
IOptions<ServiceOptions> options,
IBlazorDetector blazorDetector)
IBlazorDetector blazorDetector,
ILogger<NegotiateHandler<THub>> logger)
{
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
_endpointManager = endpointManager ?? throw new ArgumentNullException(nameof(endpointManager));
_router = router ?? throw new ArgumentNullException(nameof(router));
_serverName = nameProvider?.GetName();
Expand All @@ -49,18 +56,20 @@ public NegotiateHandler(
_diagnosticClientFilter = options?.Value?.DiagnosticClientFilter;
_blazorDetector = blazorDetector ?? new DefaultBlazorDetector();
_mode = options.Value.ServerStickyMode;
_enableDetailedErrors = hubOptions.Value.EnableDetailedErrors == true;
_enableDetailedErrors = globalHubOptions.Value.EnableDetailedErrors == true;
_endpointsCount = options.Value.Endpoints.Length;
_maxPollInterval = options.Value.MaxPollIntervalInSeconds;
_customHandshakeTimeout = GetCustomHandshakeTimeout(hubOptions.Value.HandshakeTimeout ?? globalHubOptions.Value.HandshakeTimeout);
_hubName = typeof(THub).Name;
}

public async Task<NegotiationResponse> Process(HttpContext context, string hubName)
public async Task<NegotiationResponse> Process(HttpContext context)
{
var claims = BuildClaims(context, hubName);
var claims = BuildClaims(context);
var request = context.Request;
var cultureName = context.Features.Get<IRequestCultureFeature>()?.RequestCulture.Culture.Name;
var originalPath = GetOriginalPath(request.Path);
var provider = _endpointManager.GetEndpointProvider(_router.GetNegotiateEndpoint(context, _endpointManager.GetEndpoints(hubName)));
var provider = _endpointManager.GetEndpointProvider(_router.GetNegotiateEndpoint(context, _endpointManager.GetEndpoints(_hubName)));

if (provider == null)
{
Expand All @@ -71,8 +80,8 @@ public async Task<NegotiationResponse> Process(HttpContext context, string hubNa

return new NegotiationResponse
{
Url = provider.GetClientEndpoint(hubName, originalPath, queryString),
AccessToken = await provider.GenerateClientAccessTokenAsync(hubName, claims),
Url = provider.GetClientEndpoint(_hubName, originalPath, queryString),
AccessToken = await provider.GenerateClientAccessTokenAsync(_hubName, claims),
// Need to set this even though it's technically protocol violation https://github.com/aspnet/SignalR/issues/2133
AvailableTransports = new List<AvailableTransport>()
};
Expand All @@ -97,12 +106,12 @@ private string GetQueryString(string originalQueryString, string cultureName)
: queryString;
}

private IEnumerable<Claim> BuildClaims(HttpContext context, string hubName)
private IEnumerable<Claim> BuildClaims(HttpContext context)
{
// Make sticky mode required if detect using blazor
var mode = _blazorDetector.IsBlazor(hubName) ? ServerStickyMode.Required : _mode;
var mode = _blazorDetector.IsBlazor(_hubName) ? ServerStickyMode.Required : _mode;
var userId = _userIdProvider.GetUserId(new ServiceHubConnectionContext(context));
return ClaimsUtility.BuildJwtClaims(context.User, userId, GetClaimsProvider(context), _serverName, mode, _enableDetailedErrors, _endpointsCount, _maxPollInterval, IsDiagnosticClient(context)).ToList();
return ClaimsUtility.BuildJwtClaims(context.User, userId, GetClaimsProvider(context), _serverName, mode, _enableDetailedErrors, _endpointsCount, _maxPollInterval, IsDiagnosticClient(context), _customHandshakeTimeout).ToList();
}

private Func<IEnumerable<Claim>> GetClaimsProvider(HttpContext context)
Expand All @@ -120,12 +129,68 @@ private bool IsDiagnosticClient(HttpContext context)
return _diagnosticClientFilter != null && _diagnosticClientFilter(context);
}

private int GetCustomHandshakeTimeout(TimeSpan? handshakeTimeout)
{
if (!handshakeTimeout.HasValue)
{
Log.UseDefaultHandshakeTimeout(_logger);
return Constants.Periods.DefaultHandshakeTimeout;
}

var timeout = (int)handshakeTimeout.Value.TotalSeconds;

// use default handshake timeout
if (timeout == Constants.Periods.DefaultHandshakeTimeout)
{
Log.UseDefaultHandshakeTimeout(_logger);
return Constants.Periods.DefaultHandshakeTimeout;
}

// the custom handshake timeout is invalid, use default hanshake timeout instead
if (timeout <= 0 || timeout > Constants.Periods.MaxCustomHandshakeTimeout)
{
Log.FailToSetCustomHandshakeTimeout(_logger, new ArgumentOutOfRangeException(nameof(handshakeTimeout)));
return Constants.Periods.DefaultHandshakeTimeout;
}

// the custom handshake timeout is valid
Log.SucceedToSetCustomHandshakeTimeout(_logger, timeout);
return timeout;
}

private static string GetOriginalPath(string path)
{
path = path.TrimEnd('/');
return path.EndsWith(Constants.Path.Negotiate)
? path.Substring(0, path.Length - Constants.Path.Negotiate.Length)
: string.Empty;
}

private static class Log
{
private static readonly Action<ILogger, Exception> _useDefaultHandshakeTimeout =
LoggerMessage.Define(LogLevel.Information, new EventId(0, "UseDefaultHandshakeTimeout"), "Use default handshake timeout.");

private static readonly Action<ILogger, int, Exception> _succeedToSetCustomHandshakeTimeout =
LoggerMessage.Define<int>(LogLevel.Information, new EventId(1, "SucceedToSetCustomHandshakeTimeout"), "Succeed to set custom handshake timeout: {timeout} seconds.");

private static readonly Action<ILogger, Exception> _failToSetCustomHandshakeTimeout =
LoggerMessage.Define(LogLevel.Warning, new EventId(2, "FailToSetCustomHandshakeTimeout"), $"Fail to set custom handshake timeout, use default handshake timeout {Constants.Periods.DefaultHandshakeTimeout} seconds instead. The range of custom handshake timeout should between 1 second to {Constants.Periods.MaxCustomHandshakeTimeout} seconds.");

public static void UseDefaultHandshakeTimeout(ILogger logger)
{
_useDefaultHandshakeTimeout(logger, null);
}

public static void SucceedToSetCustomHandshakeTimeout(ILogger logger, int customHandshakeTimeout)
{
_succeedToSetCustomHandshakeTimeout(logger, customHandshakeTimeout, null);
}

public static void FailToSetCustomHandshakeTimeout(ILogger logger, Exception exception)
{
_failToSetCustomHandshakeTimeout(logger, exception);
}
}
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.Azure.SignalR/ServiceRouteBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public void MapHub<THub>(PathString path) where THub : Hub
{
// Get auth attributes
var authorizationData = AuthorizeHelper.BuildAuthorizePolicy(typeof(THub));
_routes.MapRoute(path + Constants.Path.Negotiate, c => ServiceRouteHelper.RedirectToService(c, typeof(THub).Name, authorizationData));
_routes.MapRoute(path + Constants.Path.Negotiate, c => ServiceRouteHelper.RedirectToService<THub>(c, authorizationData));

Start<THub>();
}
Expand Down
20 changes: 16 additions & 4 deletions src/Microsoft.Azure.SignalR/Startup/NegotiateMatcherPolicy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Reflection;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Connections;
Expand All @@ -16,9 +17,11 @@ namespace Microsoft.Azure.SignalR
{
internal class NegotiateMatcherPolicy : MatcherPolicy, IEndpointSelectorPolicy
{
private static readonly MethodInfo _createNegotiateEndpointCoreMethodInfo = typeof(NegotiateMatcherPolicy).GetMethod(nameof(CreateNegotiateEndpointCore), BindingFlags.NonPublic | BindingFlags.Static);

// This caches the replacement endpoints for negotiate so they are not recomputed on every request
private readonly ConcurrentDictionary<Type, Endpoint> _negotiateEndpointCache = new ConcurrentDictionary<Type, Endpoint>();

public override int Order => 1;

public bool AppliesToEndpoints(IReadOnlyList<Endpoint> endpoints)
Expand Down Expand Up @@ -49,7 +52,7 @@ public Task ApplyAsync(HttpContext httpContext, CandidateSet candidates)
// skip endpoint not apply hub.
if (hubMetadata != null)
{
var newEndpoint = _negotiateEndpointCache.GetOrAdd(hubMetadata.HubType, e => CreateNegotiateEndpoint(routeEndpoint));
var newEndpoint = _negotiateEndpointCache.GetOrAdd(hubMetadata.HubType, CreateNegotiateEndpoint(hubMetadata.HubType, routeEndpoint));

candidates.ReplaceEndpoint(i, newEndpoint, candidate.Values);
}
Expand All @@ -59,14 +62,23 @@ public Task ApplyAsync(HttpContext httpContext, CandidateSet candidates)
return Task.CompletedTask;
}

private Endpoint CreateNegotiateEndpoint(RouteEndpoint routeEndpoint)
private Func<Type, Endpoint> CreateNegotiateEndpoint(Type hubType, RouteEndpoint routeEndpoint)
{
var genericMethodInfo = _createNegotiateEndpointCoreMethodInfo.MakeGenericMethod(hubType);
return type =>
{
return (Endpoint)genericMethodInfo.Invoke(this, new object[] { routeEndpoint });
};
}

private static Endpoint CreateNegotiateEndpointCore<THub>(RouteEndpoint routeEndpoint) where THub : Hub
{
var hubMetadata = routeEndpoint.Metadata.GetMetadata<HubMetadata>();

// Replaces the negotiate endpoint with one that does the service redirect
var routeEndpointBuilder = new RouteEndpointBuilder(async context =>
{
await ServiceRouteHelper.RedirectToService(context, hubMetadata.HubType.Name, null);
await ServiceRouteHelper.RedirectToService<THub>(context, null);
},
routeEndpoint.RoutePattern,
routeEndpoint.Order);
Expand Down
7 changes: 4 additions & 3 deletions src/Microsoft.Azure.SignalR/Utilities/ServiceRouteHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Azure.SignalR.Common;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
Expand All @@ -15,9 +16,9 @@ namespace Microsoft.Azure.SignalR
{
internal class ServiceRouteHelper
{
public static async Task RedirectToService(HttpContext context, string hubName, IList<IAuthorizeData> authorizationData)
public static async Task RedirectToService<THub>(HttpContext context, IList<IAuthorizeData> authorizationData) where THub : Hub
{
var handler = context.RequestServices.GetRequiredService<NegotiateHandler>();
var handler = context.RequestServices.GetRequiredService<NegotiateHandler<THub>>();
var loggerFactory = context.RequestServices.GetService<ILoggerFactory>();
var logger = loggerFactory.CreateLogger<ServiceRouteHelper>();

Expand All @@ -29,7 +30,7 @@ public static async Task RedirectToService(HttpContext context, string hubName,
NegotiationResponse negotiateResponse = null;
try
{
negotiateResponse = await handler.Process(context, hubName);
negotiateResponse = await handler.Process(context);

if (context.Response.HasStarted)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Security.Claims;
Expand Down
60 changes: 60 additions & 0 deletions test/Microsoft.Azure.SignalR.Tests/AddAzureSignalRFacts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@

using System;
using System.Collections.Generic;
using System.IdentityModel.Tokens.Jwt;
using System.IO;
using System.Linq;
using System.Security.Claims;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Azure.SignalR.Common;
using Microsoft.Azure.SignalR.Tests.Common;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Newtonsoft.Json;
Expand Down Expand Up @@ -500,5 +504,61 @@ public async Task AddAzureSignalRHotReloadConfigValue()
Assert.Single(manager.Endpoints.Where(x => x.Value.ConnectionString == customeCS));
}
}

[Fact]
public async Task AddAzureSignalRWithCustomHandshakeTimeout()
{
// set custom handshake timeout in global hub options
var claims = await GetClaims(sc => sc.AddSignalR(o => o.HandshakeTimeout = TimeSpan.FromSeconds(1)).AddAzureSignalR());
Assert.Contains(claims, c => c.Type == Constants.ClaimType.CustomHandshakeTimeout && c.Value == "1");

// set custom handshake timeout in particular hub options to override the settings in global hub options
claims = await GetClaims(sc => sc.AddSignalR(o => o.HandshakeTimeout = TimeSpan.FromSeconds(1)).AddHubOptions<TestHub>(o => o.HandshakeTimeout = TimeSpan.FromSeconds(2)).AddAzureSignalR());
Assert.Contains(claims, c => c.Type == Constants.ClaimType.CustomHandshakeTimeout && c.Value == "2");

// no custom timeout
claims = await GetClaims(sc => sc.AddSignalR().AddAzureSignalR());
Assert.DoesNotContain(claims, c => c.Type == Constants.ClaimType.CustomHandshakeTimeout);

// invalid timeout: larger than 30s
claims = await GetClaims(sc => sc.AddSignalR(o => o.HandshakeTimeout = TimeSpan.FromSeconds(31)).AddAzureSignalR());
Assert.DoesNotContain(claims, c => c.Type == Constants.ClaimType.CustomHandshakeTimeout);

// invalid timeout: smaller than 1s
claims = await GetClaims(sc => sc.AddSignalR(o => o.HandshakeTimeout = TimeSpan.FromSeconds(0)).AddAzureSignalR());
Assert.DoesNotContain(claims, c => c.Type == Constants.ClaimType.CustomHandshakeTimeout);
}

private static async Task<IEnumerable<Claim>> GetClaims(Action<ServiceCollection> addSignalR)
{
var config = new ConfigurationBuilder()
.AddInMemoryCollection(new Dictionary<string, string>
{
{"Azure:SignalR:ConnectionString", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQR55555555012345678933333333;Version=1.0;"}
})
.Build();

var services = new ServiceCollection();
addSignalR(services);

var sp = services
.AddLogging()
.AddSingleton<IHostApplicationLifetime>(new EmptyApplicationLifetime())
.AddSingleton<IConfiguration>(config)
.BuildServiceProvider();

var app = new ApplicationBuilder(sp);
app.UseRouting();
app.UseEndpoints(routes =>
{
routes.MapHub<TestHub>("/chat");
});

var h = sp.GetRequiredService<NegotiateHandler<TestHub>>();
var r = await h.Process(new DefaultHttpContext());
var jwtSecurityTokenHandler = new JwtSecurityTokenHandler();
var t = jwtSecurityTokenHandler.ReadJwtToken(r.AccessToken);
return t.Claims;
}
}
}
Loading

0 comments on commit 346970a

Please sign in to comment.