Skip to content

Commit

Permalink
Fix DNS cancellation deadlock (#63904)
Browse files Browse the repository at this point in the history
Avoid taking a lock, and address the use-after-free race condition by guarding GetAddrInfoExContext with a SafeHandle.
  • Loading branch information
antonfirsov authored Mar 21, 2022
1 parent 135e566 commit 548b70d
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Threading;
using System.Threading.Tasks;
using System.Diagnostics;
using Microsoft.Win32.SafeHandles;

namespace System.Net
{
Expand Down Expand Up @@ -138,17 +139,14 @@ public static unsafe string GetHostName()
{
Interop.Winsock.EnsureInitialized();

GetAddrInfoExContext* context = GetAddrInfoExContext.AllocateContext();

GetAddrInfoExState state;
GetAddrInfoExState? state = null;
try
{
state = new GetAddrInfoExState(context, hostName, justAddresses);
context->QueryStateHandle = state.CreateHandle();
state = new GetAddrInfoExState(hostName, justAddresses);
}
catch
{
GetAddrInfoExContext.FreeContext(context);
state?.Dispose();
throw;
}

Expand All @@ -158,6 +156,8 @@ public static unsafe string GetHostName()
hints.ai_flags = AddressInfoHints.AI_CANONNAME;
}

GetAddrInfoExContext* context = state.Context;

SocketError errorCode = (SocketError)Interop.Winsock.GetAddrInfoExW(
hostName, null, Interop.Winsock.NS_ALL, IntPtr.Zero, &hints, &context->Result, IntPtr.Zero, &context->Overlapped, &GetAddressInfoExCallback, &context->CancelHandle);

Expand All @@ -172,7 +172,7 @@ public static unsafe string GetHostName()
// and final result would be posted via overlapped IO.
// synchronous failure here may signal issue when GetAddrInfoExW does not work from
// impersonated context. Windows 8 and Server 2012 fail for same reason with different errorCode.
GetAddrInfoExContext.FreeContext(context);
state.Dispose();
return null;
}
else
Expand All @@ -194,10 +194,10 @@ private static unsafe void GetAddressInfoExCallback(int error, int bytes, Native

private static unsafe void ProcessResult(SocketError errorCode, GetAddrInfoExContext* context)
{
GetAddrInfoExState state = GetAddrInfoExState.FromHandleAndFree(context->QueryStateHandle);

try
{
GetAddrInfoExState state = GetAddrInfoExState.FromHandleAndFree(context->QueryStateHandle);

CancellationToken cancellationToken = state.UnregisterAndGetCancellationToken();

if (errorCode == SocketError.Success)
Expand All @@ -222,7 +222,7 @@ private static unsafe void ProcessResult(SocketError errorCode, GetAddrInfoExCon
}
finally
{
GetAddrInfoExContext.FreeContext(context);
state.Dispose();
}
}

Expand Down Expand Up @@ -360,18 +360,21 @@ private static unsafe IPAddress CreateIPv6Address(ReadOnlySpan<byte> socketAddre
return new IPAddress(address, scope);
}

private sealed unsafe class GetAddrInfoExState : IThreadPoolWorkItem
// GetAddrInfoExState is a SafeHandle that manages the lifetime of GetAddrInfoExContext*
// to make sure GetAddrInfoExCancel always takes a valid memory address regardless of the race
// between cancellation and completion callbacks.
private sealed unsafe class GetAddrInfoExState : SafeHandleZeroOrMinusOneIsInvalid, IThreadPoolWorkItem
{
private GetAddrInfoExContext* _cancellationContext;
private CancellationTokenRegistration _cancellationRegistration;

private AsyncTaskMethodBuilder<IPHostEntry> IPHostEntryBuilder;
private AsyncTaskMethodBuilder<IPAddress[]> IPAddressArrayBuilder;
private object? _result;
private volatile bool _completed;

public GetAddrInfoExState(GetAddrInfoExContext *context, string hostName, bool justAddresses)
public GetAddrInfoExState(string hostName, bool justAddresses)
: base(true)
{
_cancellationContext = context;
HostName = hostName;
JustAddresses = justAddresses;
if (justAddresses)
Expand All @@ -384,6 +387,10 @@ public GetAddrInfoExState(GetAddrInfoExContext *context, string hostName, bool j
IPHostEntryBuilder = AsyncTaskMethodBuilder<IPHostEntry>.Create();
_ = IPHostEntryBuilder.Task; // force initialization
}

GetAddrInfoExContext* context = GetAddrInfoExContext.AllocateContext();
context->QueryStateHandle = CreateHandle();
SetHandle((IntPtr)context);
}

public string HostName { get; }
Expand All @@ -392,52 +399,62 @@ public GetAddrInfoExState(GetAddrInfoExContext *context, string hostName, bool j

public Task Task => JustAddresses ? (Task)IPAddressArrayBuilder.Task : IPHostEntryBuilder.Task;

internal GetAddrInfoExContext* Context => (GetAddrInfoExContext*)handle;

public void RegisterForCancellation(CancellationToken cancellationToken)
{
if (!cancellationToken.CanBeCanceled) return;

lock (this)
if (_completed)
{
if (_cancellationContext == null)
// The operation completed before registration could be done.
return;
}

_cancellationRegistration = cancellationToken.UnsafeRegister(static o =>
{
var @this = (GetAddrInfoExState)o!;
if (@this._completed)
{
// The operation completed before registration could be done.
// Escape early and avoid ObjectDisposedException in DangerousAddRef
return;
}

_cancellationRegistration = cancellationToken.UnsafeRegister(o =>
bool needRelease = false;
try
{
var @this = (GetAddrInfoExState)o!;
int cancelResult = 0;
@this.DangerousAddRef(ref needRelease);

lock (@this)
{
GetAddrInfoExContext* context = @this._cancellationContext;

if (context != null)
{
// An outstanding operation will be completed with WSA_E_CANCELLED, and GetAddrInfoExCancel will return NO_ERROR.
// If this thread has lost the race between cancellation and completion, this will be a NOP
// with GetAddrInfoExCancel returning WSA_INVALID_HANDLE.
cancelResult = Interop.Winsock.GetAddrInfoExCancel(&context->CancelHandle);
}
}
// If DangerousAddRef didn't throw ODE, the handle should contain a valid pointer.
GetAddrInfoExContext* context = @this.Context;

if (cancelResult != 0 && cancelResult != Interop.Winsock.WSA_INVALID_HANDLE && NetEventSource.Log.IsEnabled())
// An outstanding operation will be completed with WSA_E_CANCELLED, and GetAddrInfoExCancel will return NO_ERROR.
// If this thread has lost the race between cancellation and completion, this will be a NOP
// with GetAddrInfoExCancel returning WSA_INVALID_HANDLE.
int cancelResult = Interop.Winsock.GetAddrInfoExCancel(&context->CancelHandle);
if (cancelResult != Interop.Winsock.WSA_INVALID_HANDLE && NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(@this, $"GetAddrInfoExCancel returned error {cancelResult}");
}
}, this);
}
}
finally
{
if (needRelease)
{
@this.DangerousRelease();
}
}

}, this);
}

public CancellationToken UnregisterAndGetCancellationToken()
{
lock (this)
{
_cancellationContext = null;
_cancellationRegistration.Unregister();
}
_completed = true;

// We should not wait for pending cancellation callbacks with CTR.Dispose(),
// since we are in a completion routine and GetAddrInfoExCancel may get blocked until it's finished.
_cancellationRegistration.Unregister();
return _cancellationRegistration.Token;
}

Expand Down Expand Up @@ -479,15 +496,22 @@ void IThreadPoolWorkItem.Execute()
}
}

public IntPtr CreateHandle() => GCHandle.ToIntPtr(GCHandle.Alloc(this, GCHandleType.Normal));

public static GetAddrInfoExState FromHandleAndFree(IntPtr handle)
{
GCHandle gcHandle = GCHandle.FromIntPtr(handle);
var state = (GetAddrInfoExState)gcHandle.Target!;
gcHandle.Free();
return state;
}

protected override bool ReleaseHandle()
{
GetAddrInfoExContext.FreeContext(Context);

return true;
}

private IntPtr CreateHandle() => GCHandle.ToIntPtr(GCHandle.Alloc(this, GCHandleType.Normal));
}

[StructLayout(LayoutKind.Sequential)]
Expand All @@ -498,21 +522,15 @@ private unsafe struct GetAddrInfoExContext
public IntPtr CancelHandle;
public IntPtr QueryStateHandle;

public static GetAddrInfoExContext* AllocateContext()
{
var context = (GetAddrInfoExContext*)Marshal.AllocHGlobal(sizeof(GetAddrInfoExContext));
*context = default;
return context;
}
public static GetAddrInfoExContext* AllocateContext() => (GetAddrInfoExContext*)NativeMemory.AllocZeroed((nuint)sizeof(GetAddrInfoExContext));

public static void FreeContext(GetAddrInfoExContext* context)
{
if (context->Result != null)
{
Interop.Winsock.FreeAddrInfoExW(context->Result);
}

Marshal.FreeHGlobal((IntPtr)context);
NativeMemory.Free(context);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Linq;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -170,10 +171,13 @@ public async Task DnsGetHostAddresses_PreCancelledToken_Throws()
OperationCanceledException oce = await Assert.ThrowsAnyAsync<OperationCanceledException>(() => Dns.GetHostAddressesAsync(TestSettings.LocalHost, cts.Token));
Assert.Equal(cts.Token, oce.CancellationToken);
}
}

[OuterLoop]
// Cancellation tests are sequential to reduce the chance of timing issues.
[Collection(nameof(DisableParallelization))]
public class GetHostAddressesTest_Cancellation
{
[Fact]
[ActiveIssue("https://github.com/dotnet/runtime/issues/43816")] // Race condition outlined below.
[ActiveIssue("https://github.com/dotnet/runtime/issues/33378", TestPlatforms.AnyUnix)] // Cancellation of an outstanding getaddrinfo is not supported on *nix.
public async Task DnsGetHostAddresses_PostCancelledToken_Throws()
{
Expand All @@ -188,5 +192,35 @@ public async Task DnsGetHostAddresses_PostCancelledToken_Throws()
OperationCanceledException oce = await Assert.ThrowsAnyAsync<OperationCanceledException>(() => task);
Assert.Equal(cts.Token, oce.CancellationToken);
}

// This is a regression test for https://github.com/dotnet/runtime/issues/63552
[Fact]
[ActiveIssue("https://github.com/dotnet/runtime/issues/33378", TestPlatforms.AnyUnix)] // Cancellation of an outstanding getaddrinfo is not supported on *nix.
public async Task DnsGetHostAddresses_ResolveParallelCancelOnFailure_AllCallsReturn()
{
string invalidAddress = TestSettings.UncachedHost;
await ResolveManyAsync(invalidAddress);
await ResolveManyAsync(invalidAddress, TestSettings.LocalHost)
.WaitAsync(TestSettings.PassingTestTimeout);

static async Task ResolveManyAsync(params string[] addresses)
{
using CancellationTokenSource cts = new();
Task[] resolveTasks = addresses.Select(a => ResolveOneAsync(a, cts)).ToArray();
await Task.WhenAll(resolveTasks);
}

static async Task ResolveOneAsync(string address, CancellationTokenSource cancellationTokenSource)
{
try
{
await Dns.GetHostAddressesAsync(address, cancellationTokenSource.Token);
}
catch (Exception)
{
cancellationTokenSource.Cancel();
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -309,13 +309,21 @@ public async Task DnsGetHostEntry_PreCancelledToken_Throws()
OperationCanceledException oce = await Assert.ThrowsAnyAsync<OperationCanceledException>(() => Dns.GetHostEntryAsync(TestSettings.LocalHost, cts.Token));
Assert.Equal(cts.Token, oce.CancellationToken);
}
}

// Cancellation tests are sequential to reduce the chance of timing issues.
[Collection(nameof(DisableParallelization))]
public class GetHostEntryTest_Cancellation
{
[OuterLoop]
[ActiveIssue("https://github.com/dotnet/runtime/issues/43816")] // Race condition outlined below.
[ActiveIssue("https://github.com/dotnet/runtime/issues/33378", TestPlatforms.AnyUnix)] // Cancellation of an outstanding getaddrinfo is not supported on *nix.
[Fact]
public async Task DnsGetHostEntry_PostCancelledToken_Throws()
{
// Windows 7 name resolution is synchronous and does not respect cancellation.
if (PlatformDetection.IsWindows7)
return;

using var cts = new CancellationTokenSource();

Task task = Dns.GetHostEntryAsync(TestSettings.UncachedHost, cts.Token);
Expand Down

0 comments on commit 548b70d

Please sign in to comment.