Skip to content

Commit

Permalink
server connection migration implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan committed Nov 28, 2019
1 parent 506b832 commit 33aefc5
Show file tree
Hide file tree
Showing 8 changed files with 233 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ internal static class Constants

public const int DefaultShutdownTimeoutInSeconds = 30;

public const string AsrsMigratedFrom = "Asrs-Migrated-From";
public const string AsrsUserAgent = "Asrs-User-Agent";
public const string AsrsInstanceId = "Asrs-Instance-Id";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ internal class ClientConnectionContext : ConnectionContext,

private readonly TaskCompletionSource<object> _connectionEndTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

public Task CompleteTask => _connectionEndTcs.Task;
public Task CompleteTask => _connectionEndTcs.Task;

public readonly bool IsMigrated;

private readonly object _heartbeatLock = new object();
private List<(Action<object> handler, object state)> _heartbeatHandlers;
Expand All @@ -47,6 +49,11 @@ public ClientConnectionContext(OpenConnectionMessage serviceMessage, Action<Http
ConnectionId = serviceMessage.ConnectionId;
User = serviceMessage.GetUserPrincipal();

if (serviceMessage.Headers.TryGetValue(Constants.AsrsMigratedFrom, out _))
{
IsMigrated = true;
}

// Create the Duplix Pipeline for the virtual connection
transportPipeOptions = transportPipeOptions ?? DefaultPipeOptions;
appPipeOptions = appPipeOptions ?? DefaultPipeOptions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ private static class Log
private static readonly Action<ILogger, Exception> _applicationTaskTimedOut =
LoggerMessage.Define(LogLevel.Error, new EventId(21, "ApplicationTaskTimedOut"), "Timed out waiting for the application task to complete.");

private static readonly Action<ILogger, string, Exception> _migrationStarting =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(22, "MigrationStarting"), "Connection {TransportConnectionId} migrated from another server.");

private static readonly Action<ILogger, string, Exception> _errorSkippingHandshakeResponse =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(23, "ErrorSkippingHandshakeResponse"), "Error while skipping handshake response during migration, the connection will be dropped on the client-side. Error detail: {message}");

public static void FailedToCleanupConnections(ILogger logger, Exception exception)
{
_failedToCleanupConnections(logger, exception);
Expand Down Expand Up @@ -82,6 +88,11 @@ public static void ConnectedStarting(ILogger logger, string connectionId)
_connectedStarting(logger, connectionId, null);
}

public static void MigrationStarting(ILogger logger, string connectionId)
{
_migrationStarting(logger, connectionId, null);
}

public static void ConnectedEnding(ILogger logger, string connectionId)
{
_connectedEnding(logger, connectionId, null);
Expand All @@ -101,6 +112,11 @@ public static void ApplicationTaskTimedOut(ILogger logger)
{
_applicationTaskTimedOut(logger, null);
}

public static void ErrorSkippingHandshakeResponse(ILogger logger, Exception ex)
{
_errorSkippingHandshakeResponse(logger, ex.Message, ex);
}
}
}
}
74 changes: 69 additions & 5 deletions src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;

using SignalRProtocol = Microsoft.AspNetCore.SignalR.Protocol;

namespace Microsoft.Azure.SignalR
{
internal partial class ServiceConnection : ServiceConnectionBase
Expand All @@ -22,6 +23,8 @@ internal partial class ServiceConnection : ServiceConnectionBase
private static readonly Dictionary<string, string> CustomHeader = new Dictionary<string, string> { { Constants.AsrsUserAgent, ProductInfo.GetProductInfo() } };
private static readonly TimeSpan CloseTimeout = TimeSpan.FromSeconds(5);

private readonly bool _enableConnectionMigration;

private const string ClientConnectionCountInHub = "#clientInHub";
private const string ClientConnectionCountInServiceConnection = "#client";

Expand Down Expand Up @@ -53,6 +56,8 @@ public ServiceConnection(IServiceProtocol serviceProtocol,
_connectionFactory = connectionFactory;
_connectionDelegate = connectionDelegate;
_clientConnectionFactory = clientConnectionFactory;

_enableConnectionMigration = false;
}

protected override Task<ConnectionContext> CreateConnection(string target = null)
Expand Down Expand Up @@ -102,8 +107,16 @@ protected override Task OnClientConnectedAsync(OpenConnectionMessage message)
{
var connection = _clientConnectionFactory.CreateConnection(message, ConfigureContext);
connection.ServiceConnection = this;
AddClientConnection(connection, GetInstanceId(message.Headers));
Log.ConnectedStarting(Logger, connection.ConnectionId);
AddClientConnection(connection, message);

if (connection.IsMigrated)
{
Log.MigrationStarting(Logger, connection.ConnectionId);
}
else
{
Log.ConnectedStarting(Logger, connection.ConnectionId);
}

// Execute the application code
connection.ApplicationTask = _connectionDelegate(connection);
Expand All @@ -120,6 +133,10 @@ protected override Task OnClientConnectedAsync(OpenConnectionMessage message)
protected override Task OnClientDisconnectedAsync(CloseConnectionMessage closeConnectionMessage)
{
var connectionId = closeConnectionMessage.ConnectionId;
if (_enableConnectionMigration && _clientConnectionManager.ClientConnections.TryGetValue(connectionId, out var context))
{
context.Application.Input.CancelPendingRead();
}
return PerformDisconnectAsyncCore(connectionId, false);
}

Expand Down Expand Up @@ -158,20 +175,66 @@ protected override async Task OnClientMessageAsync(ConnectionDataMessage connect
}
}

private async Task SkipHandshakeResponse(ClientConnectionContext connection)
{
try
{
while (true)
{
var result = await connection.Application.Input.ReadAsync();
if (result.IsCanceled)
{
return;
}

var buffer = result.Buffer;
if (buffer.IsEmpty)
{
continue;
}

if (SignalRProtocol.HandshakeProtocol.TryParseResponseMessage(ref buffer, out var message))
{
connection.Application.Input.AdvanceTo(buffer.Start);
return;
}

if (result.IsCompleted)
{
return;
}
}
}
catch (Exception ex)
{
Log.ErrorSkippingHandshakeResponse(Logger, ex);
}
}

private async Task ProcessOutgoingMessagesAsync(ClientConnectionContext connection)
{
if (connection.IsMigrated)
{
// A handshake response is not expected to be given
// if the connection was migrated from another server,
// since the connection hasn't been `dropped` from the client point of view.
await SkipHandshakeResponse(connection);
}

try
{
while (true)
{
var result = await connection.Application.Input.ReadAsync();

if (result.IsCanceled)
{
break;
}

var buffer = result.Buffer;
if (!buffer.IsEmpty)

if (!buffer.IsEmpty)
{
try
{
Expand Down Expand Up @@ -209,8 +272,9 @@ private async Task ProcessOutgoingMessagesAsync(ClientConnectionContext connecti
}
}

private void AddClientConnection(ClientConnectionContext connection, string instanceId)
private void AddClientConnection(ClientConnectionContext connection, OpenConnectionMessage message)
{
var instanceId = GetInstanceId(message.Headers);
_clientConnectionManager.AddClientConnection(connection);
_connectionIds.TryAdd(connection.ConnectionId, instanceId);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.Http;
using Microsoft.Azure.SignalR.Protocol;

namespace Microsoft.Azure.SignalR.Tests
{
class TestClientConnectionFactory : IClientConnectionFactory
{
public IList<ClientConnectionContext> Connections = new List<ClientConnectionContext>();

public ClientConnectionContext CreateConnection(OpenConnectionMessage message, Action<HttpContext> configureContext = null)
{
var context = new ClientConnectionContext(message, configureContext);
Connections.Add(context);
return context;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -14,13 +13,20 @@ internal class TestConnectionFactory : IConnectionFactory
{
private readonly Func<TestConnection, Task> _connectCallback;

public IList<TestConnection> Connections = new List<TestConnection>();

public List<DateTime> Times { get; } = new List<DateTime>();

public TestConnectionFactory()
{
_connectCallback = null;
}

public TestConnectionFactory(Func<TestConnection, Task> connectCallback)
{
_connectCallback = connectCallback;
}

public async Task<ConnectionContext> ConnectAsync(HubServiceEndpoint endpoint, TransferFormat transferFormat, string connectionId, string target,
CancellationToken cancellationToken = default, IDictionary<string, string> headers = null)
{
Expand All @@ -31,13 +37,18 @@ public async Task<ConnectionContext> ConnectAsync(HubServiceEndpoint endpoint, T
ConnectionId = connectionId,
Target = target
};
Connections.Add(connection);

// Start a task to process handshake request from the newly-created server connection.
_ = HandshakeAsync(connection);

// Do something for test purpose
await AfterConnectedAsync(connection);

await _connectCallback(connection);
if (null != _connectCallback)
{
await _connectCallback(connection);
}

return connection;
}
Expand Down
90 changes: 89 additions & 1 deletion test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,32 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Azure.SignalR.Tests.Common;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Primitives;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.Azure.SignalR.Tests
{
public class ServiceConnectionTests : VerifiableLoggedTest
{
private static readonly PipeOptions DefaultPipeOptions = new PipeOptions(
pauseWriterThreshold: 0,
resumeWriterThreshold: 0,
readerScheduler: PipeScheduler.ThreadPool,
useSynchronizationContext: false);

public ServiceConnectionTests(ITestOutputHelper output) : base(output)
{
}
Expand Down Expand Up @@ -196,6 +205,85 @@ await transportConnection.Application.Output.WriteAsync(
}
}

private class TestServiceConnection : ServiceConnection
{
public TestServiceConnection(IConnectionFactory serviceConnectionFactory,
IClientConnectionFactory clientConnectionFactory,
ILoggerFactory loggerFactory,
ConnectionDelegate handler) : base(
new ServiceProtocol(),
new TestClientConnectionManager(),
serviceConnectionFactory,
loggerFactory,
handler,
clientConnectionFactory,
Guid.NewGuid().ToString("N"),
null,
null
)
{
}

public Task OnClientConnectedAsyncForTest(OpenConnectionMessage message)
{
return base.OnClientConnectedAsync(message);
}
}

private TestServiceConnection MockServiceConnection(IConnectionFactory serviceConnectionFactory = null,
IClientConnectionFactory clientConnectionFactory = null,
ILoggerFactory loggerFactory = null)
{
clientConnectionFactory ??= new ClientConnectionFactory();
serviceConnectionFactory ??= new TestConnectionFactory(conn => Task.CompletedTask);
loggerFactory ??= NullLoggerFactory.Instance;

var services = new ServiceCollection();
var connectionHandler = new EndlessConnectionHandler();
services.AddSingleton(connectionHandler);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<EndlessConnectionHandler>();
ConnectionDelegate handler = builder.Build();

return new TestServiceConnection(
serviceConnectionFactory,
clientConnectionFactory,
loggerFactory,
handler
);
}

[Fact]
public async void ServiceConnectionShouldIgnoreFirstHandshakeResponse()
{
var factory = new TestClientConnectionFactory();
var connection = MockServiceConnection(null, factory);

// create a connection with migration header.
await connection.OnClientConnectedAsyncForTest(new OpenConnectionMessage("foo", new Claim[0])
{
Headers = new Dictionary<string, StringValues>{
{ Constants.AsrsMigratedFrom, "another-server" }
}
});

Assert.Equal(1, factory.Connections.Count);
var context = factory.Connections[0];
Assert.True(context.IsMigrated);

var message = new AspNetCore.SignalR.Protocol.HandshakeResponseMessage("");
HandshakeProtocol.WriteResponseMessage(message, context.Transport.Output);
await context.Transport.Output.FlushAsync();

var task = context.Transport.Input.ReadAsync();
await Task.Delay(100);

// nothing should be written into the transport
Assert.False(task.IsCompleted);
// but the `migrated` status should remain False (readonly)
Assert.True(context.IsMigrated);
}

private sealed class TestConnectionHandler : ConnectionHandler
{
private TaskCompletionSource<object> _startedTcs = new TaskCompletionSource<object>();
Expand Down
Loading

0 comments on commit 33aefc5

Please sign in to comment.