Skip to content

Commit

Permalink
Redirect send2conn to SvcConn from Hub
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan committed Nov 26, 2019
1 parent d858f9d commit 28e969d
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 58 deletions.
4 changes: 2 additions & 2 deletions samples/ChatSample/ChatSample/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public Startup(IConfiguration configuration)
public void ConfigureServices(IServiceCollection services)
{
services.AddSignalR()
.AddAzureSignalR();
.AddAzureSignalR();
}

public void Configure(IApplicationBuilder app)
Expand All @@ -28,7 +28,7 @@ public void Configure(IApplicationBuilder app)
app.UseAzureSignalR(routes =>
{
routes.MapHub<Chat>("/chat");
routes.MapHub<BenchHub>("/bench");
routes.MapHub<BenchHub>("/signalrbench");
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNet.SignalR;
using Microsoft.AspNet.SignalR.Hosting;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Protocol;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ internal class MultiEndpointServiceConnectionContainer : IServiceConnectionConta

public Dictionary<ServiceEndpoint, IServiceConnectionContainer> Connections { get; }

public MultiEndpointServiceConnectionContainer(string hub, Func<HubServiceEndpoint, IServiceConnectionContainer> generator, IServiceEndpointManager endpointManager, IMessageRouter router, ILoggerFactory loggerFactory)
public MultiEndpointServiceConnectionContainer(string hub,
Func<HubServiceEndpoint, IServiceConnectionContainer> generator,
IServiceEndpointManager endpointManager,
IMessageRouter router,
ILoggerFactory loggerFactory)
{
if (generator == null)
{
Expand Down Expand Up @@ -144,7 +148,6 @@ public Task WriteAsync(ServiceMessage serviceMessage)
{
return _inner.WriteAsync(serviceMessage);
}

return WriteMultiEndpointMessageAsync(serviceMessage, connection => connection.WriteAsync(serviceMessage));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ protected async Task RemoveConnectionFromService(IServiceConnection c)
{
_ = WriteFinAsync(c);

var source = new CancellationTokenSource();
using var source = new CancellationTokenSource();
var task = await Task.WhenAny(c.ConnectionOfflineTask, Task.Delay(RemoveFromServiceTimeout, source.Token));
source.Cancel();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
Expand Down
44 changes: 16 additions & 28 deletions src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

Expand All @@ -18,16 +17,22 @@ internal class ServiceLifetimeManager<THub> : ServiceLifetimeManagerBase<THub> w
"'AddAzureSignalR(...)' was called without a matching call to 'IApplicationBuilder.UseAzureSignalR(...)'.";

private readonly ILogger<ServiceLifetimeManager<THub>> _logger;
private readonly IReadOnlyList<IHubProtocol> _allProtocols;

private readonly IServiceConnectionManager<THub> _serviceConnectionManager;
private readonly IClientConnectionManager _clientConnectionManager;

public ServiceLifetimeManager(IServiceConnectionManager<THub> serviceConnectionManager,
IClientConnectionManager clientConnectionManager, IHubProtocolResolver protocolResolver,
ILogger<ServiceLifetimeManager<THub>> logger, AzureSignalRMarkerService marker,
IOptions<HubOptions> globalHubOptions, IOptions<HubOptions<THub>> hubOptions)
: base(serviceConnectionManager, protocolResolver, globalHubOptions, hubOptions)
public ServiceLifetimeManager(
IServiceConnectionManager<THub> serviceConnectionManager,
IClientConnectionManager clientConnectionManager,
IHubProtocolResolver protocolResolver,
ILogger<ServiceLifetimeManager<THub>> logger,
AzureSignalRMarkerService marker,
IOptions<HubOptions> globalHubOptions,
IOptions<HubOptions<THub>> hubOptions)
: base(
serviceConnectionManager,
protocolResolver,
globalHubOptions,
hubOptions)
{
// after core 3.0 UseAzureSignalR() is not required.
#if NETSTANDARD2_0
Expand All @@ -36,23 +41,10 @@ public ServiceLifetimeManager(IServiceConnectionManager<THub> serviceConnectionM
throw new InvalidOperationException(MarkerNotConfiguredError);
}
#endif

_serviceConnectionManager = serviceConnectionManager;
_clientConnectionManager = clientConnectionManager;
_allProtocols = protocolResolver.AllProtocols;
_logger = logger;
}

public override Task OnConnectedAsync(HubConnectionContext connection)
{
if (_clientConnectionManager.ClientConnections.TryGetValue(connection.ConnectionId, out var serviceConnectionContext))
{
serviceConnectionContext.HubConnectionContext = connection;
}

return Task.CompletedTask;
}

public override Task SendConnectionAsync(string connectionId, string methodName, object[] args, CancellationToken cancellationToken = default)
{
if (IsInvalidArgument(connectionId))
Expand All @@ -67,15 +59,11 @@ public override Task SendConnectionAsync(string connectionId, string methodName,

if (_clientConnectionManager.ClientConnections.TryGetValue(connectionId, out var serviceConnectionContext))
{
var message = new InvocationMessage(methodName, args);

var message = new MultiConnectionDataMessage(new[] { connectionId }, SerializeAllProtocols(methodName, args));
// Write directly to this connection
return serviceConnectionContext.HubConnectionContext.WriteAsync(message).AsTask();
return serviceConnectionContext.ServiceConnection.WriteAsync(message);
}

return base.SendConnectionAsync(connectionId, methodName, args, cancellationToken);


}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ public void TickHeartbeat()

public Task ApplicationTask { get; set; }

// The associated HubConnectionContext
public HubConnectionContext HubConnectionContext { get; set; }
public ServiceConnectionBase ServiceConnection { get; set; }

public HttpContext HttpContext { get; set; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;


namespace Microsoft.Azure.SignalR
{
internal partial class ServiceConnection : ServiceConnectionBase
Expand Down Expand Up @@ -101,6 +101,7 @@ protected override ReadOnlyMemory<byte> GetPingMessage()
protected override Task OnClientConnectedAsync(OpenConnectionMessage message)
{
var connection = _clientConnectionFactory.CreateConnection(message, ConfigureContext);
connection.ServiceConnection = this;
AddClientConnection(connection, GetInstanceId(message.Headers));
Log.ConnectedStarting(Logger, connection.ConnectionId);

Expand Down
100 changes: 81 additions & 19 deletions test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Security.Claims;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Azure.SignalR.Tests.Common;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
Expand All @@ -21,7 +23,9 @@ public class ServiceLifetimeManagerFacts
{
private static readonly List<string> TestUsers = new List<string> {"user1", "user2"};

private static readonly List<string> TestGroups = new List<string> {"group1", "group2"};
private static readonly List<string> TestGroups = new List<string> {"group1", "group2"};

private const string MockProtocol = "blazorpack";

private const string TestMethod = "TestMethod";

Expand Down Expand Up @@ -78,8 +82,14 @@ public async void ServiceLifetimeManagerTest(string functionName, Type type)
public async void ServiceLifetimeManagerGroupTest(string functionName, Type type)
{
var serviceConnectionManager = new TestServiceConnectionManager<TestHub>();
var serviceLifetimeManager = new ServiceLifetimeManager<TestHub>(serviceConnectionManager,
new ClientConnectionManager(), HubProtocolResolver, Logger, Marker, _globalHubOptions, _localHubOptions);
var serviceLifetimeManager = new ServiceLifetimeManager<TestHub>(
serviceConnectionManager,
new ClientConnectionManager(),
HubProtocolResolver,
Logger,
Marker,
_globalHubOptions,
_localHubOptions);

await InvokeMethod(serviceLifetimeManager, functionName);

Expand Down Expand Up @@ -147,8 +157,8 @@ public async void ServiceLifetimeManagerIgnoreBlazorHubProtocolTest(string funct
new CustomHubProtocol(),
},
NullLogger<DefaultHubProtocolResolver>.Instance);
IOptions<HubOptions> globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List<string>() { "json", "messagepack", "blazorpack" } });
IOptions<HubOptions<TestHub>> localHubOptions = Options.Create(new HubOptions<TestHub>() { SupportedProtocols = new List<string>() { "json", "messagepack", "blazorpack" } });
IOptions<HubOptions> globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List<string>() { "json", "messagepack", MockProtocol } });
IOptions<HubOptions<TestHub>> localHubOptions = Options.Create(new HubOptions<TestHub>() { SupportedProtocols = new List<string>() { "json", "messagepack", MockProtocol } });
var serviceConnectionManager = new TestServiceConnectionManager<TestHub>();
var serviceLifetimeManager = new ServiceLifetimeManager<TestHub>(serviceConnectionManager,
new ClientConnectionManager(), protocolResolver, Logger, Marker, globalHubOptions, localHubOptions);
Expand All @@ -169,18 +179,8 @@ public async void ServiceLifetimeManagerIgnoreBlazorHubProtocolTest(string funct
[InlineData("SendUsersAsync", typeof(MultiUserDataMessage))]
public async void ServiceLifetimeManagerOnlyBlazorHubProtocolTest(string functionName, Type type)
{
var protocolResolver = new DefaultHubProtocolResolver(new IHubProtocol[]
{
new JsonHubProtocol(),
new MessagePackHubProtocol(),
new CustomHubProtocol(),
},
NullLogger<DefaultHubProtocolResolver>.Instance);
IOptions<HubOptions> globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List<string>() { "blazorpack" } });
IOptions<HubOptions<TestHub>> localHubOptions = Options.Create(new HubOptions<TestHub>() { SupportedProtocols = new List<string>() { "blazorpack" } });
var serviceConnectionManager = new TestServiceConnectionManager<TestHub>();
var serviceLifetimeManager = new ServiceLifetimeManager<TestHub>(serviceConnectionManager,
new ClientConnectionManager(), protocolResolver, Logger, Marker, globalHubOptions, localHubOptions);
var serviceConnectionManager = new TestServiceConnectionManager<TestHub>();
var serviceLifetimeManager = MockLifetimeManager(serviceConnectionManager);

await InvokeMethod(serviceLifetimeManager, functionName);

Expand All @@ -189,6 +189,57 @@ public async void ServiceLifetimeManagerOnlyBlazorHubProtocolTest(string functio
Assert.Equal(1, (serviceConnectionManager.ServiceMessage as MulticastDataMessage).Payloads.Count);
}

[Fact]
public async void TestSendConnectionAsyncisOverwrittenWhenClientConnectionExisted()
{
var serviceConnectionManager = new TestServiceConnectionManager<TestHub>();
var clientConnectionManager = new ClientConnectionManager();

var context = new ClientConnectionContext(new OpenConnectionMessage("conn1", new Claim[] { }));
var connection = new TestServiceConnectionPrivate();
context.ServiceConnection = connection;
clientConnectionManager.AddClientConnection(context);

var manager = MockLifetimeManager(serviceConnectionManager, clientConnectionManager);

await manager.SendConnectionAsync("conn1", "foo", new object[] { 1, 2 });

Assert.NotNull(connection.last);
if (connection.last is MultiConnectionDataMessage m)
{
Assert.Equal("conn1", m.ConnectionList[0]);
Assert.Equal(1, m.Payloads.Count);
Assert.True(m.Payloads.ContainsKey(MockProtocol));
return;
}
Assert.True(false);
}

private HubLifetimeManager<TestHub> MockLifetimeManager(IServiceConnectionManager<TestHub> serviceConnectionManager, IClientConnectionManager clientConnectionManager = null)
{
clientConnectionManager ??= new ClientConnectionManager();

var protocolResolver = new DefaultHubProtocolResolver(new IHubProtocol[]
{
new JsonHubProtocol(),
new MessagePackHubProtocol(),
new CustomHubProtocol(),
},
NullLogger<DefaultHubProtocolResolver>.Instance
);
IOptions<HubOptions> globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List<string>() { MockProtocol } });
IOptions<HubOptions<TestHub>> localHubOptions = Options.Create(new HubOptions<TestHub>() { SupportedProtocols = new List<string>() { MockProtocol } });
return new ServiceLifetimeManager<TestHub>(
serviceConnectionManager,
clientConnectionManager,
protocolResolver,
Logger,
Marker,
globalHubOptions,
localHubOptions
);
}

private static async Task InvokeMethod(HubLifetimeManager<TestHub> serviceLifetimeManager, string methodName)
{
switch (methodName)
Expand Down Expand Up @@ -278,9 +329,20 @@ private static void VerifyServiceMessage(string methodName, ServiceMessage servi
}
}

private class CustomHubProtocol : IHubProtocol
private sealed class TestServiceConnectionPrivate : TestServiceConnection
{
public ServiceMessage last = null;

public override Task WriteAsync(ServiceMessage serviceMessage)
{
last = serviceMessage;
return Task.CompletedTask;
}
}

private sealed class CustomHubProtocol : IHubProtocol
{
public string Name => "blazorpack";
public string Name => MockProtocol;

public TransferFormat TransferFormat => throw new NotImplementedException();

Expand Down

0 comments on commit 28e969d

Please sign in to comment.