Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update SslOverTdsStream #541

Merged
merged 8 commits into from
Oct 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@
<Compile Include="Microsoft\Data\SqlClient\SqlDiagnosticListener.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\SqlDelegatedTransaction.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\TdsParser.NetStandard.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.NetStandard.cs" />
</ItemGroup>
<ItemGroup Condition="'$(OSGroup)' != 'AnyOS' AND '$(TargetFramework)' != 'netstandard2.0'">
<Compile Include="Microsoft\Data\SqlClient\SqlColumnEncryptionEnclaveProvider.NetCoreApp.cs" />
Expand All @@ -289,6 +290,7 @@
<Compile Include="Microsoft\Data\SqlClient\SqlDiagnosticListener.NetCoreApp.cs" />
<Compile Include="Microsoft\Data\SqlClient\SqlDelegatedTransaction.NetCoreApp.cs" />
<Compile Include="Microsoft\Data\SqlClient\TdsParser.NetCoreApp.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.NetCoreApp.cs" />
</ItemGroup>
<ItemGroup Condition="'$(OSGroup)' != 'AnyOS' AND '$(TargetGroup)' == 'netcoreapp' AND '$(BuildSimulator)' == 'true'">
<Compile Include="Microsoft\Data\SqlClient\SimulatorEnclaveProvider.NetCoreApp.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Buffers;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
{
internal sealed partial class SslOverTdsStream
{
public override int Read(byte[] buffer, int offset, int count)
{
return Read(buffer.AsSpan(offset, count));
}

public override void Write(byte[] buffer, int offset, int count)
{
Write(buffer.AsSpan(offset, count));
}

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
}

public override int Read(Span<byte> buffer)
{
if (!_encapsulate)
{
return _stream.Read(buffer);
}

using (SNIEventScope.Create("<sc.SNI.SslOverTdsStream.Read |SNI|INFO|SCOPE> reading encapsulated bytes"))
{
if (_packetBytes > 0)
{
// there are queued bytes from a previous packet available
// work out how many of the remaining bytes we can consume
int wantedCount = Math.Min(buffer.Length, _packetBytes);
int readCount = _stream.Read(buffer.Slice(0, wantedCount));
if (readCount == 0)
{
// 0 means the connection was closed, tell the caller
return 0;
}
_packetBytes -= readCount;
return readCount;
}
else
{
Span<byte> headerBytes = stackalloc byte[TdsEnums.HEADER_LEN];

// fetch the packet header to determine how long the packet is
int headerBytesRead = 0;
do
{
int headerBytesReadIteration = _stream.Read(headerBytes.Slice(headerBytesRead, TdsEnums.HEADER_LEN - headerBytesRead));
if (headerBytesReadIteration == 0)
{
// 0 means the connection was closed, tell the caller
return 0;
}
headerBytesRead += headerBytesReadIteration;
} while (headerBytesRead < TdsEnums.HEADER_LEN);

// read the packet data size from the header and store it in case it is needed for a subsequent call
_packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN;

// read as much from the packet as the caller can accept
int packetBytesRead = _stream.Read(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes)));
_packetBytes -= packetBytesRead;
return packetBytesRead;
}
}
}

public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
{
if (!_encapsulate)
{
int read;
{
ValueTask<int> readValueTask = _stream.ReadAsync(buffer, cancellationToken);
if (readValueTask.IsCompletedSuccessfully)
{
read = readValueTask.Result;
}
else
{
read = await readValueTask.ConfigureAwait(false);
}
}
return read;
}
DavoudEshtehari marked this conversation as resolved.
Show resolved Hide resolved
using (SNIEventScope.Create("<sc.SNI.SslOverTdsStream.ReadAsync |SNI|INFO|SCOPE> reading encapsulated bytes"))
{
if (_packetBytes > 0)
{
// there are queued bytes from a previous packet available
// work out how many of the remaining bytes we can consume
int wantedCount = Math.Min(buffer.Length, _packetBytes);

int readCount;
{
ValueTask<int> remainderReadValueTask = _stream.ReadAsync(buffer.Slice(0, wantedCount), cancellationToken);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am curious... What is gained by the split logic for IsCompletedSuccessfully? Why not just always await?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It saves a call into the state machine which can avoid a lot of work. it's a common pattern with hot paths and since this is internal and will be required on all network io I thought I may as well include it. You can see it in https://devblogs.microsoft.com/dotnet/understanding-the-whys-whats-and-whens-of-valuetask/

Copy link
Member

@roji roji Apr 27, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm familiar with the pattern, but only when the async path has some additional logic that you want to skip in the sync path. For example, in the blog post you point to, the async path contains a RegisterCancellation invocation which should indeed be skipped if the operation completes synchronously - but here there's no such thing. Note also how Stephen explicitly recommends developers use this pattern "hopefully only after measuring carefully and finding it provides meaningful benefit".

Out of curiosity, I went ahead and measured the difference and got this:

BenchmarkDotNet=v0.12.0, OS=ubuntu 19.10
Intel Xeon W-2133 CPU 3.60GHz, 1 CPU, 12 logical and 6 physical cores
.NET Core SDK=5.0.100-preview.2.20176.6
  [Host]     : .NET Core 3.1.1 (CoreCLR 4.700.19.60701, CoreFX 4.700.19.60801), X64 RyuJIT
  DefaultJob : .NET Core 3.1.1 (CoreCLR 4.700.19.60701, CoreFX 4.700.19.60801), X64 RyuJIT

Method Mean Error StdDev Median
SplitLogic 48.28 ns 0.976 ns 1.631 ns 47.33 ns
JustAwaitIt 55.31 ns 0.047 ns 0.042 ns 55.31 ns
Benchmark code
public class Program
{
    [Benchmark]
    public async ValueTask<int> SplitLogic()
    {
        var valueTask = SomeSyncReturningAsyncMethod();
        return valueTask.IsCompletedSuccessfully
            ? valueTask.Result
            : await SomeSyncReturningAsyncMethod();
    }
    
    [Benchmark]
    public async ValueTask<int> JustAwaitIt()
        => await SomeSyncReturningAsyncMethod();

    ValueTask<int> SomeSyncReturningAsyncMethod() => new ValueTask<int>(8);
    
    static void Main(string[] args)
        => BenchmarkRunner.Run<Program>();
}

I personally wouldn't complicate the code with this pattern for a 7ns improvement - not in this project anyway, where it's almost sure to be lost in overall perf - but of course that's up to you and the SqlClient people to decide on. In any case, when working any kind of micro-optimizations like this one, it's really a good idea to have a BenchmarkDotNet suite to verify that there's meaningful improvement.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the numbers. It is a tiny change in times. I view this part of the codebase as both performance critical and low level because it blocks any better performance higher in the pipeline so as you've seen I chose to go with the more complex and less maintainable long form. I've no problem changing it if someone strongly desires it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part of the codebase does indeed seem quite critical for perf and what you're doing seems important (am lacking context here and am just looking at the PR code). My point is only that specifically for this ValueTask pattern, I'd only introduce this kind of code when a functional benchmark shows it has a meaningful impact. My benchmark above isn't that - I'd be extremely surprised if this could be shown to have any impact on actual SqlClient code executing a command (and code complexity does have a price). Up to you guys.

if (remainderReadValueTask.IsCompletedSuccessfully)
{
readCount = remainderReadValueTask.Result;
}
else
{
readCount = await remainderReadValueTask.ConfigureAwait(false);
}
}
if (readCount == 0)
{
// 0 means the connection was closed, tell the caller
return 0;
}
_packetBytes -= readCount;
return readCount;
}
else
{
byte[] headerBytes = ArrayPool<byte>.Shared.Rent(TdsEnums.HEADER_LEN);

// fetch the packet header to determine how long the packet is
int headerBytesRead = 0;
do
{
int headerBytesReadIteration;
{
ValueTask<int> headerReadValueTask = _stream.ReadAsync(headerBytes.AsMemory(headerBytesRead, (TdsEnums.HEADER_LEN - headerBytesRead)), cancellationToken);
if (headerReadValueTask.IsCompletedSuccessfully)
{
headerBytesReadIteration = headerReadValueTask.Result;
}
else
{
headerBytesReadIteration = await headerReadValueTask.ConfigureAwait(false);
}
}
if (headerBytesReadIteration == 0)
{
// 0 means the connection was closed, cleanup the rented array and then tell the caller
ArrayPool<byte>.Shared.Return(headerBytes, clearArray: true);
return 0;
}
headerBytesRead += headerBytesReadIteration;
} while (headerBytesRead < TdsEnums.HEADER_LEN);

// read the packet data size from the header and store it in case it is needed for a subsequent call
_packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN;

ArrayPool<byte>.Shared.Return(headerBytes, clearArray: true);

// read as much from the packet as the caller can accept
int packetBytesRead;
{
ValueTask<int> packetReadValueTask = _stream.ReadAsync(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes)), cancellationToken);
if (packetReadValueTask.IsCompletedSuccessfully)
{
packetBytesRead = packetReadValueTask.Result;
}
else
{
packetBytesRead = await packetReadValueTask.ConfigureAwait(false);
}
}
_packetBytes -= packetBytesRead;
return packetBytesRead;
}
}
}

public override void Write(ReadOnlySpan<byte> buffer)
{
// During the SSL negotiation phase, SSL is tunnelled over TDS packet type 0x12. After
// negotiation, the underlying socket only sees SSL frames.
if (!_encapsulate)
{
_stream.Write(buffer);
_stream.Flush();
return;
}

using (SNIEventScope.Create("<sc.SNI.SslOverTdsStream.Write |SNI|INFO|SCOPE> writing encapsulated bytes"))
{
ReadOnlySpan<byte> remaining = buffer;
byte[] packetBuffer = null;
try
{
while (remaining.Length > 0)
{
int dataLength = Math.Min(PACKET_SIZE_WITHOUT_HEADER, remaining.Length);
int packetLength = TdsEnums.HEADER_LEN + dataLength;

if (packetBuffer == null)
{
packetBuffer = ArrayPool<byte>.Shared.Rent(packetLength);
}
else if (packetBuffer.Length < packetLength)
{
ArrayPool<byte>.Shared.Return(packetBuffer, clearArray: true);
packetBuffer = ArrayPool<byte>.Shared.Rent(packetLength);
}

SetupPreLoginPacketHeader(packetBuffer, dataLength, remaining.Length - dataLength);

Span<byte> data = packetBuffer.AsSpan(TdsEnums.HEADER_LEN, dataLength);
remaining.Slice(0, dataLength).CopyTo(data);

_stream.Write(packetBuffer.AsSpan(0, packetLength));
_stream.Flush();

remaining = remaining.Slice(dataLength);
}
}
finally
{
if (packetBuffer != null)
{
ArrayPool<byte>.Shared.Return(packetBuffer, clearArray: true);
}
}
}
}

public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
if (!_encapsulate)
{
{
ValueTask valueTask = _stream.WriteAsync(buffer, cancellationToken);
if (!valueTask.IsCompletedSuccessfully)
{
await valueTask.ConfigureAwait(false);
}
}
Task flushTask = _stream.FlushAsync();
if (flushTask.IsCompletedSuccessfully)
{
await flushTask.ConfigureAwait(false);
}
return;
}

using (SNIEventScope.Create("<sc.SNI.SslOverTdsStream.WriteAsync |SNI|INFO|SCOPE> writing encapsulated bytes"))
{
ReadOnlyMemory<byte> remaining = buffer;
byte[] packetBuffer = null;
try
{
while (remaining.Length > 0)
{
int dataLength = Math.Min(PACKET_SIZE_WITHOUT_HEADER, remaining.Length);
int packetLength = TdsEnums.HEADER_LEN + dataLength;

if (packetBuffer == null)
{
packetBuffer = ArrayPool<byte>.Shared.Rent(packetLength);
}
else if (packetBuffer.Length < packetLength)
{
ArrayPool<byte>.Shared.Return(packetBuffer, clearArray: true);
packetBuffer = ArrayPool<byte>.Shared.Rent(packetLength);
}

SetupPreLoginPacketHeader(packetBuffer, dataLength, remaining.Length - dataLength);

remaining.Span.Slice(0, dataLength).CopyTo(packetBuffer.AsSpan(TdsEnums.HEADER_LEN, dataLength));

{
ValueTask packetWriteValueTask = _stream.WriteAsync(new ReadOnlyMemory<byte>(packetBuffer, 0, packetLength), cancellationToken);
if (!packetWriteValueTask.IsCompletedSuccessfully)
{
await packetWriteValueTask.ConfigureAwait(false);
}
}

await _stream.FlushAsync().ConfigureAwait(false);


remaining = remaining.Slice(dataLength);
}
}
finally
{
if (packetBuffer != null)
{
ArrayPool<byte>.Shared.Return(packetBuffer, clearArray: true);
}
}
}
}
}
}
Loading