diff --git a/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Windows.cs b/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Windows.cs index 51c2314fde3669..d088a30d6f7445 100644 --- a/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Windows.cs +++ b/src/libraries/System.Net.NameResolution/src/System/Net/NameResolutionPal.Windows.cs @@ -8,6 +8,7 @@ using System.Threading; using System.Threading.Tasks; using System.Diagnostics; +using Microsoft.Win32.SafeHandles; namespace System.Net { @@ -138,17 +139,14 @@ public static unsafe string GetHostName() { Interop.Winsock.EnsureInitialized(); - GetAddrInfoExContext* context = GetAddrInfoExContext.AllocateContext(); - - GetAddrInfoExState state; + GetAddrInfoExState? state = null; try { - state = new GetAddrInfoExState(context, hostName, justAddresses); - context->QueryStateHandle = state.CreateHandle(); + state = new GetAddrInfoExState(hostName, justAddresses); } catch { - GetAddrInfoExContext.FreeContext(context); + state?.Dispose(); throw; } @@ -158,6 +156,8 @@ public static unsafe string GetHostName() hints.ai_flags = AddressInfoHints.AI_CANONNAME; } + GetAddrInfoExContext* context = state.Context; + SocketError errorCode = (SocketError)Interop.Winsock.GetAddrInfoExW( hostName, null, Interop.Winsock.NS_ALL, IntPtr.Zero, &hints, &context->Result, IntPtr.Zero, &context->Overlapped, &GetAddressInfoExCallback, &context->CancelHandle); @@ -172,7 +172,7 @@ public static unsafe string GetHostName() // and final result would be posted via overlapped IO. // synchronous failure here may signal issue when GetAddrInfoExW does not work from // impersonated context. Windows 8 and Server 2012 fail for same reason with different errorCode. - GetAddrInfoExContext.FreeContext(context); + state.Dispose(); return null; } else @@ -194,10 +194,10 @@ private static unsafe void GetAddressInfoExCallback(int error, int bytes, Native private static unsafe void ProcessResult(SocketError errorCode, GetAddrInfoExContext* context) { + GetAddrInfoExState state = GetAddrInfoExState.FromHandleAndFree(context->QueryStateHandle); + try { - GetAddrInfoExState state = GetAddrInfoExState.FromHandleAndFree(context->QueryStateHandle); - CancellationToken cancellationToken = state.UnregisterAndGetCancellationToken(); if (errorCode == SocketError.Success) @@ -222,7 +222,7 @@ private static unsafe void ProcessResult(SocketError errorCode, GetAddrInfoExCon } finally { - GetAddrInfoExContext.FreeContext(context); + state.Dispose(); } } @@ -360,18 +360,21 @@ private static unsafe IPAddress CreateIPv6Address(ReadOnlySpan socketAddre return new IPAddress(address, scope); } - private sealed unsafe class GetAddrInfoExState : IThreadPoolWorkItem + // GetAddrInfoExState is a SafeHandle that manages the lifetime of GetAddrInfoExContext* + // to make sure GetAddrInfoExCancel always takes a valid memory address regardless of the race + // between cancellation and completion callbacks. + private sealed unsafe class GetAddrInfoExState : SafeHandleZeroOrMinusOneIsInvalid, IThreadPoolWorkItem { - private GetAddrInfoExContext* _cancellationContext; private CancellationTokenRegistration _cancellationRegistration; private AsyncTaskMethodBuilder IPHostEntryBuilder; private AsyncTaskMethodBuilder IPAddressArrayBuilder; private object? _result; + private volatile bool _completed; - public GetAddrInfoExState(GetAddrInfoExContext *context, string hostName, bool justAddresses) + public GetAddrInfoExState(string hostName, bool justAddresses) + : base(true) { - _cancellationContext = context; HostName = hostName; JustAddresses = justAddresses; if (justAddresses) @@ -384,6 +387,10 @@ public GetAddrInfoExState(GetAddrInfoExContext *context, string hostName, bool j IPHostEntryBuilder = AsyncTaskMethodBuilder.Create(); _ = IPHostEntryBuilder.Task; // force initialization } + + GetAddrInfoExContext* context = GetAddrInfoExContext.AllocateContext(); + context->QueryStateHandle = CreateHandle(); + SetHandle((IntPtr)context); } public string HostName { get; } @@ -392,52 +399,62 @@ public GetAddrInfoExState(GetAddrInfoExContext *context, string hostName, bool j public Task Task => JustAddresses ? (Task)IPAddressArrayBuilder.Task : IPHostEntryBuilder.Task; + internal GetAddrInfoExContext* Context => (GetAddrInfoExContext*)handle; + public void RegisterForCancellation(CancellationToken cancellationToken) { if (!cancellationToken.CanBeCanceled) return; - lock (this) + if (_completed) { - if (_cancellationContext == null) + // The operation completed before registration could be done. + return; + } + + _cancellationRegistration = cancellationToken.UnsafeRegister(static o => + { + var @this = (GetAddrInfoExState)o!; + if (@this._completed) { - // The operation completed before registration could be done. + // Escape early and avoid ObjectDisposedException in DangerousAddRef return; } - _cancellationRegistration = cancellationToken.UnsafeRegister(o => + bool needRelease = false; + try { - var @this = (GetAddrInfoExState)o!; - int cancelResult = 0; + @this.DangerousAddRef(ref needRelease); - lock (@this) - { - GetAddrInfoExContext* context = @this._cancellationContext; - - if (context != null) - { - // An outstanding operation will be completed with WSA_E_CANCELLED, and GetAddrInfoExCancel will return NO_ERROR. - // If this thread has lost the race between cancellation and completion, this will be a NOP - // with GetAddrInfoExCancel returning WSA_INVALID_HANDLE. - cancelResult = Interop.Winsock.GetAddrInfoExCancel(&context->CancelHandle); - } - } + // If DangerousAddRef didn't throw ODE, the handle should contain a valid pointer. + GetAddrInfoExContext* context = @this.Context; - if (cancelResult != 0 && cancelResult != Interop.Winsock.WSA_INVALID_HANDLE && NetEventSource.Log.IsEnabled()) + // An outstanding operation will be completed with WSA_E_CANCELLED, and GetAddrInfoExCancel will return NO_ERROR. + // If this thread has lost the race between cancellation and completion, this will be a NOP + // with GetAddrInfoExCancel returning WSA_INVALID_HANDLE. + int cancelResult = Interop.Winsock.GetAddrInfoExCancel(&context->CancelHandle); + if (cancelResult != Interop.Winsock.WSA_INVALID_HANDLE && NetEventSource.Log.IsEnabled()) { NetEventSource.Info(@this, $"GetAddrInfoExCancel returned error {cancelResult}"); } - }, this); - } + } + finally + { + if (needRelease) + { + @this.DangerousRelease(); + } + } + + }, this); } public CancellationToken UnregisterAndGetCancellationToken() { - lock (this) - { - _cancellationContext = null; - _cancellationRegistration.Unregister(); - } + _completed = true; + // We should not wait for pending cancellation callbacks with CTR.Dispose(), + // since we are in a completion routine and GetAddrInfoExCancel may get blocked until it's finished. + _cancellationRegistration.Unregister(); return _cancellationRegistration.Token; } @@ -479,8 +496,6 @@ void IThreadPoolWorkItem.Execute() } } - public IntPtr CreateHandle() => GCHandle.ToIntPtr(GCHandle.Alloc(this, GCHandleType.Normal)); - public static GetAddrInfoExState FromHandleAndFree(IntPtr handle) { GCHandle gcHandle = GCHandle.FromIntPtr(handle); @@ -488,6 +503,15 @@ public static GetAddrInfoExState FromHandleAndFree(IntPtr handle) gcHandle.Free(); return state; } + + protected override bool ReleaseHandle() + { + GetAddrInfoExContext.FreeContext(Context); + + return true; + } + + private IntPtr CreateHandle() => GCHandle.ToIntPtr(GCHandle.Alloc(this, GCHandleType.Normal)); } [StructLayout(LayoutKind.Sequential)] @@ -498,12 +522,7 @@ private unsafe struct GetAddrInfoExContext public IntPtr CancelHandle; public IntPtr QueryStateHandle; - public static GetAddrInfoExContext* AllocateContext() - { - var context = (GetAddrInfoExContext*)Marshal.AllocHGlobal(sizeof(GetAddrInfoExContext)); - *context = default; - return context; - } + public static GetAddrInfoExContext* AllocateContext() => (GetAddrInfoExContext*)NativeMemory.AllocZeroed((nuint)sizeof(GetAddrInfoExContext)); public static void FreeContext(GetAddrInfoExContext* context) { @@ -511,8 +530,7 @@ public static void FreeContext(GetAddrInfoExContext* context) { Interop.Winsock.FreeAddrInfoExW(context->Result); } - - Marshal.FreeHGlobal((IntPtr)context); + NativeMemory.Free(context); } } } diff --git a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostAddressesTest.cs b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostAddressesTest.cs index 450357c57a130b..10b588ebda6c8f 100644 --- a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostAddressesTest.cs +++ b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostAddressesTest.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Linq; using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; @@ -170,10 +171,13 @@ public async Task DnsGetHostAddresses_PreCancelledToken_Throws() OperationCanceledException oce = await Assert.ThrowsAnyAsync(() => Dns.GetHostAddressesAsync(TestSettings.LocalHost, cts.Token)); Assert.Equal(cts.Token, oce.CancellationToken); } + } - [OuterLoop] + // Cancellation tests are sequential to reduce the chance of timing issues. + [Collection(nameof(DisableParallelization))] + public class GetHostAddressesTest_Cancellation + { [Fact] - [ActiveIssue("https://github.com/dotnet/runtime/issues/43816")] // Race condition outlined below. [ActiveIssue("https://github.com/dotnet/runtime/issues/33378", TestPlatforms.AnyUnix)] // Cancellation of an outstanding getaddrinfo is not supported on *nix. public async Task DnsGetHostAddresses_PostCancelledToken_Throws() { @@ -188,5 +192,35 @@ public async Task DnsGetHostAddresses_PostCancelledToken_Throws() OperationCanceledException oce = await Assert.ThrowsAnyAsync(() => task); Assert.Equal(cts.Token, oce.CancellationToken); } + + // This is a regression test for https://github.com/dotnet/runtime/issues/63552 + [Fact] + [ActiveIssue("https://github.com/dotnet/runtime/issues/33378", TestPlatforms.AnyUnix)] // Cancellation of an outstanding getaddrinfo is not supported on *nix. + public async Task DnsGetHostAddresses_ResolveParallelCancelOnFailure_AllCallsReturn() + { + string invalidAddress = TestSettings.UncachedHost; + await ResolveManyAsync(invalidAddress); + await ResolveManyAsync(invalidAddress, TestSettings.LocalHost) + .WaitAsync(TestSettings.PassingTestTimeout); + + static async Task ResolveManyAsync(params string[] addresses) + { + using CancellationTokenSource cts = new(); + Task[] resolveTasks = addresses.Select(a => ResolveOneAsync(a, cts)).ToArray(); + await Task.WhenAll(resolveTasks); + } + + static async Task ResolveOneAsync(string address, CancellationTokenSource cancellationTokenSource) + { + try + { + await Dns.GetHostAddressesAsync(address, cancellationTokenSource.Token); + } + catch (Exception) + { + cancellationTokenSource.Cancel(); + } + } + } } } diff --git a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostEntryTest.cs b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostEntryTest.cs index 5a8b6d5d92f8f5..69cf72e00af08f 100644 --- a/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostEntryTest.cs +++ b/src/libraries/System.Net.NameResolution/tests/FunctionalTests/GetHostEntryTest.cs @@ -309,13 +309,21 @@ public async Task DnsGetHostEntry_PreCancelledToken_Throws() OperationCanceledException oce = await Assert.ThrowsAnyAsync(() => Dns.GetHostEntryAsync(TestSettings.LocalHost, cts.Token)); Assert.Equal(cts.Token, oce.CancellationToken); } + } + // Cancellation tests are sequential to reduce the chance of timing issues. + [Collection(nameof(DisableParallelization))] + public class GetHostEntryTest_Cancellation + { [OuterLoop] - [ActiveIssue("https://github.com/dotnet/runtime/issues/43816")] // Race condition outlined below. [ActiveIssue("https://github.com/dotnet/runtime/issues/33378", TestPlatforms.AnyUnix)] // Cancellation of an outstanding getaddrinfo is not supported on *nix. [Fact] public async Task DnsGetHostEntry_PostCancelledToken_Throws() { + // Windows 7 name resolution is synchronous and does not respect cancellation. + if (PlatformDetection.IsWindows7) + return; + using var cts = new CancellationTokenSource(); Task task = Dns.GetHostEntryAsync(TestSettings.UncachedHost, cts.Token);