Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix broken Select with error list on macOS #104915

Merged
merged 12 commits into from
Jul 28, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ internal static partial class SocketPal
public static readonly int MaximumAddressSize = Interop.Sys.GetMaximumAddressSize();
private static readonly bool SupportsDualModeIPv4PacketInfo = GetPlatformSupportsDualModeIPv4PacketInfo();

private static readonly bool PollNeedsErrorListFixup = OperatingSystem.IsMacOS() || OperatingSystem.IsIOS() || OperatingSystem.IsTvOS();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume there's nothing we can query for in the PAL layer and it really is just us knowing that these OSes behave differently?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is update we can check Darwin kernel version. But I'm not aware of anything so far.


// IovStackThreshold matches Linux's UIO_FASTIOV, which is the number of 'struct iovec'
// that get stackalloced in the Linux kernel.
private const int IovStackThreshold = 8;
Expand Down Expand Up @@ -1817,20 +1819,22 @@ private static unsafe SocketError SelectViaPoll(
Debug.Assert(eventsLength == checkReadInitialCount + checkWriteInitialCount + checkErrorInitialCount, "Invalid eventsLength");
int offset = 0;
int refsAdded = 0;
int readRefs;
try
{
// In case we can't increase the reference count for each Socket,
// we'll unref refAdded Sockets in the finally block ordered: [checkRead, checkWrite, checkError].
AddToPollArray(events, eventsLength, checkRead, ref offset, Interop.PollEvents.POLLIN | Interop.PollEvents.POLLHUP, ref refsAdded);
readRefs = refsAdded;
AddToPollArray(events, eventsLength, checkWrite, ref offset, Interop.PollEvents.POLLOUT, ref refsAdded);
AddToPollArray(events, eventsLength, checkError, ref offset, Interop.PollEvents.POLLPRI, ref refsAdded);
Debug.Assert(offset == eventsLength, $"Invalid adds. offset={offset}, eventsLength={eventsLength}.");
Debug.Assert(refsAdded == eventsLength, $"Invalid ref adds. refsAdded={refsAdded}, eventsLength={eventsLength}.");
AddToPollArray(events, eventsLength, checkError, ref offset, Interop.PollEvents.POLLPRI, ref refsAdded, PollNeedsErrorListFixup ? readRefs : 0);
Debug.Assert(offset <= eventsLength, $"Invalid adds. offset={offset}, eventsLength={eventsLength}.");
Debug.Assert(refsAdded <= eventsLength, $"Invalid ref adds. refsAdded={refsAdded}, eventsLength={eventsLength}.");

// Do the poll
uint triggered = 0;
int milliseconds = microseconds == -1 ? -1 : microseconds / 1000;
Interop.Error err = Interop.Sys.Poll(events, (uint)eventsLength, milliseconds, &triggered);
Interop.Error err = Interop.Sys.Poll(events, (uint)refsAdded, milliseconds, &triggered);
if (err != Interop.Error.SUCCESS)
{
return GetSocketErrorForErrorCode(err);
Expand Down Expand Up @@ -1867,7 +1871,7 @@ private static unsafe SocketError SelectViaPoll(
}
}

private static unsafe void AddToPollArray(Interop.PollEvent* arr, int arrLength, IList? socketList, ref int arrOffset, Interop.PollEvents events, ref int refsAdded)
private static unsafe void AddToPollArray(Interop.PollEvent* arr, int arrLength, IList? socketList, ref int arrOffset, Interop.PollEvents events, ref int refsAdded, int readCount = 0)
wfurt marked this conversation as resolved.
Show resolved Hide resolved
{
if (socketList == null)
return;
Expand All @@ -1887,6 +1891,29 @@ private static unsafe void AddToPollArray(Interop.PollEvent* arr, int arrLength,
bool success = false;
socket.InternalSafeHandle.DangerousAddRef(ref success);
int fd = (int)socket.InternalSafeHandle.DangerousGetHandle();

if (readCount > 0)
{
// some platfoms like macOS do not like if there is duplication between real and error list.
wfurt marked this conversation as resolved.
Show resolved Hide resolved
// To fix that we will search read list and if macthing descriptor exiost we will add events flags
wfurt marked this conversation as resolved.
Show resolved Hide resolved
// instead of adding new entry to error list.
int readIndex = 0;
while (readIndex < readCount)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this turning a linear operation into an N^2 operation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. This is where the assumption comes in that this is used from small number of sockets where it does not matter as much .... and the cost is only on platforms that are currently broken. (and use read and error list together)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@GrabYourPitchforks, any concerns?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I put up different implementation that avoids the N^2 problem. Please take another look @stephentoub

{
if (arr[readIndex].FileDescriptor == fd)
{
arr[i].Events |= events;
socket.InternalSafeHandle.DangerousRelease();
break;
}
readIndex++;
}
if (readIndex != readCount)
{
continue;
}
}

arr[arrOffset++] = new Interop.PollEvent { Events = events, FileDescriptor = fd };
refsAdded++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,81 @@ public void Select_ReadWrite_AllReady(int reads, int writes)
}
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public void Select_ReadError_Success(bool dispose)
{
using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);
using Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);

listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);
sender.Connect(listener.LocalEndPoint);
using Socket receiver = listener.Accept();

if (dispose)
{
sender.Dispose();
}
else
{
sender.Send(new byte[] { 1 });
}

var readList = new List<Socket> { receiver };
var errorList = new List<Socket> { receiver };
Socket.Select(readList, null, errorList, -1);
if (dispose)
{
Assert.True(readList.Count == 1 || errorList.Count == 1);
}
else
{
Assert.Equal(1, readList.Count);
Assert.Equal(0, errorList.Count);
}
}

[Fact]
public void Select_WriteError_Success()
{
using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);
using Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);

listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);
sender.Connect(listener.LocalEndPoint);
using Socket receiver = listener.Accept();

var writeList = new List<Socket> { receiver };
var errorList = new List<Socket> { receiver };
Socket.Select(null, writeList, errorList, -1);
Assert.Equal(1, writeList.Count);
Assert.Equal(0, errorList.Count);
}

[Fact]
public void Select_ReadWriteError_Success()
{
using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);
using Socket sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Unspecified);

listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);
sender.Connect(listener.LocalEndPoint);
using Socket receiver = listener.Accept();

sender.Send(new byte[] { 1 });
var readList = new List<Socket> { receiver };
var writeList = new List<Socket> { receiver };
var errorList = new List<Socket> { receiver };
Socket.Select(readList, writeList, errorList, -1);
Assert.Equal(1, readList.Count);
Assert.Equal(1, writeList.Count);
Assert.Equal(0, errorList.Count);
}

[Theory]
[InlineData(2, 0)]
[InlineData(2, 1)]
Expand Down