Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix broken Select with error list on macOS #104915

Merged
merged 12 commits into from
Jul 28, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;

internal static partial class Interop
{
internal static partial class Sys
{
[LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_Select")]
internal static unsafe partial Error Select(Span<int> readFDs, int readFDsLength, Span<int> writeFDs, int writeFDsLength, Span<int> checkError, int checkErrorLength, int timeout, int maxFd, out int triggered);
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@
Link="Common\Interop\Unix\System.Native\Interop.ReceiveMessage.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.Send.cs"
Link="Common\Interop\Unix\System.Native\Interop.Send.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.Select.cs"
Link="Common\Interop\Unix\System.Native\Interop.Select.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.SendMessage.cs"
Link="Common\Interop\Unix\System.Native\Interop.SendMessage.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.SetSockOpt.cs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ internal static partial class SocketPal
public static readonly int MaximumAddressSize = Interop.Sys.GetMaximumAddressSize();
private static readonly bool SupportsDualModeIPv4PacketInfo = GetPlatformSupportsDualModeIPv4PacketInfo();

private static readonly bool SelectOverPollIsBroken = OperatingSystem.IsMacOS() || OperatingSystem.IsIOS() || OperatingSystem.IsTvOS() || OperatingSystem.IsMacCatalyst();

// IovStackThreshold matches Linux's UIO_FASTIOV, which is the number of 'struct iovec'
// that get stackalloced in the Linux kernel.
private const int IovStackThreshold = 8;
Expand Down Expand Up @@ -1782,6 +1784,11 @@ public static unsafe SocketError Select(IList? checkRead, IList? checkWrite, ILi
// by the system. Since poll then expects an array of entries, we try to allocate the array on the stack,
// only falling back to allocating it on the heap if it's deemed too big.

if (SelectOverPollIsBroken)
{
return SelectViaSelect(checkRead, checkWrite, checkError, microseconds);
}
wfurt marked this conversation as resolved.
Show resolved Hide resolved

const int StackThreshold = 80; // arbitrary limit to avoid too much space on stack
if (count < StackThreshold)
{
Expand All @@ -1806,6 +1813,103 @@ public static unsafe SocketError Select(IList? checkRead, IList? checkWrite, ILi
}
}

private static SocketError SelectViaSelect(IList? checkRead, IList? checkWrite, IList? checkError, int microseconds)
{
const int MaxStackAllocCount = 20; // this is just arbitrary limit 3x 20 -> 60 e.g. close to 64 we have in some other places
Span<int> readFDs = checkRead?.Count > MaxStackAllocCount ? new int[checkRead.Count] : stackalloc int[checkRead?.Count ?? 0];
Span<int> writeFDs = checkWrite?.Count > MaxStackAllocCount ? new int[checkWrite.Count] : stackalloc int[checkWrite?.Count ?? 0];
Span<int> errorFDs = checkError?.Count > MaxStackAllocCount ? new int[checkError.Count] : stackalloc int[checkError?.Count ?? 0];

int refsAdded = 0;
int maxFd = 0;
try
{
AddDesriptors(readFDs, checkRead, ref refsAdded, ref maxFd);
AddDesriptors(writeFDs, checkWrite, ref refsAdded, ref maxFd);
AddDesriptors(errorFDs, checkError, ref refsAdded, ref maxFd);

int triggered = 0;
Interop.Error err = Interop.Sys.Select(readFDs, readFDs.Length, writeFDs, writeFDs.Length, errorFDs, errorFDs.Length, microseconds, maxFd, out triggered);
if (err != Interop.Error.SUCCESS)
{
return GetSocketErrorForErrorCode(err);
}

Socket.SocketListDangerousReleaseRefs(checkRead, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkWrite, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkError, ref refsAdded);

if (triggered == 0)
{
checkRead?.Clear();
checkWrite?.Clear();
checkError?.Clear();
}
else
{
FilterSelectList(checkRead, readFDs);
FilterSelectList(checkWrite, writeFDs);
FilterSelectList(checkError, errorFDs);
}
}
finally
{
// This order matches with the AddToPollArray calls
// to release only the handles that were ref'd.
Socket.SocketListDangerousReleaseRefs(checkRead, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkWrite, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkError, ref refsAdded);
Debug.Assert(refsAdded == 0);
}

return (SocketError)0;
}

private static void AddDesriptors(Span<int> buffer, IList? socketList, ref int refsAdded, ref int maxFd)
{
if (socketList == null || socketList.Count == 0 )
{
return;
}

Debug.Assert(buffer.Length == socketList.Count);
for (int i = 0; i < socketList.Count; i++)
{
Socket? socket = socketList[i] as Socket;
if (socket == null)
{
throw new ArgumentException(SR.Format(SR.net_sockets_select, socket?.GetType().FullName ?? "null", typeof(Socket).FullName), nameof(socketList));
}

if (socket.Handle > maxFd)
{
maxFd = (int)socket.Handle;
}

bool success = false;
socket.InternalSafeHandle.DangerousAddRef(ref success);
buffer[i] = (int)socket.InternalSafeHandle.DangerousGetHandle();

refsAdded++;
}
}

private static void FilterSelectList(IList? socketList, Span<int> results)
{
if (socketList == null)
return;

// This loop can be O(n^2) in the unexpected and worst case. Some more thoughts are written in FilterPollList that does exactly same operation.

for (int i = socketList.Count - 1; i >= 0; --i)
{
if (results[i] == 0)
{
socketList.RemoveAt(i);
}
}
}

private static unsafe SocketError SelectViaPoll(
IList? checkRead, int checkReadInitialCount,
IList? checkWrite, int checkWriteInitialCount,
Expand Down
107 changes: 103 additions & 4 deletions src/libraries/System.Net.Sockets/tests/FunctionalTests/SelectTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.DotNet.XUnitExtensions;
using Xunit;
using Xunit.Abstractions;

Expand All @@ -21,7 +21,7 @@ public SelectTest(ITestOutputHelper output)
}

private const int SmallTimeoutMicroseconds = 10 * 1000;
private const int FailTimeoutMicroseconds = 30 * 1000 * 1000;
internal const int FailTimeoutMicroseconds = 30 * 1000 * 1000;

[SkipOnPlatform(TestPlatforms.OSX, "typical OSX install has very low max open file descriptors value")]
[Theory]
Expand Down Expand Up @@ -78,6 +78,82 @@ public void Select_ReadWrite_AllReady(int reads, int writes)
}
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public void Select_ReadError_Success(bool dispose)
{
using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);
using Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);

listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);
sender.Connect(listener.LocalEndPoint);
using Socket receiver = listener.Accept();

if (dispose)
{
sender.Dispose();
}
else
{
sender.Send(new byte[] { 1 });
}

var readList = new List<Socket> { receiver };
var errorList = new List<Socket> { receiver };
Socket.Select(readList, null, errorList, -1);
if (dispose)
{
Assert.True(readList.Count == 1 || errorList.Count == 1);
}
else
{
Assert.Equal(1, readList.Count);
Assert.Equal(0, errorList.Count);
}
}

[Fact]
public void Select_WriteError_Success()
{
using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);
using Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);

listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);
sender.Connect(listener.LocalEndPoint);
using Socket receiver = listener.Accept();

var writeList = new List<Socket> { receiver };
var errorList = new List<Socket> { receiver };
Socket.Select(null, writeList, errorList, -1);
Assert.Equal(1, writeList.Count);
Assert.Equal(0, errorList.Count);
}

[Fact]
public void Select_ReadWriteError_Success()
{
using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);
using Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);

listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);
sender.Connect(listener.LocalEndPoint);
using Socket receiver = listener.Accept();

sender.Send(new byte[] { 1 });
receiver.Poll(FailTimeoutMicroseconds, SelectMode.SelectRead);
var readList = new List<Socket> { receiver };
var writeList = new List<Socket> { receiver };
var errorList = new List<Socket> { receiver };
Socket.Select(readList, writeList, errorList, -1);
Assert.Equal(1, readList.Count);
Assert.Equal(1, writeList.Count);
Assert.Equal(0, errorList.Count);
}

[Theory]
[InlineData(2, 0)]
[InlineData(2, 1)]
Expand Down Expand Up @@ -109,7 +185,6 @@ public void Select_SocketAlreadyClosed_AllSocketsClosableAfterException(int sock
}
}

[SkipOnPlatform(TestPlatforms.OSX, "typical OSX install has very low max open file descriptors value")]
[Fact]
[ActiveIssue("https://github.com/dotnet/runtime/issues/51392", TestPlatforms.iOS | TestPlatforms.tvOS | TestPlatforms.MacCatalyst)]
public void Select_ReadError_NoneReady_ManySockets()
Expand Down Expand Up @@ -245,7 +320,7 @@ public void Poll_ReadReady_LongTimeouts(int microsecondsTimeout)
}
}

private static KeyValuePair<Socket, Socket> CreateConnectedSockets()
internal static KeyValuePair<Socket, Socket> CreateConnectedSockets()
{
using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
Expand Down Expand Up @@ -342,5 +417,29 @@ private static void DoAccept(Socket listenSocket, int connectionsToAccept)
}
}
}

[ConditionalFact]
public void Select_LargeNumber_Succcess()
{
const int MaxSockets = 1025;
KeyValuePair<Socket, Socket>[] socketPairs;
try
{
// we try to shoot for more socket than FD_SETSIZE (that is typically 1024)
socketPairs = Enumerable.Range(0, MaxSockets).Select(_ => SelectTest.CreateConnectedSockets()).ToArray();
}
catch
{
throw new SkipTestException("Unable to open large count number of socket");
}

var readList = new List<Socket>(socketPairs.Select(p => p.Key).ToArray());

// Try to write and read on last sockets
(Socket reader, Socket writer) = socketPairs[MaxSockets - 1];
writer.Send(new byte[1]);
Socket.Select(readList, null, null, SelectTest.FailTimeoutMicroseconds);
Assert.Equal(1, readList.Count);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,8 @@ public void FailedConnect_GetSocketOption_SocketOptionNameError(bool simpleGet)
Assert.ThrowsAny<Exception>(() => client.Connect(server.LocalEndPoint));
}

// Verify via Select that there's an error
const int FailedTimeout = 10 * 1000 * 1000; // 10 seconds
var errorList = new List<Socket> { client };
Socket.Select(null, null, errorList, FailedTimeout);
Assert.Equal(1, errorList.Count);
// Verify via Poll that there's an error
Assert.True(client.Poll(10_000_000, SelectMode.SelectError));

// Get the last error and validate it's what's expected
int errorCode;
Expand Down
1 change: 1 addition & 0 deletions src/native/libs/System.Native/entrypoints.c
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ static const Entry s_sysNative[] =
DllImportEntry(SystemNative_GetGroupName)
DllImportEntry(SystemNative_GetUInt64OSThreadId)
DllImportEntry(SystemNative_TryGetUInt32OSThreadId)
DllImportEntry(SystemNative_Select)
};

EXTERN_C const void* SystemResolveDllImport(const char* name);
Expand Down
Loading
Loading