Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UdpClient with span support #53429

7 changes: 7 additions & 0 deletions src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -756,12 +756,19 @@ public void JoinMulticastGroup(System.Net.IPAddress multicastAddr, int timeToLiv
public void JoinMulticastGroup(System.Net.IPAddress multicastAddr, System.Net.IPAddress localAddress) { }
public byte[] Receive([System.Diagnostics.CodeAnalysis.NotNullAttribute] ref System.Net.IPEndPoint? remoteEP) { throw null; }
public System.Threading.Tasks.Task<System.Net.Sockets.UdpReceiveResult> ReceiveAsync() { throw null; }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.UdpReceiveResult> ReceiveAsync(System.Threading.CancellationToken cancellationToken) { throw null; }
public int Send(byte[] dgram, int bytes) { throw null; }
public int Send(System.ReadOnlySpan<byte> datagram) {throw null; }
public int Send(byte[] dgram, int bytes, System.Net.IPEndPoint? endPoint) { throw null; }
public int Send(System.ReadOnlySpan<byte> datagram, System.Net.IPEndPoint? endPoint) { throw null; }
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
public int Send(byte[] dgram, int bytes, string? hostname, int port) { throw null; }
public int Send(System.ReadOnlySpan<byte> datagram, string? hostname, int port) { throw null; }
public System.Threading.Tasks.Task<int> SendAsync(byte[] datagram, int bytes) { throw null; }
public System.Threading.Tasks.ValueTask<int> SendAsync(System.ReadOnlyMemory<byte> datagram, System.Threading.CancellationToken cancellationToken = default) { throw null; }
public System.Threading.Tasks.Task<int> SendAsync(byte[] datagram, int bytes, System.Net.IPEndPoint? endPoint) { throw null; }
public System.Threading.Tasks.ValueTask<int> SendAsync(System.ReadOnlyMemory<byte> datagram, System.Net.IPEndPoint? endPoint, System.Threading.CancellationToken cancellationToken = default) { throw null; }
public System.Threading.Tasks.Task<int> SendAsync(byte[] datagram, int bytes, string? hostname, int port) { throw null; }
public System.Threading.Tasks.ValueTask<int> SendAsync(System.ReadOnlyMemory<byte> datagram, string? hostname, int port, System.Threading.CancellationToken cancellationToken = default) { throw null; }
}
public partial struct UdpReceiveResult : System.IEquatable<System.Net.Sockets.UdpReceiveResult>
{
Expand Down
105 changes: 80 additions & 25 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/UDPClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Diagnostics.CodeAnalysis;
using System.Threading.Tasks;
using System.Runtime.Versioning;
using System.Threading;

namespace System.Net.Sockets
{
Expand Down Expand Up @@ -334,6 +335,17 @@ private void ValidateDatagram(byte[] datagram, int bytes, IPEndPoint? endPoint)
}
}

private void ValidateDatagram(ReadOnlyMemory<byte> datagram, IPEndPoint? endPoint)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The datagram argument is unused, note that this also makes the naming unfortunate. Since this particular overload has only one usage, it's cleaner to inline probably.

Copy link
Contributor Author

@lateapexearlyspeed lateapexearlyspeed Jun 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored code to validate inside calling code.

{
ThrowIfDisposed();

if (_active && endPoint != null)
{
// Do not allow sending packets to arbitrary host when connected.
throw new InvalidOperationException(SR.net_udpconnected);
}
}

private IPEndPoint? GetEndpoint(string? hostname, int port)
{
if (_active && ((hostname != null) || (port != 0)))
Expand Down Expand Up @@ -600,9 +612,15 @@ public void DropMulticastGroup(IPAddress multicastAddr, int ifindex)
public Task<int> SendAsync(byte[] datagram, int bytes) =>
SendAsync(datagram, bytes, null);

public ValueTask<int> SendAsync(ReadOnlyMemory<byte> datagram, CancellationToken cancellationToken = default) =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to a new rule, we need to add triple-slash docs for all new public methods right in product code PR-s. They will be fed to a tool to generate API docs.

You can find examples in Socket.cs. My recommendation is to copy and alter documentation text from existing overloads

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added xml doc for new APIs.

SendAsync(datagram, null, cancellationToken);

public Task<int> SendAsync(byte[] datagram, int bytes, string? hostname, int port) =>
SendAsync(datagram, bytes, GetEndpoint(hostname, port));

public ValueTask<int> SendAsync(ReadOnlyMemory<byte> datagram, string? hostname, int port, CancellationToken cancellationToken = default) =>
SendAsync(datagram, GetEndpoint(hostname, port), cancellationToken);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is an async overload, we should not use the blocking Dns.GetHostAddresses when resolving hostname. We need an async variant of GetEndpoint.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree we should have an async version of GetEndpoint. That said, the existing SendAsync above just uses sync GetEndpoint. So since we don't have this today, I think we could live without it for this PR and file a separate issue.


public Task<int> SendAsync(byte[] datagram, int bytes, IPEndPoint? endPoint)
{
ValidateDatagram(datagram, bytes, endPoint);
Expand All @@ -618,6 +636,21 @@ public Task<int> SendAsync(byte[] datagram, int bytes, IPEndPoint? endPoint)
}
}

public ValueTask<int> SendAsync(ReadOnlyMemory<byte> datagram, IPEndPoint? endPoint, CancellationToken cancellationToken = default)
{
ValidateDatagram(datagram, endPoint);

if (endPoint is null)
{
return _clientSocket.SendAsync(datagram, SocketFlags.None, cancellationToken);
}
else
{
CheckForBroadcast(endPoint.Address);
return _clientSocket.SendToAsync(datagram, SocketFlags.None, endPoint, cancellationToken);
}
}

public Task<UdpReceiveResult> ReceiveAsync()
{
ThrowIfDisposed();
Expand All @@ -639,6 +672,27 @@ async Task<UdpReceiveResult> WaitAndWrap(Task<SocketReceiveFromResult> task)
}
}

public ValueTask<UdpReceiveResult> ReceiveAsync(CancellationToken cancellationToken)
{
ThrowIfDisposed();

return WaitAndWrap(_clientSocket.ReceiveFromAsync(
_buffer,
SocketFlags.None,
_family == AddressFamily.InterNetwork ? IPEndPointStatics.Any : IPEndPointStatics.IPv6Any, cancellationToken));

async ValueTask<UdpReceiveResult> WaitAndWrap(ValueTask<SocketReceiveFromResult> task)
antonfirsov marked this conversation as resolved.
Show resolved Hide resolved
{
SocketReceiveFromResult result = await task.ConfigureAwait(false);

byte[] buffer = result.ReceivedBytes < MaxUDPSize ?
_buffer.AsSpan(0, result.ReceivedBytes).ToArray() :
_buffer;

return new UdpReceiveResult(buffer, (IPEndPoint)result.RemoteEndPoint);
}
}

private void CreateClientSocket()
{
// Common initialization code.
Expand Down Expand Up @@ -892,45 +946,32 @@ public int Send(byte[] dgram, int bytes, IPEndPoint? endPoint)
return Client.SendTo(dgram, 0, bytes, SocketFlags.None, endPoint);
}


// Sends a UDP datagram to the specified port on the specified remote host.
public int Send(byte[] dgram, int bytes, string? hostname, int port)
// Sends a UDP datagram to the host at the remote end point.
public int Send(ReadOnlySpan<byte> datagram, IPEndPoint? endPoint)
{
ThrowIfDisposed();

if (dgram == null)
{
throw new ArgumentNullException(nameof(dgram));
}
if (_active && ((hostname != null) || (port != 0)))
if (_active && endPoint != null)
{
// Do not allow sending packets to arbitrary host when connected
throw new InvalidOperationException(SR.net_udpconnected);
}

if (hostname == null || port == 0)
{
return Client.Send(dgram, 0, bytes, SocketFlags.None);
}

IPAddress[] addresses = Dns.GetHostAddresses(hostname);

int i = 0;
for (; i < addresses.Length && !IsAddressFamilyCompatible(addresses[i].AddressFamily); i++)
if (endPoint == null)
{
; // just count the addresses
return Client.Send(datagram, SocketFlags.None);
}

if (addresses.Length == 0 || i == addresses.Length)
{
throw new ArgumentException(SR.net_invalidAddressList, nameof(hostname));
}
CheckForBroadcast(endPoint.Address);

CheckForBroadcast(addresses[i]);
IPEndPoint ipEndPoint = new IPEndPoint(addresses[i], port);
return Client.SendTo(dgram, 0, bytes, SocketFlags.None, ipEndPoint);
return Client.SendTo(datagram, SocketFlags.None, endPoint);
}

// Sends a UDP datagram to the specified port on the specified remote host.
public int Send(byte[] dgram, int bytes, string? hostname, int port) => Send(dgram, bytes, GetEndpoint(hostname, port));

// Sends a UDP datagram to the specified port on the specified remote host.
public int Send(ReadOnlySpan<byte> datagram, string? hostname, int port) => Send(datagram, GetEndpoint(hostname, port));

// Sends a UDP datagram to a remote host.
public int Send(byte[] dgram, int bytes)
Expand All @@ -950,6 +991,20 @@ public int Send(byte[] dgram, int bytes)
return Client.Send(dgram, 0, bytes, SocketFlags.None);
}

// Sends a UDP datagram to a remote host.
public int Send(ReadOnlySpan<byte> datagram)
{
ThrowIfDisposed();

if (!_active)
{
// only allowed on connected socket
throw new InvalidOperationException(SR.net_notconnected);
}

return Client.Send(datagram, SocketFlags.None);
}

private void ThrowIfDisposed()
{
if (_disposed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
namespace System.Net.Sockets
{
/// <summary>
/// Presents UDP receive result information from a call to the <see cref="UdpClient.ReceiveAsync"/> method
/// Presents UDP receive result information from a call to the <see cref="UdpClient.ReceiveAsync()"/> and <see cref="UdpClient.ReceiveAsync(System.Threading.CancellationToken)"/> method
/// </summary>
public struct UdpReceiveResult : IEquatable<UdpReceiveResult>
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,78 @@ public async Task SendToRecvFromAsync_Datagram_UDP_UdpClient(IPAddress loopbackA
}
}
}

[OuterLoop]
[Theory]
[MemberData(nameof(Loopbacks))]
public async Task SendToRecvFromAsyncWithReadOnlyMemory_Datagram_UDP_UdpClient(IPAddress loopbackAddress)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of duplicating the whole test case, can we add an additional bool useMemoryOverload parameter to the previous one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

{
IPAddress leftAddress = loopbackAddress, rightAddress = loopbackAddress;

const int DatagramSize = 256;
const int DatagramsToSend = 256;
const int AckTimeout = 20000;
const int TestTimeout = 60000;

using (var left = new UdpClient(new IPEndPoint(leftAddress, 0)))
using (var right = new UdpClient(new IPEndPoint(rightAddress, 0)))
{
var leftEndpoint = (IPEndPoint)left.Client.LocalEndPoint;
var rightEndpoint = (IPEndPoint)right.Client.LocalEndPoint;

var receiverAck = new ManualResetEventSlim();

var receivedChecksums = new uint?[DatagramsToSend];
int receivedDatagrams = 0;

Task receiverTask = Task.Run(async () =>
{
for (; receivedDatagrams < DatagramsToSend; receivedDatagrams++)
{
UdpReceiveResult result = await left.ReceiveAsync(default);

receiverAck.Set();

Assert.Equal(DatagramSize, result.Buffer.Length);
Assert.Equal(rightEndpoint, result.RemoteEndPoint);

int datagramId = (int)result.Buffer[0];
Assert.Null(receivedChecksums[datagramId]);

receivedChecksums[datagramId] = Fletcher32.Checksum(result.Buffer, 0, result.Buffer.Length);
}
});

var sentChecksums = new uint[DatagramsToSend];
int sentDatagrams = 0;

Task senderTask = Task.Run(async () =>
{
var random = new Random();
var sendBuffer = new byte[DatagramSize];

for (; sentDatagrams < DatagramsToSend; sentDatagrams++)
{
random.NextBytes(sendBuffer);
sendBuffer[0] = (byte)sentDatagrams;

int sent = await right.SendAsync(new ReadOnlyMemory<byte>(sendBuffer), leftEndpoint);

Assert.True(receiverAck.Wait(AckTimeout));
receiverAck.Reset();

Assert.Equal(DatagramSize, sent);
sentChecksums[sentDatagrams] = Fletcher32.Checksum(sendBuffer, 0, sent);
}
});

await (new[] { receiverTask, senderTask }).WhenAllOrAnyFailed(TestTimeout);
for (int i = 0; i < DatagramsToSend; i++)
{
Assert.NotNull(receivedChecksums[i]);
Assert.Equal(sentChecksums[i], (uint)receivedChecksums[i]);
}
}
}
}
}
Loading