Skip to content

Commit

Permalink
server connection migration implementation (#739)
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan authored Nov 29, 2019
1 parent 9652efb commit 91ea570
Show file tree
Hide file tree
Showing 12 changed files with 249 additions and 20 deletions.
6 changes: 3 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ __pycache__/

.publish/

# docker
.docker/

# vim
*.swp

# docker
.docker/
2 changes: 2 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ internal static class Constants

public const int DefaultShutdownTimeoutInSeconds = 30;

public const string AsrsMigrateIn = "Asrs-Migrate-In";
public const string AsrsMigrateOut = "Asrs-Migrate-Out";
public const string AsrsUserAgent = "Asrs-User-Agent";
public const string AsrsInstanceId = "Asrs-Instance-Id";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace Microsoft.Azure.SignalR
{
internal abstract class ServiceConnectionBase : IServiceConnection
{
private static readonly TimeSpan DefaultHandshakeTimeout = TimeSpan.FromSeconds(15);
protected static readonly TimeSpan DefaultHandshakeTimeout = TimeSpan.FromSeconds(15);
// Service ping rate is 5 sec to let server know service status. Set timeout for 30 sec for some space.
private static readonly TimeSpan DefaultServiceTimeout = TimeSpan.FromSeconds(30);
private static readonly long DefaultServiceTimeoutTicks = DefaultServiceTimeout.Seconds * Stopwatch.Frequency;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ protected async Task WriteFinAsync(IServiceConnection c)
await c.WriteAsync(_shutdownFinMessage);
}

protected async Task RemoveConnectionFromService(IServiceConnection c)
protected async Task RemoveConnectionAsync(IServiceConnection c)
{
_ = WriteFinAsync(c);

Expand All @@ -371,7 +371,7 @@ protected async Task RemoveConnectionFromService(IServiceConnection c)

public virtual Task OfflineAsync()
{
return Task.WhenAll(FixedServiceConnections.Select(c => RemoveConnectionFromService(c)));
return Task.WhenAll(FixedServiceConnections.Select(c => RemoveConnectionAsync(c)));
}

private static class Log
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public override Task StopAsync()
public override Task OfflineAsync()
{
var task1 = base.OfflineAsync();
var task2 = Task.WhenAll(_onDemandServiceConnections.Select(c => RemoveConnectionFromService(c)));
var task2 = Task.WhenAll(_onDemandServiceConnections.Select(c => RemoveConnectionAsync(c)));
return Task.WhenAll(task1, task2);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
using Microsoft.AspNetCore.Http.Connections.Features;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Http.Features.Authentication;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Primitives;
Expand All @@ -37,7 +36,9 @@ internal class ClientConnectionContext : ConnectionContext,

private readonly TaskCompletionSource<object> _connectionEndTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

public Task CompleteTask => _connectionEndTcs.Task;
public Task CompleteTask => _connectionEndTcs.Task;

public bool IsMigrated { get; }

private readonly object _heartbeatLock = new object();
private List<(Action<object> handler, object state)> _heartbeatHandlers;
Expand All @@ -47,6 +48,11 @@ public ClientConnectionContext(OpenConnectionMessage serviceMessage, Action<Http
ConnectionId = serviceMessage.ConnectionId;
User = serviceMessage.GetUserPrincipal();

if (serviceMessage.Headers.TryGetValue(Constants.AsrsMigrateIn, out _))
{
IsMigrated = true;
}

// Create the Duplix Pipeline for the virtual connection
transportPipeOptions = transportPipeOptions ?? DefaultPipeOptions;
appPipeOptions = appPipeOptions ?? DefaultPipeOptions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ private static class Log
private static readonly Action<ILogger, Exception> _applicationTaskTimedOut =
LoggerMessage.Define(LogLevel.Error, new EventId(21, "ApplicationTaskTimedOut"), "Timed out waiting for the application task to complete.");

private static readonly Action<ILogger, string, Exception> _migrationStarting =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(22, "MigrationStarting"), "Connection {TransportConnectionId} migrated from another server.");

private static readonly Action<ILogger, string, Exception> _errorSkippingHandshakeResponse =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(23, "ErrorSkippingHandshakeResponse"), "Error while skipping handshake response during migration, the connection will be dropped on the client-side. Error detail: {message}");

public static void FailedToCleanupConnections(ILogger logger, Exception exception)
{
_failedToCleanupConnections(logger, exception);
Expand Down Expand Up @@ -82,6 +88,11 @@ public static void ConnectedStarting(ILogger logger, string connectionId)
_connectedStarting(logger, connectionId, null);
}

public static void MigrationStarting(ILogger logger, string connectionId)
{
_migrationStarting(logger, connectionId, null);
}

public static void ConnectedEnding(ILogger logger, string connectionId)
{
_connectedEnding(logger, connectionId, null);
Expand All @@ -101,6 +112,11 @@ public static void ApplicationTaskTimedOut(ILogger logger)
{
_applicationTaskTimedOut(logger, null);
}

public static void ErrorSkippingHandshakeResponse(ILogger logger, Exception ex)
{
_errorSkippingHandshakeResponse(logger, ex.Message, ex);
}
}
}
}
89 changes: 83 additions & 6 deletions src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;

using SignalRProtocol = Microsoft.AspNetCore.SignalR.Protocol;

namespace Microsoft.Azure.SignalR
{
internal partial class ServiceConnection : ServiceConnectionBase
Expand All @@ -22,6 +23,8 @@ internal partial class ServiceConnection : ServiceConnectionBase
private static readonly Dictionary<string, string> CustomHeader = new Dictionary<string, string> { { Constants.AsrsUserAgent, ProductInfo.GetProductInfo() } };
private static readonly TimeSpan CloseTimeout = TimeSpan.FromSeconds(5);

private readonly bool _enableConnectionMigration;

private const string ClientConnectionCountInHub = "#clientInHub";
private const string ClientConnectionCountInServiceConnection = "#client";

Expand Down Expand Up @@ -53,6 +56,8 @@ public ServiceConnection(IServiceProtocol serviceProtocol,
_connectionFactory = connectionFactory;
_connectionDelegate = connectionDelegate;
_clientConnectionFactory = clientConnectionFactory;

_enableConnectionMigration = false;
}

protected override Task<ConnectionContext> CreateConnection(string target = null)
Expand Down Expand Up @@ -102,8 +107,16 @@ protected override Task OnClientConnectedAsync(OpenConnectionMessage message)
{
var connection = _clientConnectionFactory.CreateConnection(message, ConfigureContext);
connection.ServiceConnection = this;
AddClientConnection(connection, GetInstanceId(message.Headers));
Log.ConnectedStarting(Logger, connection.ConnectionId);
AddClientConnection(connection, message);

if (connection.IsMigrated)
{
Log.MigrationStarting(Logger, connection.ConnectionId);
}
else
{
Log.ConnectedStarting(Logger, connection.ConnectionId);
}

// Execute the application code
connection.ApplicationTask = _connectionDelegate(connection);
Expand All @@ -120,6 +133,17 @@ protected override Task OnClientConnectedAsync(OpenConnectionMessage message)
protected override Task OnClientDisconnectedAsync(CloseConnectionMessage closeConnectionMessage)
{
var connectionId = closeConnectionMessage.ConnectionId;
if (_enableConnectionMigration && _clientConnectionManager.ClientConnections.TryGetValue(connectionId, out var context))
{
if (!context.HttpContext.Request.Headers.ContainsKey(Constants.AsrsMigrateOut))
{
context.HttpContext.Request.Headers.Add(Constants.AsrsMigrateOut, "");
}
// We have to prevent SignalR `{type: 7}` (close message) from reaching our client while doing migration.
// Since all user-created messages will be sent to `ServiceConnection` directly.
// We can simply ignore all messages came from the application pipe.
context.Application.Input.CancelPendingRead();
}
return PerformDisconnectAsyncCore(connectionId, false);
}

Expand Down Expand Up @@ -158,20 +182,72 @@ protected override async Task OnClientMessageAsync(ConnectionDataMessage connect
}
}

private async Task<bool> SkipHandshakeResponse(ClientConnectionContext connection, CancellationToken token)
{
try
{
while (true)
{
var result = await connection.Application.Input.ReadAsync(token);
if (result.IsCanceled || token.IsCancellationRequested)
{
return false;
}

var buffer = result.Buffer;
if (buffer.IsEmpty)
{
continue;
}

if (SignalRProtocol.HandshakeProtocol.TryParseResponseMessage(ref buffer, out var message))
{
connection.Application.Input.AdvanceTo(buffer.Start);
return true;
}

if (result.IsCompleted)
{
return false;
}
}
}
catch (Exception ex)
{
Log.ErrorSkippingHandshakeResponse(Logger, ex);
}
return false;
}

private async Task ProcessOutgoingMessagesAsync(ClientConnectionContext connection)
{
try
{
{
if (connection.IsMigrated)
{
using var source = new CancellationTokenSource(DefaultHandshakeTimeout);

// A handshake response is not expected to be given
// if the connection was migrated from another server,
// since the connection hasn't been `dropped` from the client point of view.
if (!await SkipHandshakeResponse(connection, source.Token))
{
return;
}
}

while (true)
{
var result = await connection.Application.Input.ReadAsync();

if (result.IsCanceled)
{
break;
}

var buffer = result.Buffer;
if (!buffer.IsEmpty)

if (!buffer.IsEmpty)
{
try
{
Expand Down Expand Up @@ -209,8 +285,9 @@ private async Task ProcessOutgoingMessagesAsync(ClientConnectionContext connecti
}
}

private void AddClientConnection(ClientConnectionContext connection, string instanceId)
private void AddClientConnection(ClientConnectionContext connection, OpenConnectionMessage message)
{
var instanceId = GetInstanceId(message.Headers);
_clientConnectionManager.AddClientConnection(connection);
_connectionIds.TryAdd(connection.ConnectionId, instanceId);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.Http;
using Microsoft.Azure.SignalR.Protocol;

namespace Microsoft.Azure.SignalR.Tests
{
class TestClientConnectionFactory : IClientConnectionFactory
{
public IList<ClientConnectionContext> Connections = new List<ClientConnectionContext>();

public ClientConnectionContext CreateConnection(OpenConnectionMessage message, Action<HttpContext> configureContext = null)
{
var context = new ClientConnectionContext(message, configureContext);
Connections.Add(context);
return context;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -14,13 +13,20 @@ internal class TestConnectionFactory : IConnectionFactory
{
private readonly Func<TestConnection, Task> _connectCallback;

public IList<TestConnection> Connections = new List<TestConnection>();

public List<DateTime> Times { get; } = new List<DateTime>();

public TestConnectionFactory()
{
_connectCallback = null;
}

public TestConnectionFactory(Func<TestConnection, Task> connectCallback)
{
_connectCallback = connectCallback;
}

public async Task<ConnectionContext> ConnectAsync(HubServiceEndpoint endpoint, TransferFormat transferFormat, string connectionId, string target,
CancellationToken cancellationToken = default, IDictionary<string, string> headers = null)
{
Expand All @@ -31,13 +37,18 @@ public async Task<ConnectionContext> ConnectAsync(HubServiceEndpoint endpoint, T
ConnectionId = connectionId,
Target = target
};
Connections.Add(connection);

// Start a task to process handshake request from the newly-created server connection.
_ = HandshakeAsync(connection);

// Do something for test purpose
await AfterConnectedAsync(connection);

await _connectCallback(connection);
if (_connectCallback != null)
{
await _connectCallback(connection);
}

return connection;
}
Expand Down
Loading

0 comments on commit 91ea570

Please sign in to comment.