Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix memory leak when using call context propagation with cancellation token #2421

Merged
merged 2 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
using Grpc.Core.Interceptors;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Grpc.AspNetCore.ClientFactory;

Expand Down Expand Up @@ -53,14 +52,15 @@ public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreami
}
else
{
var state = CreateContextState(call, cts);
return new AsyncClientStreamingCall<TRequest, TResponse>(
requestStream: call.RequestStream,
responseAsync: call.ResponseAsync,
responseAsync: OnResponseAsync(call.ResponseAsync, state),
responseHeadersAsync: ClientStreamingCallbacks<TRequest, TResponse>.GetResponseHeadersAsync,
getStatusFunc: ClientStreamingCallbacks<TRequest, TResponse>.GetStatus,
getTrailersFunc: ClientStreamingCallbacks<TRequest, TResponse>.GetTrailers,
disposeAction: ClientStreamingCallbacks<TRequest, TResponse>.Dispose,
CreateContextState(call, cts));
state);
}
}

Expand All @@ -73,14 +73,15 @@ public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreami
}
else
{
var state = CreateContextState(call, cts);
return new AsyncDuplexStreamingCall<TRequest, TResponse>(
requestStream: call.RequestStream,
responseStream: call.ResponseStream,
responseStream: new ResponseStreamWrapper<TResponse>(call.ResponseStream, state),
responseHeadersAsync: DuplexStreamingCallbacks<TRequest, TResponse>.GetResponseHeadersAsync,
getStatusFunc: DuplexStreamingCallbacks<TRequest, TResponse>.GetStatus,
getTrailersFunc: DuplexStreamingCallbacks<TRequest, TResponse>.GetTrailers,
disposeAction: DuplexStreamingCallbacks<TRequest, TResponse>.Dispose,
CreateContextState(call, cts));
state);
}
}

Expand All @@ -93,13 +94,14 @@ public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRe
}
else
{
var state = CreateContextState(call, cts);
return new AsyncServerStreamingCall<TResponse>(
responseStream: call.ResponseStream,
responseStream: new ResponseStreamWrapper<TResponse>(call.ResponseStream, state),
responseHeadersAsync: ServerStreamingCallbacks<TResponse>.GetResponseHeadersAsync,
getStatusFunc: ServerStreamingCallbacks<TResponse>.GetStatus,
getTrailersFunc: ServerStreamingCallbacks<TResponse>.GetTrailers,
disposeAction: ServerStreamingCallbacks<TResponse>.Dispose,
CreateContextState(call, cts));
state);
}
}

Expand All @@ -112,13 +114,14 @@ public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(TR
}
else
{
var state = CreateContextState(call, cts);
return new AsyncUnaryCall<TResponse>(
responseAsync: call.ResponseAsync,
responseAsync: OnResponseAsync(call.ResponseAsync, state),
responseHeadersAsync: UnaryCallbacks<TResponse>.GetResponseHeadersAsync,
getStatusFunc: UnaryCallbacks<TResponse>.GetStatus,
getTrailersFunc: UnaryCallbacks<TResponse>.GetTrailers,
disposeAction: UnaryCallbacks<TResponse>.Dispose,
CreateContextState(call, cts));
state);
}
}

Expand All @@ -129,6 +132,19 @@ public override TResponse BlockingUnaryCall<TRequest, TResponse>(TRequest reques
return response;
}

// Automatically dispose state after awaiting the response.
private static async Task<TResponse> OnResponseAsync<TResponse>(Task<TResponse> task, IDisposable state)
{
try
{
return await task.ConfigureAwait(false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might this be a rare occasion when ContinueWith makes more sense?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would it be better here? From what I know, async/await is preferred.

}
finally
{
state.Dispose();
}
}

private ClientInterceptorContext<TRequest, TResponse> ConfigureContext<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, out CancellationTokenSource? linkedCts)
where TRequest : class
where TResponse : class
Expand Down Expand Up @@ -197,7 +213,7 @@ private bool TryGetServerCallContext([NotNullWhen(true)] out ServerCallContext?
private ContextState<TCall> CreateContextState<TCall>(TCall call, CancellationTokenSource cancellationTokenSource) where TCall : IDisposable =>
new ContextState<TCall>(call, cancellationTokenSource);

private class ContextState<TCall> : IDisposable where TCall : IDisposable
private sealed class ContextState<TCall> : IDisposable where TCall : IDisposable
{
public ContextState(TCall call, CancellationTokenSource cancellationTokenSource)
{
Expand All @@ -215,6 +231,33 @@ public void Dispose()
}
}

// Automatically dispose state after reading to the end of the stream.
private sealed class ResponseStreamWrapper<TResponse> : IAsyncStreamReader<TResponse>
{
private readonly IAsyncStreamReader<TResponse> _inner;
private readonly IDisposable _state;
private bool _disposed;

public ResponseStreamWrapper(IAsyncStreamReader<TResponse> inner, IDisposable state)
{
_inner = inner;
_state = state;
}

public TResponse Current => _inner.Current;

public async Task<bool> MoveNext(CancellationToken cancellationToken)
{
var result = await _inner.MoveNext(cancellationToken);
if (!result && !_disposed)
{
_state.Dispose();
_disposed = true;
}
return result;
}
}

private static class Log
{
private static readonly Action<ILogger, string, Exception?> _propagateServerCallContextFailure =
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand All @@ -20,6 +20,7 @@
using Greet;
using Grpc.AspNetCore.Server.ClientFactory.Tests.TestObjects;
using Grpc.Core;
using Grpc.Core.Interceptors;
using Grpc.Net.ClientFactory;
using Grpc.Net.ClientFactory.Internal;
using Grpc.Tests.Shared;
Expand Down Expand Up @@ -91,6 +92,155 @@ public async Task CreateClient_ServerCallContextHasValues_PropogatedDeadlineAndC
Assert.AreEqual(cancellationToken, options.CancellationToken);
}

[Test]
public async Task CreateClient_Unary_ServerCallContextHasValues_StateDisposed()
{
// Arrange
var baseAddress = new Uri("http://localhost");
var deadline = DateTime.UtcNow.AddDays(1);
var cancellationToken = new CancellationTokenSource().Token;

var interceptor = new OnDisposedInterceptor();

var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton(CreateHttpContextAccessorWithServerCallContext(deadline: deadline, cancellationToken: cancellationToken));
services
.AddGrpcClient<Greeter.GreeterClient>(o =>
{
o.Address = baseAddress;
})
.EnableCallContextPropagation()
.AddInterceptor(() => interceptor)
.ConfigurePrimaryHttpMessageHandler(() => ClientTestHelpers.CreateTestMessageHandler(new HelloReply()));

var serviceProvider = services.BuildServiceProvider(validateScopes: true);

var clientFactory = CreateGrpcClientFactory(serviceProvider);
var client = clientFactory.CreateClient<Greeter.GreeterClient>(nameof(Greeter.GreeterClient));

// Checking that token register calls don't build up on CTS and create a memory leak.
var cts = new CancellationTokenSource();

// Act
// Send calls in a different method so there is no chance that a stack reference
// to a gRPC call is still alive after calls are complete.
var response = await client.SayHelloAsync(new HelloRequest(), cancellationToken: cts.Token);

// Assert
Assert.IsTrue(interceptor.ContextDisposed);
}

[Test]
public async Task CreateClient_ServerStreaming_ServerCallContextHasValues_StateDisposed()
{
// Arrange
var baseAddress = new Uri("http://localhost");
var deadline = DateTime.UtcNow.AddDays(1);
var cancellationToken = new CancellationTokenSource().Token;

var interceptor = new OnDisposedInterceptor();

var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton(CreateHttpContextAccessorWithServerCallContext(deadline: deadline, cancellationToken: cancellationToken));
services
.AddGrpcClient<Greeter.GreeterClient>(o =>
{
o.Address = baseAddress;
})
.EnableCallContextPropagation()
.AddInterceptor(() => interceptor)
.ConfigurePrimaryHttpMessageHandler(() => ClientTestHelpers.CreateTestMessageHandler(new HelloReply()));

var serviceProvider = services.BuildServiceProvider(validateScopes: true);

var clientFactory = CreateGrpcClientFactory(serviceProvider);
var client = clientFactory.CreateClient<Greeter.GreeterClient>(nameof(Greeter.GreeterClient));

// Checking that token register calls don't build up on CTS and create a memory leak.
var cts = new CancellationTokenSource();

// Act
// Send calls in a different method so there is no chance that a stack reference
// to a gRPC call is still alive after calls are complete.
var call = client.SayHellos(new HelloRequest(), cancellationToken: cts.Token);

Assert.IsTrue(await call.ResponseStream.MoveNext());
Assert.IsFalse(await call.ResponseStream.MoveNext());

// Assert
Assert.IsTrue(interceptor.ContextDisposed);
}

private sealed class OnDisposedInterceptor : Interceptor
{
public bool ContextDisposed { get; private set; }

public override TResponse BlockingUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, BlockingUnaryCallContinuation<TRequest, TResponse> continuation)
{
return continuation(request, context);
}

public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncUnaryCallContinuation<TRequest, TResponse> continuation)
{
var call = continuation(request, context);
return new AsyncUnaryCall<TResponse>(call.ResponseAsync,
call.ResponseHeadersAsync,
call.GetStatus,
call.GetTrailers,
() =>
{
call.Dispose();
ContextDisposed = true;
});
}

public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(TRequest request, ClientInterceptorContext<TRequest, TResponse> context, AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation)
{
var call = continuation(request, context);
return new AsyncServerStreamingCall<TResponse>(call.ResponseStream,
call.ResponseHeadersAsync,
call.GetStatus,
call.GetTrailers,
() =>
{
call.Dispose();
ContextDisposed = true;
});
}

public override AsyncClientStreamingCall<TRequest, TResponse> AsyncClientStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncClientStreamingCallContinuation<TRequest, TResponse> continuation)
{
var call = continuation(context);
return new AsyncClientStreamingCall<TRequest, TResponse>(call.RequestStream,
call.ResponseAsync,
call.ResponseHeadersAsync,
call.GetStatus,
call.GetTrailers,
() =>
{
call.Dispose();
ContextDisposed = true;
});
}

public override AsyncDuplexStreamingCall<TRequest, TResponse> AsyncDuplexStreamingCall<TRequest, TResponse>(ClientInterceptorContext<TRequest, TResponse> context, AsyncDuplexStreamingCallContinuation<TRequest, TResponse> continuation)
{
var call = continuation(context);
return new AsyncDuplexStreamingCall<TRequest, TResponse>(call.RequestStream,
call.ResponseStream,
call.ResponseHeadersAsync,
call.GetStatus,
call.GetTrailers,
() =>
{
call.Dispose();
ContextDisposed = true;
});
}
}

[TestCase(Canceller.Context)]
[TestCase(Canceller.User)]
public async Task CreateClient_ServerCallContextAndUserCancellationToken_PropogatedDeadlineAndCancellation(Canceller canceller)
Expand Down
Loading