Skip to content

Commit

Permalink
Graceful shutdown implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan committed Nov 13, 2019
1 parent 3d88464 commit baa06f0
Show file tree
Hide file tree
Showing 27 changed files with 555 additions and 119 deletions.
2 changes: 1 addition & 1 deletion src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ internal static ServiceHubDispatcher PrepareAndGetDispatcher(IAppBuilder builder
configuration.Resolver.Register(typeof(IServiceConnectionFactory), () => scf);
}

var sccf = new ServiceConnectionContainerFactory(scf, endpoint, router, options, serverNameProvider, ccm, loggerFactory);
var sccf = new ServiceConnectionContainerFactory(scf, endpoint, router, options, serverNameProvider, loggerFactory);

if (hubs?.Count > 0)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,10 @@ private string GetInstanceId(IDictionary<string, StringValues> header)
if (header.TryGetValue(Constants.AsrsInstanceId, out var instanceId))
{
return instanceId;
}
return null;
}
return null;
}


private sealed class ClientContext
{
private readonly CancellationTokenSource _cancellationTokenSource = new CancellationTokenSource();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ public Task StopAsync()
return Task.WhenAll(GetConnections().Select(s => s.StopAsync()));
}

public Task ShutdownAsync(TimeSpan timeout)
public Task OfflineAsync(CancellationToken token)
{
return Task.WhenAll(GetConnections().Select(s => s.ShutdownAsync(timeout)));
}

return Task.WhenAll(GetConnections().Select(s => s.OfflineAsync(token)));
}

public IServiceConnectionContainer WithHub(string hubName)
{
if (_hubConnections == null ||!_hubConnections.TryGetValue(hubName, out var connection))
Expand Down Expand Up @@ -131,6 +131,6 @@ private IEnumerable<IServiceConnectionContainer> GetConnections()
yield return conn.Value;
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Protocol;
Expand All @@ -15,7 +14,7 @@ internal interface IServiceConnectionContainer

Task StopAsync();

Task ShutdownAsync(TimeSpan timeout);
Task OfflineAsync(CancellationToken token);

Task WriteAsync(ServiceMessage serviceMessage);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,18 @@ internal class MultiEndpointServiceConnectionContainer : IServiceConnectionConta
private readonly IMessageRouter _router;
private readonly ILogger _logger;
private readonly IServiceConnectionContainer _inner;
private readonly IClientConnectionLifetimeManager _clientLifetime;

private IReadOnlyList<HubServiceEndpoint> _endpoints;

public Dictionary<ServiceEndpoint, IServiceConnectionContainer> Connections { get; }

public MultiEndpointServiceConnectionContainer(string hub, Func<HubServiceEndpoint, IServiceConnectionContainer> generator, IServiceEndpointManager endpointManager, IMessageRouter router, IClientConnectionLifetimeManager lifetime, ILoggerFactory loggerFactory)
public MultiEndpointServiceConnectionContainer(string hub, Func<HubServiceEndpoint, IServiceConnectionContainer> generator, IServiceEndpointManager endpointManager, IMessageRouter router, ILoggerFactory loggerFactory)
{
if (generator == null)
{
throw new ArgumentNullException(nameof(generator));
}

_clientLifetime = lifetime;

_logger = loggerFactory?.CreateLogger<MultiEndpointServiceConnectionContainer>() ?? throw new ArgumentNullException(nameof(loggerFactory));

// provides a copy to the endpoint per container
Expand All @@ -57,14 +54,12 @@ public MultiEndpointServiceConnectionContainer(
IServiceEndpointManager endpointManager,
IMessageRouter router,
IServerNameProvider nameProvider,
IClientConnectionLifetimeManager lifetime,
ILoggerFactory loggerFactory
) : this(
hub,
endpoint => CreateContainer(serviceConnectionFactory, endpoint, count, loggerFactory),
endpointManager,
router,
lifetime,
loggerFactory
)
{
Expand Down Expand Up @@ -131,18 +126,18 @@ public Task StopAsync()
}));
}

public async Task ShutdownAsync(TimeSpan timeout)
public Task OfflineAsync(CancellationToken token = default)
{
// TODOS

// 1. write FIN to every server connection of every connection container.
if (_inner != null)
{
return _inner.OfflineAsync(token);
}
else
{
return Task.WhenAll(Connections.Select(c => c.Value.OfflineAsync(token)));
}
}

// 2. wait until all client connections have been closed (either by server/client side)

// 3. stop every container.
await StopAsync();
}

public Task WriteAsync(ServiceMessage serviceMessage)
{
if (_inner != null)
Expand Down Expand Up @@ -251,8 +246,8 @@ private Task WriteMultiEndpointMessageAsync(ServiceMessage serviceMessage, Func<
}

return Task.WhenAll(routed);
}

}

private static class Log
{
private static readonly Action<ILogger, string, Exception> _startingConnection =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ public Task StopAsync()
{
Log.UnexpectedExceptionInStop(Logger, ConnectionId, ex);
}

return Task.CompletedTask;
}

Expand Down Expand Up @@ -394,8 +393,7 @@ private async Task<bool> ReceiveHandshakeResponseAsync(PipeReader input, Cancell
else
{
Log.HandshakeError(Logger, handshakeResponse.ErrorMessage, ConnectionId);
}

}
return false;
}
}
Expand All @@ -416,10 +414,11 @@ private async Task<bool> ReceiveHandshakeResponseAsync(PipeReader input, Cancell
private async Task ProcessIncomingAsync(ConnectionContext connection)
{
var keepAliveTimer = StartKeepAliveTimer();

try
{
while (true)
{
{
var result = await connection.Transport.Input.ReadAsync();
var buffer = result.Buffer;

Expand Down Expand Up @@ -466,27 +465,28 @@ private async Task ProcessIncomingAsync(ConnectionContext connection)
}
}
finally
{
keepAliveTimer.Stop();
}
{
keepAliveTimer.Stop();
_serviceConnectionOfflineTcs.TrySetResult(true);
}
}

private Task DispatchMessageAsync(ServiceMessage message)
{
switch (message)
{
case OpenConnectionMessage openConnectionMessage:
return OnConnectedAsync(openConnectionMessage);
case CloseConnectionMessage closeConnectionMessage:
return OnDisconnectedAsync(closeConnectionMessage);
case ConnectionDataMessage connectionDataMessage:
return OnMessageAsync(connectionDataMessage);
case ServiceErrorMessage serviceErrorMessage:
return OnServiceErrorAsync(serviceErrorMessage);
case PingMessage pingMessage:
return OnPingMessageAsync(pingMessage);
case AckMessage ackMessage:
return OnAckMessageAsync(ackMessage);
{
switch (message)
{
case OpenConnectionMessage openConnectionMessage:
return OnConnectedAsync(openConnectionMessage);
case CloseConnectionMessage closeConnectionMessage:
return OnDisconnectedAsync(closeConnectionMessage);
case ConnectionDataMessage connectionDataMessage:
return OnMessageAsync(connectionDataMessage);
case ServiceErrorMessage serviceErrorMessage:
return OnServiceErrorAsync(serviceErrorMessage);
case PingMessage pingMessage:
return OnPingMessageAsync(pingMessage);
case AckMessage ackMessage:
return OnAckMessageAsync(ackMessage);
}
return Task.CompletedTask;
}
Expand Down Expand Up @@ -626,7 +626,6 @@ private static class Log
private static readonly Action<ILogger, string, Exception> _receivedInstanceOfflinePing =
LoggerMessage.Define<string>(LogLevel.Information, new EventId(31, "ReceivedInstanceOfflinePing"), "Received instance offline service ping: {InstanceId}");


public static void FailedToWrite(ILogger logger, string serviceConnectionId, Exception exception)
{
_failedToWrite(logger, exception.Message, serviceConnectionId, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,6 @@ public virtual Task StopAsync()
return Task.WhenAll(FixedServiceConnections.Select(c => c.StopAsync()));
}

public virtual Task ShutdownAsync(TimeSpan timeout)
{
_terminated = true;
return Task.CompletedTask;
}

/// <summary>
/// Start and manage the whole connection lifetime
/// </summary>
Expand Down Expand Up @@ -246,8 +240,8 @@ protected void ReplaceFixedConnections(int index, IServiceConnection serviceConn
}

public Task ConnectionInitializedTask => Task.WhenAll(from connection in FixedServiceConnections
select connection.ConnectionInitializedTask);

select connection.ConnectionInitializedTask);

public virtual Task WriteAsync(ServiceMessage serviceMessage)
{
return WriteToRandomAvailableConnection(serviceMessage);
Expand Down Expand Up @@ -347,8 +341,28 @@ private IEnumerable<IServiceConnection> CreateFixedServiceConnection(int count)
{
yield return CreateServiceConnectionCore(InitialConnectionType);
}
}

}

protected async Task WriteFinAsync(IServiceConnection c)
{
var ping = new PingMessage()
{
Messages = new string[2] { Constants.ServicePingMessageKey.ShutdownKey, Constants.ServicePingMessageValue.ShutdownFin }
};
await c.WriteAsync(ping);
}

protected async Task RemoveConnectionFromService(IServiceConnection c, CancellationToken token)
{
_ = WriteFinAsync(c);
await Task.WhenAny(c.ConnectionOfflineTask, Task.Delay(TimeSpan.FromSeconds(3), token));
}

public virtual Task OfflineAsync(CancellationToken token)
{
return Task.WhenAll(FixedServiceConnections.Select(c => RemoveConnectionFromService(c, token)));
}

private static class Log
{
private static readonly Action<ILogger, string, string, Exception> _endpointOnline =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ internal class ServiceConnectionContainerFactory : IServiceConnectionContainerFa
private readonly IServiceEndpointManager _serviceEndpointManager;
private readonly IMessageRouter _router;
private readonly IServerNameProvider _nameProvider;
private readonly IClientConnectionLifetimeManager _lifetime;
private readonly IServiceConnectionFactory _serviceConnectionFactory;

public ServiceConnectionContainerFactory(
Expand All @@ -22,21 +21,19 @@ public ServiceConnectionContainerFactory(
IMessageRouter router,
IServiceEndpointOptions options,
IServerNameProvider nameProvider,
IClientConnectionLifetimeManager lifetime,
ILoggerFactory loggerFactory)
{
_serviceConnectionFactory = serviceConnectionFactory;
_serviceEndpointManager = serviceEndpointManager ?? throw new ArgumentNullException(nameof(serviceEndpointManager));
_router = router ?? throw new ArgumentNullException(nameof(router));
_options = options;
_nameProvider = nameProvider;
_lifetime = lifetime;
_loggerFactory = loggerFactory;
}

public IServiceConnectionContainer Create(string hub)
{
return new MultiEndpointServiceConnectionContainer(_serviceConnectionFactory, hub, _options.ConnectionCount, _serviceEndpointManager, _router, _nameProvider, _lifetime, _loggerFactory);
return new MultiEndpointServiceConnectionContainer(_serviceConnectionFactory, hub, _options.ConnectionCount, _serviceEndpointManager, _router, _nameProvider, _loggerFactory);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -42,13 +43,11 @@ public override Task StopAsync()
);
}

public override Task ShutdownAsync(TimeSpan timeout)
public override Task OfflineAsync(CancellationToken token)
{
var task = base.StopAsync();
return Task.WhenAll(
task,
Task.CompletedTask // TODO
);
var task1 = base.OfflineAsync(token);
var task2 = Task.WhenAll(_onDemandServiceConnections.Select(c => RemoveConnectionFromService(c, token)));
return Task.WhenAll(task1, task2);
}

protected override ServiceConnectionStatus GetStatus()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ public override Task HandlePingAsync(PingMessage pingMessage)
return Task.CompletedTask;
}

public override Task ShutdownAsync(TimeSpan timeout) => StopAsync();

public override Task WriteAsync(ServiceMessage serviceMessage)
{
if (!_active && !(serviceMessage is PingMessage))
Expand All @@ -61,6 +59,11 @@ public override Task WriteAsync(ServiceMessage serviceMessage)
return base.WriteAsync(serviceMessage);
}

public override Task OfflineAsync(CancellationToken token)
{
return Task.CompletedTask;
}

internal bool GetServiceStatus(bool active, int checkWindow, TimeSpan checkTimeSpan)
{
lock (_lock)
Expand Down
22 changes: 15 additions & 7 deletions src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,18 @@ public void Start(ConnectionDelegate connectionDelegate, Action<HttpContext> con
_ = _serviceConnectionManager.StartAsync();
}

public Task Shutdown(TimeSpan timeout)
public async Task ShutdownAsync(TimeSpan timeout)
{
return _serviceConnectionManager.ShutdownAsync(timeout);
await Task.WhenAny(
Task.Delay(timeout), OfflineAndWaitForCompletedAsync()
);
await _serviceConnectionManager.StopAsync();
}

private async Task OfflineAndWaitForCompletedAsync()
{
await _serviceConnectionManager.OfflineAsync();
await _clientConnectionManager.WhenAllCompleted();
}

private IServiceConnectionContainer GetMultiEndpointServiceConnectionContainer(string hub, ConnectionDelegate connectionDelegate, Action<HttpContext> contextConfig = null)
Expand All @@ -77,12 +86,11 @@ private IServiceConnectionContainer GetMultiEndpointServiceConnectionContainer(s
serviceConnectionFactory.ConfigureContext = contextConfig;

var factory = new ServiceConnectionContainerFactory(
serviceConnectionFactory,
serviceConnectionFactory,
_serviceEndpointManager,
_router,
_options,
_nameProvider,
_clientConnectionManager,
_router,
_options,
_nameProvider,
_loggerFactory
);
return factory.Create(hub);
Expand Down
Loading

0 comments on commit baa06f0

Please sign in to comment.