Skip to content

Commit

Permalink
server connection migration implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan committed Nov 19, 2019
1 parent 36228ea commit c2c6f08
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 11 deletions.
2 changes: 1 addition & 1 deletion samples/ChatSample/ChatSample/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public void ConfigureServices(IServiceCollection services)
}

public void Configure(IApplicationBuilder app)
{
{
app.UseFileServer();
app.UseAzureSignalR(routes =>
{
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ internal static class Constants

public const int DefaultShutdownTimeoutInSeconds = 30;

public const string AsrsMigratedFrom = "Asrs-Migrated-From";
public const string AsrsUserAgent = "Asrs-User-Agent";
public const string AsrsInstanceId = "Asrs-Instance-Id";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ private static class Log
private static readonly Action<ILogger, string, Exception> _connectedStarting =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(11, "ConnectedStarting"), "Connection {TransportConnectionId} started.");

private static readonly Action<ILogger, string, Exception> _migrationStarting =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(22, "MigrationStarting"), "Connection {TransportConnectionId} migrated from another server.");

private static readonly Action<ILogger, string, Exception> _connectedEnding =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(12, "ConnectedEnding"), "Connection {TransportConnectionId} ended.");

Expand Down Expand Up @@ -82,6 +85,11 @@ public static void ConnectedStarting(ILogger logger, string connectionId)
_connectedStarting(logger, connectionId, null);
}

public static void MigrationStarting(ILogger logger, string connectionId)
{
_migrationStarting(logger, connectionId, null);
}

public static void ConnectedEnding(ILogger logger, string connectionId)
{
_connectedEnding(logger, connectionId, null);
Expand Down
28 changes: 25 additions & 3 deletions src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using SignalRProtocol = Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http;
using Microsoft.Azure.SignalR.Protocol;
Expand Down Expand Up @@ -113,6 +114,17 @@ private async Task ProcessOutgoingMessagesAsync(ServiceConnectionContext connect
var buffer = result.Buffer;
if (!buffer.IsEmpty)
{
// We assume the first response message would be a HandshakeResponse.
if (connection.IsMigrated)
{
if (SignalRProtocol.HandshakeProtocol.TryParseResponseMessage(ref buffer, out var message))
{
connection.IsMigrated = false;
connection.Application.Input.AdvanceTo(buffer.End);
}
continue;
}

try
{
// Forward the message to the service
Expand Down Expand Up @@ -149,17 +161,27 @@ private async Task ProcessOutgoingMessagesAsync(ServiceConnectionContext connect
}
}

private void AddClientConnection(ServiceConnectionContext connection, string instanceId)
private void AddClientConnection(ServiceConnectionContext connection, OpenConnectionMessage message)
{
var instanceId = GetInstanceId(message.Headers);

_clientConnectionManager.AddClientConnection(connection);
_connectionIds.TryAdd(connection.ConnectionId, instanceId);
}

protected override Task OnConnectedAsync(OpenConnectionMessage message)
{
var connection = _clientConnectionFactory.CreateConnection(message, ConfigureContext);
AddClientConnection(connection, GetInstanceId(message.Headers));
Log.ConnectedStarting(Logger, connection.ConnectionId);
AddClientConnection(connection, message);

if (connection.IsMigrated)
{
Log.MigrationStarting(Logger, connection.ConnectionId);
}
else
{
Log.ConnectedStarting(Logger, connection.ConnectionId);
}

// Execute the application code
connection.ApplicationTask = _connectionDelegate(connection);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ internal class ServiceConnectionContext : ConnectionContext,
private readonly TaskCompletionSource<object> _connectionEndTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

public Task CompleteTask => _connectionEndTcs.Task;

public bool IsMigrated { get; set; }

private readonly object _heartbeatLock = new object();
private List<(Action<object> handler, object state)> _heartbeatHandlers;
Expand All @@ -47,6 +49,11 @@ public ServiceConnectionContext(OpenConnectionMessage serviceMessage, Action<Htt
ConnectionId = serviceMessage.ConnectionId;
User = serviceMessage.GetUserPrincipal();

if (serviceMessage.Headers.TryGetValue(Constants.AsrsMigratedFrom, out _))
{
IsMigrated = true;
}

// Create the Duplix Pipeline for the virtual connection
transportPipeOptions = transportPipeOptions ?? DefaultPipeOptions;
appPipeOptions = appPipeOptions ?? DefaultPipeOptions;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.Http;
using Microsoft.Azure.SignalR.Protocol;

namespace Microsoft.Azure.SignalR.Tests
{
class TestClientConnectionFactory : IClientConnectionFactory
{
public IList<ServiceConnectionContext> Connections = new List<ServiceConnectionContext>();

public ServiceConnectionContext CreateConnection(OpenConnectionMessage message, Action<HttpContext> configureContext = null)
{
var context = new ServiceConnectionContext(message, configureContext);
Connections.Add(context);
return context;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -14,13 +13,20 @@ internal class TestConnectionFactory : IConnectionFactory
{
private readonly Func<TestConnection, Task> _connectCallback;

public IList<TestConnection> Connections = new List<TestConnection>();

public List<DateTime> Times { get; } = new List<DateTime>();

public TestConnectionFactory()
{
_connectCallback = null;
}

public TestConnectionFactory(Func<TestConnection, Task> connectCallback)
{
_connectCallback = connectCallback;
}

public async Task<ConnectionContext> ConnectAsync(HubServiceEndpoint endpoint, TransferFormat transferFormat, string connectionId, string target,
CancellationToken cancellationToken = default, IDictionary<string, string> headers = null)
{
Expand All @@ -31,13 +37,18 @@ public async Task<ConnectionContext> ConnectAsync(HubServiceEndpoint endpoint, T
ConnectionId = connectionId,
Target = target
};
Connections.Add(connection);

// Start a task to process handshake request from the newly-created server connection.
_ = HandshakeAsync(connection);

// Do something for test purpose
await AfterConnectedAsync(connection);

await _connectCallback(connection);
if (null != _connectCallback)
{
await _connectCallback(connection);
}

return connection;
}
Expand Down
89 changes: 88 additions & 1 deletion test/Microsoft.Azure.SignalR.Tests/ServiceConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,32 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Azure.SignalR.Tests.Common;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Primitives;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.Azure.SignalR.Tests
{
public class ServiceConnectionTests : VerifiableLoggedTest
{
private static readonly PipeOptions DefaultPipeOptions = new PipeOptions(
pauseWriterThreshold: 0,
resumeWriterThreshold: 0,
readerScheduler: PipeScheduler.ThreadPool,
useSynchronizationContext: false);

public ServiceConnectionTests(ITestOutputHelper output) : base(output)
{
}
Expand Down Expand Up @@ -196,6 +205,84 @@ await transportConnection.Application.Output.WriteAsync(
}
}

private class TestServiceConnection : ServiceConnection
{
public TestServiceConnection(IConnectionFactory serviceConnectionFactory,
IClientConnectionFactory clientConnectionFactory,
ILoggerFactory loggerFactory,
ConnectionDelegate handler) : base(
new ServiceProtocol(),
new TestClientConnectionManager(),
serviceConnectionFactory,
loggerFactory,
handler,
clientConnectionFactory,
Guid.NewGuid().ToString("N"),
null,
null
)
{
}

public Task OnConnectedAsyncForTest(OpenConnectionMessage message)
{
return base.OnConnectedAsync(message);
}
}

private TestServiceConnection MockServiceConnection(IConnectionFactory serviceConnectionFactory = null,
IClientConnectionFactory clientConnectionFactory = null,
ILoggerFactory loggerFactory = null)
{
clientConnectionFactory ??= new ClientConnectionFactory();
serviceConnectionFactory ??= new TestConnectionFactory(conn => Task.CompletedTask);
loggerFactory ??= NullLoggerFactory.Instance;

var services = new ServiceCollection();
var connectionHandler = new EndlessConnectionHandler();
services.AddSingleton(connectionHandler);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<EndlessConnectionHandler>();
ConnectionDelegate handler = builder.Build();

return new TestServiceConnection(
serviceConnectionFactory,
clientConnectionFactory,
loggerFactory,
handler
);
}

[Fact]
public async void ServiceConnectionShouldIgnoreFirstHandshakeResponse()
{
var factory = new TestClientConnectionFactory();
var connection = MockServiceConnection(null, factory);

// create a connection with migration header.
await connection.OnConnectedAsyncForTest(new OpenConnectionMessage("foo", new Claim[0])
{
Headers = new Dictionary<string, StringValues>{
{ Constants.AsrsMigratedFrom, "another-server" }
}
});

Assert.Equal(1, factory.Connections.Count);
var context = factory.Connections[0];
Assert.True(context.IsMigrated);

var message = new AspNetCore.SignalR.Protocol.HandshakeResponseMessage("");
HandshakeProtocol.WriteResponseMessage(message, context.Transport.Output);
await context.Transport.Output.FlushAsync();

var task = context.Transport.Input.ReadAsync();
await Task.Delay(100);

// nothing should be written into the transport
Assert.False(task.IsCompleted);
Assert.False(context.IsMigrated);
}

private sealed class TestConnectionHandler : ConnectionHandler
{
private TaskCompletionSource<object> _startedTcs = new TaskCompletionSource<object>();
Expand Down
22 changes: 19 additions & 3 deletions test/Microsoft.Azure.SignalR.Tests/ServiceContextFacts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Linq;
using System.Net;
using System.Security.Claims;
using Microsoft.AspNetCore.Http;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Azure.SignalR.Tests.Common;
using Microsoft.Extensions.Primitives;
using Xunit;

Expand Down Expand Up @@ -42,7 +44,7 @@ public void ServiceConnectionContextWithSystemClaimsIsUnauthenticated()
new Claim("exp", "1234567890"),
new Claim("iat", "1234567890"),
new Claim("nbf", "1234567890"),
new Claim(Constants.ClaimType.UserId, "customUserId"),
new Claim(Constants.ClaimType.UserId, "customUserId"),
};
var serviceConnectionContext = new ServiceConnectionContext(new OpenConnectionMessage("1", claims));
Assert.NotNull(serviceConnectionContext.User.Identity);
Expand Down Expand Up @@ -106,9 +108,9 @@ public void ServiceConnectionContextWithNonEmptyHeaders()
const string key1 = "header-key-1";
const string key2 = "header-key-2";
const string value1 = "header-value-1";
var value2 = new[] {"header-value-2a", "header-value-2b"};
var value2 = new[] { "header-value-2a", "header-value-2b" };
var serviceConnectionContext = new ServiceConnectionContext(new OpenConnectionMessage("1", new Claim[0],
new Dictionary<string, StringValues> (StringComparer.OrdinalIgnoreCase)
new Dictionary<string, StringValues>(StringComparer.OrdinalIgnoreCase)
{
{key1, value1},
{key2, value2}
Expand Down Expand Up @@ -161,6 +163,20 @@ public void ServiceConnectionContextWithRequestPath()
Assert.Equal(path, request.Path);
}

[Fact]
public void ServiceConnectionShouldBeMigrated()
{
var open = new OpenConnectionMessage("foo", new Claim[0]);
var context = new ServiceConnectionContext(open);
Assert.False(context.IsMigrated);

open.Headers = new Dictionary<string, StringValues>{
{ Constants.AsrsMigratedFrom, "another-server" }
};
context = new ServiceConnectionContext(open);
Assert.True(context.IsMigrated);
}

[Theory]
[InlineData("1.1.1.1", true, "1.1.1.1")]
[InlineData("1.1.1.1, 2.2.2.2", true, "1.1.1.1")]
Expand Down

0 comments on commit c2c6f08

Please sign in to comment.