diff --git a/release_notes.md b/release_notes.md index 4674a21bc6..f500fd171a 100644 --- a/release_notes.md +++ b/release_notes.md @@ -13,4 +13,5 @@ - Fixed incorrect function count in the log message.(#10220) - Migrate Diagnostic Events to Azure.Data.Tables (#10218) - Sanitize worker arguments before logging (#10260) +- Fix race condition on startup with extension RPC endpoints not being available. (#10282) - Adding a timeout when retrieving function metadata from metadata providers (#10219) diff --git a/src/WebJobs.Script.Grpc/Server/ExtensionsCompositeEndpointDataSource.cs b/src/WebJobs.Script.Grpc/Server/ExtensionsCompositeEndpointDataSource.cs index cb384a3b34..ac64a6a22c 100644 --- a/src/WebJobs.Script.Grpc/Server/ExtensionsCompositeEndpointDataSource.cs +++ b/src/WebJobs.Script.Grpc/Server/ExtensionsCompositeEndpointDataSource.cs @@ -6,10 +6,12 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Threading; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Routing; using Microsoft.Azure.WebJobs.Rpc.Core.Internal; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; namespace Microsoft.Azure.WebJobs.Script.Grpc @@ -26,6 +28,7 @@ internal sealed class ExtensionsCompositeEndpointDataSource : EndpointDataSource private readonly object _lock = new(); private readonly List _dataSources = new(); private readonly IScriptHostManager _scriptHostManager; + private readonly TaskCompletionSource _initialized = new(); private IServiceProvider _extensionServices; private List _endpoints; @@ -191,6 +194,7 @@ private void OnHostChanged(object sender, ActiveHostChangedEventArgs args) .GetService>() ?? Enumerable.Empty(); _dataSources.AddRange(sources); + _initialized.TrySetResult(); // signal we have first initialized. } else { @@ -301,5 +305,49 @@ private void ThrowIfDisposed() throw new ObjectDisposedException(nameof(ExtensionsCompositeEndpointDataSource)); } } + + /// + /// Middleware to ensure is initialized before routing for the first time. + /// Must be registered as a singleton service. + /// + /// The to ensure is initialized. + /// The logger. + public sealed class EnsureInitializedMiddleware(ExtensionsCompositeEndpointDataSource dataSource, ILogger logger) : IMiddleware + { + private TaskCompletionSource _initialized = new(); + private bool _firstRun = true; + + // used for testing to verify initialization success. + internal Task Initialized => _initialized.Task; + + // settable only for testing purposes. + internal TimeSpan Timeout { get; init; } = TimeSpan.FromSeconds(2); + + public Task InvokeAsync(HttpContext context, RequestDelegate next) + { + return _firstRun ? InvokeCoreAsync(context, next) : next(context); + } + + private async Task InvokeCoreAsync(HttpContext context, RequestDelegate next) + { + try + { + await dataSource._initialized.Task.WaitAsync(Timeout); + } + catch (TimeoutException ex) + { + // In case of deadlock we don't want to block all gRPC requests. + // Log an error and continue. + logger.LogError(ex, "Error initializing extension endpoints."); + _initialized.TrySetException(ex); + } + + // Even in case of timeout we don't want to continually test for initialization on subsequent requests. + // That would be a serious performance degredation. + _firstRun = false; + _initialized.TrySetResult(); + await next(context); + } + } } } diff --git a/src/WebJobs.Script.Grpc/Server/Startup.cs b/src/WebJobs.Script.Grpc/Server/Startup.cs index d30b539df3..695615338c 100644 --- a/src/WebJobs.Script.Grpc/Server/Startup.cs +++ b/src/WebJobs.Script.Grpc/Server/Startup.cs @@ -16,6 +16,7 @@ internal class Startup public void ConfigureServices(IServiceCollection services) { services.AddSingleton(); + services.AddSingleton(); services.AddGrpc(options => { options.MaxReceiveMessageSize = MaxMessageLengthBytes; @@ -30,12 +31,16 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env) app.UseDeveloperExceptionPage(); } + // This must occur before 'UseRouting'. This ensures extension endpoints are registered before the + // endpoints are collected by the routing middleware. + app.UseMiddleware(); app.UseRouting(); app.UseEndpoints(endpoints => { endpoints.MapGrpcService(); - endpoints.DataSources.Add(endpoints.ServiceProvider.GetRequiredService()); + endpoints.DataSources.Add( + endpoints.ServiceProvider.GetRequiredService()); }); } } diff --git a/test/WebJobs.Script.Tests/Workers/Rpc/ExtensionsCompositeEndpointDataSourceTests.cs b/test/WebJobs.Script.Tests/Workers/Rpc/ExtensionsCompositeEndpointDataSourceTests.cs index ed8d32b355..84f8bb8cc7 100644 --- a/test/WebJobs.Script.Tests/Workers/Rpc/ExtensionsCompositeEndpointDataSourceTests.cs +++ b/test/WebJobs.Script.Tests/Workers/Rpc/ExtensionsCompositeEndpointDataSourceTests.cs @@ -3,12 +3,16 @@ using System; using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.Azure.WebJobs.Rpc.Core.Internal; using Microsoft.Azure.WebJobs.Script.Grpc; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.FileProviders; using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Primitives; using Moq; using Xunit; @@ -17,6 +21,9 @@ namespace Microsoft.Azure.WebJobs.Script.Tests.Workers.Rpc { public class ExtensionsCompositeEndpointDataSourceTests { + private static readonly ILogger _logger + = NullLogger.Instance; + [Fact] public void NoActiveHost_NoEndpoints() { @@ -41,6 +48,7 @@ public void ActiveHostChanged_NullHost_NoEndpoints() public void ActiveHostChanged_NoExtensions_NoEndpoints() { Mock manager = new(); + ExtensionsCompositeEndpointDataSource dataSource = new(manager.Object); IChangeToken token = dataSource.GetChangeToken(); @@ -67,6 +75,45 @@ public void ActiveHostChanged_NewExtensions_NewEndpoints() endpoint => Assert.Equal("Test2", endpoint.DisplayName)); } + [Fact] + public async Task ActiveHostChanged_MiddlewareWaits_Success() + { + Mock manager = new(); + + ExtensionsCompositeEndpointDataSource dataSource = new(manager.Object); + ExtensionsCompositeEndpointDataSource.EnsureInitializedMiddleware middleware = + new(dataSource, _logger) { Timeout = Timeout.InfiniteTimeSpan }; + TestDelegate next = new(); + + Task waiter = middleware.InvokeAsync(null, next.InvokeAsync); + Assert.False(waiter.IsCompleted); // should be blocked until we raise the event. + + manager.Raise(x => x.ActiveHostChanged += null, new ActiveHostChangedEventArgs(null, GetHost())); + await waiter.WaitAsync(TimeSpan.FromSeconds(5)); + await middleware.Initialized; + await next.Invoked; + } + + [Fact] + public async Task NoActiveHostChanged_MiddlewareWaits_Timeout() + { + Mock manager = new(); + + ExtensionsCompositeEndpointDataSource dataSource = new(manager.Object); + ExtensionsCompositeEndpointDataSource.EnsureInitializedMiddleware middleware = + new(dataSource, _logger) { Timeout = TimeSpan.Zero }; + TestDelegate next = new(); + + await middleware.InvokeAsync(null, next.InvokeAsync).WaitAsync(TimeSpan.FromSeconds(5)); // should not throw + await Assert.ThrowsAsync(() => middleware.Initialized); + await next.Invoked; + + // invoke again to verify it processes the next request. + next = new(); + await middleware.InvokeAsync(null, next.InvokeAsync); + await next.Invoked; + } + [Fact] public void Dispose_GetThrows() { @@ -105,5 +152,18 @@ public TestEndpoints(params Endpoint[] endpoints) public override IChangeToken GetChangeToken() => NullChangeToken.Singleton; } + + private class TestDelegate + { + private readonly TaskCompletionSource _invoked = new(); + + public Task Invoked => _invoked.Task; + + public Task InvokeAsync(HttpContext context) + { + _invoked.TrySetResult(); + return Task.CompletedTask; + } + } } }