diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs index e1b1362ba..fcf009216 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs @@ -118,7 +118,7 @@ public Task StartAsync() /// Start and manage the whole connection lifetime /// /// - protected async Task StartCoreAsync(IServiceConnection connection, string target = null) + protected virtual async Task StartCoreAsync(IServiceConnection connection, string target = null) { try { diff --git a/src/Microsoft.Azure.SignalR.Management/Properties/AssemblyInfo.cs b/src/Microsoft.Azure.SignalR.Management/Properties/AssemblyInfo.cs index 2b21cb22f..d9d6b50bc 100644 --- a/src/Microsoft.Azure.SignalR.Management/Properties/AssemblyInfo.cs +++ b/src/Microsoft.Azure.SignalR.Management/Properties/AssemblyInfo.cs @@ -1,3 +1,4 @@ using System.Runtime.CompilerServices; -[assembly : InternalsVisibleTo("Microsoft.Azure.SignalR.Management.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] \ No newline at end of file +[assembly : InternalsVisibleTo("Microsoft.Azure.SignalR.Management.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] +[assembly : InternalsVisibleTo("Microsoft.Azure.SignalR.E2ETests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] diff --git a/src/Microsoft.Azure.SignalR.Management/ServiceHubContext.cs b/src/Microsoft.Azure.SignalR.Management/ServiceHubContext.cs index 419a5b487..5fcc6ecbe 100644 --- a/src/Microsoft.Azure.SignalR.Management/ServiceHubContext.cs +++ b/src/Microsoft.Azure.SignalR.Management/ServiceHubContext.cs @@ -3,6 +3,7 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR; +using Microsoft.Azure.SignalR.Common.ServiceConnections; using Microsoft.Extensions.DependencyInjection; namespace Microsoft.Azure.SignalR.Management @@ -31,7 +32,14 @@ public async Task DisposeAsync() _serviceProvider?.Dispose(); } - private Task StopConnectionAsync() + // for test only + public ServiceConnectionStatus GetConnectionStatus() + { + var container = _serviceProvider.GetService(); + return ((ManagementServiceConnectionContainer)container).GetServiceConnectionStatus(); + } + + public Task StopConnectionAsync() { var serviceConnectionManager = _serviceProvider.GetService>(); if (serviceConnectionManager == null) diff --git a/src/Microsoft.Azure.SignalR.Management/ServiceManager.cs b/src/Microsoft.Azure.SignalR.Management/ServiceManager.cs index 7d144f5bb..7005b9a43 100644 --- a/src/Microsoft.Azure.SignalR.Management/ServiceManager.cs +++ b/src/Microsoft.Azure.SignalR.Management/ServiceManager.cs @@ -49,7 +49,7 @@ public async Task CreateHubContextAsync(string hubName, ILog var clientConnectionFactory = new ClientConnectionFactory(); ConnectionDelegate connectionDelegate = connectionContext => Task.CompletedTask; var serviceConnectionFactory = new ServiceConnectionFactory(serviceProtocol, clientConnectionManager, connectionFactory, loggerFactory, connectionDelegate, clientConnectionFactory); - var weakConnectionContainer = new WeakServiceConnectionContainer(serviceConnectionFactory, _serviceManagerOptions.ConnectionCount, new HubServiceEndpoint(hubName, _endpointProvider, _endpoint)); + var managementConnectionContainer = new ManagementServiceConnectionContainer(serviceConnectionFactory, _serviceManagerOptions.ConnectionCount, new HubServiceEndpoint(hubName, _endpointProvider, _endpoint)); var serviceCollection = new ServiceCollection(); serviceCollection.AddSignalRCore(); @@ -64,7 +64,7 @@ public async Task CreateHubContextAsync(string hubName, ILog .AddSingleton(typeof(IConnectionFactory), sp => connectionFactory) .AddSingleton(typeof(HubLifetimeManager<>), typeof(WebSocketsHubLifetimeManager<>)) .AddSingleton(typeof(IServiceConnectionManager<>), typeof(ServiceConnectionManager<>)) - .AddSingleton(typeof(IServiceConnectionContainer), sp => weakConnectionContainer); + .AddSingleton(typeof(IServiceConnectionContainer), sp => managementConnectionContainer); var success = false; ServiceProvider serviceProvider = null; @@ -73,11 +73,11 @@ public async Task CreateHubContextAsync(string hubName, ILog serviceProvider = serviceCollection.BuildServiceProvider(); var serviceConnectionManager = serviceProvider.GetRequiredService>(); - serviceConnectionManager.SetServiceConnection(weakConnectionContainer); + serviceConnectionManager.SetServiceConnection(managementConnectionContainer); _ = serviceConnectionManager.StartAsync(); // wait until service connection established - await weakConnectionContainer.ConnectionInitializedTask.OrTimeout(cancellationToken); + await managementConnectionContainer.ConnectionInitializedTask.OrTimeout(cancellationToken); var webSocketsHubLifetimeManager = (WebSocketsHubLifetimeManager)serviceProvider.GetRequiredService>(); diff --git a/test/Microsoft.Azure.SignalR.E2ETests/Management/ServiceHubContextE2EFacts.cs b/test/Microsoft.Azure.SignalR.E2ETests/Management/ServiceHubContextE2EFacts.cs index 59cbc6182..fe606698f 100644 --- a/test/Microsoft.Azure.SignalR.E2ETests/Management/ServiceHubContextE2EFacts.cs +++ b/test/Microsoft.Azure.SignalR.E2ETests/Management/ServiceHubContextE2EFacts.cs @@ -171,9 +171,9 @@ internal async Task SendToConnectionTest(ServiceTransportType serviceTransportTy try { await RunTestCore(clientEndpoint, clientAccessTokens, - async () => + async () => { - var connectionId = await task.OrTimeout(); + var connectionId = await SignalR.Tests.Common.TaskExtensions.OrTimeout(task); await serviceHubContext.Clients.Client(connectionId).SendAsync(MethodName, Message); }, 1, receivedMessageDict); @@ -201,7 +201,7 @@ internal async Task ConnectionJoinLeaveGroupTest(ServiceTransportType serviceTra await RunTestCore(clientEndpoint, clientAccessTokens, async () => { - var connectionId = await task.OrTimeout(); + var connectionId = await SignalR.Tests.Common.TaskExtensions.OrTimeout(task); await serviceHubContext.Groups.AddToGroupAsync(connectionId, _groupNames[0]); await serviceHubContext.Clients.Group(_groupNames[0]).SendAsync(MethodName, Message); // We can't guarantee the order between the send group and the following leave group diff --git a/test/Microsoft.Azure.SignalR.Tests.Common/TaskExtensions.cs b/test/Microsoft.Azure.SignalR.Tests.Common/TaskExtensions.cs index 589d2e97a..20e60606e 100644 --- a/test/Microsoft.Azure.SignalR.Tests.Common/TaskExtensions.cs +++ b/test/Microsoft.Azure.SignalR.Tests.Common/TaskExtensions.cs @@ -1,10 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System; using System.Diagnostics; using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; -namespace System.Threading.Tasks +namespace Microsoft.Azure.SignalR.Tests.Common { public static class TaskExtensions { diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs index d88e6d111..61e0cfc0c 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs @@ -8,6 +8,7 @@ using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Azure.SignalR.Protocol; +using Microsoft.Azure.SignalR.Tests.Common; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Xunit;