Skip to content

Commit

Permalink
Graceful shutdown implementation. (Azure#689)
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan authored and JialinXin committed Dec 20, 2019
1 parent a1eb603 commit 80a6ed9
Show file tree
Hide file tree
Showing 28 changed files with 576 additions and 125 deletions.
6 changes: 3 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ __pycache__/

.publish/

# vim
*.swp

# docker
.docker/

# vim
*.swp
2 changes: 1 addition & 1 deletion src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ internal static ServiceHubDispatcher PrepareAndGetDispatcher(IAppBuilder builder
configuration.Resolver.Register(typeof(IServiceConnectionFactory), () => scf);
}

var sccf = new ServiceConnectionContainerFactory(scf, endpoint, router, options, serverNameProvider, ccm, loggerFactory);
var sccf = new ServiceConnectionContainerFactory(scf, endpoint, router, options, serverNameProvider, loggerFactory);

if (hubs?.Count > 0)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ protected override Task OnConnectedAsync(OpenConnectionMessage openConnectionMes
}
else
{
// the manager still contains this connectionId, probably this connection is not yet cleaned up
// the manager still contains this connectionId, probably this connection is not yet cleaned up
Log.DuplicateConnectionId(Logger, connectionId, null);
return WriteAsync(
new CloseConnectionMessage(connectionId, $"Duplicate connection ID {connectionId}"));
Expand Down Expand Up @@ -288,10 +288,10 @@ private string GetInstanceId(IDictionary<string, StringValues> header)
if (header.TryGetValue(Constants.AsrsInstanceId, out var instanceId))
{
return instanceId;
}
return null;
}
return null;
}


private sealed class ClientContext
{
private readonly CancellationTokenSource _cancellationTokenSource = new CancellationTokenSource();
Expand Down Expand Up @@ -319,7 +319,7 @@ public void CancelPendingRead()
public string InstanceId { get; }

public ChannelReader<ServiceMessage> Input { get; }

public ChannelWriter<ServiceMessage> Output { get; }

public IServiceTransport Transport { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ public Task StopAsync()
return Task.WhenAll(GetConnections().Select(s => s.StopAsync()));
}

public Task ShutdownAsync(TimeSpan timeout)
public Task OfflineAsync()
{
return Task.WhenAll(GetConnections().Select(s => s.ShutdownAsync(timeout)));
}

return Task.WhenAll(GetConnections().Select(s => s.OfflineAsync()));
}

public IServiceConnectionContainer WithHub(string hubName)
{
if (_hubConnections == null ||!_hubConnections.TryGetValue(hubName, out var connection))
Expand Down Expand Up @@ -131,6 +131,6 @@ private IEnumerable<IServiceConnectionContainer> GetConnections()
yield return conn.Value;
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Protocol;
Expand All @@ -15,7 +14,7 @@ internal interface IServiceConnectionContainer

Task StopAsync();

Task ShutdownAsync(TimeSpan timeout);
Task OfflineAsync();

Task WriteAsync(ServiceMessage serviceMessage);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,18 @@ internal class MultiEndpointServiceConnectionContainer : IServiceConnectionConta
private readonly IMessageRouter _router;
private readonly ILogger _logger;
private readonly IServiceConnectionContainer _inner;
private readonly IClientConnectionLifetimeManager _clientLifetime;

private IReadOnlyList<HubServiceEndpoint> _endpoints;

public Dictionary<ServiceEndpoint, IServiceConnectionContainer> Connections { get; }

public MultiEndpointServiceConnectionContainer(string hub, Func<HubServiceEndpoint, IServiceConnectionContainer> generator, IServiceEndpointManager endpointManager, IMessageRouter router, IClientConnectionLifetimeManager lifetime, ILoggerFactory loggerFactory)
public MultiEndpointServiceConnectionContainer(string hub, Func<HubServiceEndpoint, IServiceConnectionContainer> generator, IServiceEndpointManager endpointManager, IMessageRouter router, ILoggerFactory loggerFactory)
{
if (generator == null)
{
throw new ArgumentNullException(nameof(generator));
}

_clientLifetime = lifetime;

_logger = loggerFactory?.CreateLogger<MultiEndpointServiceConnectionContainer>() ?? throw new ArgumentNullException(nameof(loggerFactory));

// provides a copy to the endpoint per container
Expand All @@ -57,14 +54,12 @@ public MultiEndpointServiceConnectionContainer(
IServiceEndpointManager endpointManager,
IMessageRouter router,
IServerNameProvider nameProvider,
IClientConnectionLifetimeManager lifetime,
ILoggerFactory loggerFactory
) : this(
hub,
endpoint => CreateContainer(serviceConnectionFactory, endpoint, count, loggerFactory),
endpointManager,
router,
lifetime,
loggerFactory
)
{
Expand Down Expand Up @@ -131,18 +126,18 @@ public Task StopAsync()
}));
}

public async Task ShutdownAsync(TimeSpan timeout)
public Task OfflineAsync()
{
// TODOS

// 1. write FIN to every server connection of every connection container.
if (_inner != null)
{
return _inner.OfflineAsync();
}
else
{
return Task.WhenAll(Connections.Select(c => c.Value.OfflineAsync()));
}
}

// 2. wait until all client connections have been closed (either by server/client side)

// 3. stop every container.
await StopAsync();
}

public Task WriteAsync(ServiceMessage serviceMessage)
{
if (_inner != null)
Expand Down Expand Up @@ -251,8 +246,8 @@ private Task WriteMultiEndpointMessageAsync(ServiceMessage serviceMessage, Func<
}

return Task.WhenAll(routed);
}

}

private static class Log
{
private static readonly Action<ILogger, string, Exception> _startingConnection =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ public Task StopAsync()
{
Log.UnexpectedExceptionInStop(Logger, ConnectionId, ex);
}

return Task.CompletedTask;
}

Expand Down Expand Up @@ -394,8 +393,7 @@ private async Task<bool> ReceiveHandshakeResponseAsync(PipeReader input, Cancell
else
{
Log.HandshakeError(Logger, handshakeResponse.ErrorMessage, ConnectionId);
}

}
return false;
}
}
Expand All @@ -416,10 +414,11 @@ private async Task<bool> ReceiveHandshakeResponseAsync(PipeReader input, Cancell
private async Task ProcessIncomingAsync(ConnectionContext connection)
{
var keepAliveTimer = StartKeepAliveTimer();

try
{
while (true)
{
{
var result = await connection.Transport.Input.ReadAsync();
var buffer = result.Buffer;

Expand Down Expand Up @@ -466,27 +465,28 @@ private async Task ProcessIncomingAsync(ConnectionContext connection)
}
}
finally
{
keepAliveTimer.Stop();
}
{
keepAliveTimer.Stop();
_serviceConnectionOfflineTcs.TrySetResult(true);
}
}

private Task DispatchMessageAsync(ServiceMessage message)
{
switch (message)
{
case OpenConnectionMessage openConnectionMessage:
return OnConnectedAsync(openConnectionMessage);
case CloseConnectionMessage closeConnectionMessage:
return OnDisconnectedAsync(closeConnectionMessage);
case ConnectionDataMessage connectionDataMessage:
return OnMessageAsync(connectionDataMessage);
case ServiceErrorMessage serviceErrorMessage:
return OnServiceErrorAsync(serviceErrorMessage);
case PingMessage pingMessage:
return OnPingMessageAsync(pingMessage);
case AckMessage ackMessage:
return OnAckMessageAsync(ackMessage);
{
switch (message)
{
case OpenConnectionMessage openConnectionMessage:
return OnConnectedAsync(openConnectionMessage);
case CloseConnectionMessage closeConnectionMessage:
return OnDisconnectedAsync(closeConnectionMessage);
case ConnectionDataMessage connectionDataMessage:
return OnMessageAsync(connectionDataMessage);
case ServiceErrorMessage serviceErrorMessage:
return OnServiceErrorAsync(serviceErrorMessage);
case PingMessage pingMessage:
return OnPingMessageAsync(pingMessage);
case AckMessage ackMessage:
return OnAckMessageAsync(ackMessage);
}
return Task.CompletedTask;
}
Expand Down Expand Up @@ -626,7 +626,6 @@ private static class Log
private static readonly Action<ILogger, string, Exception> _receivedInstanceOfflinePing =
LoggerMessage.Define<string>(LogLevel.Information, new EventId(31, "ReceivedInstanceOfflinePing"), "Received instance offline service ping: {InstanceId}");


public static void FailedToWrite(ILogger logger, string serviceConnectionId, Exception exception)
{
_failedToWrite(logger, exception.Message, serviceConnectionId, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ internal abstract class ServiceConnectionContainerBase : IServiceConnectionConta
private static TimeSpan ReconnectInterval =>
TimeSpan.FromMilliseconds(StaticRandom.Next(MaxReconnectBackOffInternalInMilliseconds));

private static TimeSpan RemoveFromServiceTimeout = TimeSpan.FromSeconds(3);

private readonly BackOffPolicy _backOffPolicy = new BackOffPolicy();

private readonly object _lock = new object();
Expand All @@ -30,7 +32,12 @@ internal abstract class ServiceConnectionContainerBase : IServiceConnectionConta

private volatile ServiceConnectionStatus _status;

private volatile bool _terminated = false;
private volatile bool _terminated = false;

private static readonly PingMessage _shutdownFinMessage = new PingMessage()
{
Messages = new string[2] { Constants.ServicePingMessageKey.ShutdownKey, Constants.ServicePingMessageValue.ShutdownFin }
};

protected ILogger Logger { get; }

Expand Down Expand Up @@ -117,12 +124,6 @@ public virtual Task StopAsync()
return Task.WhenAll(FixedServiceConnections.Select(c => c.StopAsync()));
}

public virtual Task ShutdownAsync(TimeSpan timeout)
{
_terminated = true;
return Task.CompletedTask;
}

/// <summary>
/// Start and manage the whole connection lifetime
/// </summary>
Expand Down Expand Up @@ -246,8 +247,8 @@ protected void ReplaceFixedConnections(int index, IServiceConnection serviceConn
}

public Task ConnectionInitializedTask => Task.WhenAll(from connection in FixedServiceConnections
select connection.ConnectionInitializedTask);

select connection.ConnectionInitializedTask);

public virtual Task WriteAsync(ServiceMessage serviceMessage)
{
return WriteToRandomAvailableConnection(serviceMessage);
Expand Down Expand Up @@ -347,8 +348,32 @@ private IEnumerable<IServiceConnection> CreateFixedServiceConnection(int count)
{
yield return CreateServiceConnectionCore(InitialConnectionType);
}
}

}

protected async Task WriteFinAsync(IServiceConnection c)
{
await c.WriteAsync(_shutdownFinMessage);
}

protected async Task RemoveConnectionFromService(IServiceConnection c)
{
_ = WriteFinAsync(c);

var source = new CancellationTokenSource();
var task = await Task.WhenAny(c.ConnectionOfflineTask, Task.Delay(RemoveFromServiceTimeout, source.Token));
source.Cancel();

if (task != c.ConnectionOfflineTask)
{
// log
}
}

public virtual Task OfflineAsync()
{
return Task.WhenAll(FixedServiceConnections.Select(c => RemoveConnectionFromService(c)));
}

private static class Log
{
private static readonly Action<ILogger, string, string, Exception> _endpointOnline =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ internal class ServiceConnectionContainerFactory : IServiceConnectionContainerFa
private readonly IServiceEndpointManager _serviceEndpointManager;
private readonly IMessageRouter _router;
private readonly IServerNameProvider _nameProvider;
private readonly IClientConnectionLifetimeManager _lifetime;
private readonly IServiceConnectionFactory _serviceConnectionFactory;

public ServiceConnectionContainerFactory(
Expand All @@ -22,21 +21,19 @@ public ServiceConnectionContainerFactory(
IMessageRouter router,
IServiceEndpointOptions options,
IServerNameProvider nameProvider,
IClientConnectionLifetimeManager lifetime,
ILoggerFactory loggerFactory)
{
_serviceConnectionFactory = serviceConnectionFactory;
_serviceEndpointManager = serviceEndpointManager ?? throw new ArgumentNullException(nameof(serviceEndpointManager));
_router = router ?? throw new ArgumentNullException(nameof(router));
_options = options;
_nameProvider = nameProvider;
_lifetime = lifetime;
_loggerFactory = loggerFactory;
}

public IServiceConnectionContainer Create(string hub)
{
return new MultiEndpointServiceConnectionContainer(_serviceConnectionFactory, hub, _options.ConnectionCount, _serviceEndpointManager, _router, _nameProvider, _lifetime, _loggerFactory);
return new MultiEndpointServiceConnectionContainer(_serviceConnectionFactory, hub, _options.ConnectionCount, _serviceEndpointManager, _router, _nameProvider, _loggerFactory);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -42,13 +43,11 @@ public override Task StopAsync()
);
}

public override Task ShutdownAsync(TimeSpan timeout)
public override Task OfflineAsync()
{
var task = base.StopAsync();
return Task.WhenAll(
task,
Task.CompletedTask // TODO
);
var task1 = base.OfflineAsync();
var task2 = Task.WhenAll(_onDemandServiceConnections.Select(c => RemoveConnectionFromService(c)));
return Task.WhenAll(task1, task2);
}

protected override ServiceConnectionStatus GetStatus()
Expand Down
Loading

0 comments on commit 80a6ed9

Please sign in to comment.