Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redirect send2conn from sending to Hub to SvcConn #744

Merged
merged 1 commit into from
Nov 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions samples/ChatSample/ChatSample/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public Startup(IConfiguration configuration)
public void ConfigureServices(IServiceCollection services)
{
services.AddSignalR()
.AddAzureSignalR();
.AddAzureSignalR();
}

public void Configure(IApplicationBuilder app)
Expand All @@ -28,7 +28,7 @@ public void Configure(IApplicationBuilder app)
app.UseAzureSignalR(routes =>
{
routes.MapHub<Chat>("/chat");
routes.MapHub<BenchHub>("/bench");
routes.MapHub<BenchHub>("/signalrbench");
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNet.SignalR;
using Microsoft.AspNet.SignalR.Hosting;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Protocol;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ internal class MultiEndpointServiceConnectionContainer : IServiceConnectionConta

public Dictionary<ServiceEndpoint, IServiceConnectionContainer> Connections { get; }

public MultiEndpointServiceConnectionContainer(string hub, Func<HubServiceEndpoint, IServiceConnectionContainer> generator, IServiceEndpointManager endpointManager, IMessageRouter router, ILoggerFactory loggerFactory)
public MultiEndpointServiceConnectionContainer(string hub,
Func<HubServiceEndpoint, IServiceConnectionContainer> generator,
IServiceEndpointManager endpointManager,
IMessageRouter router,
ILoggerFactory loggerFactory)
{
if (generator == null)
{
Expand Down Expand Up @@ -144,7 +148,6 @@ public Task WriteAsync(ServiceMessage serviceMessage)
{
return _inner.WriteAsync(serviceMessage);
}

return WriteMultiEndpointMessageAsync(serviceMessage, connection => connection.WriteAsync(serviceMessage));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ protected async Task RemoveConnectionFromService(IServiceConnection c)
{
_ = WriteFinAsync(c);

var source = new CancellationTokenSource();
using var source = new CancellationTokenSource();
var task = await Task.WhenAny(c.ConnectionOfflineTask, Task.Delay(RemoveFromServiceTimeout, source.Token));
source.Cancel();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
Expand Down
44 changes: 16 additions & 28 deletions src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

Expand All @@ -18,16 +17,22 @@ internal class ServiceLifetimeManager<THub> : ServiceLifetimeManagerBase<THub> w
"'AddAzureSignalR(...)' was called without a matching call to 'IApplicationBuilder.UseAzureSignalR(...)'.";

private readonly ILogger<ServiceLifetimeManager<THub>> _logger;
private readonly IReadOnlyList<IHubProtocol> _allProtocols;

private readonly IServiceConnectionManager<THub> _serviceConnectionManager;
private readonly IClientConnectionManager _clientConnectionManager;

public ServiceLifetimeManager(IServiceConnectionManager<THub> serviceConnectionManager,
IClientConnectionManager clientConnectionManager, IHubProtocolResolver protocolResolver,
ILogger<ServiceLifetimeManager<THub>> logger, AzureSignalRMarkerService marker,
IOptions<HubOptions> globalHubOptions, IOptions<HubOptions<THub>> hubOptions)
: base(serviceConnectionManager, protocolResolver, globalHubOptions, hubOptions)
public ServiceLifetimeManager(
IServiceConnectionManager<THub> serviceConnectionManager,
IClientConnectionManager clientConnectionManager,
IHubProtocolResolver protocolResolver,
ILogger<ServiceLifetimeManager<THub>> logger,
AzureSignalRMarkerService marker,
IOptions<HubOptions> globalHubOptions,
IOptions<HubOptions<THub>> hubOptions)
: base(
serviceConnectionManager,
protocolResolver,
globalHubOptions,
hubOptions)
{
// after core 3.0 UseAzureSignalR() is not required.
#if NETSTANDARD2_0
Expand All @@ -36,23 +41,10 @@ public ServiceLifetimeManager(IServiceConnectionManager<THub> serviceConnectionM
throw new InvalidOperationException(MarkerNotConfiguredError);
}
#endif

_serviceConnectionManager = serviceConnectionManager;
_clientConnectionManager = clientConnectionManager;
_allProtocols = protocolResolver.AllProtocols;
_logger = logger;
}

public override Task OnConnectedAsync(HubConnectionContext connection)
{
if (_clientConnectionManager.ClientConnections.TryGetValue(connection.ConnectionId, out var serviceConnectionContext))
{
serviceConnectionContext.HubConnectionContext = connection;
}

return Task.CompletedTask;
}

public override Task SendConnectionAsync(string connectionId, string methodName, object[] args, CancellationToken cancellationToken = default)
{
if (IsInvalidArgument(connectionId))
Expand All @@ -67,15 +59,11 @@ public override Task SendConnectionAsync(string connectionId, string methodName,

if (_clientConnectionManager.ClientConnections.TryGetValue(connectionId, out var serviceConnectionContext))
{
var message = new InvocationMessage(methodName, args);

var message = new MultiConnectionDataMessage(new[] { connectionId }, SerializeAllProtocols(methodName, args));
// Write directly to this connection
return serviceConnectionContext.HubConnectionContext.WriteAsync(message).AsTask();
return serviceConnectionContext.ServiceConnection.WriteAsync(message);
}

return base.SendConnectionAsync(connectionId, methodName, args, cancellationToken);


}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ public void TickHeartbeat()

public Task ApplicationTask { get; set; }

// The associated HubConnectionContext
public HubConnectionContext HubConnectionContext { get; set; }
public ServiceConnectionBase ServiceConnection { get; set; }

public HttpContext HttpContext { get; set; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;


namespace Microsoft.Azure.SignalR
{
internal partial class ServiceConnection : ServiceConnectionBase
Expand Down Expand Up @@ -101,6 +101,7 @@ protected override ReadOnlyMemory<byte> GetPingMessage()
protected override Task OnClientConnectedAsync(OpenConnectionMessage message)
{
var connection = _clientConnectionFactory.CreateConnection(message, ConfigureContext);
connection.ServiceConnection = this;
AddClientConnection(connection, GetInstanceId(message.Headers));
Log.ConnectedStarting(Logger, connection.ConnectionId);

Expand Down
100 changes: 81 additions & 19 deletions test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Security.Claims;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.Azure.SignalR.Tests.Common;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
Expand All @@ -21,7 +23,9 @@ public class ServiceLifetimeManagerFacts
{
private static readonly List<string> TestUsers = new List<string> {"user1", "user2"};

private static readonly List<string> TestGroups = new List<string> {"group1", "group2"};
private static readonly List<string> TestGroups = new List<string> {"group1", "group2"};

private const string MockProtocol = "blazorpack";

private const string TestMethod = "TestMethod";

Expand Down Expand Up @@ -78,8 +82,14 @@ public async void ServiceLifetimeManagerTest(string functionName, Type type)
public async void ServiceLifetimeManagerGroupTest(string functionName, Type type)
{
var serviceConnectionManager = new TestServiceConnectionManager<TestHub>();
var serviceLifetimeManager = new ServiceLifetimeManager<TestHub>(serviceConnectionManager,
new ClientConnectionManager(), HubProtocolResolver, Logger, Marker, _globalHubOptions, _localHubOptions);
var serviceLifetimeManager = new ServiceLifetimeManager<TestHub>(
serviceConnectionManager,
new ClientConnectionManager(),
HubProtocolResolver,
Logger,
Marker,
_globalHubOptions,
_localHubOptions);

await InvokeMethod(serviceLifetimeManager, functionName);

Expand Down Expand Up @@ -147,8 +157,8 @@ public async void ServiceLifetimeManagerIgnoreBlazorHubProtocolTest(string funct
new CustomHubProtocol(),
},
NullLogger<DefaultHubProtocolResolver>.Instance);
IOptions<HubOptions> globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List<string>() { "json", "messagepack", "blazorpack" } });
IOptions<HubOptions<TestHub>> localHubOptions = Options.Create(new HubOptions<TestHub>() { SupportedProtocols = new List<string>() { "json", "messagepack", "blazorpack" } });
IOptions<HubOptions> globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List<string>() { "json", "messagepack", MockProtocol } });
IOptions<HubOptions<TestHub>> localHubOptions = Options.Create(new HubOptions<TestHub>() { SupportedProtocols = new List<string>() { "json", "messagepack", MockProtocol } });
var serviceConnectionManager = new TestServiceConnectionManager<TestHub>();
var serviceLifetimeManager = new ServiceLifetimeManager<TestHub>(serviceConnectionManager,
new ClientConnectionManager(), protocolResolver, Logger, Marker, globalHubOptions, localHubOptions);
Expand All @@ -169,18 +179,8 @@ public async void ServiceLifetimeManagerIgnoreBlazorHubProtocolTest(string funct
[InlineData("SendUsersAsync", typeof(MultiUserDataMessage))]
public async void ServiceLifetimeManagerOnlyBlazorHubProtocolTest(string functionName, Type type)
{
var protocolResolver = new DefaultHubProtocolResolver(new IHubProtocol[]
{
new JsonHubProtocol(),
new MessagePackHubProtocol(),
new CustomHubProtocol(),
},
NullLogger<DefaultHubProtocolResolver>.Instance);
IOptions<HubOptions> globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List<string>() { "blazorpack" } });
IOptions<HubOptions<TestHub>> localHubOptions = Options.Create(new HubOptions<TestHub>() { SupportedProtocols = new List<string>() { "blazorpack" } });
var serviceConnectionManager = new TestServiceConnectionManager<TestHub>();
var serviceLifetimeManager = new ServiceLifetimeManager<TestHub>(serviceConnectionManager,
new ClientConnectionManager(), protocolResolver, Logger, Marker, globalHubOptions, localHubOptions);
var serviceConnectionManager = new TestServiceConnectionManager<TestHub>();
var serviceLifetimeManager = MockLifetimeManager(serviceConnectionManager);

await InvokeMethod(serviceLifetimeManager, functionName);

Expand All @@ -189,6 +189,57 @@ public async void ServiceLifetimeManagerOnlyBlazorHubProtocolTest(string functio
Assert.Equal(1, (serviceConnectionManager.ServiceMessage as MulticastDataMessage).Payloads.Count);
}

[Fact]
public async void TestSendConnectionAsyncisOverwrittenWhenClientConnectionExisted()
{
var serviceConnectionManager = new TestServiceConnectionManager<TestHub>();
var clientConnectionManager = new ClientConnectionManager();

var context = new ClientConnectionContext(new OpenConnectionMessage("conn1", new Claim[] { }));
var connection = new TestServiceConnectionPrivate();
context.ServiceConnection = connection;
clientConnectionManager.AddClientConnection(context);

var manager = MockLifetimeManager(serviceConnectionManager, clientConnectionManager);

await manager.SendConnectionAsync("conn1", "foo", new object[] { 1, 2 });

Assert.NotNull(connection.last);
if (connection.last is MultiConnectionDataMessage m)
{
Assert.Equal("conn1", m.ConnectionList[0]);
Assert.Equal(1, m.Payloads.Count);
Assert.True(m.Payloads.ContainsKey(MockProtocol));
Copy link
Member

@vicancy vicancy Nov 26, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MockProtocol [](start = 51, length = 12)

payloads should not contain blazor, use another name as MockProtocol?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it should.
Our resolver supports 3 different protocols (including a mocked protocol "blazor"),
and this mocked protocol has been designated in the lifetime manager creation options.

return;
}
Assert.True(false);
}

private HubLifetimeManager<TestHub> MockLifetimeManager(IServiceConnectionManager<TestHub> serviceConnectionManager, IClientConnectionManager clientConnectionManager = null)
{
clientConnectionManager ??= new ClientConnectionManager();

var protocolResolver = new DefaultHubProtocolResolver(new IHubProtocol[]
{
new JsonHubProtocol(),
new MessagePackHubProtocol(),
new CustomHubProtocol(),
},
NullLogger<DefaultHubProtocolResolver>.Instance
);
IOptions<HubOptions> globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List<string>() { MockProtocol } });
IOptions<HubOptions<TestHub>> localHubOptions = Options.Create(new HubOptions<TestHub>() { SupportedProtocols = new List<string>() { MockProtocol } });
return new ServiceLifetimeManager<TestHub>(
serviceConnectionManager,
clientConnectionManager,
protocolResolver,
Logger,
Marker,
globalHubOptions,
localHubOptions
);
}

private static async Task InvokeMethod(HubLifetimeManager<TestHub> serviceLifetimeManager, string methodName)
{
switch (methodName)
Expand Down Expand Up @@ -278,9 +329,20 @@ private static void VerifyServiceMessage(string methodName, ServiceMessage servi
}
}

private class CustomHubProtocol : IHubProtocol
private sealed class TestServiceConnectionPrivate : TestServiceConnection
{
public ServiceMessage last = null;

public override Task WriteAsync(ServiceMessage serviceMessage)
{
last = serviceMessage;
return Task.CompletedTask;
}
}

private sealed class CustomHubProtocol : IHubProtocol
{
public string Name => "blazorpack";
public string Name => MockProtocol;

public TransferFormat TransferFormat => throw new NotImplementedException();

Expand Down