Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Amortize WebSocket.EnsureBufferContainsAsync calls #39455

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 120 additions & 14 deletions src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Sources;

namespace System.Net.WebSockets
{
Expand Down Expand Up @@ -154,6 +155,8 @@ public static ManagedWebSocket CreateFromConnectedStream(
/// </summary>
private object ReceiveAsyncLock => _utf8TextState; // some object, as we're simply lock'ing on it

private EnsureBufferValueTaskSource _ensureBuffer;

/// <summary>Initializes the websocket.</summary>
/// <param name="stream">The connected Stream.</param>
/// <param name="isServer">true if this is the server-side of the connection; false if this is the client-side of the connection.</param>
Expand Down Expand Up @@ -1191,10 +1194,9 @@ private void ConsumeFromBuffer(int count)
_receiveBufferOffset += count;
}

private async Task EnsureBufferContainsAsync(int minimumRequiredBytes, CancellationToken cancellationToken, bool throwOnPrematureClosure = true)
private ValueTask EnsureBufferContainsAsync(int minimumRequiredBytes, CancellationToken cancellationToken, bool throwOnPrematureClosure = true)
{
Debug.Assert(minimumRequiredBytes <= _receiveBuffer.Length, $"Requested number of bytes {minimumRequiredBytes} must not exceed {_receiveBuffer.Length}");

// If we don't have enough data in the buffer to satisfy the minimum required, read some more.
if (_receiveBufferCount < minimumRequiredBytes)
{
Expand All @@ -1205,18 +1207,12 @@ private async Task EnsureBufferContainsAsync(int minimumRequiredBytes, Cancellat
}
_receiveBufferOffset = 0;

// While we don't have enough data, read more.
while (_receiveBufferCount < minimumRequiredBytes)
{
int numRead = await _stream.ReadAsync(_receiveBuffer.Slice(_receiveBufferCount, _receiveBuffer.Length - _receiveBufferCount), cancellationToken).ConfigureAwait(false);
Debug.Assert(numRead >= 0, $"Expected non-negative bytes read, got {numRead}");
if (numRead <= 0)
{
ThrowIfEOFUnexpected(throwOnPrematureClosure);
break;
}
_receiveBufferCount += numRead;
}
EnsureBufferValueTaskSource ensureBuffer = (_ensureBuffer ??= new EnsureBufferValueTaskSource(this));
return ensureBuffer.EnsureBufferContainsAsync(minimumRequiredBytes, cancellationToken, throwOnPrematureClosure);
}
else
{
return default;
}
}

Expand Down Expand Up @@ -1503,5 +1499,115 @@ private interface IWebSocketReceiveResultGetter<TResult>
public WebSocketReceiveResult GetResult(int count, WebSocketMessageType messageType, bool endOfMessage, WebSocketCloseStatus? closeStatus, string closeDescription) =>
new WebSocketReceiveResult(count, messageType, endOfMessage, closeStatus, closeDescription);
}

private sealed class EnsureBufferValueTaskSource : IValueTaskSource
{
private readonly ManagedWebSocket _webSocket;
private readonly Action _onComplete;
private ManualResetValueTaskSourceCore<bool> _valueTaskSource;

private int _minimumRequiredBytes;
private CancellationToken _cancellationToken;
private bool _throwOnPrematureClosure;

private ConfiguredValueTaskAwaitable<int>.ConfiguredValueTaskAwaiter _awaiter;

public EnsureBufferValueTaskSource(ManagedWebSocket webSocket)
{
_webSocket = webSocket;
_onComplete = new Action(OnComplete);
}

public ValueTask EnsureBufferContainsAsync(int minimumRequiredBytes, CancellationToken cancellationToken, bool throwOnPrematureClosure)
{
_minimumRequiredBytes = minimumRequiredBytes;
_cancellationToken = cancellationToken;
_throwOnPrematureClosure = throwOnPrematureClosure;

return EnsureBufferContains();
}

private ValueTask EnsureBufferContains(bool completeSourceIfComplete = false)
{
// While we don't have enough data, read more.
while (_webSocket._receiveBufferCount < _minimumRequiredBytes)
{
ValueTask<int> vt = ReadAsync();
if (!vt.IsCompletedSuccessfully)
{
_awaiter = vt.ConfigureAwait(false).GetAwaiter();
_awaiter.UnsafeOnCompleted(_onComplete);
return new ValueTask(this, _valueTaskSource.Version);
}
benaadams marked this conversation as resolved.
Show resolved Hide resolved

int numRead = vt.Result;
Debug.Assert(numRead >= 0, $"Expected non-negative bytes read, got {numRead}");
if (numRead <= 0)
{
_webSocket.ThrowIfEOFUnexpected(_throwOnPrematureClosure);
break;
}
_webSocket._receiveBufferCount += numRead;
}

// Completed sync
if (completeSourceIfComplete)
{
// SetResult if this was called from OnComplete
_valueTaskSource.SetResult(default);
}

return default;
}

private void OnComplete()
{
try
{
int numRead = _awaiter.GetResult();
Debug.Assert(numRead >= 0, $"Expected non-negative bytes read, got {numRead}");
if (numRead <= 0)
{
_webSocket.ThrowIfEOFUnexpected(_throwOnPrematureClosure);
_valueTaskSource.SetResult(default);
return;
}
_webSocket._receiveBufferCount += numRead;

if (_webSocket._receiveBufferCount < _minimumRequiredBytes)
{
EnsureBufferContains(completeSourceIfComplete: true);
return;
}
else
{
_valueTaskSource.SetResult(default);
}
}
catch (Exception ex)
{
_valueTaskSource.SetException(ex);
}
}

private ValueTask<int> ReadAsync()
{
return _webSocket._stream.ReadAsync(
_webSocket._receiveBuffer.Slice(_webSocket._receiveBufferCount, _webSocket._receiveBuffer.Length - _webSocket._receiveBufferCount),
_cancellationToken);
}

void IValueTaskSource.GetResult(short token)
{
_valueTaskSource.GetResult(token);
_valueTaskSource.Reset();
}

ValueTaskSourceStatus IValueTaskSource.GetStatus(short token)
=> _valueTaskSource.GetStatus(token);

void IValueTaskSource.OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags)
=> _valueTaskSource.OnCompleted(continuation, state, token, flags);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,10 @@
<ItemGroup>
benaadams marked this conversation as resolved.
Show resolved Hide resolved
<Reference Include="System.Memory" />
</ItemGroup>
<ItemGroup Condition="'$(TargetGroup)' == 'netstandard' OR '$(TargetsNetFx)' == 'true'">
<Reference Include="mscorlib" />
<Reference Include="netstandard" />
<Reference Include="System.Threading.Tasks.Extensions" />
<ProjectReference Include="..\..\Microsoft.Bcl.AsyncInterfaces\ref\Microsoft.Bcl.AsyncInterfaces.csproj" />
benaadams marked this conversation as resolved.
Show resolved Hide resolved
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
<ProjectGuid>{747BE014-7C1D-4460-95AF-B41C35717165}</ProjectGuid>
<NoWarn>$(NoWarn);CS1573</NoWarn>
<Configurations>netcoreapp-Debug;netcoreapp-Release;netcoreapp2.1-Debug;netcoreapp2.1-Release;netstandard-Debug;netstandard-Release</Configurations>
<DefineConstants Condition="'$(TargetsNetFx)' == 'true'">$(DefineConstants);netstandard</DefineConstants>
</PropertyGroup>
<ItemGroup>
<Compile Include="$(CommonPath)\System\Net\WebSockets\ManagedWebSocket.cs">
Expand Down Expand Up @@ -41,4 +42,7 @@
<Reference Include="System.Threading.Tasks.Extensions" />
<Reference Include="System.Threading.Timer" />
</ItemGroup>
<ItemGroup Condition="'$(TargetGroup)'!='netcoreapp'">
<Reference Include="Microsoft.Bcl.AsyncInterfaces" />
</ItemGroup>
</Project>
7 changes: 7 additions & 0 deletions src/System.Net.WebSockets/ref/System.Net.WebSockets.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,11 @@
<ProjectReference Include="..\..\System.Security.Principal\ref\System.Security.Principal.csproj" />
<ProjectReference Include="..\..\System.Threading.Tasks\ref\System.Threading.Tasks.csproj" />
</ItemGroup>
<ItemGroup Condition="'$(TargetGroup)' == 'netstandard' OR '$(TargetsNetFx)' == 'true'">
<Reference Include="mscorlib" />
<Reference Include="netstandard" />
<Reference Include="System.Memory" />
<Reference Include="System.Threading.Tasks.Extensions" />
<ProjectReference Include="..\..\Microsoft.Bcl.AsyncInterfaces\ref\Microsoft.Bcl.AsyncInterfaces.csproj" />
</ItemGroup>
benaadams marked this conversation as resolved.
Show resolved Hide resolved
</Project>
4 changes: 4 additions & 0 deletions src/System.Net.WebSockets/src/System.Net.WebSockets.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>
<NoWarn>$(NoWarn);CS1573</NoWarn>
<Configurations>netcoreapp-Debug;netcoreapp-Release;uap-Windows_NT-Debug;uap-Windows_NT-Release</Configurations>
<DefineConstants Condition="'$(TargetsNetFx)' == 'true'">$(DefineConstants);netstandard</DefineConstants>
</PropertyGroup>
<ItemGroup>
<Compile Include="System\Net\WebSockets\ManagedWebSocket.netcoreapp.cs" />
Expand Down Expand Up @@ -44,4 +45,7 @@
<Reference Include="System.Threading.Tasks" />
<Reference Include="System.Threading.Timer" />
</ItemGroup>
<ItemGroup Condition="'$(TargetGroup)'!='netcoreapp'">
<Reference Include="Microsoft.Bcl.AsyncInterfaces" />
</ItemGroup>
benaadams marked this conversation as resolved.
Show resolved Hide resolved
</Project>