diff --git a/src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs b/src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs index 9387ae182..8c12f2a59 100644 --- a/src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs +++ b/src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs @@ -104,7 +104,7 @@ internal static ServiceHubDispatcher PrepareAndGetDispatcher(IAppBuilder builder configuration.Resolver.Register(typeof(IServiceConnectionFactory), () => scf); } - var sccf = new ServiceConnectionContainerFactory(scf, endpoint, router, options, serverNameProvider, ccm, loggerFactory); + var sccf = new ServiceConnectionContainerFactory(scf, endpoint, router, options, serverNameProvider, loggerFactory); if (hubs?.Count > 0) { diff --git a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs index d0ccf303a..b8451b91c 100644 --- a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs +++ b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs @@ -80,7 +80,7 @@ protected override Task OnConnectedAsync(OpenConnectionMessage openConnectionMes } else { - // the manager still contains this connectionId, probably this connection is not yet cleaned up + // the manager still contains this connectionId, probably this connection is not yet cleaned up Log.DuplicateConnectionId(Logger, connectionId, null); return WriteAsync( new CloseConnectionMessage(connectionId, $"Duplicate connection ID {connectionId}")); @@ -288,10 +288,10 @@ private string GetInstanceId(IDictionary header) if (header.TryGetValue(Constants.AsrsInstanceId, out var instanceId)) { return instanceId; - } - return null; + } + return null; } - + private sealed class ClientContext { private readonly CancellationTokenSource _cancellationTokenSource = new CancellationTokenSource(); @@ -319,7 +319,7 @@ public void CancelPendingRead() public string InstanceId { get; } public ChannelReader Input { get; } - + public ChannelWriter Output { get; } public IServiceTransport Transport { get; set; } diff --git a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionManager.cs b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionManager.cs index 2e06745bc..efbf35be7 100644 --- a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionManager.cs +++ b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnectionManager.cs @@ -82,11 +82,11 @@ public Task StopAsync() return Task.WhenAll(GetConnections().Select(s => s.StopAsync())); } - public Task ShutdownAsync(TimeSpan timeout) + public Task OfflineAsync() { - return Task.WhenAll(GetConnections().Select(s => s.ShutdownAsync(timeout))); - } - + return Task.WhenAll(GetConnections().Select(s => s.OfflineAsync())); + } + public IServiceConnectionContainer WithHub(string hubName) { if (_hubConnections == null ||!_hubConnections.TryGetValue(hubName, out var connection)) @@ -131,6 +131,6 @@ private IEnumerable GetConnections() yield return conn.Value; } } - } + } } } diff --git a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainer.cs b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainer.cs index 8fa0e3781..7c277271c 100644 --- a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainer.cs +++ b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainer.cs @@ -2,7 +2,6 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; -using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.SignalR.Protocol; @@ -15,7 +14,7 @@ internal interface IServiceConnectionContainer Task StopAsync(); - Task ShutdownAsync(TimeSpan timeout); + Task OfflineAsync(); Task WriteAsync(ServiceMessage serviceMessage); diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs index 08b0cb2e2..8aaf53140 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs @@ -18,21 +18,18 @@ internal class MultiEndpointServiceConnectionContainer : IServiceConnectionConta private readonly IMessageRouter _router; private readonly ILogger _logger; private readonly IServiceConnectionContainer _inner; - private readonly IClientConnectionLifetimeManager _clientLifetime; private IReadOnlyList _endpoints; public Dictionary Connections { get; } - public MultiEndpointServiceConnectionContainer(string hub, Func generator, IServiceEndpointManager endpointManager, IMessageRouter router, IClientConnectionLifetimeManager lifetime, ILoggerFactory loggerFactory) + public MultiEndpointServiceConnectionContainer(string hub, Func generator, IServiceEndpointManager endpointManager, IMessageRouter router, ILoggerFactory loggerFactory) { if (generator == null) { throw new ArgumentNullException(nameof(generator)); } - _clientLifetime = lifetime; - _logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory)); // provides a copy to the endpoint per container @@ -57,14 +54,12 @@ public MultiEndpointServiceConnectionContainer( IServiceEndpointManager endpointManager, IMessageRouter router, IServerNameProvider nameProvider, - IClientConnectionLifetimeManager lifetime, ILoggerFactory loggerFactory ) : this( hub, endpoint => CreateContainer(serviceConnectionFactory, endpoint, count, loggerFactory), endpointManager, router, - lifetime, loggerFactory ) { @@ -131,18 +126,18 @@ public Task StopAsync() })); } - public async Task ShutdownAsync(TimeSpan timeout) + public Task OfflineAsync() { - // TODOS - - // 1. write FIN to every server connection of every connection container. + if (_inner != null) + { + return _inner.OfflineAsync(); + } + else + { + return Task.WhenAll(Connections.Select(c => c.Value.OfflineAsync())); + } + } - // 2. wait until all client connections have been closed (either by server/client side) - - // 3. stop every container. - await StopAsync(); - } - public Task WriteAsync(ServiceMessage serviceMessage) { if (_inner != null) @@ -251,8 +246,8 @@ private Task WriteMultiEndpointMessageAsync(ServiceMessage serviceMessage, Func< } return Task.WhenAll(routed); - } - + } + private static class Log { private static readonly Action _startingConnection = diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs index 59f7ff4b7..37099638a 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs @@ -183,7 +183,6 @@ public Task StopAsync() { Log.UnexpectedExceptionInStop(Logger, ConnectionId, ex); } - return Task.CompletedTask; } @@ -394,8 +393,7 @@ private async Task ReceiveHandshakeResponseAsync(PipeReader input, Cancell else { Log.HandshakeError(Logger, handshakeResponse.ErrorMessage, ConnectionId); - } - + } return false; } } @@ -416,10 +414,11 @@ private async Task ReceiveHandshakeResponseAsync(PipeReader input, Cancell private async Task ProcessIncomingAsync(ConnectionContext connection) { var keepAliveTimer = StartKeepAliveTimer(); + try { while (true) - { + { var result = await connection.Transport.Input.ReadAsync(); var buffer = result.Buffer; @@ -466,27 +465,28 @@ private async Task ProcessIncomingAsync(ConnectionContext connection) } } finally - { - keepAliveTimer.Stop(); - } + { + keepAliveTimer.Stop(); + _serviceConnectionOfflineTcs.TrySetResult(true); + } } private Task DispatchMessageAsync(ServiceMessage message) - { - switch (message) - { - case OpenConnectionMessage openConnectionMessage: - return OnConnectedAsync(openConnectionMessage); - case CloseConnectionMessage closeConnectionMessage: - return OnDisconnectedAsync(closeConnectionMessage); - case ConnectionDataMessage connectionDataMessage: - return OnMessageAsync(connectionDataMessage); - case ServiceErrorMessage serviceErrorMessage: - return OnServiceErrorAsync(serviceErrorMessage); - case PingMessage pingMessage: - return OnPingMessageAsync(pingMessage); - case AckMessage ackMessage: - return OnAckMessageAsync(ackMessage); + { + switch (message) + { + case OpenConnectionMessage openConnectionMessage: + return OnConnectedAsync(openConnectionMessage); + case CloseConnectionMessage closeConnectionMessage: + return OnDisconnectedAsync(closeConnectionMessage); + case ConnectionDataMessage connectionDataMessage: + return OnMessageAsync(connectionDataMessage); + case ServiceErrorMessage serviceErrorMessage: + return OnServiceErrorAsync(serviceErrorMessage); + case PingMessage pingMessage: + return OnPingMessageAsync(pingMessage); + case AckMessage ackMessage: + return OnAckMessageAsync(ackMessage); } return Task.CompletedTask; } @@ -626,7 +626,6 @@ private static class Log private static readonly Action _receivedInstanceOfflinePing = LoggerMessage.Define(LogLevel.Information, new EventId(31, "ReceivedInstanceOfflinePing"), "Received instance offline service ping: {InstanceId}"); - public static void FailedToWrite(ILogger logger, string serviceConnectionId, Exception exception) { _failedToWrite(logger, exception.Message, serviceConnectionId, null); diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs index 7b5e847a4..12281b889 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs @@ -117,12 +117,6 @@ public virtual Task StopAsync() return Task.WhenAll(FixedServiceConnections.Select(c => c.StopAsync())); } - public virtual Task ShutdownAsync(TimeSpan timeout) - { - _terminated = true; - return Task.CompletedTask; - } - /// /// Start and manage the whole connection lifetime /// @@ -246,8 +240,8 @@ protected void ReplaceFixedConnections(int index, IServiceConnection serviceConn } public Task ConnectionInitializedTask => Task.WhenAll(from connection in FixedServiceConnections - select connection.ConnectionInitializedTask); - + select connection.ConnectionInitializedTask); + public virtual Task WriteAsync(ServiceMessage serviceMessage) { return WriteToRandomAvailableConnection(serviceMessage); @@ -347,8 +341,36 @@ private IEnumerable CreateFixedServiceConnection(int count) { yield return CreateServiceConnectionCore(InitialConnectionType); } - } - + } + + protected async Task WriteFinAsync(IServiceConnection c) + { + var ping = new PingMessage() + { + Messages = new string[2] { Constants.ServicePingMessageKey.ShutdownKey, Constants.ServicePingMessageValue.ShutdownFin } + }; + await c.WriteAsync(ping); + } + + protected async Task RemoveConnectionFromService(IServiceConnection c) + { + _ = WriteFinAsync(c); + + var source = new CancellationTokenSource(); + var task = await Task.WhenAny(c.ConnectionOfflineTask, Task.Delay(TimeSpan.FromSeconds(3), source.Token)); + source.Cancel(); + + if (task != c.ConnectionOfflineTask) + { + // log + } + } + + public virtual Task OfflineAsync() + { + return Task.WhenAll(FixedServiceConnections.Select(c => RemoveConnectionFromService(c))); + } + private static class Log { private static readonly Action _endpointOnline = diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerFactory.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerFactory.cs index ccc48a8e7..294b95536 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerFactory.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerFactory.cs @@ -13,7 +13,6 @@ internal class ServiceConnectionContainerFactory : IServiceConnectionContainerFa private readonly IServiceEndpointManager _serviceEndpointManager; private readonly IMessageRouter _router; private readonly IServerNameProvider _nameProvider; - private readonly IClientConnectionLifetimeManager _lifetime; private readonly IServiceConnectionFactory _serviceConnectionFactory; public ServiceConnectionContainerFactory( @@ -22,7 +21,6 @@ public ServiceConnectionContainerFactory( IMessageRouter router, IServiceEndpointOptions options, IServerNameProvider nameProvider, - IClientConnectionLifetimeManager lifetime, ILoggerFactory loggerFactory) { _serviceConnectionFactory = serviceConnectionFactory; @@ -30,13 +28,12 @@ public ServiceConnectionContainerFactory( _router = router ?? throw new ArgumentNullException(nameof(router)); _options = options; _nameProvider = nameProvider; - _lifetime = lifetime; _loggerFactory = loggerFactory; } public IServiceConnectionContainer Create(string hub) { - return new MultiEndpointServiceConnectionContainer(_serviceConnectionFactory, hub, _options.ConnectionCount, _serviceEndpointManager, _router, _nameProvider, _lifetime, _loggerFactory); + return new MultiEndpointServiceConnectionContainer(_serviceConnectionFactory, hub, _options.ConnectionCount, _serviceEndpointManager, _router, _nameProvider, _loggerFactory); } } } diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StrongServiceConnectionContainer.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StrongServiceConnectionContainer.cs index 74e3981ab..ccc5a802e 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StrongServiceConnectionContainer.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StrongServiceConnectionContainer.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.SignalR.Protocol; using Microsoft.Extensions.Logging; @@ -42,13 +43,11 @@ public override Task StopAsync() ); } - public override Task ShutdownAsync(TimeSpan timeout) + public override Task OfflineAsync() { - var task = base.StopAsync(); - return Task.WhenAll( - task, - Task.CompletedTask // TODO - ); + var task1 = base.OfflineAsync(); + var task2 = Task.WhenAll(_onDemandServiceConnections.Select(c => RemoveConnectionFromService(c))); + return Task.WhenAll(task1, task2); } protected override ServiceConnectionStatus GetStatus() diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/WeakServiceConnectionContainer.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/WeakServiceConnectionContainer.cs index 7443170ac..6246540e6 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/WeakServiceConnectionContainer.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/WeakServiceConnectionContainer.cs @@ -47,8 +47,6 @@ public override Task HandlePingAsync(PingMessage pingMessage) return Task.CompletedTask; } - public override Task ShutdownAsync(TimeSpan timeout) => StopAsync(); - public override Task WriteAsync(ServiceMessage serviceMessage) { if (!_active && !(serviceMessage is PingMessage)) @@ -61,6 +59,11 @@ public override Task WriteAsync(ServiceMessage serviceMessage) return base.WriteAsync(serviceMessage); } + public override Task OfflineAsync() + { + return Task.CompletedTask; + } + internal bool GetServiceStatus(bool active, int checkWindow, TimeSpan checkTimeSpan) { lock (_lock) diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs index 5c7b5d1c5..79775ddab 100644 --- a/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs +++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http; @@ -65,9 +66,28 @@ public void Start(ConnectionDelegate connectionDelegate, Action con _ = _serviceConnectionManager.StartAsync(); } - public Task Shutdown(TimeSpan timeout) + public async Task ShutdownAsync(TimeSpan timeout) { - return _serviceConnectionManager.ShutdownAsync(timeout); + using CancellationTokenSource source = new CancellationTokenSource(); + + var expected = OfflineAndWaitForCompletedAsync(); + var actual = await Task.WhenAny( + Task.Delay(timeout, source.Token), OfflineAndWaitForCompletedAsync() + ); + + if (actual != expected) + { + // TODO log timeout. + } + + source.Cancel(); + await _serviceConnectionManager.StopAsync(); + } + + private async Task OfflineAndWaitForCompletedAsync() + { + await _serviceConnectionManager.OfflineAsync(); + await _clientConnectionManager.WhenAllCompleted(); } private IServiceConnectionContainer GetMultiEndpointServiceConnectionContainer(string hub, ConnectionDelegate connectionDelegate, Action contextConfig = null) @@ -77,12 +97,11 @@ private IServiceConnectionContainer GetMultiEndpointServiceConnectionContainer(s serviceConnectionFactory.ConfigureContext = contextConfig; var factory = new ServiceConnectionContainerFactory( - serviceConnectionFactory, + serviceConnectionFactory, _serviceEndpointManager, - _router, - _options, - _nameProvider, - _clientConnectionManager, + _router, + _options, + _nameProvider, _loggerFactory ); return factory.Create(hub); diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/IServiceConnectionManager.cs b/src/Microsoft.Azure.SignalR/ServerConnections/IServiceConnectionManager.cs index 59efb7202..c59ddf78d 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/IServiceConnectionManager.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/IServiceConnectionManager.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -using System; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR; @@ -17,7 +16,7 @@ internal interface IServiceConnectionManager where THub : Hub Task StopAsync(); - Task ShutdownAsync(TimeSpan timeout); + Task OfflineAsync(); Task WriteAsync(ServiceMessage seviceMessage); diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionManager.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionManager.cs index 7bdf09786..79e1f5e51 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionManager.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnectionManager.cs @@ -1,9 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -using System; -using System.Collections.Generic; -using System.Linq; +using System; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR; @@ -31,9 +29,9 @@ public Task StopAsync() return _serviceConnection.StopAsync(); } - public Task ShutdownAsync(TimeSpan timeout) + public async Task OfflineAsync() { - return _serviceConnection.ShutdownAsync(timeout); + await _serviceConnection.OfflineAsync(); } public Task WriteAsync(ServiceMessage serviceMessage) diff --git a/src/Microsoft.Azure.SignalR/ServiceOptionsSetup.cs b/src/Microsoft.Azure.SignalR/ServiceOptionsSetup.cs index 79d81e0fe..8f99ad9a5 100644 --- a/src/Microsoft.Azure.SignalR/ServiceOptionsSetup.cs +++ b/src/Microsoft.Azure.SignalR/ServiceOptionsSetup.cs @@ -16,6 +16,9 @@ internal class ServiceOptionsSetup : IConfigureOptions private readonly string _connectionString; private readonly ServiceEndpoint[] _endpoints; + private readonly bool _gracefulShutdownEnabled = false; + private readonly TimeSpan _shutdownTimeout = TimeSpan.FromSeconds(Constants.DefaultShutdownTimeoutInSeconds); + public ServiceOptionsSetup(IConfiguration configuration) { _appName = configuration[Constants.ApplicationNameDefaultKeyPrefix]; @@ -44,6 +47,8 @@ public void Configure(ServiceOptions options) options.Endpoints = _endpoints; options.ApplicationName = _appName; options.ServerStickyMode = _serverStickyMode; + options.EnableGracefulShutdown = _gracefulShutdownEnabled; + options.ServerShutdownTimeout = _shutdownTimeout; } private static (string, ServiceEndpoint[]) GetEndpoint(IConfiguration configuration, string key) diff --git a/src/Microsoft.Azure.SignalR/ServiceRouteBuilder.cs b/src/Microsoft.Azure.SignalR/ServiceRouteBuilder.cs index 76e9e430b..05b71fb58 100644 --- a/src/Microsoft.Azure.SignalR/ServiceRouteBuilder.cs +++ b/src/Microsoft.Azure.SignalR/ServiceRouteBuilder.cs @@ -7,6 +7,7 @@ using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.SignalR; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; namespace Microsoft.Azure.SignalR { @@ -58,6 +59,18 @@ private void Start() where THub : Hub var dispatcher = _serviceProvider.GetRequiredService>(); dispatcher.Start(app); + +#if NETCOREAPP + var lifetime = _serviceProvider.GetService(); +#elif NETSTANDARD + var lifetime = _serviceProvider.GetService(); +#else + var lifetime = null; +#endif + if (lifetime != null) + { + lifetime.ApplicationStopping.Register(() => dispatcher.ShutdownAsync(TimeSpan.FromSeconds(30)).Wait()); + } } } } diff --git a/test/Microsoft.Azure.SignalR.AspNet.Tests/MultiEndpointServiceConnectionContainerTests.cs b/test/Microsoft.Azure.SignalR.AspNet.Tests/MultiEndpointServiceConnectionContainerTests.cs index 2d3736ab9..98a5639d7 100644 --- a/test/Microsoft.Azure.SignalR.AspNet.Tests/MultiEndpointServiceConnectionContainerTests.cs +++ b/test/Microsoft.Azure.SignalR.AspNet.Tests/MultiEndpointServiceConnectionContainerTests.cs @@ -26,7 +26,7 @@ public TestMultiEndpointServiceConnectionContainer(string hub, IEndpointRouter router, ILoggerFactory loggerFactory - ) : base(hub, generator, endpoint, router, null, loggerFactory) + ) : base(hub, generator, endpoint, router, loggerFactory) { } } diff --git a/test/Microsoft.Azure.SignalR.AspNet.Tests/TestClasses/TestClientConnectionManager.cs b/test/Microsoft.Azure.SignalR.AspNet.Tests/TestClasses/TestClientConnectionManager.cs index 8d612a527..df14fe023 100644 --- a/test/Microsoft.Azure.SignalR.AspNet.Tests/TestClasses/TestClientConnectionManager.cs +++ b/test/Microsoft.Azure.SignalR.AspNet.Tests/TestClasses/TestClientConnectionManager.cs @@ -1,9 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -using System; using System.Collections.Concurrent; -using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.Azure.SignalR.Protocol; diff --git a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestBaseServiceConnectionContainer.cs b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestBaseServiceConnectionContainer.cs index d0996ac83..0e27f24cc 100644 --- a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestBaseServiceConnectionContainer.cs +++ b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestBaseServiceConnectionContainer.cs @@ -19,10 +19,5 @@ public override Task HandlePingAsync(PingMessage pingMessage) { return Task.CompletedTask; } - - protected override Task OnConnectionComplete(IServiceConnection connection) - { - return Task.CompletedTask; - } } } diff --git a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnection.cs b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnection.cs index 8b7249b66..84ed8e341 100644 --- a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnection.cs +++ b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnection.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System; using System.IO.Pipelines; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; @@ -10,13 +11,31 @@ namespace Microsoft.Azure.SignalR.Tests.Common { - internal sealed class TestServiceConnection : ServiceConnectionBase + internal class TestServiceConnection : ServiceConnectionBase { private readonly bool _throws; + private ServiceConnectionStatus _expectedStatus; + private ConnectionContext _connection; - public TestServiceConnection(ServiceConnectionStatus status = ServiceConnectionStatus.Connected, bool throws = false) : base(null, null, null, null, ServerConnectionType.Default, NullLogger.Instance) + public IDuplexPipe Application { get; private set; } + + private TaskCompletionSource _created = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + public Task ConnectionCreated + { + get => _created.Task; + } + + public TestServiceConnection(ServiceConnectionStatus status = ServiceConnectionStatus.Connected, bool throws = false) : base( + new ServiceProtocol(), + Guid.NewGuid().ToString(), + new HubServiceEndpoint(), + null, // TODO replace it with a NullMessageHandler + ServerConnectionType.Default, + NullLogger.Instance + ) { _expectedStatus = status; _throws = throws; @@ -37,6 +56,10 @@ protected override Task CreateConnection(string target = null { var pipeOptions = new PipeOptions(); var duplex = DuplexPipe.CreateConnectionPair(pipeOptions, pipeOptions); + + Application = duplex.Application; + _created.SetResult(1); + return Task.FromResult(new DefaultConnectionContext() { Application = duplex.Application, @@ -71,6 +94,11 @@ protected override Task OnMessageAsync(ConnectionDataMessage connectionDataMessa return Task.CompletedTask; } + protected Task WriteAsyncBase(ServiceMessage serviceMessage) + { + return base.WriteAsync(serviceMessage); + } + public override Task WriteAsync(ServiceMessage serviceMessage) { if (_throws) @@ -84,6 +112,6 @@ public override Task WriteAsync(ServiceMessage serviceMessage) public void Stop() { _connection?.Transport.Input.CancelPendingRead(); - } + } } } diff --git a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnectionContainer.cs b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnectionContainer.cs index 06db5b8b6..b62544144 100644 --- a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnectionContainer.cs +++ b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnectionContainer.cs @@ -62,7 +62,7 @@ public Task StopAsync() return Task.CompletedTask; } - public Task ShutdownAsync(TimeSpan timeout) + public Task OfflineAsync() { return Task.CompletedTask; } diff --git a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestServiceConnectionManager.cs b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestServiceConnectionManager.cs index f69597f36..d51d59609 100644 --- a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestServiceConnectionManager.cs +++ b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestServiceConnectionManager.cs @@ -57,9 +57,6 @@ public Task StopAsync() return Task.CompletedTask; } - public Task ShutdownAsync(TimeSpan timeout) - { - return Task.CompletedTask; - } + public Task OfflineAsync() => Task.CompletedTask; } } diff --git a/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs b/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs index 0864644eb..afed23980 100644 --- a/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/MultiEndpointServiceConnectionContainerTests.cs @@ -28,7 +28,7 @@ public TestMultiEndpointServiceConnectionContainer(string hub, IEndpointRouter router, ILoggerFactory loggerFactory - ) : base(hub, generator, endpoint, router, null, loggerFactory) + ) : base(hub, generator, endpoint, router, loggerFactory) { } } @@ -705,6 +705,60 @@ public async Task TestTwoEndpointsWithCancellationToken() await Assert.ThrowsAnyAsync(async () => await task).OrTimeout(); } + private async Task TestEndpointOfflineInner(IServiceEndpointManager manager, IEndpointRouter router) + { + var containers = new List(); + + var container = new TestMultiEndpointServiceConnectionContainer("hub", e => + { + var c = new TestServiceConnectionContainer(new List + { + new TestSimpleServiceConnection(), + new TestSimpleServiceConnection() + }); + c.MockOffline = true; + containers.Add(c); + return c; + }, manager, router, NullLoggerFactory.Instance); + + foreach (var c in containers) + { + Assert.False(c.IsOffline); + } + + var expected = container.OfflineAsync(); + var actual = await Task.WhenAny( + expected, + Task.Delay(TimeSpan.FromSeconds(1)) + ); + Assert.Equal(expected, actual); + + foreach (var c in containers) + { + Assert.True(c.IsOffline); + } + + } + + [Fact] + public async Task TestSingleEndpointOffline() + { + var manager = new TestServiceEndpointManager( + new ServiceEndpoint(ConnectionString1) + ); + await TestEndpointOfflineInner(manager, null); + } + + [Fact] + public async Task TestMultiEndpointOffline() + { + var manager = new TestServiceEndpointManager( + new ServiceEndpoint(ConnectionString1), + new ServiceEndpoint(ConnectionString2) + ); + await TestEndpointOfflineInner(manager, new TestEndpointRouter()); + } + private class NotExistEndpointRouter : EndpointRouterDecorator { public override IEnumerable GetEndpointsForConnection(string connectionId, IEnumerable endpoints) @@ -718,11 +772,6 @@ public override IEnumerable GetEndpointsForGroup(string groupNa } } - private IServiceConnection CreateServiceConnection(ServerConnectionType type, IConnectionFactory factory) - { - return new TestSimpleServiceConnection(); - } - private class TestServiceEndpointManager : ServiceEndpointManagerBase { private readonly ServiceEndpoint[] _endpoints; diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerBaseTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerBaseTests.cs index 7733c8f7c..48bde45e7 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerBaseTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerBaseTests.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.SignalR.Protocol; using Xunit; @@ -29,29 +30,51 @@ public async Task TestIfConnectionWillRestartAfterShutdown() Assert.NotEqual(ServiceConnectionStatus.Connected, container.Connections[1].Status); } - private sealed class SimpleTestServiceConnectionFactory : IServiceConnectionFactory + [Fact] + public async Task TestOffline() { - public IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, ServerConnectionType type) + List connections = new List + { + new SimpleTestServiceConnection(), + new SimpleTestServiceConnection() + }; + using TestServiceConnectionContainer container = new TestServiceConnectionContainer(connections, factory: new SimpleTestServiceConnectionFactory()); + + foreach (SimpleTestServiceConnection c in connections) + { + Assert.False(c.ConnectionOfflineTask.IsCompleted); + } + + await container.OfflineAsync(); + + foreach (SimpleTestServiceConnection c in connections) { - return new SimpleTestServiceConnection(); + Assert.True(c.ConnectionOfflineTask.IsCompleted); } } - private sealed class SimpleTestServiceConnection : IServiceConnection + private sealed class SimpleTestServiceConnectionFactory : IServiceConnectionFactory { - public ServiceConnectionStatus Status { get; set; } + public IServiceConnection Create(HubServiceEndpoint endpoint, IServiceMessageHandler serviceMessageHandler, ServerConnectionType type) => new SimpleTestServiceConnection(); + } + private sealed class SimpleTestServiceConnection : IServiceConnection + { public Task ConnectionInitializedTask => Task.Delay(TimeSpan.FromSeconds(1)); - public Task ConnectionOfflineTask => Task.CompletedTask; + public ServiceConnectionStatus Status { get; set; } = ServiceConnectionStatus.Disconnected; - public event Action ConnectionStatusChanged; + private readonly TaskCompletionSource _offline = new TaskCompletionSource(TaskContinuationOptions.RunContinuationsAsynchronously); + + public Task ConnectionOfflineTask => _offline.Task; public SimpleTestServiceConnection(ServiceConnectionStatus status = ServiceConnectionStatus.Disconnected) { Status = status; } + public event Action ConnectionStatusChanged; + public Task StartAsync(string target = null) { Status = ServiceConnectionStatus.Connected; @@ -60,12 +83,16 @@ public Task StartAsync(string target = null) public Task StopAsync() { - throw new NotImplementedException(); + return Task.CompletedTask; } public Task WriteAsync(ServiceMessage serviceMessage) { - throw new NotImplementedException(); + if (serviceMessage is PingMessage ping && ping.TryGetValue(Constants.ServicePingMessageKey.ShutdownKey, out var val) && val == Constants.ServicePingMessageValue.ShutdownFin) + { + _offline.SetResult(true); + } + return Task.CompletedTask; } } } diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerTests.cs new file mode 100644 index 000000000..b6beb6393 --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerTests.cs @@ -0,0 +1,130 @@ +using Microsoft.Azure.SignalR.Protocol; +using Microsoft.Azure.SignalR.Tests.Common; +using System; +using System.Collections.Generic; +using System.Security.Claims; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Azure.SignalR.Tests +{ + + public class ServiceConnectionContainerTests + { + + private async Task MockServiceAsync(TestServiceConnectionForCloseAsync conn) + { + IServiceProtocol proto = new ServiceProtocol(); + + await conn.ConnectionCreated; + + // open 2 new connections (to create 2 new outgoing tasks + proto.WriteMessage(new OpenConnectionMessage(Guid.NewGuid().ToString(), new Claim[0]), conn.Application.Output); + proto.WriteMessage(new OpenConnectionMessage(Guid.NewGuid().ToString(), new Claim[0]), conn.Application.Output); + await conn.Application.Output.FlushAsync(); + + while (true) + { + var result = await conn.Application.Input.ReadAsync(); + var buffer = result.Buffer; + + try + { + // write back a FinAck after receiving a Fin + if (proto.TryParseMessage(ref buffer, out ServiceMessage message)) + { + if (message is PingMessage ping && ping.TryGetValue(Constants.ServicePingMessageKey.ShutdownKey, out string val)) + { + if (val == Constants.ServicePingMessageValue.ShutdownFin) + { + PingMessage pong = new PingMessage + { + Messages = new string[2] { Constants.ServicePingMessageKey.ShutdownKey, Constants.ServicePingMessageValue.ShutdownFinAck } + }; + proto.WriteMessage(pong, conn.Application.Output); + await conn.Application.Output.FlushAsync(); + break; + } + } + } + } + finally + { + conn.Application.Input.AdvanceTo(buffer.Start, buffer.End); + } + } + } + + private PingMessage BuildPingMessage(string key, string val) + { + return new PingMessage + { + Messages = new string[2] { key, val } + }; + } + + private async Task MockServiceAsyncWithException(TestServiceConnectionForCloseAsync conn) + { + IServiceProtocol proto = new ServiceProtocol(); + + await conn.ConnectionCreated; + + // open 2 new connections (to create 2 new outgoing tasks + proto.WriteMessage(new OpenConnectionMessage(Guid.NewGuid().ToString(), new Claim[0]), conn.Application.Output); + proto.WriteMessage(new OpenConnectionMessage(Guid.NewGuid().ToString(), new Claim[0]), conn.Application.Output); + await conn.Application.Output.FlushAsync(); + + await Task.Delay(TimeSpan.FromSeconds(1)); + proto.WriteMessage(BuildPingMessage("_exception", "1"), conn.Application.Output); + await conn.Application.Output.FlushAsync(); + } + + private async Task AssertTask(Task task, TimeSpan timeout) + { + // prevent our test cases from running permanently + Task r = await Task.WhenAny(task, Task.Delay(timeout)); + Assert.Equal(r, task); + } + + [Fact] + public async void TestCloseAsync() + { + var conn = new TestServiceConnectionForCloseAsync(); + var hub = new HubServiceEndpoint(); + using var container = new TestBaseServiceConnectionContainer(new List { conn }, hub); + + _ = conn.StartAsync(); + _ = MockServiceAsync(conn); + + // close connection after 1 seconds. + await Task.Delay(TimeSpan.FromSeconds(1)); + // await AssertTask(container.CloseClientConnectionForTest(conn), TimeSpan.FromSeconds(5)); + } + + [Fact] + public async void TestCloseAsyncWithoutStartAsync() + { + var conn = new TestServiceConnectionForCloseAsync(); + var hub = new HubServiceEndpoint(); + using var container = new TestBaseServiceConnectionContainer(new List { conn }, hub); + + // await AssertTask(container.CloseClientConnectionForTest(conn), TimeSpan.FromSeconds(5)); + } + + [Fact] + public async void TestCloseAsyncWithExceptionAndNoFinAck() + { + var conn = new TestServiceConnectionForCloseAsync(); + var hub = new HubServiceEndpoint(); + using var container = new TestBaseServiceConnectionContainer(new List { conn }, hub); + + _ = conn.StartAsync(); + _ = MockServiceAsyncWithException(conn); + + // close connection after 2 seconds to make sure we have received an exception. + await Task.Delay(TimeSpan.FromSeconds(2)); + // TODO double check if we received an exception. + // await AssertTask(container.CloseClientConnectionForTest(conn), TimeSpan.FromSeconds(5)); + } + } +} diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceHubDispatcherTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceHubDispatcherTests.cs new file mode 100644 index 000000000..e9f9f8794 --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceHubDispatcherTests.cs @@ -0,0 +1,135 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.SignalR; +using Microsoft.Azure.SignalR.Protocol; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using Xunit; + +namespace Microsoft.Azure.SignalR.Tests +{ + public class ServiceHubDispatcherTests + { + [Fact] + public async void TestShutdown() + { + var clientManager = new TestClientConnectionManager(); + var serviceManager = new TestServiceConnectionManager(); + var dispatcher = new ServiceHubDispatcher( + null, + serviceManager, + clientManager, + null, + new TestOptions(), + NullLoggerFactory.Instance, + new TestRouter(), + null, + null + ); + + await dispatcher.ShutdownAsync(TimeSpan.FromSeconds(1)); + + Assert.True(clientManager.completeTime.Subtract(serviceManager.offlineTime) > TimeSpan.FromMilliseconds(100)); + Assert.True(clientManager.completeTime.Subtract(serviceManager.stopTime) < -TimeSpan.FromMilliseconds(100)); + Assert.True(serviceManager.offlineTime != serviceManager.stopTime); + } + + private sealed class TestRouter : IEndpointRouter + { + public IEnumerable GetEndpointsForBroadcast(IEnumerable endpoints) + { + throw new NotImplementedException(); + } + + public IEnumerable GetEndpointsForConnection(string connectionId, IEnumerable endpoints) + { + throw new NotImplementedException(); + } + + public IEnumerable GetEndpointsForGroup(string groupName, IEnumerable endpoints) + { + throw new NotImplementedException(); + } + + public IEnumerable GetEndpointsForUser(string userId, IEnumerable endpoints) + { + throw new NotImplementedException(); + } + + public ServiceEndpoint GetNegotiateEndpoint(HttpContext context, IEnumerable endpoints) + { + throw new NotImplementedException(); + } + } + + private sealed class TestClientConnectionManager : IClientConnectionManager + { + public IReadOnlyDictionary ClientConnections => throw new NotImplementedException(); + + public DateTime completeTime = new DateTime(); + + public void AddClientConnection(ServiceConnectionContext clientConnection) + { + throw new NotImplementedException(); + } + + public ServiceConnectionContext RemoveClientConnection(string connectionId) + { + throw new NotImplementedException(); + } + + public async Task WhenAllCompleted() + { + await Task.Delay(100); + completeTime = DateTime.Now; + } + } + + private sealed class TestOptions : IOptions + { + public ServiceOptions Value => new ServiceOptions(); + } + + private sealed class TestServiceConnectionManager : IServiceConnectionManager where THub : Hub + { + public DateTime offlineTime = new DateTime(); + public DateTime stopTime = new DateTime(); + + public async Task OfflineAsync() + { + await Task.Delay(100); + offlineTime = DateTime.Now; + } + + public void SetServiceConnection(IServiceConnectionContainer serviceConnection) + { + throw new NotImplementedException(); + } + + public Task StartAsync() + { + throw new NotImplementedException(); + } + + public async Task StopAsync() + { + await Task.Delay(100); + stopTime = DateTime.Now; + } + + public Task WriteAckableMessageAsync(ServiceMessage seviceMessage, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public Task WriteAsync(ServiceMessage seviceMessage) + { + throw new NotImplementedException(); + } + } + + } +} diff --git a/test/Microsoft.Azure.SignalR.Tests/TestServiceConnectionContainer.cs b/test/Microsoft.Azure.SignalR.Tests/TestServiceConnectionContainer.cs index cb48bdf27..c8e4749d2 100644 --- a/test/Microsoft.Azure.SignalR.Tests/TestServiceConnectionContainer.cs +++ b/test/Microsoft.Azure.SignalR.Tests/TestServiceConnectionContainer.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Reflection; +using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.SignalR.Protocol; using Microsoft.Extensions.Logging.Abstractions; @@ -11,6 +12,10 @@ namespace Microsoft.Azure.SignalR.Tests { internal sealed class TestServiceConnectionContainer : ServiceConnectionContainerBase { + public bool IsOffline { get; set; } = false; + + public bool MockOffline { get; set; } = false; + public TestServiceConnectionContainer(List serviceConnections, HubServiceEndpoint endpoint = null, AckHandler ackHandler = null, IServiceConnectionFactory factory = null) : base(factory, 0, endpoint, serviceConnections, ackHandler: ackHandler, logger: NullLogger.Instance) { @@ -24,6 +29,18 @@ public void ShutdownForTest() prop.SetValue(this, true); } + public override async Task OfflineAsync() + { + if (MockOffline) + { + await Task.Delay(100); + IsOffline = true; + } else + { + await base.OfflineAsync(); + } + } + public override Task HandlePingAsync(PingMessage pingMessage) { return Task.CompletedTask; diff --git a/test/Microsoft.Azure.SignalR.Tests/TestServiceConnectionForCloseAsync.cs b/test/Microsoft.Azure.SignalR.Tests/TestServiceConnectionForCloseAsync.cs new file mode 100644 index 000000000..294977936 --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Tests/TestServiceConnectionForCloseAsync.cs @@ -0,0 +1,27 @@ +using System; +using System.Threading.Tasks; +using Microsoft.Azure.SignalR.Protocol; +using Microsoft.Azure.SignalR.Tests.Common; + +namespace Microsoft.Azure.SignalR.Tests +{ + class TestServiceConnectionForCloseAsync : TestServiceConnection + { + public TestServiceConnectionForCloseAsync() : base(ServiceConnectionStatus.Connected, false) + { + } + + /** + * Register an outgoing Task. + */ + protected override Task OnConnectedAsync(OpenConnectionMessage openConnectionMessage) + { + return Task.CompletedTask; + } + + public override Task WriteAsync(ServiceMessage serviceMessage) + { + return WriteAsyncBase(serviceMessage); + } + } +}