diff --git a/samples/ChatSample/ChatSample/Startup.cs b/samples/ChatSample/ChatSample/Startup.cs index 3836d1042..3cf7a5f43 100644 --- a/samples/ChatSample/ChatSample/Startup.cs +++ b/samples/ChatSample/ChatSample/Startup.cs @@ -19,7 +19,7 @@ public Startup(IConfiguration configuration) public void ConfigureServices(IServiceCollection services) { services.AddSignalR() - .AddAzureSignalR(); + .AddAzureSignalR(); } public void Configure(IApplicationBuilder app) @@ -28,7 +28,7 @@ public void Configure(IApplicationBuilder app) app.UseAzureSignalR(routes => { routes.MapHub("/chat"); - routes.MapHub("/bench"); + routes.MapHub("/signalrbench"); }); } } diff --git a/src/Microsoft.Azure.SignalR.AspNet/ClientConnections/ClientConnectionManager.cs b/src/Microsoft.Azure.SignalR.AspNet/ClientConnections/ClientConnectionManager.cs index 23b1ef016..4e56f003b 100644 --- a/src/Microsoft.Azure.SignalR.AspNet/ClientConnections/ClientConnectionManager.cs +++ b/src/Microsoft.Azure.SignalR.AspNet/ClientConnections/ClientConnectionManager.cs @@ -5,7 +5,6 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.IO; -using System.Linq; using System.Threading.Tasks; using Microsoft.AspNet.SignalR; using Microsoft.AspNet.SignalR.Hosting; diff --git a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainer.cs b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainer.cs index 7c277271c..0f16adb97 100644 --- a/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainer.cs +++ b/src/Microsoft.Azure.SignalR.Common/Interfaces/IServiceConnectionContainer.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -using System; using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.SignalR.Protocol; diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs index 8aaf53140..227ec8e24 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs @@ -23,7 +23,11 @@ internal class MultiEndpointServiceConnectionContainer : IServiceConnectionConta public Dictionary Connections { get; } - public MultiEndpointServiceConnectionContainer(string hub, Func generator, IServiceEndpointManager endpointManager, IMessageRouter router, ILoggerFactory loggerFactory) + public MultiEndpointServiceConnectionContainer(string hub, + Func generator, + IServiceEndpointManager endpointManager, + IMessageRouter router, + ILoggerFactory loggerFactory) { if (generator == null) { @@ -144,7 +148,6 @@ public Task WriteAsync(ServiceMessage serviceMessage) { return _inner.WriteAsync(serviceMessage); } - return WriteMultiEndpointMessageAsync(serviceMessage, connection => connection.WriteAsync(serviceMessage)); } diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs index 5218cb067..084ca5a38 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs @@ -359,7 +359,7 @@ protected async Task RemoveConnectionFromService(IServiceConnection c) { _ = WriteFinAsync(c); - var source = new CancellationTokenSource(); + using var source = new CancellationTokenSource(); var task = await Task.WhenAny(c.ConnectionOfflineTask, Task.Delay(RemoveFromServiceTimeout, source.Token)); source.Cancel(); diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StrongServiceConnectionContainer.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StrongServiceConnectionContainer.cs index ccc5a802e..dd882c7f2 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StrongServiceConnectionContainer.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/StrongServiceConnectionContainer.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.SignalR.Protocol; using Microsoft.Extensions.Logging; diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs index 4726fa555..d0b138d17 100644 --- a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs +++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs @@ -2,11 +2,10 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; -using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR; -using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Azure.SignalR.Protocol; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; @@ -18,16 +17,22 @@ internal class ServiceLifetimeManager : ServiceLifetimeManagerBase w "'AddAzureSignalR(...)' was called without a matching call to 'IApplicationBuilder.UseAzureSignalR(...)'."; private readonly ILogger> _logger; - private readonly IReadOnlyList _allProtocols; - private readonly IServiceConnectionManager _serviceConnectionManager; private readonly IClientConnectionManager _clientConnectionManager; - public ServiceLifetimeManager(IServiceConnectionManager serviceConnectionManager, - IClientConnectionManager clientConnectionManager, IHubProtocolResolver protocolResolver, - ILogger> logger, AzureSignalRMarkerService marker, - IOptions globalHubOptions, IOptions> hubOptions) - : base(serviceConnectionManager, protocolResolver, globalHubOptions, hubOptions) + public ServiceLifetimeManager( + IServiceConnectionManager serviceConnectionManager, + IClientConnectionManager clientConnectionManager, + IHubProtocolResolver protocolResolver, + ILogger> logger, + AzureSignalRMarkerService marker, + IOptions globalHubOptions, + IOptions> hubOptions) + : base( + serviceConnectionManager, + protocolResolver, + globalHubOptions, + hubOptions) { // after core 3.0 UseAzureSignalR() is not required. #if NETSTANDARD2_0 @@ -36,23 +41,10 @@ public ServiceLifetimeManager(IServiceConnectionManager serviceConnectionM throw new InvalidOperationException(MarkerNotConfiguredError); } #endif - - _serviceConnectionManager = serviceConnectionManager; _clientConnectionManager = clientConnectionManager; - _allProtocols = protocolResolver.AllProtocols; _logger = logger; } - public override Task OnConnectedAsync(HubConnectionContext connection) - { - if (_clientConnectionManager.ClientConnections.TryGetValue(connection.ConnectionId, out var serviceConnectionContext)) - { - serviceConnectionContext.HubConnectionContext = connection; - } - - return Task.CompletedTask; - } - public override Task SendConnectionAsync(string connectionId, string methodName, object[] args, CancellationToken cancellationToken = default) { if (IsInvalidArgument(connectionId)) @@ -67,15 +59,11 @@ public override Task SendConnectionAsync(string connectionId, string methodName, if (_clientConnectionManager.ClientConnections.TryGetValue(connectionId, out var serviceConnectionContext)) { - var message = new InvocationMessage(methodName, args); - + var message = new MultiConnectionDataMessage(new[] { connectionId }, SerializeAllProtocols(methodName, args)); // Write directly to this connection - return serviceConnectionContext.HubConnectionContext.WriteAsync(message).AsTask(); + return serviceConnectionContext.ServiceConnection.WriteAsync(message); } - return base.SendConnectionAsync(connectionId, methodName, args, cancellationToken); - - } } } diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs index c16fa48c6..8bef08107 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs @@ -111,8 +111,7 @@ public void TickHeartbeat() public Task ApplicationTask { get; set; } - // The associated HubConnectionContext - public HubConnectionContext HubConnectionContext { get; set; } + public ServiceConnectionBase ServiceConnection { get; set; } public HttpContext HttpContext { get; set; } diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs index a6a68fd1c..832ce682d 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs @@ -12,7 +12,7 @@ using Microsoft.Azure.SignalR.Protocol; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; - + namespace Microsoft.Azure.SignalR { internal partial class ServiceConnection : ServiceConnectionBase @@ -101,6 +101,7 @@ protected override ReadOnlyMemory GetPingMessage() protected override Task OnClientConnectedAsync(OpenConnectionMessage message) { var connection = _clientConnectionFactory.CreateConnection(message, ConfigureContext); + connection.ServiceConnection = this; AddClientConnection(connection, GetInstanceId(message.Headers)); Log.ConnectedStarting(Logger, connection.ConnectionId); diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs index c922b5d3f..a60724327 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs @@ -4,12 +4,14 @@ using System; using System.Buffers; using System.Collections.Generic; +using System.Security.Claims; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.SignalR; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Azure.SignalR.Protocol; +using Microsoft.Azure.SignalR.Tests.Common; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; @@ -21,7 +23,9 @@ public class ServiceLifetimeManagerFacts { private static readonly List TestUsers = new List {"user1", "user2"}; - private static readonly List TestGroups = new List {"group1", "group2"}; + private static readonly List TestGroups = new List {"group1", "group2"}; + + private const string MockProtocol = "blazorpack"; private const string TestMethod = "TestMethod"; @@ -78,8 +82,14 @@ public async void ServiceLifetimeManagerTest(string functionName, Type type) public async void ServiceLifetimeManagerGroupTest(string functionName, Type type) { var serviceConnectionManager = new TestServiceConnectionManager(); - var serviceLifetimeManager = new ServiceLifetimeManager(serviceConnectionManager, - new ClientConnectionManager(), HubProtocolResolver, Logger, Marker, _globalHubOptions, _localHubOptions); + var serviceLifetimeManager = new ServiceLifetimeManager( + serviceConnectionManager, + new ClientConnectionManager(), + HubProtocolResolver, + Logger, + Marker, + _globalHubOptions, + _localHubOptions); await InvokeMethod(serviceLifetimeManager, functionName); @@ -147,8 +157,8 @@ public async void ServiceLifetimeManagerIgnoreBlazorHubProtocolTest(string funct new CustomHubProtocol(), }, NullLogger.Instance); - IOptions globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { "json", "messagepack", "blazorpack" } }); - IOptions> localHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { "json", "messagepack", "blazorpack" } }); + IOptions globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { "json", "messagepack", MockProtocol } }); + IOptions> localHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { "json", "messagepack", MockProtocol } }); var serviceConnectionManager = new TestServiceConnectionManager(); var serviceLifetimeManager = new ServiceLifetimeManager(serviceConnectionManager, new ClientConnectionManager(), protocolResolver, Logger, Marker, globalHubOptions, localHubOptions); @@ -169,18 +179,8 @@ public async void ServiceLifetimeManagerIgnoreBlazorHubProtocolTest(string funct [InlineData("SendUsersAsync", typeof(MultiUserDataMessage))] public async void ServiceLifetimeManagerOnlyBlazorHubProtocolTest(string functionName, Type type) { - var protocolResolver = new DefaultHubProtocolResolver(new IHubProtocol[] - { - new JsonHubProtocol(), - new MessagePackHubProtocol(), - new CustomHubProtocol(), - }, - NullLogger.Instance); - IOptions globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { "blazorpack" } }); - IOptions> localHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { "blazorpack" } }); - var serviceConnectionManager = new TestServiceConnectionManager(); - var serviceLifetimeManager = new ServiceLifetimeManager(serviceConnectionManager, - new ClientConnectionManager(), protocolResolver, Logger, Marker, globalHubOptions, localHubOptions); + var serviceConnectionManager = new TestServiceConnectionManager(); + var serviceLifetimeManager = MockLifetimeManager(serviceConnectionManager); await InvokeMethod(serviceLifetimeManager, functionName); @@ -189,6 +189,57 @@ public async void ServiceLifetimeManagerOnlyBlazorHubProtocolTest(string functio Assert.Equal(1, (serviceConnectionManager.ServiceMessage as MulticastDataMessage).Payloads.Count); } + [Fact] + public async void TestSendConnectionAsyncisOverwrittenWhenClientConnectionExisted() + { + var serviceConnectionManager = new TestServiceConnectionManager(); + var clientConnectionManager = new ClientConnectionManager(); + + var context = new ClientConnectionContext(new OpenConnectionMessage("conn1", new Claim[] { })); + var connection = new TestServiceConnectionPrivate(); + context.ServiceConnection = connection; + clientConnectionManager.AddClientConnection(context); + + var manager = MockLifetimeManager(serviceConnectionManager, clientConnectionManager); + + await manager.SendConnectionAsync("conn1", "foo", new object[] { 1, 2 }); + + Assert.NotNull(connection.last); + if (connection.last is MultiConnectionDataMessage m) + { + Assert.Equal("conn1", m.ConnectionList[0]); + Assert.Equal(1, m.Payloads.Count); + Assert.True(m.Payloads.ContainsKey(MockProtocol)); + return; + } + Assert.True(false); + } + + private HubLifetimeManager MockLifetimeManager(IServiceConnectionManager serviceConnectionManager, IClientConnectionManager clientConnectionManager = null) + { + clientConnectionManager ??= new ClientConnectionManager(); + + var protocolResolver = new DefaultHubProtocolResolver(new IHubProtocol[] + { + new JsonHubProtocol(), + new MessagePackHubProtocol(), + new CustomHubProtocol(), + }, + NullLogger.Instance + ); + IOptions globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { MockProtocol } }); + IOptions> localHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { MockProtocol } }); + return new ServiceLifetimeManager( + serviceConnectionManager, + clientConnectionManager, + protocolResolver, + Logger, + Marker, + globalHubOptions, + localHubOptions + ); + } + private static async Task InvokeMethod(HubLifetimeManager serviceLifetimeManager, string methodName) { switch (methodName) @@ -278,9 +329,20 @@ private static void VerifyServiceMessage(string methodName, ServiceMessage servi } } - private class CustomHubProtocol : IHubProtocol + private sealed class TestServiceConnectionPrivate : TestServiceConnection + { + public ServiceMessage last = null; + + public override Task WriteAsync(ServiceMessage serviceMessage) + { + last = serviceMessage; + return Task.CompletedTask; + } + } + + private sealed class CustomHubProtocol : IHubProtocol { - public string Name => "blazorpack"; + public string Name => MockProtocol; public TransferFormat TransferFormat => throw new NotImplementedException();