Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graceful shutdown support #689

Merged
merged 1 commit into from
Nov 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ __pycache__/

.publish/

# vim
*.swp

# docker
.docker/

# vim
*.swp
2 changes: 1 addition & 1 deletion src/Microsoft.Azure.SignalR.AspNet/DispatcherHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ internal static ServiceHubDispatcher PrepareAndGetDispatcher(IAppBuilder builder
configuration.Resolver.Register(typeof(IServiceConnectionFactory), () => scf);
}

var sccf = new ServiceConnectionContainerFactory(scf, endpoint, router, options, serverNameProvider, ccm, loggerFactory);
var sccf = new ServiceConnectionContainerFactory(scf, endpoint, router, options, serverNameProvider, loggerFactory);

if (hubs?.Count > 0)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ protected override Task OnConnectedAsync(OpenConnectionMessage openConnectionMes
}
else
{
// the manager still contains this connectionId, probably this connection is not yet cleaned up
// the manager still contains this connectionId, probably this connection is not yet cleaned up
Log.DuplicateConnectionId(Logger, connectionId, null);
return WriteAsync(
new CloseConnectionMessage(connectionId, $"Duplicate connection ID {connectionId}"));
Expand Down Expand Up @@ -288,10 +288,10 @@ private string GetInstanceId(IDictionary<string, StringValues> header)
if (header.TryGetValue(Constants.AsrsInstanceId, out var instanceId))
{
return instanceId;
}
return null;
}
return null;
}


private sealed class ClientContext
{
private readonly CancellationTokenSource _cancellationTokenSource = new CancellationTokenSource();
Expand Down Expand Up @@ -319,7 +319,7 @@ public void CancelPendingRead()
public string InstanceId { get; }

public ChannelReader<ServiceMessage> Input { get; }

public ChannelWriter<ServiceMessage> Output { get; }

public IServiceTransport Transport { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ public Task StopAsync()
return Task.WhenAll(GetConnections().Select(s => s.StopAsync()));
}

public Task ShutdownAsync(TimeSpan timeout)
public Task OfflineAsync()
{
return Task.WhenAll(GetConnections().Select(s => s.ShutdownAsync(timeout)));
}

return Task.WhenAll(GetConnections().Select(s => s.OfflineAsync()));
}

public IServiceConnectionContainer WithHub(string hubName)
{
if (_hubConnections == null ||!_hubConnections.TryGetValue(hubName, out var connection))
Expand Down Expand Up @@ -131,6 +131,6 @@ private IEnumerable<IServiceConnectionContainer> GetConnections()
yield return conn.Value;
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Protocol;
Expand All @@ -15,7 +14,7 @@ internal interface IServiceConnectionContainer

Task StopAsync();

Task ShutdownAsync(TimeSpan timeout);
Task OfflineAsync();

Task WriteAsync(ServiceMessage serviceMessage);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,18 @@ internal class MultiEndpointServiceConnectionContainer : IServiceConnectionConta
private readonly IMessageRouter _router;
private readonly ILogger _logger;
private readonly IServiceConnectionContainer _inner;
private readonly IClientConnectionLifetimeManager _clientLifetime;

private IReadOnlyList<HubServiceEndpoint> _endpoints;

public Dictionary<ServiceEndpoint, IServiceConnectionContainer> Connections { get; }

public MultiEndpointServiceConnectionContainer(string hub, Func<HubServiceEndpoint, IServiceConnectionContainer> generator, IServiceEndpointManager endpointManager, IMessageRouter router, IClientConnectionLifetimeManager lifetime, ILoggerFactory loggerFactory)
public MultiEndpointServiceConnectionContainer(string hub, Func<HubServiceEndpoint, IServiceConnectionContainer> generator, IServiceEndpointManager endpointManager, IMessageRouter router, ILoggerFactory loggerFactory)
{
if (generator == null)
{
throw new ArgumentNullException(nameof(generator));
}

_clientLifetime = lifetime;

_logger = loggerFactory?.CreateLogger<MultiEndpointServiceConnectionContainer>() ?? throw new ArgumentNullException(nameof(loggerFactory));

// provides a copy to the endpoint per container
Expand All @@ -57,14 +54,12 @@ public MultiEndpointServiceConnectionContainer(
IServiceEndpointManager endpointManager,
IMessageRouter router,
IServerNameProvider nameProvider,
IClientConnectionLifetimeManager lifetime,
ILoggerFactory loggerFactory
) : this(
hub,
endpoint => CreateContainer(serviceConnectionFactory, endpoint, count, loggerFactory),
endpointManager,
router,
lifetime,
loggerFactory
)
{
Expand Down Expand Up @@ -131,18 +126,18 @@ public Task StopAsync()
}));
}

public async Task ShutdownAsync(TimeSpan timeout)
public Task OfflineAsync()
{
// TODOS

// 1. write FIN to every server connection of every connection container.
if (_inner != null)
{
return _inner.OfflineAsync();
}
else
{
return Task.WhenAll(Connections.Select(c => c.Value.OfflineAsync()));
}
}

// 2. wait until all client connections have been closed (either by server/client side)

// 3. stop every container.
await StopAsync();
}

public Task WriteAsync(ServiceMessage serviceMessage)
{
if (_inner != null)
Expand Down Expand Up @@ -251,8 +246,8 @@ private Task WriteMultiEndpointMessageAsync(ServiceMessage serviceMessage, Func<
}

return Task.WhenAll(routed);
}

}

private static class Log
{
private static readonly Action<ILogger, string, Exception> _startingConnection =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ public Task StopAsync()
{
Log.UnexpectedExceptionInStop(Logger, ConnectionId, ex);
}

return Task.CompletedTask;
}

Expand Down Expand Up @@ -394,8 +393,7 @@ private async Task<bool> ReceiveHandshakeResponseAsync(PipeReader input, Cancell
else
{
Log.HandshakeError(Logger, handshakeResponse.ErrorMessage, ConnectionId);
}

}
return false;
}
}
Expand All @@ -416,10 +414,11 @@ private async Task<bool> ReceiveHandshakeResponseAsync(PipeReader input, Cancell
private async Task ProcessIncomingAsync(ConnectionContext connection)
{
var keepAliveTimer = StartKeepAliveTimer();

try
{
while (true)
{
{
var result = await connection.Transport.Input.ReadAsync();
var buffer = result.Buffer;

Expand Down Expand Up @@ -466,27 +465,28 @@ private async Task ProcessIncomingAsync(ConnectionContext connection)
}
}
finally
{
keepAliveTimer.Stop();
}
{
keepAliveTimer.Stop();
_serviceConnectionOfflineTcs.TrySetResult(true);
}
}

private Task DispatchMessageAsync(ServiceMessage message)
{
switch (message)
{
case OpenConnectionMessage openConnectionMessage:
return OnConnectedAsync(openConnectionMessage);
case CloseConnectionMessage closeConnectionMessage:
return OnDisconnectedAsync(closeConnectionMessage);
case ConnectionDataMessage connectionDataMessage:
return OnMessageAsync(connectionDataMessage);
case ServiceErrorMessage serviceErrorMessage:
return OnServiceErrorAsync(serviceErrorMessage);
case PingMessage pingMessage:
return OnPingMessageAsync(pingMessage);
case AckMessage ackMessage:
return OnAckMessageAsync(ackMessage);
{
switch (message)
{
case OpenConnectionMessage openConnectionMessage:
return OnConnectedAsync(openConnectionMessage);
case CloseConnectionMessage closeConnectionMessage:
return OnDisconnectedAsync(closeConnectionMessage);
case ConnectionDataMessage connectionDataMessage:
return OnMessageAsync(connectionDataMessage);
case ServiceErrorMessage serviceErrorMessage:
return OnServiceErrorAsync(serviceErrorMessage);
case PingMessage pingMessage:
return OnPingMessageAsync(pingMessage);
case AckMessage ackMessage:
return OnAckMessageAsync(ackMessage);
}
return Task.CompletedTask;
}
Expand Down Expand Up @@ -626,7 +626,6 @@ private static class Log
private static readonly Action<ILogger, string, Exception> _receivedInstanceOfflinePing =
LoggerMessage.Define<string>(LogLevel.Information, new EventId(31, "ReceivedInstanceOfflinePing"), "Received instance offline service ping: {InstanceId}");


public static void FailedToWrite(ILogger logger, string serviceConnectionId, Exception exception)
{
_failedToWrite(logger, exception.Message, serviceConnectionId, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ internal abstract class ServiceConnectionContainerBase : IServiceConnectionConta
private static TimeSpan ReconnectInterval =>
TimeSpan.FromMilliseconds(StaticRandom.Next(MaxReconnectBackOffInternalInMilliseconds));

private static TimeSpan RemoveFromServiceTimeout = TimeSpan.FromSeconds(3);

private readonly BackOffPolicy _backOffPolicy = new BackOffPolicy();

private readonly object _lock = new object();
Expand All @@ -30,7 +32,12 @@ internal abstract class ServiceConnectionContainerBase : IServiceConnectionConta

private volatile ServiceConnectionStatus _status;

private volatile bool _terminated = false;
private volatile bool _terminated = false;

private static readonly PingMessage _shutdownFinMessage = new PingMessage()
{
Messages = new string[2] { Constants.ServicePingMessageKey.ShutdownKey, Constants.ServicePingMessageValue.ShutdownFin }
};

protected ILogger Logger { get; }

Expand Down Expand Up @@ -117,12 +124,6 @@ public virtual Task StopAsync()
return Task.WhenAll(FixedServiceConnections.Select(c => c.StopAsync()));
}

public virtual Task ShutdownAsync(TimeSpan timeout)
{
_terminated = true;
return Task.CompletedTask;
}

/// <summary>
/// Start and manage the whole connection lifetime
/// </summary>
Expand Down Expand Up @@ -246,8 +247,8 @@ protected void ReplaceFixedConnections(int index, IServiceConnection serviceConn
}

public Task ConnectionInitializedTask => Task.WhenAll(from connection in FixedServiceConnections
select connection.ConnectionInitializedTask);

select connection.ConnectionInitializedTask);

public virtual Task WriteAsync(ServiceMessage serviceMessage)
{
return WriteToRandomAvailableConnection(serviceMessage);
Expand Down Expand Up @@ -347,8 +348,32 @@ private IEnumerable<IServiceConnection> CreateFixedServiceConnection(int count)
{
yield return CreateServiceConnectionCore(InitialConnectionType);
}
}

}

protected async Task WriteFinAsync(IServiceConnection c)
{
await c.WriteAsync(_shutdownFinMessage);
}

protected async Task RemoveConnectionFromService(IServiceConnection c)
{
_ = WriteFinAsync(c);

var source = new CancellationTokenSource();
var task = await Task.WhenAny(c.ConnectionOfflineTask, Task.Delay(RemoveFromServiceTimeout, source.Token));
source.Cancel();

if (task != c.ConnectionOfflineTask)
{
// log
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why didn't this get fixed in the PR?

}
}

public virtual Task OfflineAsync()
{
return Task.WhenAll(FixedServiceConnections.Select(c => RemoveConnectionFromService(c)));
}

private static class Log
{
private static readonly Action<ILogger, string, string, Exception> _endpointOnline =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ internal class ServiceConnectionContainerFactory : IServiceConnectionContainerFa
private readonly IServiceEndpointManager _serviceEndpointManager;
private readonly IMessageRouter _router;
private readonly IServerNameProvider _nameProvider;
private readonly IClientConnectionLifetimeManager _lifetime;
private readonly IServiceConnectionFactory _serviceConnectionFactory;

public ServiceConnectionContainerFactory(
Expand All @@ -22,21 +21,19 @@ public ServiceConnectionContainerFactory(
IMessageRouter router,
IServiceEndpointOptions options,
IServerNameProvider nameProvider,
IClientConnectionLifetimeManager lifetime,
ILoggerFactory loggerFactory)
{
_serviceConnectionFactory = serviceConnectionFactory;
_serviceEndpointManager = serviceEndpointManager ?? throw new ArgumentNullException(nameof(serviceEndpointManager));
_router = router ?? throw new ArgumentNullException(nameof(router));
_options = options;
_nameProvider = nameProvider;
_lifetime = lifetime;
_loggerFactory = loggerFactory;
}

public IServiceConnectionContainer Create(string hub)
{
return new MultiEndpointServiceConnectionContainer(_serviceConnectionFactory, hub, _options.ConnectionCount, _serviceEndpointManager, _router, _nameProvider, _lifetime, _loggerFactory);
return new MultiEndpointServiceConnectionContainer(_serviceConnectionFactory, hub, _options.ConnectionCount, _serviceEndpointManager, _router, _nameProvider, _loggerFactory);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -42,13 +43,11 @@ public override Task StopAsync()
);
}

public override Task ShutdownAsync(TimeSpan timeout)
public override Task OfflineAsync()
{
var task = base.StopAsync();
return Task.WhenAll(
task,
Task.CompletedTask // TODO
);
var task1 = base.OfflineAsync();
var task2 = Task.WhenAll(_onDemandServiceConnections.Select(c => RemoveConnectionFromService(c)));
return Task.WhenAll(task1, task2);
}

protected override ServiceConnectionStatus GetStatus()
Expand Down
Loading