Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add configurable timeout #1103

Merged
merged 5 commits into from
Nov 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ public static class Periods
public static readonly TimeSpan DefaultServersPingInterval = TimeSpan.FromSeconds(5);
// Depends on DefaultStatusPingInterval, make 1/2 to fast check.
public static readonly TimeSpan DefaultCloseDelayInterval = TimeSpan.FromSeconds(5);

// Custom handshake timeout of SignalR Service
public const int DefaultHandshakeTimeout = 15;
public const int MaxCustomHandshakeTimeout = 30;
}

public static class ClaimType
Expand All @@ -59,6 +63,7 @@ public static class ClaimType
public const string ServiceEndpointsCount = AzureSignalRSysPrefix + "secn";
public const string MaxPollInterval = AzureSignalRSysPrefix + "ttl";
public const string DiagnosticClient = AzureSignalRSysPrefix + "dc";
public const string CustomHandshakeTimeout = AzureSignalRSysPrefix + "cht";

public const string AzureSignalRUserPrefix = "asrs.u.";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public static IEnumerable<Claim> BuildJwtClaims(
bool enableDetailedErrors = false,
int endpointsCount = 1,
int? maxPollInterval = null,
bool isDiagnosticClient = false)
bool isDiagnosticClient = false, int handshakeTimeout = Constants.Periods.DefaultHandshakeTimeout)
{
if (userId != null)
{
Expand All @@ -52,6 +52,11 @@ public static IEnumerable<Claim> BuildJwtClaims(
yield return new Claim(Constants.ClaimType.DiagnosticClient, "true");
}

if (handshakeTimeout != Constants.Periods.DefaultHandshakeTimeout)
{
yield return new Claim(Constants.ClaimType.CustomHandshakeTimeout, handshakeTimeout.ToString());
}

var authenticationType = user?.Identity?.AuthenticationType;

// No need to pass it when the authentication type is Bearer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ private static ISignalRServerBuilder AddAzureSignalRCore(this ISignalRServerBuil
.AddSingleton(typeof(AzureSignalRMarkerService))
.AddSingleton<IClientConnectionFactory, ClientConnectionFactory>()
.AddSingleton<IHostedService, HeartBeat>()
.AddSingleton<NegotiateHandler>();
.AddSingleton(typeof(NegotiateHandler<>));

// If a custom router is added, do not add the default router
builder.Services.TryAddSingleton(typeof(IEndpointRouter), typeof(DefaultEndpointRouter));
Expand Down
89 changes: 77 additions & 12 deletions src/Microsoft.Azure.SignalR/HubHost/NegotiateHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Localization;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Microsoft.Azure.SignalR
{
internal class NegotiateHandler
internal class NegotiateHandler<THub> where THub : Hub
{
private readonly IUserIdProvider _userIdProvider;
private readonly IConnectionRequestIdProvider _connectionRequestIdProvider;
Expand All @@ -29,17 +30,23 @@ internal class NegotiateHandler
private readonly bool _enableDetailedErrors;
private readonly int _endpointsCount;
private readonly int? _maxPollInterval;
private readonly int _customHandshakeTimeout;
private readonly string _hubName;
private readonly ILogger<NegotiateHandler<THub>> _logger;

public NegotiateHandler(
IOptions<HubOptions> hubOptions,
IOptions<HubOptions> globalHubOptions,
IOptions<HubOptions<THub>> hubOptions,
IServiceEndpointManager endpointManager,
IEndpointRouter router,
IUserIdProvider userIdProvider,
IServerNameProvider nameProvider,
IConnectionRequestIdProvider connectionRequestIdProvider,
IOptions<ServiceOptions> options,
IBlazorDetector blazorDetector)
IBlazorDetector blazorDetector,
ILogger<NegotiateHandler<THub>> logger)
{
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
_endpointManager = endpointManager ?? throw new ArgumentNullException(nameof(endpointManager));
_router = router ?? throw new ArgumentNullException(nameof(router));
_serverName = nameProvider?.GetName();
Expand All @@ -49,18 +56,20 @@ public NegotiateHandler(
_diagnosticClientFilter = options?.Value?.DiagnosticClientFilter;
_blazorDetector = blazorDetector ?? new DefaultBlazorDetector();
_mode = options.Value.ServerStickyMode;
_enableDetailedErrors = hubOptions.Value.EnableDetailedErrors == true;
_enableDetailedErrors = globalHubOptions.Value.EnableDetailedErrors == true;
_endpointsCount = options.Value.Endpoints.Length;
_maxPollInterval = options.Value.MaxPollIntervalInSeconds;
_customHandshakeTimeout = GetCustomHandshakeTimeout(hubOptions.Value.HandshakeTimeout ?? globalHubOptions.Value.HandshakeTimeout);
_hubName = typeof(THub).Name;
}

public async Task<NegotiationResponse> Process(HttpContext context, string hubName)
public async Task<NegotiationResponse> Process(HttpContext context)
{
var claims = BuildClaims(context, hubName);
var claims = BuildClaims(context);
var request = context.Request;
var cultureName = context.Features.Get<IRequestCultureFeature>()?.RequestCulture.Culture.Name;
var originalPath = GetOriginalPath(request.Path);
var provider = _endpointManager.GetEndpointProvider(_router.GetNegotiateEndpoint(context, _endpointManager.GetEndpoints(hubName)));
var provider = _endpointManager.GetEndpointProvider(_router.GetNegotiateEndpoint(context, _endpointManager.GetEndpoints(_hubName)));

if (provider == null)
{
Expand All @@ -71,8 +80,8 @@ public async Task<NegotiationResponse> Process(HttpContext context, string hubNa

return new NegotiationResponse
{
Url = provider.GetClientEndpoint(hubName, originalPath, queryString),
AccessToken = await provider.GenerateClientAccessTokenAsync(hubName, claims),
Url = provider.GetClientEndpoint(_hubName, originalPath, queryString),
AccessToken = await provider.GenerateClientAccessTokenAsync(_hubName, claims),
// Need to set this even though it's technically protocol violation https://github.com/aspnet/SignalR/issues/2133
AvailableTransports = new List<AvailableTransport>()
};
Expand All @@ -97,12 +106,12 @@ private string GetQueryString(string originalQueryString, string cultureName)
: queryString;
}

private IEnumerable<Claim> BuildClaims(HttpContext context, string hubName)
private IEnumerable<Claim> BuildClaims(HttpContext context)
{
// Make sticky mode required if detect using blazor
var mode = _blazorDetector.IsBlazor(hubName) ? ServerStickyMode.Required : _mode;
var mode = _blazorDetector.IsBlazor(_hubName) ? ServerStickyMode.Required : _mode;
var userId = _userIdProvider.GetUserId(new ServiceHubConnectionContext(context));
return ClaimsUtility.BuildJwtClaims(context.User, userId, GetClaimsProvider(context), _serverName, mode, _enableDetailedErrors, _endpointsCount, _maxPollInterval, IsDiagnosticClient(context)).ToList();
return ClaimsUtility.BuildJwtClaims(context.User, userId, GetClaimsProvider(context), _serverName, mode, _enableDetailedErrors, _endpointsCount, _maxPollInterval, IsDiagnosticClient(context), _customHandshakeTimeout).ToList();
}

private Func<IEnumerable<Claim>> GetClaimsProvider(HttpContext context)
Expand All @@ -120,12 +129,68 @@ private bool IsDiagnosticClient(HttpContext context)
return _diagnosticClientFilter != null && _diagnosticClientFilter(context);
}

private int GetCustomHandshakeTimeout(TimeSpan? handshakeTimeout)
{
if (!handshakeTimeout.HasValue)
{
Log.UseDefaultHandshakeTimeout(_logger);
return Constants.Periods.DefaultHandshakeTimeout;
}

var timeout = (int)handshakeTimeout.Value.TotalSeconds;

// use default handshake timeout
if (timeout == Constants.Periods.DefaultHandshakeTimeout)
{
Log.UseDefaultHandshakeTimeout(_logger);
return Constants.Periods.DefaultHandshakeTimeout;
}

// the custom handshake timeout is invalid, use default hanshake timeout instead
if (timeout <= 0 || timeout > Constants.Periods.MaxCustomHandshakeTimeout)
{
Log.FailToSetCustomHandshakeTimeout(_logger, new ArgumentOutOfRangeException(nameof(handshakeTimeout)));
return Constants.Periods.DefaultHandshakeTimeout;
}

// the custom handshake timeout is valid
Log.SucceedToSetCustomHandshakeTimeout(_logger, timeout);
return timeout;
}

private static string GetOriginalPath(string path)
{
path = path.TrimEnd('/');
return path.EndsWith(Constants.Path.Negotiate)
? path.Substring(0, path.Length - Constants.Path.Negotiate.Length)
: string.Empty;
}

private static class Log
{
private static readonly Action<ILogger, Exception> _useDefaultHandshakeTimeout =
LoggerMessage.Define(LogLevel.Information, new EventId(0, "UseDefaultHandshakeTimeout"), "Use default handshake timeout.");

private static readonly Action<ILogger, int, Exception> _succeedToSetCustomHandshakeTimeout =
LoggerMessage.Define<int>(LogLevel.Information, new EventId(1, "SucceedToSetCustomHandshakeTimeout"), "Succeed to set custom handshake timeout: {timeout} seconds.");

private static readonly Action<ILogger, Exception> _failToSetCustomHandshakeTimeout =
LoggerMessage.Define(LogLevel.Warning, new EventId(2, "FailToSetCustomHandshakeTimeout"), $"Fail to set custom handshake timeout, use default handshake timeout {Constants.Periods.DefaultHandshakeTimeout} seconds instead. The range of custom handshake timeout should between 1 second to {Constants.Periods.MaxCustomHandshakeTimeout} seconds.");

public static void UseDefaultHandshakeTimeout(ILogger logger)
{
_useDefaultHandshakeTimeout(logger, null);
}

public static void SucceedToSetCustomHandshakeTimeout(ILogger logger, int customHandshakeTimeout)
{
_succeedToSetCustomHandshakeTimeout(logger, customHandshakeTimeout, null);
}

public static void FailToSetCustomHandshakeTimeout(ILogger logger, Exception exception)
{
_failToSetCustomHandshakeTimeout(logger, exception);
}
}
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.Azure.SignalR/ServiceRouteBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public void MapHub<THub>(PathString path) where THub : Hub
{
// Get auth attributes
var authorizationData = AuthorizeHelper.BuildAuthorizePolicy(typeof(THub));
_routes.MapRoute(path + Constants.Path.Negotiate, c => ServiceRouteHelper.RedirectToService(c, typeof(THub).Name, authorizationData));
_routes.MapRoute(path + Constants.Path.Negotiate, c => ServiceRouteHelper.RedirectToService<THub>(c, authorizationData));

Start<THub>();
}
Expand Down
20 changes: 16 additions & 4 deletions src/Microsoft.Azure.SignalR/Startup/NegotiateMatcherPolicy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Reflection;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Connections;
Expand All @@ -16,9 +17,11 @@ namespace Microsoft.Azure.SignalR
{
internal class NegotiateMatcherPolicy : MatcherPolicy, IEndpointSelectorPolicy
{
private static readonly MethodInfo _createNegotiateEndpointCoreMethodInfo = typeof(NegotiateMatcherPolicy).GetMethod(nameof(CreateNegotiateEndpointCore), BindingFlags.NonPublic | BindingFlags.Static);

// This caches the replacement endpoints for negotiate so they are not recomputed on every request
private readonly ConcurrentDictionary<Type, Endpoint> _negotiateEndpointCache = new ConcurrentDictionary<Type, Endpoint>();

public override int Order => 1;

public bool AppliesToEndpoints(IReadOnlyList<Endpoint> endpoints)
Expand Down Expand Up @@ -49,7 +52,7 @@ public Task ApplyAsync(HttpContext httpContext, CandidateSet candidates)
// skip endpoint not apply hub.
if (hubMetadata != null)
{
var newEndpoint = _negotiateEndpointCache.GetOrAdd(hubMetadata.HubType, e => CreateNegotiateEndpoint(routeEndpoint));
var newEndpoint = _negotiateEndpointCache.GetOrAdd(hubMetadata.HubType, CreateNegotiateEndpoint(hubMetadata.HubType, routeEndpoint));

candidates.ReplaceEndpoint(i, newEndpoint, candidate.Values);
}
Expand All @@ -59,14 +62,23 @@ public Task ApplyAsync(HttpContext httpContext, CandidateSet candidates)
return Task.CompletedTask;
}

private Endpoint CreateNegotiateEndpoint(RouteEndpoint routeEndpoint)
private Func<Type, Endpoint> CreateNegotiateEndpoint(Type hubType, RouteEndpoint routeEndpoint)
{
var genericMethodInfo = _createNegotiateEndpointCoreMethodInfo.MakeGenericMethod(hubType);
return type =>
{
return (Endpoint)genericMethodInfo.Invoke(this, new object[] { routeEndpoint });
};
}

private static Endpoint CreateNegotiateEndpointCore<THub>(RouteEndpoint routeEndpoint) where THub : Hub
{
var hubMetadata = routeEndpoint.Metadata.GetMetadata<HubMetadata>();

// Replaces the negotiate endpoint with one that does the service redirect
var routeEndpointBuilder = new RouteEndpointBuilder(async context =>
{
await ServiceRouteHelper.RedirectToService(context, hubMetadata.HubType.Name, null);
await ServiceRouteHelper.RedirectToService<THub>(context, null);
},
routeEndpoint.RoutePattern,
routeEndpoint.Order);
Expand Down
7 changes: 4 additions & 3 deletions src/Microsoft.Azure.SignalR/Utilities/ServiceRouteHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Azure.SignalR.Common;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
Expand All @@ -15,9 +16,9 @@ namespace Microsoft.Azure.SignalR
{
internal class ServiceRouteHelper
{
public static async Task RedirectToService(HttpContext context, string hubName, IList<IAuthorizeData> authorizationData)
public static async Task RedirectToService<THub>(HttpContext context, IList<IAuthorizeData> authorizationData) where THub : Hub
{
var handler = context.RequestServices.GetRequiredService<NegotiateHandler>();
var handler = context.RequestServices.GetRequiredService<NegotiateHandler<THub>>();
var loggerFactory = context.RequestServices.GetService<ILoggerFactory>();
var logger = loggerFactory.CreateLogger<ServiceRouteHelper>();

Expand All @@ -29,7 +30,7 @@ public static async Task RedirectToService(HttpContext context, string hubName,
NegotiationResponse negotiateResponse = null;
try
{
negotiateResponse = await handler.Process(context, hubName);
negotiateResponse = await handler.Process(context);

if (context.Response.HasStarted)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Security.Claims;
Expand Down
60 changes: 60 additions & 0 deletions test/Microsoft.Azure.SignalR.Tests/AddAzureSignalRFacts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@

using System;
using System.Collections.Generic;
using System.IdentityModel.Tokens.Jwt;
using System.IO;
using System.Linq;
using System.Security.Claims;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Azure.SignalR.Common;
using Microsoft.Azure.SignalR.Tests.Common;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Newtonsoft.Json;
Expand Down Expand Up @@ -500,5 +504,61 @@ public async Task AddAzureSignalRHotReloadConfigValue()
Assert.Single(manager.Endpoints.Where(x => x.Value.ConnectionString == customeCS));
}
}

[Fact]
public async Task AddAzureSignalRWithCustomHandshakeTimeout()
{
// set custom handshake timeout in global hub options
var claims = await GetClaims(sc => sc.AddSignalR(o => o.HandshakeTimeout = TimeSpan.FromSeconds(1)).AddAzureSignalR());
Assert.Contains(claims, c => c.Type == Constants.ClaimType.CustomHandshakeTimeout && c.Value == "1");

// set custom handshake timeout in particular hub options to override the settings in global hub options
claims = await GetClaims(sc => sc.AddSignalR(o => o.HandshakeTimeout = TimeSpan.FromSeconds(1)).AddHubOptions<TestHub>(o => o.HandshakeTimeout = TimeSpan.FromSeconds(2)).AddAzureSignalR());
Assert.Contains(claims, c => c.Type == Constants.ClaimType.CustomHandshakeTimeout && c.Value == "2");

// no custom timeout
claims = await GetClaims(sc => sc.AddSignalR().AddAzureSignalR());
Assert.DoesNotContain(claims, c => c.Type == Constants.ClaimType.CustomHandshakeTimeout);

// invalid timeout: larger than 30s
claims = await GetClaims(sc => sc.AddSignalR(o => o.HandshakeTimeout = TimeSpan.FromSeconds(31)).AddAzureSignalR());
Assert.DoesNotContain(claims, c => c.Type == Constants.ClaimType.CustomHandshakeTimeout);

// invalid timeout: smaller than 1s
claims = await GetClaims(sc => sc.AddSignalR(o => o.HandshakeTimeout = TimeSpan.FromSeconds(0)).AddAzureSignalR());
Assert.DoesNotContain(claims, c => c.Type == Constants.ClaimType.CustomHandshakeTimeout);
}

private static async Task<IEnumerable<Claim>> GetClaims(Action<ServiceCollection> addSignalR)
{
var config = new ConfigurationBuilder()
.AddInMemoryCollection(new Dictionary<string, string>
{
{"Azure:SignalR:ConnectionString", "Endpoint=http://localhost;AccessKey=ABCDEFGHIJKLMNOPQR55555555012345678933333333;Version=1.0;"}
})
.Build();

var services = new ServiceCollection();
addSignalR(services);

var sp = services
.AddLogging()
.AddSingleton<IHostApplicationLifetime>(new EmptyApplicationLifetime())
.AddSingleton<IConfiguration>(config)
.BuildServiceProvider();

var app = new ApplicationBuilder(sp);
app.UseRouting();
app.UseEndpoints(routes =>
{
routes.MapHub<TestHub>("/chat");
});

var h = sp.GetRequiredService<NegotiateHandler<TestHub>>();
var r = await h.Process(new DefaultHttpContext());
var jwtSecurityTokenHandler = new JwtSecurityTokenHandler();
var t = jwtSecurityTokenHandler.ReadJwtToken(r.AccessToken);
return t.Claims;
}
}
}
Loading