Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Server connection migration implementation #739

Merged
merged 1 commit into from
Nov 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ __pycache__/

.publish/

# docker
.docker/

# vim
*.swp

# docker
.docker/
2 changes: 2 additions & 0 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ internal static class Constants

public const int DefaultShutdownTimeoutInSeconds = 30;

public const string AsrsMigrateIn = "Asrs-Migrate-In";
public const string AsrsMigrateOut = "Asrs-Migrate-Out";
public const string AsrsUserAgent = "Asrs-User-Agent";
public const string AsrsInstanceId = "Asrs-Instance-Id";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace Microsoft.Azure.SignalR
{
internal abstract class ServiceConnectionBase : IServiceConnection
{
private static readonly TimeSpan DefaultHandshakeTimeout = TimeSpan.FromSeconds(15);
protected static readonly TimeSpan DefaultHandshakeTimeout = TimeSpan.FromSeconds(15);
// Service ping rate is 5 sec to let server know service status. Set timeout for 30 sec for some space.
private static readonly TimeSpan DefaultServiceTimeout = TimeSpan.FromSeconds(30);
private static readonly long DefaultServiceTimeoutTicks = DefaultServiceTimeout.Seconds * Stopwatch.Frequency;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ protected async Task WriteFinAsync(IServiceConnection c)
await c.WriteAsync(_shutdownFinMessage);
}

protected async Task RemoveConnectionFromService(IServiceConnection c)
protected async Task RemoveConnectionAsync(IServiceConnection c)
{
_ = WriteFinAsync(c);

Expand All @@ -371,7 +371,7 @@ protected async Task RemoveConnectionFromService(IServiceConnection c)

public virtual Task OfflineAsync()
{
return Task.WhenAll(FixedServiceConnections.Select(c => RemoveConnectionFromService(c)));
return Task.WhenAll(FixedServiceConnections.Select(c => RemoveConnectionAsync(c)));
}

private static class Log
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public override Task StopAsync()
public override Task OfflineAsync()
{
var task1 = base.OfflineAsync();
var task2 = Task.WhenAll(_onDemandServiceConnections.Select(c => RemoveConnectionFromService(c)));
var task2 = Task.WhenAll(_onDemandServiceConnections.Select(c => RemoveConnectionAsync(c)));
return Task.WhenAll(task1, task2);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
using Microsoft.AspNetCore.Http.Connections.Features;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Http.Features.Authentication;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Primitives;
Expand All @@ -37,7 +36,9 @@ internal class ClientConnectionContext : ConnectionContext,

private readonly TaskCompletionSource<object> _connectionEndTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

public Task CompleteTask => _connectionEndTcs.Task;
public Task CompleteTask => _connectionEndTcs.Task;

public bool IsMigrated { get; }

private readonly object _heartbeatLock = new object();
private List<(Action<object> handler, object state)> _heartbeatHandlers;
Expand All @@ -47,6 +48,11 @@ public ClientConnectionContext(OpenConnectionMessage serviceMessage, Action<Http
ConnectionId = serviceMessage.ConnectionId;
User = serviceMessage.GetUserPrincipal();

if (serviceMessage.Headers.TryGetValue(Constants.AsrsMigrateIn, out _))
{
IsMigrated = true;
}

// Create the Duplix Pipeline for the virtual connection
transportPipeOptions = transportPipeOptions ?? DefaultPipeOptions;
appPipeOptions = appPipeOptions ?? DefaultPipeOptions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ private static class Log
private static readonly Action<ILogger, Exception> _applicationTaskTimedOut =
LoggerMessage.Define(LogLevel.Error, new EventId(21, "ApplicationTaskTimedOut"), "Timed out waiting for the application task to complete.");

private static readonly Action<ILogger, string, Exception> _migrationStarting =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(22, "MigrationStarting"), "Connection {TransportConnectionId} migrated from another server.");

private static readonly Action<ILogger, string, Exception> _errorSkippingHandshakeResponse =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(23, "ErrorSkippingHandshakeResponse"), "Error while skipping handshake response during migration, the connection will be dropped on the client-side. Error detail: {message}");

public static void FailedToCleanupConnections(ILogger logger, Exception exception)
{
_failedToCleanupConnections(logger, exception);
Expand Down Expand Up @@ -82,6 +88,11 @@ public static void ConnectedStarting(ILogger logger, string connectionId)
_connectedStarting(logger, connectionId, null);
}

public static void MigrationStarting(ILogger logger, string connectionId)
{
_migrationStarting(logger, connectionId, null);
}

public static void ConnectedEnding(ILogger logger, string connectionId)
{
_connectedEnding(logger, connectionId, null);
Expand All @@ -101,6 +112,11 @@ public static void ApplicationTaskTimedOut(ILogger logger)
{
_applicationTaskTimedOut(logger, null);
}

public static void ErrorSkippingHandshakeResponse(ILogger logger, Exception ex)
{
_errorSkippingHandshakeResponse(logger, ex.Message, ex);
}
}
}
}
89 changes: 83 additions & 6 deletions src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;

using SignalRProtocol = Microsoft.AspNetCore.SignalR.Protocol;

namespace Microsoft.Azure.SignalR
{
internal partial class ServiceConnection : ServiceConnectionBase
Expand All @@ -22,6 +23,8 @@ internal partial class ServiceConnection : ServiceConnectionBase
private static readonly Dictionary<string, string> CustomHeader = new Dictionary<string, string> { { Constants.AsrsUserAgent, ProductInfo.GetProductInfo() } };
private static readonly TimeSpan CloseTimeout = TimeSpan.FromSeconds(5);

private readonly bool _enableConnectionMigration;

private const string ClientConnectionCountInHub = "#clientInHub";
private const string ClientConnectionCountInServiceConnection = "#client";

Expand Down Expand Up @@ -53,6 +56,8 @@ public ServiceConnection(IServiceProtocol serviceProtocol,
_connectionFactory = connectionFactory;
_connectionDelegate = connectionDelegate;
_clientConnectionFactory = clientConnectionFactory;

_enableConnectionMigration = false;
}

protected override Task<ConnectionContext> CreateConnection(string target = null)
Expand Down Expand Up @@ -102,8 +107,16 @@ protected override Task OnClientConnectedAsync(OpenConnectionMessage message)
{
var connection = _clientConnectionFactory.CreateConnection(message, ConfigureContext);
connection.ServiceConnection = this;
AddClientConnection(connection, GetInstanceId(message.Headers));
Log.ConnectedStarting(Logger, connection.ConnectionId);
AddClientConnection(connection, message);

if (connection.IsMigrated)
{
Log.MigrationStarting(Logger, connection.ConnectionId);
}
else
{
Log.ConnectedStarting(Logger, connection.ConnectionId);
}

// Execute the application code
connection.ApplicationTask = _connectionDelegate(connection);
Expand All @@ -120,6 +133,17 @@ protected override Task OnClientConnectedAsync(OpenConnectionMessage message)
protected override Task OnClientDisconnectedAsync(CloseConnectionMessage closeConnectionMessage)
{
var connectionId = closeConnectionMessage.ConnectionId;
if (_enableConnectionMigration && _clientConnectionManager.ClientConnections.TryGetValue(connectionId, out var context))
{
if (!context.HttpContext.Request.Headers.ContainsKey(Constants.AsrsMigrateOut))
{
context.HttpContext.Request.Headers.Add(Constants.AsrsMigrateOut, "");
}
// We have to prevent SignalR `{type: 7}` (close message) from reaching our client while doing migration.
// Since all user-created messages will be sent to `ServiceConnection` directly.
// We can simply ignore all messages came from the application pipe.
context.Application.Input.CancelPendingRead();
}
return PerformDisconnectAsyncCore(connectionId, false);
}

Expand Down Expand Up @@ -158,20 +182,72 @@ protected override async Task OnClientMessageAsync(ConnectionDataMessage connect
}
}

private async Task<bool> SkipHandshakeResponse(ClientConnectionContext connection, CancellationToken token)
{
try
{
while (true)
{
var result = await connection.Application.Input.ReadAsync(token);
if (result.IsCanceled || token.IsCancellationRequested)
{
return false;
}

var buffer = result.Buffer;
if (buffer.IsEmpty)
{
continue;
}

if (SignalRProtocol.HandshakeProtocol.TryParseResponseMessage(ref buffer, out var message))
{
connection.Application.Input.AdvanceTo(buffer.Start);
return true;
}

if (result.IsCompleted)
{
return false;
}
}
}
catch (Exception ex)
{
Log.ErrorSkippingHandshakeResponse(Logger, ex);
}
return false;
}

private async Task ProcessOutgoingMessagesAsync(ClientConnectionContext connection)
{
try
{
{
if (connection.IsMigrated)
{
using var source = new CancellationTokenSource(DefaultHandshakeTimeout);

// A handshake response is not expected to be given
// if the connection was migrated from another server,
// since the connection hasn't been `dropped` from the client point of view.
if (!await SkipHandshakeResponse(connection, source.Token))
{
return;
}
}

while (true)
{
var result = await connection.Application.Input.ReadAsync();

if (result.IsCanceled)
{
break;
}

var buffer = result.Buffer;
if (!buffer.IsEmpty)

if (!buffer.IsEmpty)
{
try
{
Expand Down Expand Up @@ -209,8 +285,9 @@ private async Task ProcessOutgoingMessagesAsync(ClientConnectionContext connecti
}
}

private void AddClientConnection(ClientConnectionContext connection, string instanceId)
private void AddClientConnection(ClientConnectionContext connection, OpenConnectionMessage message)
{
var instanceId = GetInstanceId(message.Headers);
_clientConnectionManager.AddClientConnection(connection);
_connectionIds.TryAdd(connection.ConnectionId, instanceId);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.Http;
using Microsoft.Azure.SignalR.Protocol;

namespace Microsoft.Azure.SignalR.Tests
{
class TestClientConnectionFactory : IClientConnectionFactory
{
public IList<ClientConnectionContext> Connections = new List<ClientConnectionContext>();

public ClientConnectionContext CreateConnection(OpenConnectionMessage message, Action<HttpContext> configureContext = null)
{
var context = new ClientConnectionContext(message, configureContext);
Connections.Add(context);
return context;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -14,13 +13,20 @@ internal class TestConnectionFactory : IConnectionFactory
{
private readonly Func<TestConnection, Task> _connectCallback;

public IList<TestConnection> Connections = new List<TestConnection>();

public List<DateTime> Times { get; } = new List<DateTime>();

public TestConnectionFactory()
{
_connectCallback = null;
}

public TestConnectionFactory(Func<TestConnection, Task> connectCallback)
{
_connectCallback = connectCallback;
}
public async Task<ConnectionContext> ConnectAsync(HubServiceEndpoint endpoint, TransferFormat transferFormat, string connectionId, string target,
CancellationToken cancellationToken = default, IDictionary<string, string> headers = null)
{
Expand All @@ -31,13 +37,18 @@ public async Task<ConnectionContext> ConnectAsync(HubServiceEndpoint endpoint, T
ConnectionId = connectionId,
Target = target
};
Connections.Add(connection);

// Start a task to process handshake request from the newly-created server connection.
_ = HandshakeAsync(connection);

// Do something for test purpose
await AfterConnectedAsync(connection);

await _connectCallback(connection);
if (_connectCallback != null)
{
await _connectCallback(connection);
}

return connection;
}
Expand Down
Loading