From a70e5c7143476dbcf207be98e1a2b5441e4c7e55 Mon Sep 17 00:00:00 2001 From: stdrickforce Date: Mon, 18 Nov 2019 15:35:59 +0800 Subject: [PATCH] server connection migration implementation --- .gitignore | 6 +- .../Constants.cs | 2 + .../ServiceConnectionBase.cs | 2 +- .../ServiceConnectionContainerBase.cs | 4 +- .../StrongServiceConnectionContainer.cs | 2 +- .../ClientConnectionContext.cs | 10 ++- .../ServiceConnection.Log.cs | 16 ++++ .../ServerConnections/ServiceConnection.cs | 89 +++++++++++++++++-- .../TestClientConnectionFactory.cs | 19 ++++ .../Infrastructure/TestConnectionFactory.cs | 17 +++- .../ServiceConnectionTests.cs | 84 ++++++++++++++++- .../ServiceContextFacts.cs | 18 +++- 12 files changed, 249 insertions(+), 20 deletions(-) create mode 100644 test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestClientConnectionFactory.cs diff --git a/.gitignore b/.gitignore index 4c5ea419b..1f3597a0d 100644 --- a/.gitignore +++ b/.gitignore @@ -292,8 +292,8 @@ __pycache__/ .publish/ -# docker -.docker/ - # vim *.swp + +# docker +.docker/ diff --git a/src/Microsoft.Azure.SignalR.Common/Constants.cs b/src/Microsoft.Azure.SignalR.Common/Constants.cs index 95bd11767..9bbb26b7b 100644 --- a/src/Microsoft.Azure.SignalR.Common/Constants.cs +++ b/src/Microsoft.Azure.SignalR.Common/Constants.cs @@ -13,6 +13,8 @@ internal static class Constants public const int DefaultShutdownTimeoutInSeconds = 30; + public const string AsrsMigrateIn = "Asrs-Migrate-In"; + public const string AsrsMigrateOut = "Asrs-Migrate-Out"; public const string AsrsUserAgent = "Asrs-User-Agent"; public const string AsrsInstanceId = "Asrs-Instance-Id"; diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs index dab47cb39..92ad0a39e 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs @@ -16,7 +16,7 @@ namespace Microsoft.Azure.SignalR { internal abstract class ServiceConnectionBase : IServiceConnection { - private static readonly TimeSpan DefaultHandshakeTimeout = TimeSpan.FromSeconds(15); + protected static readonly TimeSpan DefaultHandshakeTimeout = TimeSpan.FromSeconds(15); // Service ping rate is 5 sec to let server know service status. Set timeout for 30 sec for some space. private static readonly TimeSpan DefaultServiceTimeout = TimeSpan.FromSeconds(30); private static readonly long DefaultServiceTimeoutTicks = DefaultServiceTimeout.Seconds * Stopwatch.Frequency; diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs index 084ca5a38..fc258e7cd 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs @@ -355,7 +355,7 @@ protected async Task WriteFinAsync(IServiceConnection c) await c.WriteAsync(_shutdownFinMessage); } - protected async Task RemoveConnectionFromService(IServiceConnection c) + protected async Task RemoveConnectionAsync(IServiceConnection c) { _ = WriteFinAsync(c); @@ -371,7 +371,7 @@ protected async Task RemoveConnectionFromService(IServiceConnection c) public virtual Task OfflineAsync() { - return Task.WhenAll(FixedServiceConnections.Select(c => RemoveConnectionFromService(c))); + return Task.WhenAll(FixedServiceConnections.Select(c => RemoveConnectionAsync(c))); } private static class Log diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StrongServiceConnectionContainer.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StrongServiceConnectionContainer.cs index dd882c7f2..c50610831 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StrongServiceConnectionContainer.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StrongServiceConnectionContainer.cs @@ -45,7 +45,7 @@ public override Task StopAsync() public override Task OfflineAsync() { var task1 = base.OfflineAsync(); - var task2 = Task.WhenAll(_onDemandServiceConnections.Select(c => RemoveConnectionFromService(c))); + var task2 = Task.WhenAll(_onDemandServiceConnections.Select(c => RemoveConnectionAsync(c))); return Task.WhenAll(task1, task2); } diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs index 8bef08107..8a6ebd99a 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs @@ -15,7 +15,6 @@ using Microsoft.AspNetCore.Http.Connections.Features; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Features.Authentication; -using Microsoft.AspNetCore.SignalR; using Microsoft.AspNetCore.WebUtilities; using Microsoft.Azure.SignalR.Protocol; using Microsoft.Extensions.Primitives; @@ -37,7 +36,9 @@ internal class ClientConnectionContext : ConnectionContext, private readonly TaskCompletionSource _connectionEndTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - public Task CompleteTask => _connectionEndTcs.Task; + public Task CompleteTask => _connectionEndTcs.Task; + + public bool IsMigrated { get; } private readonly object _heartbeatLock = new object(); private List<(Action handler, object state)> _heartbeatHandlers; @@ -47,6 +48,11 @@ public ClientConnectionContext(OpenConnectionMessage serviceMessage, Action _applicationTaskTimedOut = LoggerMessage.Define(LogLevel.Error, new EventId(21, "ApplicationTaskTimedOut"), "Timed out waiting for the application task to complete."); + private static readonly Action _migrationStarting = + LoggerMessage.Define(LogLevel.Debug, new EventId(22, "MigrationStarting"), "Connection {TransportConnectionId} migrated from another server."); + + private static readonly Action _errorSkippingHandshakeResponse = + LoggerMessage.Define(LogLevel.Error, new EventId(23, "ErrorSkippingHandshakeResponse"), "Error while skipping handshake response during migration, the connection will be dropped on the client-side. Error detail: {message}"); + public static void FailedToCleanupConnections(ILogger logger, Exception exception) { _failedToCleanupConnections(logger, exception); @@ -82,6 +88,11 @@ public static void ConnectedStarting(ILogger logger, string connectionId) _connectedStarting(logger, connectionId, null); } + public static void MigrationStarting(ILogger logger, string connectionId) + { + _migrationStarting(logger, connectionId, null); + } + public static void ConnectedEnding(ILogger logger, string connectionId) { _connectedEnding(logger, connectionId, null); @@ -101,6 +112,11 @@ public static void ApplicationTaskTimedOut(ILogger logger) { _applicationTaskTimedOut(logger, null); } + + public static void ErrorSkippingHandshakeResponse(ILogger logger, Exception ex) + { + _errorSkippingHandshakeResponse(logger, ex.Message, ex); + } } } } diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs index 832ce682d..fe5546b1d 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs @@ -12,7 +12,8 @@ using Microsoft.Azure.SignalR.Protocol; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; - +using SignalRProtocol = Microsoft.AspNetCore.SignalR.Protocol; + namespace Microsoft.Azure.SignalR { internal partial class ServiceConnection : ServiceConnectionBase @@ -22,6 +23,8 @@ internal partial class ServiceConnection : ServiceConnectionBase private static readonly Dictionary CustomHeader = new Dictionary { { Constants.AsrsUserAgent, ProductInfo.GetProductInfo() } }; private static readonly TimeSpan CloseTimeout = TimeSpan.FromSeconds(5); + private readonly bool _enableConnectionMigration; + private const string ClientConnectionCountInHub = "#clientInHub"; private const string ClientConnectionCountInServiceConnection = "#client"; @@ -53,6 +56,8 @@ public ServiceConnection(IServiceProtocol serviceProtocol, _connectionFactory = connectionFactory; _connectionDelegate = connectionDelegate; _clientConnectionFactory = clientConnectionFactory; + + _enableConnectionMigration = false; } protected override Task CreateConnection(string target = null) @@ -102,8 +107,16 @@ protected override Task OnClientConnectedAsync(OpenConnectionMessage message) { var connection = _clientConnectionFactory.CreateConnection(message, ConfigureContext); connection.ServiceConnection = this; - AddClientConnection(connection, GetInstanceId(message.Headers)); - Log.ConnectedStarting(Logger, connection.ConnectionId); + AddClientConnection(connection, message); + + if (connection.IsMigrated) + { + Log.MigrationStarting(Logger, connection.ConnectionId); + } + else + { + Log.ConnectedStarting(Logger, connection.ConnectionId); + } // Execute the application code connection.ApplicationTask = _connectionDelegate(connection); @@ -120,6 +133,17 @@ protected override Task OnClientConnectedAsync(OpenConnectionMessage message) protected override Task OnClientDisconnectedAsync(CloseConnectionMessage closeConnectionMessage) { var connectionId = closeConnectionMessage.ConnectionId; + if (_enableConnectionMigration && _clientConnectionManager.ClientConnections.TryGetValue(connectionId, out var context)) + { + if (!context.HttpContext.Request.Headers.ContainsKey(Constants.AsrsMigrateOut)) + { + context.HttpContext.Request.Headers.Add(Constants.AsrsMigrateOut, ""); + } + // We have to prevent SignalR `{type: 7}` (close message) from reaching our client while doing migration. + // Since all user-created messages will be sent to `ServiceConnection` directly. + // We can simply ignore all messages came from the application pipe. + context.Application.Input.CancelPendingRead(); + } return PerformDisconnectAsyncCore(connectionId, false); } @@ -158,20 +182,72 @@ protected override async Task OnClientMessageAsync(ConnectionDataMessage connect } } + private async Task SkipHandshakeResponse(ClientConnectionContext connection, CancellationToken token) + { + try + { + while (true) + { + var result = await connection.Application.Input.ReadAsync(token); + if (result.IsCanceled || token.IsCancellationRequested) + { + return false; + } + + var buffer = result.Buffer; + if (buffer.IsEmpty) + { + continue; + } + + if (SignalRProtocol.HandshakeProtocol.TryParseResponseMessage(ref buffer, out var message)) + { + connection.Application.Input.AdvanceTo(buffer.Start); + return true; + } + + if (result.IsCompleted) + { + return false; + } + } + } + catch (Exception ex) + { + Log.ErrorSkippingHandshakeResponse(Logger, ex); + } + return false; + } + private async Task ProcessOutgoingMessagesAsync(ClientConnectionContext connection) { try - { + { + if (connection.IsMigrated) + { + using var source = new CancellationTokenSource(DefaultHandshakeTimeout); + + // A handshake response is not expected to be given + // if the connection was migrated from another server, + // since the connection hasn't been `dropped` from the client point of view. + if (!await SkipHandshakeResponse(connection, source.Token)) + { + return; + } + } + while (true) { var result = await connection.Application.Input.ReadAsync(); + if (result.IsCanceled) { break; } var buffer = result.Buffer; - if (!buffer.IsEmpty) + + if (!buffer.IsEmpty) { try { @@ -209,8 +285,9 @@ private async Task ProcessOutgoingMessagesAsync(ClientConnectionContext connecti } } - private void AddClientConnection(ClientConnectionContext connection, string instanceId) + private void AddClientConnection(ClientConnectionContext connection, OpenConnectionMessage message) { + var instanceId = GetInstanceId(message.Headers); _clientConnectionManager.AddClientConnection(connection); _connectionIds.TryAdd(connection.ConnectionId, instanceId); } diff --git a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestClientConnectionFactory.cs b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestClientConnectionFactory.cs new file mode 100644 index 000000000..5e3686a3a --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestClientConnectionFactory.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using Microsoft.AspNetCore.Http; +using Microsoft.Azure.SignalR.Protocol; + +namespace Microsoft.Azure.SignalR.Tests +{ + class TestClientConnectionFactory : IClientConnectionFactory + { + public IList Connections = new List(); + + public ClientConnectionContext CreateConnection(OpenConnectionMessage message, Action configureContext = null) + { + var context = new ClientConnectionContext(message, configureContext); + Connections.Add(context); + return context; + } + } +} diff --git a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestConnectionFactory.cs b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestConnectionFactory.cs index 03547945b..76299787b 100644 --- a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestConnectionFactory.cs +++ b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestConnectionFactory.cs @@ -2,7 +2,6 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -14,13 +13,20 @@ internal class TestConnectionFactory : IConnectionFactory { private readonly Func _connectCallback; + public IList Connections = new List(); + public List Times { get; } = new List(); + public TestConnectionFactory() + { + _connectCallback = null; + } + public TestConnectionFactory(Func connectCallback) { _connectCallback = connectCallback; } - + public async Task ConnectAsync(HubServiceEndpoint endpoint, TransferFormat transferFormat, string connectionId, string target, CancellationToken cancellationToken = default, IDictionary headers = null) { @@ -31,13 +37,18 @@ public async Task ConnectAsync(HubServiceEndpoint endpoint, T ConnectionId = connectionId, Target = target }; + Connections.Add(connection); + // Start a task to process handshake request from the newly-created server connection. _ = HandshakeAsync(connection); // Do something for test purpose await AfterConnectedAsync(connection); - await _connectCallback(connection); + if (_connectCallback != null) + { + await _connectCallback(connection); + } return connection; } diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs index e9968eaec..e00c75919 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs @@ -4,16 +4,19 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.IO.Pipelines; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Azure.SignalR.Protocol; using Microsoft.Azure.SignalR.Tests.Common; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; - +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Primitives; using Xunit; using Xunit.Abstractions; @@ -196,6 +199,37 @@ await transportConnection.Application.Output.WriteAsync( } } + [Fact] + public async void ServiceConnectionShouldIgnoreFirstHandshakeResponse() + { + var factory = new TestClientConnectionFactory(); + var connection = MockServiceConnection(null, factory); + + // create a connection with migration header. + await connection.OnClientConnectedAsyncForTest(new OpenConnectionMessage("foo", new Claim[0]) + { + Headers = new Dictionary{ + { Constants.AsrsMigrateIn, "another-server" } + } + }); + + Assert.Equal(1, factory.Connections.Count); + var context = factory.Connections[0]; + Assert.True(context.IsMigrated); + + var message = new AspNetCore.SignalR.Protocol.HandshakeResponseMessage(""); + HandshakeProtocol.WriteResponseMessage(message, context.Transport.Output); + await context.Transport.Output.FlushAsync(); + + var task = context.Transport.Input.ReadAsync(); + await Task.Delay(100); + + // nothing should be written into the transport + Assert.False(task.IsCompleted); + // but the `migrated` status should remain False (readonly) + Assert.True(context.IsMigrated); + } + private sealed class TestConnectionHandler : ConnectionHandler { private TaskCompletionSource _startedTcs = new TaskCompletionSource(); @@ -225,6 +259,54 @@ public override async Task OnConnectedAsync(ConnectionContext connection) } } + private class TestServiceConnection : ServiceConnection + { + public TestServiceConnection(IConnectionFactory serviceConnectionFactory, + IClientConnectionFactory clientConnectionFactory, + ILoggerFactory loggerFactory, + ConnectionDelegate handler) : base( + new ServiceProtocol(), + new TestClientConnectionManager(), + serviceConnectionFactory, + loggerFactory, + handler, + clientConnectionFactory, + Guid.NewGuid().ToString("N"), + null, + null + ) + { + } + + public Task OnClientConnectedAsyncForTest(OpenConnectionMessage message) + { + return base.OnClientConnectedAsync(message); + } + } + + private TestServiceConnection MockServiceConnection(IConnectionFactory serviceConnectionFactory = null, + IClientConnectionFactory clientConnectionFactory = null, + ILoggerFactory loggerFactory = null) + { + clientConnectionFactory ??= new ClientConnectionFactory(); + serviceConnectionFactory ??= new TestConnectionFactory(conn => Task.CompletedTask); + loggerFactory ??= NullLoggerFactory.Instance; + + var services = new ServiceCollection(); + var connectionHandler = new EndlessConnectionHandler(); + services.AddSingleton(connectionHandler); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + ConnectionDelegate handler = builder.Build(); + + return new TestServiceConnection( + serviceConnectionFactory, + clientConnectionFactory, + loggerFactory, + handler + ); + } + private sealed class EndlessConnectionHandler : ConnectionHandler { public CancellationTokenSource CancellationToken { get; } = new CancellationTokenSource(); diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceContextFacts.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceContextFacts.cs index acb749818..7f083edae 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceContextFacts.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceContextFacts.cs @@ -3,11 +3,13 @@ using System; using System.Collections.Generic; +using System.IO.Pipelines; using System.Linq; using System.Net; using System.Security.Claims; using Microsoft.AspNetCore.Http; using Microsoft.Azure.SignalR.Protocol; +using Microsoft.Azure.SignalR.Tests.Common; using Microsoft.Extensions.Primitives; using Xunit; @@ -42,7 +44,7 @@ public void ServiceConnectionContextWithSystemClaimsIsUnauthenticated() new Claim("exp", "1234567890"), new Claim("iat", "1234567890"), new Claim("nbf", "1234567890"), - new Claim(Constants.ClaimType.UserId, "customUserId"), + new Claim(Constants.ClaimType.UserId, "customUserId"), }; var serviceConnectionContext = new ClientConnectionContext(new OpenConnectionMessage("1", claims)); Assert.NotNull(serviceConnectionContext.User.Identity); @@ -161,6 +163,20 @@ public void ServiceConnectionContextWithRequestPath() Assert.Equal(path, request.Path); } + [Fact] + public void ServiceConnectionShouldBeMigrated() + { + var open = new OpenConnectionMessage("foo", new Claim[0]); + var context = new ClientConnectionContext(open); + Assert.False(context.IsMigrated); + + open.Headers = new Dictionary{ + { Constants.AsrsMigrateIn, "another-server" } + }; + context = new ClientConnectionContext(open); + Assert.True(context.IsMigrated); + } + [Theory] [InlineData("1.1.1.1", true, "1.1.1.1")] [InlineData("1.1.1.1, 2.2.2.2", true, "1.1.1.1")]