Skip to content

Commit

Permalink
move get handshake timeout part outside of buildClaims
Browse files Browse the repository at this point in the history
  • Loading branch information
wanlwanl committed Nov 16, 2020
1 parent 1dcfc85 commit 042f46b
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ private IEnumerable<Claim> BuildClaims(IOwinContext owinContext, IRequest reques
yield return new Claim(Constants.ClaimType.AppName, _appName);
var user = owinContext.Authentication?.User;
var userId = _provider?.GetUserId(request);
var claims = ClaimsUtility.BuildJwtClaims(_logger, user, userId, GetClaimsProvider(owinContext), _serverName, _mode, _enableDetailedErrors, _endpointsCount, _maxPollInterval, IsDiagnosticClient(owinContext));
var claims = ClaimsUtility.BuildJwtClaims(user, userId, GetClaimsProvider(owinContext), _serverName, _mode, _enableDetailedErrors, _endpointsCount, _maxPollInterval, IsDiagnosticClient(owinContext));

yield return new Claim(Constants.ClaimType.Version, AssemblyVersion);

Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ public static class Periods
public static readonly TimeSpan DefaultCloseDelayInterval = TimeSpan.FromSeconds(5);

// Custom handshake timeout of SignalR Service
public static readonly TimeSpan DefaultHandshakeTimeout = TimeSpan.FromSeconds(15);
public static readonly TimeSpan MaxCustomHandshakeTimeout = TimeSpan.FromSeconds(30);
public const int DefaultHandshakeTimeout = 15;
public const int MaxCustomHandshakeTimeout = 30;
}

public static class ClaimType
Expand Down
51 changes: 3 additions & 48 deletions src/Microsoft.Azure.SignalR.Common/Utilities/ClaimsUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using System.Linq;
using System.Security.Claims;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;

namespace Microsoft.Azure.SignalR
{
Expand All @@ -27,7 +26,6 @@ internal static class ClaimsUtility
private static readonly string DefaultRoleClaimType = DefaultClaimsIdentity.RoleClaimType;

public static IEnumerable<Claim> BuildJwtClaims(
ILogger logger,
ClaimsPrincipal user,
string userId,
Func<IEnumerable<Claim>> claimsProvider,
Expand All @@ -36,7 +34,7 @@ public static IEnumerable<Claim> BuildJwtClaims(
bool enableDetailedErrors = false,
int endpointsCount = 1,
int? maxPollInterval = null,
bool isDiagnosticClient = false, TimeSpan? handshakeTimeout = null)
bool isDiagnosticClient = false, int handshakeTimeout = Constants.Periods.DefaultHandshakeTimeout)
{
if (userId != null)
{
Expand All @@ -54,9 +52,9 @@ public static IEnumerable<Claim> BuildJwtClaims(
yield return new Claim(Constants.ClaimType.DiagnosticClient, "true");
}

if (TryGetCustomHandshakeTimeoutClaim(handshakeTimeout, logger, out var handshakeTimeoutClaim))
if (handshakeTimeout != Constants.Periods.DefaultHandshakeTimeout)
{
yield return handshakeTimeoutClaim;
yield return new Claim(Constants.ClaimType.CustomHandshakeTimeout, handshakeTimeout.ToString());
}

var authenticationType = user?.Identity?.AuthenticationType;
Expand Down Expand Up @@ -175,48 +173,5 @@ internal static ClaimsPrincipal GetUserPrincipal(Claim[] messageClaims)

return new ClaimsPrincipal(new ClaimsIdentity(claims, authenticationType, nameType, roleType));
}

private static bool TryGetCustomHandshakeTimeoutClaim(TimeSpan? handshakeTimeout, ILogger logger, out Claim claim)
{
// use default handshake timeout
if (!handshakeTimeout.HasValue || handshakeTimeout.Value.Equals(Constants.Periods.DefaultHandshakeTimeout))
{
claim = null;
return false;
}

// the custom handshake timeout is invalid, use default hanshake timeout instead
if (handshakeTimeout.Value.CompareTo(TimeSpan.Zero) <= 0 ||
handshakeTimeout.Value.CompareTo(Constants.Periods.MaxCustomHandshakeTimeout) > 0)
{
Log.FailToSetCustomHandshakeTimeout(logger, new ArgumentOutOfRangeException(nameof(handshakeTimeout)));
claim = null;
return false;
}

// the custom handshake timeout is valid
Log.SucceedToSetCustomHandshakeTimeout(logger, handshakeTimeout.Value);
claim = new Claim(Constants.ClaimType.CustomHandshakeTimeout, ((int)handshakeTimeout.Value.TotalSeconds).ToString());
return true;
}

private static class Log
{
private static readonly Action<ILogger, int, Exception> _succeedToSetCustomHandshakeTimeout =
LoggerMessage.Define<int>(LogLevel.Information, new EventId(1, "SucceedToSetCustomHandshakeTimeout"), "Succeed to set custom handshake timeout: {timeout} seconds.");

private static readonly Action<ILogger, Exception> _failToSetCustomHandshakeTimeout =
LoggerMessage.Define(LogLevel.Warning, new EventId(2, "FailToSetCustomHandshakeTimeout"), $"Fail to set custom handshake timeout, use default handshake timeout {Constants.Periods.DefaultHandshakeTimeout.TotalSeconds} seconds instead. The range of custom handshake timeout should between 1 second to {Constants.Periods.MaxCustomHandshakeTimeout.TotalSeconds} seconds.");

public static void SucceedToSetCustomHandshakeTimeout(ILogger logger, TimeSpan customHandshakeTimeout)
{
_succeedToSetCustomHandshakeTimeout(logger, (int)customHandshakeTimeout.TotalSeconds, null);
}

public static void FailToSetCustomHandshakeTimeout(ILogger logger, Exception exception)
{
_failToSetCustomHandshakeTimeout(logger, exception);
}
}
}
}
62 changes: 59 additions & 3 deletions src/Microsoft.Azure.SignalR/HubHost/NegotiateHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ internal class NegotiateHandler<THub> where THub : Hub
private readonly bool _enableDetailedErrors;
private readonly int _endpointsCount;
private readonly int? _maxPollInterval;
private readonly TimeSpan? _customHandshakeTimeout;
private readonly int _customHandshakeTimeout;
private readonly string _hubName;
private readonly ILogger<NegotiateHandler<THub>> _logger;

Expand Down Expand Up @@ -59,7 +59,7 @@ public NegotiateHandler(
_enableDetailedErrors = globalHubOptions.Value.EnableDetailedErrors == true;
_endpointsCount = options.Value.Endpoints.Length;
_maxPollInterval = options.Value.MaxPollIntervalInSeconds;
_customHandshakeTimeout = hubOptions.Value.HandshakeTimeout ?? globalHubOptions.Value.HandshakeTimeout;
_customHandshakeTimeout = GetCustomHandshakeTimeout(hubOptions.Value.HandshakeTimeout ?? globalHubOptions.Value.HandshakeTimeout);
_hubName = typeof(THub).Name;
}

Expand Down Expand Up @@ -111,7 +111,7 @@ private IEnumerable<Claim> BuildClaims(HttpContext context)
// Make sticky mode required if detect using blazor
var mode = _blazorDetector.IsBlazor(_hubName) ? ServerStickyMode.Required : _mode;
var userId = _userIdProvider.GetUserId(new ServiceHubConnectionContext(context));
return ClaimsUtility.BuildJwtClaims(_logger, context.User, userId, GetClaimsProvider(context), _serverName, mode, _enableDetailedErrors, _endpointsCount, _maxPollInterval, IsDiagnosticClient(context), _customHandshakeTimeout).ToList();
return ClaimsUtility.BuildJwtClaims(context.User, userId, GetClaimsProvider(context), _serverName, mode, _enableDetailedErrors, _endpointsCount, _maxPollInterval, IsDiagnosticClient(context), _customHandshakeTimeout).ToList();
}

private Func<IEnumerable<Claim>> GetClaimsProvider(HttpContext context)
Expand All @@ -129,12 +129,68 @@ private bool IsDiagnosticClient(HttpContext context)
return _diagnosticClientFilter != null && _diagnosticClientFilter(context);
}

private int GetCustomHandshakeTimeout(TimeSpan? handshakeTimeout)
{
if (!handshakeTimeout.HasValue)
{
Log.UseDefaultHandshakeTimeout(_logger);
return Constants.Periods.DefaultHandshakeTimeout;
}

var timeout = (int)handshakeTimeout.Value.TotalSeconds;

// use default handshake timeout
if (timeout == Constants.Periods.DefaultHandshakeTimeout)
{
Log.UseDefaultHandshakeTimeout(_logger);
return Constants.Periods.DefaultHandshakeTimeout;
}

// the custom handshake timeout is invalid, use default hanshake timeout instead
if (timeout <= 0 || timeout > Constants.Periods.MaxCustomHandshakeTimeout)
{
Log.FailToSetCustomHandshakeTimeout(_logger, new ArgumentOutOfRangeException(nameof(handshakeTimeout)));
return Constants.Periods.DefaultHandshakeTimeout;
}

// the custom handshake timeout is valid
Log.SucceedToSetCustomHandshakeTimeout(_logger, timeout);
return timeout;
}

private static string GetOriginalPath(string path)
{
path = path.TrimEnd('/');
return path.EndsWith(Constants.Path.Negotiate)
? path.Substring(0, path.Length - Constants.Path.Negotiate.Length)
: string.Empty;
}

private static class Log
{
private static readonly Action<ILogger, Exception> _useDefaultHandshakeTimeout =
LoggerMessage.Define(LogLevel.Information, new EventId(0, "UseDefaultHandshakeTimeout"), "Use default handshake timeout.");

private static readonly Action<ILogger, int, Exception> _succeedToSetCustomHandshakeTimeout =
LoggerMessage.Define<int>(LogLevel.Information, new EventId(1, "SucceedToSetCustomHandshakeTimeout"), "Succeed to set custom handshake timeout: {timeout} seconds.");

private static readonly Action<ILogger, Exception> _failToSetCustomHandshakeTimeout =
LoggerMessage.Define(LogLevel.Warning, new EventId(2, "FailToSetCustomHandshakeTimeout"), $"Fail to set custom handshake timeout, use default handshake timeout {Constants.Periods.DefaultHandshakeTimeout} seconds instead. The range of custom handshake timeout should between 1 second to {Constants.Periods.MaxCustomHandshakeTimeout} seconds.");

public static void UseDefaultHandshakeTimeout(ILogger logger)
{
_useDefaultHandshakeTimeout(logger, null);
}

public static void SucceedToSetCustomHandshakeTimeout(ILogger logger, int customHandshakeTimeout)
{
_succeedToSetCustomHandshakeTimeout(logger, customHandshakeTimeout, null);
}

public static void FailToSetCustomHandshakeTimeout(ILogger logger, Exception exception)
{
_failToSetCustomHandshakeTimeout(logger, exception);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ private static readonly (ClaimsIdentity identity, string userId, Func<IEnumerabl
[Fact]
public void TestGetSystemClaimsWithDefaultValue()
{
var claims = ClaimsUtility.BuildJwtClaims(NullLogger.Instance, null, null, null).ToList();
var claims = ClaimsUtility.BuildJwtClaims(null, null, null).ToList();
Assert.Empty(claims);
}

[Theory]
[MemberData(nameof(ClaimsParameters))]
public void TestGetSystemClaims(ClaimsIdentity identity, string userId, Func<IEnumerable<Claim>> provider, string expectedAuthenticationType, int expectedClaimsCount)
{
var claims = ClaimsUtility.BuildJwtClaims(NullLogger.Instance, new ClaimsPrincipal(identity), userId, provider).ToArray();
var claims = ClaimsUtility.BuildJwtClaims(new ClaimsPrincipal(identity), userId, provider).ToArray();
var resultIdentity = ClaimsUtility.GetUserPrincipal(claims).Identity;

var ci = resultIdentity as ClaimsIdentity;
Expand Down

0 comments on commit 042f46b

Please sign in to comment.