diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
index b0a0b70040..f949ac4548 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
@@ -45,6 +45,7 @@
+
@@ -57,6 +58,7 @@
+
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs
index e6a35caeda..49a4308eb7 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs
@@ -18,7 +18,6 @@ internal partial class SNIPacket
/// Completion callback
public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback)
{
- // Treat local function as a static and pass all params otherwise as async will allocate
async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask valueTask)
{
bool error = false;
@@ -45,7 +44,15 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask<
cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS);
}
- ValueTask vt = stream.ReadAsync(new Memory(_data, 0, _capacity), CancellationToken.None);
+ ValueTask vt;
+ try
+ {
+ vt = stream.ReadAsync(new Memory(_data, 0, _capacity), CancellationToken.None);
+ }
+ catch (Exception ex)
+ {
+ vt = new ValueTask(Task.FromException(ex));
+ }
if (vt.IsCompletedSuccessfully)
{
@@ -78,8 +85,7 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask<
///
public void WriteToStreamAsync(Stream stream, SNIAsyncCallback callback, SNIProviders provider, bool disposeAfterWriteAsync = false)
{
- // Treat local function as a static and pass all params otherwise as async will allocate
- async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProviders providers, bool disposeAfter, ValueTask valueTask)
+ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProviders providers, bool dispose, ValueTask valueTask)
{
uint status = TdsEnums.SNI_SUCCESS;
try
@@ -94,13 +100,21 @@ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProvider
cb(packet, status);
- if (disposeAfter)
+ if (dispose)
{
packet.Dispose();
}
}
- ValueTask vt = stream.WriteAsync(new Memory(_data, 0, _length), CancellationToken.None);
+ ValueTask vt;
+ try
+ {
+ vt = stream.WriteAsync(new Memory(_data, 0, _length), CancellationToken.None);
+ }
+ catch (Exception ex)
+ {
+ vt = new ValueTask(Task.FromException(ex));
+ }
if (vt.IsCompletedSuccessfully)
{
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetStandard.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetStandard.cs
index 2a3cf12670..ebd93aaa99 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetStandard.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetStandard.cs
@@ -45,7 +45,15 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, Task
cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS);
}
- Task t = stream.ReadAsync(_data, 0, _capacity, CancellationToken.None);
+ Task t;
+ try
+ {
+ t = stream.ReadAsync(_data, 0, _capacity, CancellationToken.None);
+ }
+ catch (Exception ex)
+ {
+ t = Task.FromException(ex);
+ }
if ((t.Status & TaskStatus.RanToCompletion) != 0)
{
@@ -95,7 +103,15 @@ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProvider
}
}
- Task t = stream.WriteAsync(_data, 0, _length, CancellationToken.None);
+ Task t;
+ try
+ {
+ t = stream.WriteAsync(_data, 0, _length, CancellationToken.None);
+ }
+ catch (Exception ex)
+ {
+ t = Task.FromException(ex);
+ }
if ((t.Status & TaskStatus.RanToCompletion) != 0)
{
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs
new file mode 100644
index 0000000000..ab4b506c5f
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs
@@ -0,0 +1,164 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.Data.SqlClient.SNI
+{
+ internal sealed partial class SslOverTdsStream
+ {
+ public override int Read(byte[] buffer, int offset, int count)
+ => ReadInternal(new Memory(buffer, offset, count), default, async: false).GetAwaiter().GetResult();
+
+ public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken token)
+ => ReadInternal(new Memory(buffer, offset, count), token, async: true).AsTask();
+
+ public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken)
+ => ReadInternal(buffer, cancellationToken, async: true);
+
+ public override void Write(byte[] buffer, int offset, int count)
+ => WriteInternal(new ReadOnlyMemory(buffer, offset, count), default, async: true).GetAwaiter().GetResult();
+
+ public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken token)
+ => WriteInternal(new ReadOnlyMemory(buffer, offset, count), token, async: true).AsTask();
+
+ public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken)
+ => WriteInternal(buffer, cancellationToken, async: true);
+
+ ///
+ /// Read Internal is called synchronosly when async is false
+ ///
+ private async ValueTask ReadInternal(Memory buffer, CancellationToken token, bool async)
+ {
+ int readBytes = 0;
+ int count = buffer.Length;
+ byte[] packetData = new byte[count < TdsEnums.HEADER_LEN ? TdsEnums.HEADER_LEN : count];
+
+ if (_encapsulate)
+ {
+ if (_packetBytes == 0)
+ {
+ // Account for split packets
+ while (readBytes < TdsEnums.HEADER_LEN)
+ {
+ readBytes += async ?
+ await _stream.ReadAsync(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes, token).ConfigureAwait(false) :
+ _stream.Read(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes);
+ }
+
+ _packetBytes = (packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1];
+ _packetBytes -= TdsEnums.HEADER_LEN;
+ }
+
+ if (count > _packetBytes)
+ {
+ count = _packetBytes;
+ }
+ }
+
+ readBytes = async ?
+ await _stream.ReadAsync(new Memory(packetData, 0, count), token).ConfigureAwait(false) :
+ _stream.Read(packetData.AsSpan(0, count));
+
+ if (_encapsulate)
+ {
+ _packetBytes -= readBytes;
+ }
+
+ packetData.AsSpan(0, readBytes).CopyTo(buffer.Span);
+ return readBytes;
+ }
+
+ ///
+ /// The internal write method calls Sync APIs when Async flag is false
+ ///
+ private async ValueTask WriteInternal(ReadOnlyMemory buffer, CancellationToken token, bool async)
+ {
+ int count = buffer.Length;
+ int currentOffset = 0;
+
+ while (count > 0)
+ {
+ int currentCount;
+ // During the SSL negotiation phase, SSL is tunnelled over TDS packet type 0x12. After
+ // negotiation, the underlying socket only sees SSL frames.
+ //
+ if (_encapsulate)
+ {
+ if (count > PACKET_SIZE_WITHOUT_HEADER)
+ {
+ currentCount = PACKET_SIZE_WITHOUT_HEADER;
+ }
+ else
+ {
+ currentCount = count;
+ }
+
+ count -= currentCount;
+
+ // Prepend buffer data with TDS prelogin header
+ byte[] combinedBuffer = new byte[TdsEnums.HEADER_LEN + currentCount];
+
+ // We can only send 4088 bytes in one packet. Header[1] is set to 1 if this is a
+ // partial packet (whether or not count != 0).
+ //
+ combinedBuffer[0] = PRELOGIN_PACKET_TYPE;
+ combinedBuffer[1] = (byte)(count > 0 ? 0 : 1);
+ combinedBuffer[2] = (byte)((currentCount + TdsEnums.HEADER_LEN) / 0x100);
+ combinedBuffer[3] = (byte)((currentCount + TdsEnums.HEADER_LEN) % 0x100);
+ combinedBuffer[4] = 0;
+ combinedBuffer[5] = 0;
+ combinedBuffer[6] = 0;
+ combinedBuffer[7] = 0;
+
+ CopyToBuffer(combinedBuffer, buffer.Span.Slice(currentOffset));
+
+ if (async)
+ {
+ await _stream.WriteAsync(combinedBuffer, 0, combinedBuffer.Length, token).ConfigureAwait(false);
+ }
+ else
+ {
+ _stream.Write(combinedBuffer.AsSpan());
+ }
+ }
+ else
+ {
+ currentCount = count;
+ count = 0;
+
+ if (async)
+ {
+ await _stream.WriteAsync(buffer.Slice(currentOffset, currentCount), token).ConfigureAwait(false);
+ }
+ else
+ {
+ _stream.Write(buffer.Span.Slice(currentOffset, currentCount));
+ }
+ }
+
+ if (async)
+ {
+ await _stream.FlushAsync().ConfigureAwait(false);
+ }
+ else
+ {
+ _stream.Flush();
+ }
+
+ currentOffset += currentCount;
+ }
+
+ void CopyToBuffer(byte[] combinedBuffer, ReadOnlySpan span)
+ {
+ for (int i = TdsEnums.HEADER_LEN; i < combinedBuffer.Length; i++)
+ {
+ combinedBuffer[i] = span[i - TdsEnums.HEADER_LEN];
+ }
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetStandard.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetStandard.cs
new file mode 100644
index 0000000000..5e36902ce1
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetStandard.cs
@@ -0,0 +1,182 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.IO;
+using System.IO.Pipes;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.Data.SqlClient.SNI
+{
+ internal sealed partial class SslOverTdsStream
+ {
+ ///
+ /// Read buffer
+ ///
+ /// Buffer
+ /// Offset
+ /// Byte count
+ /// Bytes read
+ public override int Read(byte[] buffer, int offset, int count) =>
+ ReadInternal(buffer, offset, count, CancellationToken.None, async: false).GetAwaiter().GetResult();
+
+ ///
+ /// Write Buffer
+ ///
+ ///
+ ///
+ ///
+ public override void Write(byte[] buffer, int offset, int count)
+ => WriteInternal(buffer, offset, count, CancellationToken.None, async: false).Wait();
+
+ ///
+ /// Write Buffer Asynchronosly
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken token)
+ => WriteInternal(buffer, offset, count, token, async: true);
+
+ ///
+ /// Read Buffer Asynchronosly
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken token)
+ => ReadInternal(buffer, offset, count, token, async: true);
+
+ ///
+ /// Read Internal is called synchronosly when async is false
+ ///
+ private async Task ReadInternal(byte[] buffer, int offset, int count, CancellationToken token, bool async)
+ {
+ int readBytes = 0;
+ byte[] packetData = new byte[count < TdsEnums.HEADER_LEN ? TdsEnums.HEADER_LEN : count];
+
+ if (_encapsulate)
+ {
+ if (_packetBytes == 0)
+ {
+ // Account for split packets
+ while (readBytes < TdsEnums.HEADER_LEN)
+ {
+ readBytes += async ?
+ await _stream.ReadAsync(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes, token).ConfigureAwait(false) :
+ _stream.Read(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes);
+ }
+
+ _packetBytes = (packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1];
+ _packetBytes -= TdsEnums.HEADER_LEN;
+ }
+
+ if (count > _packetBytes)
+ {
+ count = _packetBytes;
+ }
+ }
+
+ readBytes = async ?
+ await _stream.ReadAsync(packetData, 0, count, token).ConfigureAwait(false) :
+ _stream.Read(packetData, 0, count);
+
+ if (_encapsulate)
+ {
+ _packetBytes -= readBytes;
+ }
+
+ Buffer.BlockCopy(packetData, 0, buffer, offset, readBytes);
+ return readBytes;
+ }
+
+ ///
+ /// The internal write method calls Sync APIs when Async flag is false
+ ///
+ private async Task WriteInternal(byte[] buffer, int offset, int count, CancellationToken token, bool async)
+ {
+ int currentCount = 0;
+ int currentOffset = offset;
+
+ while (count > 0)
+ {
+ // During the SSL negotiation phase, SSL is tunnelled over TDS packet type 0x12. After
+ // negotiation, the underlying socket only sees SSL frames.
+ //
+ if (_encapsulate)
+ {
+ if (count > PACKET_SIZE_WITHOUT_HEADER)
+ {
+ currentCount = PACKET_SIZE_WITHOUT_HEADER;
+ }
+ else
+ {
+ currentCount = count;
+ }
+
+ count -= currentCount;
+
+ // Prepend buffer data with TDS prelogin header
+ byte[] combinedBuffer = new byte[TdsEnums.HEADER_LEN + currentCount];
+
+ // We can only send 4088 bytes in one packet. Header[1] is set to 1 if this is a
+ // partial packet (whether or not count != 0).
+ //
+ combinedBuffer[0] = PRELOGIN_PACKET_TYPE;
+ combinedBuffer[1] = (byte)(count > 0 ? 0 : 1);
+ combinedBuffer[2] = (byte)((currentCount + TdsEnums.HEADER_LEN) / 0x100);
+ combinedBuffer[3] = (byte)((currentCount + TdsEnums.HEADER_LEN) % 0x100);
+ combinedBuffer[4] = 0;
+ combinedBuffer[5] = 0;
+ combinedBuffer[6] = 0;
+ combinedBuffer[7] = 0;
+
+ for (int i = TdsEnums.HEADER_LEN; i < combinedBuffer.Length; i++)
+ {
+ combinedBuffer[i] = buffer[currentOffset + (i - TdsEnums.HEADER_LEN)];
+ }
+
+ if (async)
+ {
+ await _stream.WriteAsync(combinedBuffer, 0, combinedBuffer.Length, token).ConfigureAwait(false);
+ }
+ else
+ {
+ _stream.Write(combinedBuffer, 0, combinedBuffer.Length);
+ }
+ }
+ else
+ {
+ currentCount = count;
+ count = 0;
+
+ if (async)
+ {
+ await _stream.WriteAsync(buffer, currentOffset, currentCount, token).ConfigureAwait(false);
+ }
+ else
+ {
+ _stream.Write(buffer, currentOffset, currentCount);
+ }
+ }
+
+ if (async)
+ {
+ await _stream.FlushAsync().ConfigureAwait(false);
+ }
+ else
+ {
+ _stream.Flush();
+ }
+
+ currentOffset += currentCount;
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs
index 6fc4ec0268..7c906fc00b 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs
@@ -15,7 +15,7 @@ namespace Microsoft.Data.SqlClient.SNI
/// transported in TDS packet type 0x12. Once SSL handshake has completed, SSL
/// packets are sent transparently.
///
- internal sealed class SslOverTdsStream : Stream
+ internal sealed partial class SslOverTdsStream : Stream
{
private readonly Stream _stream;
@@ -43,173 +43,6 @@ public void FinishHandshake()
_encapsulate = false;
}
- ///
- /// Read buffer
- ///
- /// Buffer
- /// Offset
- /// Byte count
- /// Bytes read
- public override int Read(byte[] buffer, int offset, int count) =>
- ReadInternal(buffer, offset, count, CancellationToken.None, async: false).GetAwaiter().GetResult();
-
- ///
- /// Write Buffer
- ///
- ///
- ///
- ///
- public override void Write(byte[] buffer, int offset, int count)
- => WriteInternal(buffer, offset, count, CancellationToken.None, async: false).Wait();
-
- ///
- /// Write Buffer Asynchronosly
- ///
- ///
- ///
- ///
- ///
- ///
- public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken token)
- => WriteInternal(buffer, offset, count, token, async: true);
-
- ///
- /// Read Buffer Asynchronosly
- ///
- ///
- ///
- ///
- ///
- ///
- public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken token)
- => ReadInternal(buffer, offset, count, token, async: true);
-
- ///
- /// Read Internal is called synchronosly when async is false
- ///
- private async Task ReadInternal(byte[] buffer, int offset, int count, CancellationToken token, bool async)
- {
- int readBytes = 0;
- byte[] packetData = new byte[count < TdsEnums.HEADER_LEN ? TdsEnums.HEADER_LEN : count];
-
- if (_encapsulate)
- {
- if (_packetBytes == 0)
- {
- // Account for split packets
- while (readBytes < TdsEnums.HEADER_LEN)
- {
- readBytes += async ?
- await _stream.ReadAsync(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes, token).ConfigureAwait(false) :
- _stream.Read(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes);
- }
-
- _packetBytes = (packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1];
- _packetBytes -= TdsEnums.HEADER_LEN;
- }
-
- if (count > _packetBytes)
- {
- count = _packetBytes;
- }
- }
-
- readBytes = async ?
- await _stream.ReadAsync(packetData, 0, count, token).ConfigureAwait(false) :
- _stream.Read(packetData, 0, count);
-
- if (_encapsulate)
- {
- _packetBytes -= readBytes;
- }
-
- Buffer.BlockCopy(packetData, 0, buffer, offset, readBytes);
- return readBytes;
- }
-
- ///
- /// The internal write method calls Sync APIs when Async flag is false
- ///
- private async Task WriteInternal(byte[] buffer, int offset, int count, CancellationToken token, bool async)
- {
- int currentCount = 0;
- int currentOffset = offset;
-
- while (count > 0)
- {
- // During the SSL negotiation phase, SSL is tunnelled over TDS packet type 0x12. After
- // negotiation, the underlying socket only sees SSL frames.
- //
- if (_encapsulate)
- {
- if (count > PACKET_SIZE_WITHOUT_HEADER)
- {
- currentCount = PACKET_SIZE_WITHOUT_HEADER;
- }
- else
- {
- currentCount = count;
- }
-
- count -= currentCount;
-
- // Prepend buffer data with TDS prelogin header
- byte[] combinedBuffer = new byte[TdsEnums.HEADER_LEN + currentCount];
-
- // We can only send 4088 bytes in one packet. Header[1] is set to 1 if this is a
- // partial packet (whether or not count != 0).
- //
- combinedBuffer[0] = PRELOGIN_PACKET_TYPE;
- combinedBuffer[1] = (byte)(count > 0 ? 0 : 1);
- combinedBuffer[2] = (byte)((currentCount + TdsEnums.HEADER_LEN) / 0x100);
- combinedBuffer[3] = (byte)((currentCount + TdsEnums.HEADER_LEN) % 0x100);
- combinedBuffer[4] = 0;
- combinedBuffer[5] = 0;
- combinedBuffer[6] = 0;
- combinedBuffer[7] = 0;
-
- for (int i = TdsEnums.HEADER_LEN; i < combinedBuffer.Length; i++)
- {
- combinedBuffer[i] = buffer[currentOffset + (i - TdsEnums.HEADER_LEN)];
- }
-
- if (async)
- {
- await _stream.WriteAsync(combinedBuffer, 0, combinedBuffer.Length, token).ConfigureAwait(false);
- }
- else
- {
- _stream.Write(combinedBuffer, 0, combinedBuffer.Length);
- }
- }
- else
- {
- currentCount = count;
- count = 0;
-
- if (async)
- {
- await _stream.WriteAsync(buffer, currentOffset, currentCount, token).ConfigureAwait(false);
- }
- else
- {
- _stream.Write(buffer, currentOffset, currentCount);
- }
- }
-
- if (async)
- {
- await _stream.FlushAsync().ConfigureAwait(false);
- }
- else
- {
- _stream.Flush();
- }
-
- currentOffset += currentCount;
- }
- }
-
///
/// Set stream length.
///
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.NetCoreApp.cs
index d8026b36e3..5035259414 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.NetCoreApp.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.NetCoreApp.cs
@@ -4,6 +4,8 @@
using System;
using System.Diagnostics;
+using System.Threading;
+using System.Threading.Tasks;
namespace Microsoft.Data.SqlClient
{
@@ -20,5 +22,46 @@ internal static Guid ConstructGuid(ReadOnlySpan bytes)
Debug.Assert(bytes.Length >= 16, "not enough bytes to set guid");
return new Guid(bytes);
}
+
+ private sealed partial class TdsOutputStream
+ {
+ public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+ => WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask();
+
+ public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken)
+ {
+ Debug.Assert(_parser._asyncWrite);
+ ReadOnlySpan span = buffer.Span;
+
+ StripPreamble(ref span);
+
+ ValueTask task = default;
+ if (span.Length > 0)
+ {
+ _parser.WriteInt(span.Length, _stateObj); // write length of chunk
+ task = new ValueTask(_stateObj.WriteByteSpan(span, canAccumulate: false));
+ }
+
+ return task;
+ }
+
+ private void StripPreamble(ref ReadOnlySpan buffer)
+ {
+ if (_preambleToStrip != null && buffer.Length >= _preambleToStrip.Length)
+ {
+ for (int idx = 0; idx < _preambleToStrip.Length; idx++)
+ {
+ if (_preambleToStrip[idx] != buffer[idx])
+ {
+ _preambleToStrip = null;
+ return;
+ }
+ }
+
+ buffer = buffer.Slice(_preambleToStrip.Length);
+ }
+ _preambleToStrip = null;
+ }
+ }
}
}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.NetStandard.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.NetStandard.cs
index 72c0b77b19..3b93490612 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.NetStandard.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.NetStandard.cs
@@ -5,6 +5,8 @@
using System;
using System.Buffers;
using System.Diagnostics;
+using System.Threading;
+using System.Threading.Tasks;
namespace Microsoft.Data.SqlClient
{
@@ -37,5 +39,25 @@ internal static Guid ConstructGuid(ReadOnlySpan bytes)
ArrayPool.Shared.Return(temp);
return retval;
}
+
+ private sealed partial class TdsOutputStream
+ {
+ public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+ {
+ Debug.Assert(_parser._asyncWrite);
+ ValidateWriteParameters(buffer, offset, count);
+
+ StripPreamble(buffer, ref offset, ref count);
+
+ Task task = null;
+ if (count > 0)
+ {
+ _parser.WriteInt(count, _stateObj); // write length of chunk
+ task = _stateObj.WriteByteArray(buffer, count, offset, canAccumulate: false);
+ }
+
+ return task ?? Task.CompletedTask;
+ }
+ }
}
}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs
index e530dacc6a..3331da20b4 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs
@@ -10870,7 +10870,7 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe
return null;
}
- private sealed class TdsOutputStream : Stream
+ private sealed partial class TdsOutputStream : Stream
{
private TdsParser _parser;
private TdsParserStateObject _stateObj;
@@ -10968,23 +10968,6 @@ public override void Write(byte[] buffer, int offset, int count)
}
}
- public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
- {
- Debug.Assert(_parser._asyncWrite);
- ValidateWriteParameters(buffer, offset, count);
-
- StripPreamble(buffer, ref offset, ref count);
-
- Task task = null;
- if (count > 0)
- {
- _parser.WriteInt(count, _stateObj); // write length of chunk
- task = _stateObj.WriteByteArray(buffer, count, offset, canAccumulate: false);
- }
-
- return task ?? Task.CompletedTask;
- }
-
internal static void ValidateWriteParameters(byte[] buffer, int offset, int count)
{
if (buffer == null)