Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Improve Async Paths #335

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
<Compile Include="Microsoft\Data\SqlClient\SqlDelegatedTransaction.NetCoreApp.cs" />
<Compile Include="Microsoft\Data\SqlClient\TdsParser.NetCoreApp.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPacket.NetCoreApp.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.NetCoreApp.cs" />
<Compile Include="Microsoft\Data\SqlClient\AzureAttestationBasedEnclaveProvider.NetCoreApp.cs" />
<Compile Include="Microsoft\Data\SqlClient\VirtualSecureModeEnclaveProvider.NetCoreApp.cs" />
<Compile Include="Microsoft\Data\SqlClient\VirtualSecureModeEnclaveProviderBase.NetCoreApp.cs" />
Expand All @@ -57,6 +58,7 @@
<Compile Include="Microsoft\Data\SqlClient\SqlDelegatedTransaction.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\TdsParser.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPacket.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.NetStandard.cs" />
</ItemGroup>
<ItemGroup Condition="'$(IsPartialFacadeAssembly)' != 'true' AND '$(OSGroup)' != 'AnyOS'">
<Compile Include="Microsoft\Data\SqlClient\Server\ITypedGetters.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ internal partial class SNIPacket
/// <param name="callback">Completion callback</param>
public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback)
{
// Treat local function as a static and pass all params otherwise as async will allocate
async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask<int> valueTask)
{
bool error = false;
Expand All @@ -45,7 +44,15 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask<
cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS);
}

ValueTask<int> vt = stream.ReadAsync(new Memory<byte>(_data, 0, _capacity), CancellationToken.None);
ValueTask<int> vt;
try
{
vt = stream.ReadAsync(new Memory<byte>(_data, 0, _capacity), CancellationToken.None);
}
catch (Exception ex)
{
vt = new ValueTask<int>(Task.FromException<int>(ex));
}

if (vt.IsCompletedSuccessfully)
{
Expand Down Expand Up @@ -78,8 +85,7 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask<
/// <param name="disposeAfterWriteAsync"></param>
public void WriteToStreamAsync(Stream stream, SNIAsyncCallback callback, SNIProviders provider, bool disposeAfterWriteAsync = false)
{
// Treat local function as a static and pass all params otherwise as async will allocate
async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProviders providers, bool disposeAfter, ValueTask valueTask)
async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProviders providers, bool dispose, ValueTask valueTask)
{
uint status = TdsEnums.SNI_SUCCESS;
try
Expand All @@ -94,13 +100,21 @@ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProvider

cb(packet, status);

if (disposeAfter)
if (dispose)
{
packet.Dispose();
}
}

ValueTask vt = stream.WriteAsync(new Memory<byte>(_data, 0, _length), CancellationToken.None);
ValueTask vt;
try
{
vt = stream.WriteAsync(new Memory<byte>(_data, 0, _length), CancellationToken.None);
}
catch (Exception ex)
{
vt = new ValueTask(Task.FromException(ex));
}

if (vt.IsCompletedSuccessfully)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,15 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, Task<int>
cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS);
}

Task<int> t = stream.ReadAsync(_data, 0, _capacity, CancellationToken.None);
Task<int> t;
try
{
t = stream.ReadAsync(_data, 0, _capacity, CancellationToken.None);
}
catch (Exception ex)
{
t = Task.FromException<int>(ex);
}

if ((t.Status & TaskStatus.RanToCompletion) != 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old code (t.Status & TaskStatus.RanToCompletion) != 0 in SNIPacket.NetStandard.cs looks strange, since TaskStatus is not a flag, so this will even be true for TaskStatus.WaitingForActivation, maybe it should be t.Status == TaskStatus.RanToCompletion instead to follow the logic in netcoreapp?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch. Looking at the enum values that'll let fault and cancel through. I'm surprised visual studio doesn't complain about doing bitwise operations on a non-[Flag] enum.

{
Expand Down Expand Up @@ -95,7 +103,15 @@ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProvider
}
}

Task t = stream.WriteAsync(_data, 0, _length, CancellationToken.None);
Task t;
try
{
t = stream.WriteAsync(_data, 0, _length, CancellationToken.None);
}
catch (Exception ex)
{
t = Task.FromException(ex);
}

if ((t.Status & TaskStatus.RanToCompletion) != 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And also for this one

{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
{
internal sealed partial class SslOverTdsStream
{
public override int Read(byte[] buffer, int offset, int count)
=> ReadInternal(new Memory<byte>(buffer, offset, count), default, async: false).GetAwaiter().GetResult();

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken token)
=> ReadInternal(new Memory<byte>(buffer, offset, count), token, async: true).AsTask();

public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken)
=> ReadInternal(buffer, cancellationToken, async: true);

public override void Write(byte[] buffer, int offset, int count)
=> WriteInternal(new ReadOnlyMemory<byte>(buffer, offset, count), default, async: true).GetAwaiter().GetResult();

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken token)
=> WriteInternal(new ReadOnlyMemory<byte>(buffer, offset, count), token, async: true).AsTask();

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
=> WriteInternal(buffer, cancellationToken, async: true);

/// <summary>
/// Read Internal is called synchronosly when async is false
/// </summary>
private async ValueTask<int> ReadInternal(Memory<byte> buffer, CancellationToken token, bool async)
{
int readBytes = 0;
int count = buffer.Length;
byte[] packetData = new byte[count < TdsEnums.HEADER_LEN ? TdsEnums.HEADER_LEN : count];

if (_encapsulate)
{
if (_packetBytes == 0)
{
// Account for split packets
while (readBytes < TdsEnums.HEADER_LEN)
{
readBytes += async ?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can infinite loop if the connection is closed and the read returns 0 and doesn't throw, Are closed connections guaranteed to throw or return non-zero?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm... the existing loop also has this issue.

Will fix in both. Good spot

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Someone else mentioned it elsewhere before I just haven't got around to looking at it in any depth yet so since you're here anyway I thought it worth mentioning.

There may not be a clear way to "fix" it from what I remember. I'm not sure what should happen if you end up in a partial packet read because it's below ssl stream which is going to try and decode anything you return in the buffer which can error that looks like a transport error when in fact it's connection closed. See what you think.

await _stream.ReadAsync(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes, token).ConfigureAwait(false) :
_stream.Read(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes);
}

_packetBytes = (packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1];
_packetBytes -= TdsEnums.HEADER_LEN;
}

if (count > _packetBytes)
{
count = _packetBytes;
}
}

readBytes = async ?
await _stream.ReadAsync(new Memory<byte>(packetData, 0, count), token).ConfigureAwait(false) :
_stream.Read(packetData.AsSpan(0, count));

if (_encapsulate)
{
_packetBytes -= readBytes;
}

packetData.AsSpan(0, readBytes).CopyTo(buffer.Span);
return readBytes;
}

/// <summary>
/// The internal write method calls Sync APIs when Async flag is false
/// </summary>
private async ValueTask WriteInternal(ReadOnlyMemory<byte> buffer, CancellationToken token, bool async)
{
int count = buffer.Length;
int currentOffset = 0;

while (count > 0)
{
int currentCount;
// During the SSL negotiation phase, SSL is tunnelled over TDS packet type 0x12. After
// negotiation, the underlying socket only sees SSL frames.
//
if (_encapsulate)
{
if (count > PACKET_SIZE_WITHOUT_HEADER)
{
currentCount = PACKET_SIZE_WITHOUT_HEADER;
}
else
{
currentCount = count;
}

count -= currentCount;

// Prepend buffer data with TDS prelogin header
byte[] combinedBuffer = new byte[TdsEnums.HEADER_LEN + currentCount];

// We can only send 4088 bytes in one packet. Header[1] is set to 1 if this is a
// partial packet (whether or not count != 0).
//
combinedBuffer[0] = PRELOGIN_PACKET_TYPE;
combinedBuffer[1] = (byte)(count > 0 ? 0 : 1);
combinedBuffer[2] = (byte)((currentCount + TdsEnums.HEADER_LEN) / 0x100);
combinedBuffer[3] = (byte)((currentCount + TdsEnums.HEADER_LEN) % 0x100);
combinedBuffer[4] = 0;
combinedBuffer[5] = 0;
combinedBuffer[6] = 0;
combinedBuffer[7] = 0;

CopyToBuffer(combinedBuffer, buffer.Span.Slice(currentOffset));

if (async)
{
await _stream.WriteAsync(combinedBuffer, 0, combinedBuffer.Length, token).ConfigureAwait(false);
}
else
{
_stream.Write(combinedBuffer.AsSpan());
}
}
else
{
currentCount = count;
count = 0;

if (async)
{
await _stream.WriteAsync(buffer.Slice(currentOffset, currentCount), token).ConfigureAwait(false);
}
else
{
_stream.Write(buffer.Span.Slice(currentOffset, currentCount));
}
}

if (async)
{
await _stream.FlushAsync().ConfigureAwait(false);
}
else
{
_stream.Flush();
}

currentOffset += currentCount;
}

void CopyToBuffer(byte[] combinedBuffer, ReadOnlySpan<byte> span)
{
for (int i = TdsEnums.HEADER_LEN; i < combinedBuffer.Length; i++)
{
combinedBuffer[i] = span[i - TdsEnums.HEADER_LEN];
}
}
}
}
}
Loading