diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index b0a0b70040..f949ac4548 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -45,6 +45,7 @@ + @@ -57,6 +58,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs index e6a35caeda..49a4308eb7 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetCoreApp.cs @@ -18,7 +18,6 @@ internal partial class SNIPacket /// Completion callback public void ReadFromStreamAsync(Stream stream, SNIAsyncCallback callback) { - // Treat local function as a static and pass all params otherwise as async will allocate async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask valueTask) { bool error = false; @@ -45,7 +44,15 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask< cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS); } - ValueTask vt = stream.ReadAsync(new Memory(_data, 0, _capacity), CancellationToken.None); + ValueTask vt; + try + { + vt = stream.ReadAsync(new Memory(_data, 0, _capacity), CancellationToken.None); + } + catch (Exception ex) + { + vt = new ValueTask(Task.FromException(ex)); + } if (vt.IsCompletedSuccessfully) { @@ -78,8 +85,7 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, ValueTask< /// public void WriteToStreamAsync(Stream stream, SNIAsyncCallback callback, SNIProviders provider, bool disposeAfterWriteAsync = false) { - // Treat local function as a static and pass all params otherwise as async will allocate - async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProviders providers, bool disposeAfter, ValueTask valueTask) + async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProviders providers, bool dispose, ValueTask valueTask) { uint status = TdsEnums.SNI_SUCCESS; try @@ -94,13 +100,21 @@ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProvider cb(packet, status); - if (disposeAfter) + if (dispose) { packet.Dispose(); } } - ValueTask vt = stream.WriteAsync(new Memory(_data, 0, _length), CancellationToken.None); + ValueTask vt; + try + { + vt = stream.WriteAsync(new Memory(_data, 0, _length), CancellationToken.None); + } + catch (Exception ex) + { + vt = new ValueTask(Task.FromException(ex)); + } if (vt.IsCompletedSuccessfully) { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetStandard.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetStandard.cs index 2a3cf12670..ebd93aaa99 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetStandard.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPacket.NetStandard.cs @@ -45,7 +45,15 @@ async Task ReadFromStreamAsync(SNIPacket packet, SNIAsyncCallback cb, Task cb(packet, error ? TdsEnums.SNI_ERROR : TdsEnums.SNI_SUCCESS); } - Task t = stream.ReadAsync(_data, 0, _capacity, CancellationToken.None); + Task t; + try + { + t = stream.ReadAsync(_data, 0, _capacity, CancellationToken.None); + } + catch (Exception ex) + { + t = Task.FromException(ex); + } if ((t.Status & TaskStatus.RanToCompletion) != 0) { @@ -95,7 +103,15 @@ async Task WriteToStreamAsync(SNIPacket packet, SNIAsyncCallback cb, SNIProvider } } - Task t = stream.WriteAsync(_data, 0, _length, CancellationToken.None); + Task t; + try + { + t = stream.WriteAsync(_data, 0, _length, CancellationToken.None); + } + catch (Exception ex) + { + t = Task.FromException(ex); + } if ((t.Status & TaskStatus.RanToCompletion) != 0) { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs new file mode 100644 index 0000000000..ab4b506c5f --- /dev/null +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs @@ -0,0 +1,164 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Data.SqlClient.SNI +{ + internal sealed partial class SslOverTdsStream + { + public override int Read(byte[] buffer, int offset, int count) + => ReadInternal(new Memory(buffer, offset, count), default, async: false).GetAwaiter().GetResult(); + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken token) + => ReadInternal(new Memory(buffer, offset, count), token, async: true).AsTask(); + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken) + => ReadInternal(buffer, cancellationToken, async: true); + + public override void Write(byte[] buffer, int offset, int count) + => WriteInternal(new ReadOnlyMemory(buffer, offset, count), default, async: true).GetAwaiter().GetResult(); + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken token) + => WriteInternal(new ReadOnlyMemory(buffer, offset, count), token, async: true).AsTask(); + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) + => WriteInternal(buffer, cancellationToken, async: true); + + /// + /// Read Internal is called synchronosly when async is false + /// + private async ValueTask ReadInternal(Memory buffer, CancellationToken token, bool async) + { + int readBytes = 0; + int count = buffer.Length; + byte[] packetData = new byte[count < TdsEnums.HEADER_LEN ? TdsEnums.HEADER_LEN : count]; + + if (_encapsulate) + { + if (_packetBytes == 0) + { + // Account for split packets + while (readBytes < TdsEnums.HEADER_LEN) + { + readBytes += async ? + await _stream.ReadAsync(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes, token).ConfigureAwait(false) : + _stream.Read(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes); + } + + _packetBytes = (packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]; + _packetBytes -= TdsEnums.HEADER_LEN; + } + + if (count > _packetBytes) + { + count = _packetBytes; + } + } + + readBytes = async ? + await _stream.ReadAsync(new Memory(packetData, 0, count), token).ConfigureAwait(false) : + _stream.Read(packetData.AsSpan(0, count)); + + if (_encapsulate) + { + _packetBytes -= readBytes; + } + + packetData.AsSpan(0, readBytes).CopyTo(buffer.Span); + return readBytes; + } + + /// + /// The internal write method calls Sync APIs when Async flag is false + /// + private async ValueTask WriteInternal(ReadOnlyMemory buffer, CancellationToken token, bool async) + { + int count = buffer.Length; + int currentOffset = 0; + + while (count > 0) + { + int currentCount; + // During the SSL negotiation phase, SSL is tunnelled over TDS packet type 0x12. After + // negotiation, the underlying socket only sees SSL frames. + // + if (_encapsulate) + { + if (count > PACKET_SIZE_WITHOUT_HEADER) + { + currentCount = PACKET_SIZE_WITHOUT_HEADER; + } + else + { + currentCount = count; + } + + count -= currentCount; + + // Prepend buffer data with TDS prelogin header + byte[] combinedBuffer = new byte[TdsEnums.HEADER_LEN + currentCount]; + + // We can only send 4088 bytes in one packet. Header[1] is set to 1 if this is a + // partial packet (whether or not count != 0). + // + combinedBuffer[0] = PRELOGIN_PACKET_TYPE; + combinedBuffer[1] = (byte)(count > 0 ? 0 : 1); + combinedBuffer[2] = (byte)((currentCount + TdsEnums.HEADER_LEN) / 0x100); + combinedBuffer[3] = (byte)((currentCount + TdsEnums.HEADER_LEN) % 0x100); + combinedBuffer[4] = 0; + combinedBuffer[5] = 0; + combinedBuffer[6] = 0; + combinedBuffer[7] = 0; + + CopyToBuffer(combinedBuffer, buffer.Span.Slice(currentOffset)); + + if (async) + { + await _stream.WriteAsync(combinedBuffer, 0, combinedBuffer.Length, token).ConfigureAwait(false); + } + else + { + _stream.Write(combinedBuffer.AsSpan()); + } + } + else + { + currentCount = count; + count = 0; + + if (async) + { + await _stream.WriteAsync(buffer.Slice(currentOffset, currentCount), token).ConfigureAwait(false); + } + else + { + _stream.Write(buffer.Span.Slice(currentOffset, currentCount)); + } + } + + if (async) + { + await _stream.FlushAsync().ConfigureAwait(false); + } + else + { + _stream.Flush(); + } + + currentOffset += currentCount; + } + + void CopyToBuffer(byte[] combinedBuffer, ReadOnlySpan span) + { + for (int i = TdsEnums.HEADER_LEN; i < combinedBuffer.Length; i++) + { + combinedBuffer[i] = span[i - TdsEnums.HEADER_LEN]; + } + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetStandard.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetStandard.cs new file mode 100644 index 0000000000..5e36902ce1 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetStandard.cs @@ -0,0 +1,182 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.IO; +using System.IO.Pipes; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Data.SqlClient.SNI +{ + internal sealed partial class SslOverTdsStream + { + /// + /// Read buffer + /// + /// Buffer + /// Offset + /// Byte count + /// Bytes read + public override int Read(byte[] buffer, int offset, int count) => + ReadInternal(buffer, offset, count, CancellationToken.None, async: false).GetAwaiter().GetResult(); + + /// + /// Write Buffer + /// + /// + /// + /// + public override void Write(byte[] buffer, int offset, int count) + => WriteInternal(buffer, offset, count, CancellationToken.None, async: false).Wait(); + + /// + /// Write Buffer Asynchronosly + /// + /// + /// + /// + /// + /// + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken token) + => WriteInternal(buffer, offset, count, token, async: true); + + /// + /// Read Buffer Asynchronosly + /// + /// + /// + /// + /// + /// + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken token) + => ReadInternal(buffer, offset, count, token, async: true); + + /// + /// Read Internal is called synchronosly when async is false + /// + private async Task ReadInternal(byte[] buffer, int offset, int count, CancellationToken token, bool async) + { + int readBytes = 0; + byte[] packetData = new byte[count < TdsEnums.HEADER_LEN ? TdsEnums.HEADER_LEN : count]; + + if (_encapsulate) + { + if (_packetBytes == 0) + { + // Account for split packets + while (readBytes < TdsEnums.HEADER_LEN) + { + readBytes += async ? + await _stream.ReadAsync(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes, token).ConfigureAwait(false) : + _stream.Read(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes); + } + + _packetBytes = (packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]; + _packetBytes -= TdsEnums.HEADER_LEN; + } + + if (count > _packetBytes) + { + count = _packetBytes; + } + } + + readBytes = async ? + await _stream.ReadAsync(packetData, 0, count, token).ConfigureAwait(false) : + _stream.Read(packetData, 0, count); + + if (_encapsulate) + { + _packetBytes -= readBytes; + } + + Buffer.BlockCopy(packetData, 0, buffer, offset, readBytes); + return readBytes; + } + + /// + /// The internal write method calls Sync APIs when Async flag is false + /// + private async Task WriteInternal(byte[] buffer, int offset, int count, CancellationToken token, bool async) + { + int currentCount = 0; + int currentOffset = offset; + + while (count > 0) + { + // During the SSL negotiation phase, SSL is tunnelled over TDS packet type 0x12. After + // negotiation, the underlying socket only sees SSL frames. + // + if (_encapsulate) + { + if (count > PACKET_SIZE_WITHOUT_HEADER) + { + currentCount = PACKET_SIZE_WITHOUT_HEADER; + } + else + { + currentCount = count; + } + + count -= currentCount; + + // Prepend buffer data with TDS prelogin header + byte[] combinedBuffer = new byte[TdsEnums.HEADER_LEN + currentCount]; + + // We can only send 4088 bytes in one packet. Header[1] is set to 1 if this is a + // partial packet (whether or not count != 0). + // + combinedBuffer[0] = PRELOGIN_PACKET_TYPE; + combinedBuffer[1] = (byte)(count > 0 ? 0 : 1); + combinedBuffer[2] = (byte)((currentCount + TdsEnums.HEADER_LEN) / 0x100); + combinedBuffer[3] = (byte)((currentCount + TdsEnums.HEADER_LEN) % 0x100); + combinedBuffer[4] = 0; + combinedBuffer[5] = 0; + combinedBuffer[6] = 0; + combinedBuffer[7] = 0; + + for (int i = TdsEnums.HEADER_LEN; i < combinedBuffer.Length; i++) + { + combinedBuffer[i] = buffer[currentOffset + (i - TdsEnums.HEADER_LEN)]; + } + + if (async) + { + await _stream.WriteAsync(combinedBuffer, 0, combinedBuffer.Length, token).ConfigureAwait(false); + } + else + { + _stream.Write(combinedBuffer, 0, combinedBuffer.Length); + } + } + else + { + currentCount = count; + count = 0; + + if (async) + { + await _stream.WriteAsync(buffer, currentOffset, currentCount, token).ConfigureAwait(false); + } + else + { + _stream.Write(buffer, currentOffset, currentCount); + } + } + + if (async) + { + await _stream.FlushAsync().ConfigureAwait(false); + } + else + { + _stream.Flush(); + } + + currentOffset += currentCount; + } + } + } +} diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs index 6fc4ec0268..7c906fc00b 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.cs @@ -15,7 +15,7 @@ namespace Microsoft.Data.SqlClient.SNI /// transported in TDS packet type 0x12. Once SSL handshake has completed, SSL /// packets are sent transparently. /// - internal sealed class SslOverTdsStream : Stream + internal sealed partial class SslOverTdsStream : Stream { private readonly Stream _stream; @@ -43,173 +43,6 @@ public void FinishHandshake() _encapsulate = false; } - /// - /// Read buffer - /// - /// Buffer - /// Offset - /// Byte count - /// Bytes read - public override int Read(byte[] buffer, int offset, int count) => - ReadInternal(buffer, offset, count, CancellationToken.None, async: false).GetAwaiter().GetResult(); - - /// - /// Write Buffer - /// - /// - /// - /// - public override void Write(byte[] buffer, int offset, int count) - => WriteInternal(buffer, offset, count, CancellationToken.None, async: false).Wait(); - - /// - /// Write Buffer Asynchronosly - /// - /// - /// - /// - /// - /// - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken token) - => WriteInternal(buffer, offset, count, token, async: true); - - /// - /// Read Buffer Asynchronosly - /// - /// - /// - /// - /// - /// - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken token) - => ReadInternal(buffer, offset, count, token, async: true); - - /// - /// Read Internal is called synchronosly when async is false - /// - private async Task ReadInternal(byte[] buffer, int offset, int count, CancellationToken token, bool async) - { - int readBytes = 0; - byte[] packetData = new byte[count < TdsEnums.HEADER_LEN ? TdsEnums.HEADER_LEN : count]; - - if (_encapsulate) - { - if (_packetBytes == 0) - { - // Account for split packets - while (readBytes < TdsEnums.HEADER_LEN) - { - readBytes += async ? - await _stream.ReadAsync(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes, token).ConfigureAwait(false) : - _stream.Read(packetData, readBytes, TdsEnums.HEADER_LEN - readBytes); - } - - _packetBytes = (packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | packetData[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]; - _packetBytes -= TdsEnums.HEADER_LEN; - } - - if (count > _packetBytes) - { - count = _packetBytes; - } - } - - readBytes = async ? - await _stream.ReadAsync(packetData, 0, count, token).ConfigureAwait(false) : - _stream.Read(packetData, 0, count); - - if (_encapsulate) - { - _packetBytes -= readBytes; - } - - Buffer.BlockCopy(packetData, 0, buffer, offset, readBytes); - return readBytes; - } - - /// - /// The internal write method calls Sync APIs when Async flag is false - /// - private async Task WriteInternal(byte[] buffer, int offset, int count, CancellationToken token, bool async) - { - int currentCount = 0; - int currentOffset = offset; - - while (count > 0) - { - // During the SSL negotiation phase, SSL is tunnelled over TDS packet type 0x12. After - // negotiation, the underlying socket only sees SSL frames. - // - if (_encapsulate) - { - if (count > PACKET_SIZE_WITHOUT_HEADER) - { - currentCount = PACKET_SIZE_WITHOUT_HEADER; - } - else - { - currentCount = count; - } - - count -= currentCount; - - // Prepend buffer data with TDS prelogin header - byte[] combinedBuffer = new byte[TdsEnums.HEADER_LEN + currentCount]; - - // We can only send 4088 bytes in one packet. Header[1] is set to 1 if this is a - // partial packet (whether or not count != 0). - // - combinedBuffer[0] = PRELOGIN_PACKET_TYPE; - combinedBuffer[1] = (byte)(count > 0 ? 0 : 1); - combinedBuffer[2] = (byte)((currentCount + TdsEnums.HEADER_LEN) / 0x100); - combinedBuffer[3] = (byte)((currentCount + TdsEnums.HEADER_LEN) % 0x100); - combinedBuffer[4] = 0; - combinedBuffer[5] = 0; - combinedBuffer[6] = 0; - combinedBuffer[7] = 0; - - for (int i = TdsEnums.HEADER_LEN; i < combinedBuffer.Length; i++) - { - combinedBuffer[i] = buffer[currentOffset + (i - TdsEnums.HEADER_LEN)]; - } - - if (async) - { - await _stream.WriteAsync(combinedBuffer, 0, combinedBuffer.Length, token).ConfigureAwait(false); - } - else - { - _stream.Write(combinedBuffer, 0, combinedBuffer.Length); - } - } - else - { - currentCount = count; - count = 0; - - if (async) - { - await _stream.WriteAsync(buffer, currentOffset, currentCount, token).ConfigureAwait(false); - } - else - { - _stream.Write(buffer, currentOffset, currentCount); - } - } - - if (async) - { - await _stream.FlushAsync().ConfigureAwait(false); - } - else - { - _stream.Flush(); - } - - currentOffset += currentCount; - } - } - /// /// Set stream length. /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.NetCoreApp.cs index d8026b36e3..5035259414 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.NetCoreApp.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.NetCoreApp.cs @@ -4,6 +4,8 @@ using System; using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; namespace Microsoft.Data.SqlClient { @@ -20,5 +22,46 @@ internal static Guid ConstructGuid(ReadOnlySpan bytes) Debug.Assert(bytes.Length >= 16, "not enough bytes to set guid"); return new Guid(bytes); } + + private sealed partial class TdsOutputStream + { + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask(); + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) + { + Debug.Assert(_parser._asyncWrite); + ReadOnlySpan span = buffer.Span; + + StripPreamble(ref span); + + ValueTask task = default; + if (span.Length > 0) + { + _parser.WriteInt(span.Length, _stateObj); // write length of chunk + task = new ValueTask(_stateObj.WriteByteSpan(span, canAccumulate: false)); + } + + return task; + } + + private void StripPreamble(ref ReadOnlySpan buffer) + { + if (_preambleToStrip != null && buffer.Length >= _preambleToStrip.Length) + { + for (int idx = 0; idx < _preambleToStrip.Length; idx++) + { + if (_preambleToStrip[idx] != buffer[idx]) + { + _preambleToStrip = null; + return; + } + } + + buffer = buffer.Slice(_preambleToStrip.Length); + } + _preambleToStrip = null; + } + } } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.NetStandard.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.NetStandard.cs index 72c0b77b19..3b93490612 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.NetStandard.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.NetStandard.cs @@ -5,6 +5,8 @@ using System; using System.Buffers; using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; namespace Microsoft.Data.SqlClient { @@ -37,5 +39,25 @@ internal static Guid ConstructGuid(ReadOnlySpan bytes) ArrayPool.Shared.Return(temp); return retval; } + + private sealed partial class TdsOutputStream + { + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + Debug.Assert(_parser._asyncWrite); + ValidateWriteParameters(buffer, offset, count); + + StripPreamble(buffer, ref offset, ref count); + + Task task = null; + if (count > 0) + { + _parser.WriteInt(count, _stateObj); // write length of chunk + task = _stateObj.WriteByteArray(buffer, count, offset, canAccumulate: false); + } + + return task ?? Task.CompletedTask; + } + } } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index e530dacc6a..3331da20b4 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -10870,7 +10870,7 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe return null; } - private sealed class TdsOutputStream : Stream + private sealed partial class TdsOutputStream : Stream { private TdsParser _parser; private TdsParserStateObject _stateObj; @@ -10968,23 +10968,6 @@ public override void Write(byte[] buffer, int offset, int count) } } - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - Debug.Assert(_parser._asyncWrite); - ValidateWriteParameters(buffer, offset, count); - - StripPreamble(buffer, ref offset, ref count); - - Task task = null; - if (count > 0) - { - _parser.WriteInt(count, _stateObj); // write length of chunk - task = _stateObj.WriteByteArray(buffer, count, offset, canAccumulate: false); - } - - return task ?? Task.CompletedTask; - } - internal static void ValidateWriteParameters(byte[] buffer, int offset, int count) { if (buffer == null)