diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
index d4d8d25cb7..e46d123395 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
@@ -442,6 +442,7 @@
+
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs
index 692aa9b7fe..0132d7df58 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs
@@ -93,7 +93,7 @@ public SNINpHandle(string serverName, string pipeName, long timerExpire, object
}
_sslOverTdsStream = new SslOverTdsStream(_pipeStream);
- _sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);
+ _sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));
_stream = _pipeStream;
_status = TdsEnums.SNI_SUCCESS;
@@ -286,7 +286,7 @@ public override uint Send(SNIPacket packet)
}
// this lock ensures that two packets are not being written to the transport at the same time
- // so that sending a standard and an out-of-band packet are both written atomically no data is
+ // so that sending a standard and an out-of-band packet are both written atomically no data is
// interleaved
lock (_sendSync)
{
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs
new file mode 100644
index 0000000000..eb8661d022
--- /dev/null
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIStreams.cs
@@ -0,0 +1,99 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Net.Security;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+using System.Net.Sockets;
+
+namespace Microsoft.Data.SqlClient.SNI
+{
+ ///
+ /// This class extends SslStream to customize stream behavior for Managed SNI implementation.
+ ///
+ internal class SNISslStream : SslStream
+ {
+ private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore;
+ private readonly ConcurrentQueueSemaphore _readAsyncSemaphore;
+
+ public SNISslStream(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertificateValidationCallback userCertificateValidationCallback)
+ : base(innerStream, leaveInnerStreamOpen, userCertificateValidationCallback)
+ {
+ _writeAsyncSemaphore = new ConcurrentQueueSemaphore(1);
+ _readAsyncSemaphore = new ConcurrentQueueSemaphore(1);
+ }
+
+ // Prevent ReadAsync collisions by running the task in a Semaphore Slim
+ public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+ {
+ await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
+ try
+ {
+ return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
+ }
+ finally
+ {
+ _readAsyncSemaphore.Release();
+ }
+ }
+
+ // Prevent the WriteAsync collisions by running the task in a Semaphore Slim
+ public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+ {
+ await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
+ try
+ {
+ await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
+ }
+ finally
+ {
+ _writeAsyncSemaphore.Release();
+ }
+ }
+ }
+
+ ///
+ /// This class extends NetworkStream to customize stream behavior for Managed SNI implementation.
+ ///
+ internal class SNINetworkStream : NetworkStream
+ {
+ private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore;
+ private readonly ConcurrentQueueSemaphore _readAsyncSemaphore;
+
+ public SNINetworkStream(Socket socket, bool ownsSocket) : base(socket, ownsSocket)
+ {
+ _writeAsyncSemaphore = new ConcurrentQueueSemaphore(1);
+ _readAsyncSemaphore = new ConcurrentQueueSemaphore(1);
+ }
+
+ // Prevent ReadAsync collisions by running the task in a Semaphore Slim
+ public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+ {
+ await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
+ try
+ {
+ return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
+ }
+ finally
+ {
+ _readAsyncSemaphore.Release();
+ }
+ }
+
+ // Prevent the WriteAsync collisions by running the task in a Semaphore Slim
+ public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+ {
+ await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
+ try
+ {
+ await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
+ }
+ finally
+ {
+ _writeAsyncSemaphore.Release();
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs
index b072a4fa01..ef85841d24 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs
@@ -143,7 +143,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
bool reportError = true;
// We will always first try to connect with serverName as before and let the DNS server to resolve the serverName.
- // If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with IPv4 first and followed by IPv6 if
+ // If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with IPv4 first and followed by IPv6 if
// IPv4 fails. The exceptions will be throw to upper level and be handled as before.
try
{
@@ -160,14 +160,14 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
{
// Retry with cached IP address
if (ex is SocketException || ex is ArgumentException || ex is AggregateException)
- {
+ {
if (hasCachedDNSInfo == false)
{
throw;
}
else
{
- int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port);
+ int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port);
try
{
@@ -180,9 +180,9 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
_socket = Connect(cachedDNSInfo.AddrIPv4, portRetry, ts, isInfiniteTimeOut, cachedFQDN, ref pendingDNSInfo);
}
}
- catch(Exception exRetry)
+ catch (Exception exRetry)
{
- if (exRetry is SocketException || exRetry is ArgumentNullException
+ if (exRetry is SocketException || exRetry is ArgumentNullException
|| exRetry is ArgumentException || exRetry is ArgumentOutOfRangeException || exRetry is AggregateException)
{
if (parallel)
@@ -199,7 +199,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
throw;
}
}
- }
+ }
}
else
{
@@ -223,10 +223,10 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
}
_socket.NoDelay = true;
- _tcpStream = new NetworkStream(_socket, true);
+ _tcpStream = new SNINetworkStream(_socket, true);
_sslOverTdsStream = new SslOverTdsStream(_tcpStream);
- _sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);
+ _sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));
}
catch (SocketException se)
{
@@ -331,7 +331,7 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
}
CancellationTokenSource cts = null;
-
+
void Cancel()
{
for (int i = 0; i < sockets.Length; ++i)
@@ -355,7 +355,7 @@ void Cancel()
}
Socket availableSocket = null;
- try
+ try
{
for (int i = 0; i < sockets.Length; ++i)
{
@@ -566,45 +566,45 @@ public override uint Send(SNIPacket packet)
{
bool releaseLock = false;
try
- {
- // is the packet is marked out out-of-band (attention packets only) it must be
- // sent immediately even if a send of recieve operation is already in progress
- // because out of band packets are used to cancel ongoing operations
- // so try to take the lock if possible but continue even if it can't be taken
- if (packet.IsOutOfBand)
- {
- Monitor.TryEnter(this, ref releaseLock);
- }
- else
- {
- Monitor.Enter(this);
- releaseLock = true;
- }
-
- // this lock ensures that two packets are not being written to the transport at the same time
- // so that sending a standard and an out-of-band packet are both written atomically no data is
- // interleaved
- lock (_sendSync)
{
- try
- {
- packet.WriteToStream(_stream);
- return TdsEnums.SNI_SUCCESS;
- }
- catch (ObjectDisposedException ode)
+ // is the packet is marked out out-of-band (attention packets only) it must be
+ // sent immediately even if a send of recieve operation is already in progress
+ // because out of band packets are used to cancel ongoing operations
+ // so try to take the lock if possible but continue even if it can't be taken
+ if (packet.IsOutOfBand)
{
- return ReportTcpSNIError(ode);
+ Monitor.TryEnter(this, ref releaseLock);
}
- catch (SocketException se)
+ else
{
- return ReportTcpSNIError(se);
+ Monitor.Enter(this);
+ releaseLock = true;
}
- catch (IOException ioe)
+
+ // this lock ensures that two packets are not being written to the transport at the same time
+ // so that sending a standard and an out-of-band packet are both written atomically no data is
+ // interleaved
+ lock (_sendSync)
{
- return ReportTcpSNIError(ioe);
+ try
+ {
+ packet.WriteToStream(_stream);
+ return TdsEnums.SNI_SUCCESS;
+ }
+ catch (ObjectDisposedException ode)
+ {
+ return ReportTcpSNIError(ode);
+ }
+ catch (SocketException se)
+ {
+ return ReportTcpSNIError(se);
+ }
+ catch (IOException ioe)
+ {
+ return ReportTcpSNIError(ioe);
+ }
}
}
- }
finally
{
if (releaseLock)
@@ -633,7 +633,8 @@ public override uint Receive(out SNIPacket packet, int timeoutInMilliseconds)
_socket.ReceiveTimeout = timeoutInMilliseconds;
}
else if (timeoutInMilliseconds == -1)
- { // SqlCient internally represents infinite timeout by -1, and for TcpClient this is translated to a timeout of 0
+ {
+ // SqlClient internally represents infinite timeout by -1, and for TcpClient this is translated to a timeout of 0
_socket.ReceiveTimeout = 0;
}
else
@@ -706,12 +707,17 @@ public override void SetAsyncCallbacks(SNIAsyncCallback receiveCallback, SNIAsyn
/// SNI error code
public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
{
- SNIAsyncCallback cb = callback ?? _sendCallback;
- lock (this)
+ long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent("");
+ try
{
+ SNIAsyncCallback cb = callback ?? _sendCallback;
packet.WriteToStreamAsync(_stream, cb, SNIProviders.TCP_PROV);
+ return TdsEnums.SNI_SUCCESS_IO_PENDING;
+ }
+ finally
+ {
+ SqlClientEventSource.Log.TrySNIScopeLeaveEvent(scopeID);
}
- return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
///
@@ -745,15 +751,15 @@ public override uint CheckConnection()
{
try
{
- // _socket.Poll method with argument SelectMode.SelectRead returns
+ // _socket.Poll method with argument SelectMode.SelectRead returns
// True : if Listen has been called and a connection is pending, or
// True : if data is available for reading, or
// True : if the connection has been closed, reset, or terminated, i.e no active connection.
// False : otherwise.
// _socket.Available property returns the number of bytes of data available to read.
//
- // Since _socket.Connected alone doesn't guarantee if the connection is still active, we use it in
- // combination with _socket.Poll method and _socket.Available == 0 check. When both of them
+ // Since _socket.Connected alone doesn't guarantee if the connection is still active, we use it in
+ // combination with _socket.Poll method and _socket.Available == 0 check. When both of them
// return true we can safely determine that the connection is no longer active.
if (!_socket.Connected || (_socket.Poll(100, SelectMode.SelectRead) && _socket.Available == 0))
{
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs
index cb634cb6af..97a1181ae3 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SslOverTdsStream.NetCoreApp.cs
@@ -12,24 +12,16 @@ namespace Microsoft.Data.SqlClient.SNI
internal sealed partial class SslOverTdsStream
{
public override int Read(byte[] buffer, int offset, int count)
- {
- return Read(buffer.AsSpan(offset, count));
- }
+ => Read(buffer.AsSpan(offset, count));
public override void Write(byte[] buffer, int offset, int count)
- {
- Write(buffer.AsSpan(offset, count));
- }
+ => Write(buffer.AsSpan(offset, count));
public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
- {
- return ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask();
- }
+ => ReadAsync(new Memory(buffer, offset, count), cancellationToken).AsTask();
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
- {
- return WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask();
- }
+ => WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken).AsTask();
public override int Read(Span buffer)
{
@@ -288,7 +280,6 @@ public override async ValueTask WriteAsync(ReadOnlyMemory buffer, Cancella
await _stream.FlushAsync().ConfigureAwait(false);
-
remaining = remaining.Slice(dataLength);
}
}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs
index e19ee3eba0..10d5064a87 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs
@@ -1319,7 +1319,7 @@ private void ThrowIfReconnectionHasBeenCanceled()
if (_stateObj == null)
{
var reconnectionCompletionSource = _reconnectionCompletionSource;
- if (reconnectionCompletionSource != null && reconnectionCompletionSource.Task.IsCanceled)
+ if (reconnectionCompletionSource != null && reconnectionCompletionSource.Task != null && reconnectionCompletionSource.Task.IsCanceled)
{
throw SQL.CR_ReconnectionCancelled();
}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs
index 5ddd978093..bf21c8db3a 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs
@@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.
using System;
+using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
@@ -2131,4 +2132,38 @@ public static MethodInfo GetPromotedToken
}
}
}
+
+ ///
+ /// This class implements a FIFO Queue with SemaphoreSlim for ordered execution of parallel tasks.
+ /// Currently used in Managed SNI (SNISslStream) to override SslStream's WriteAsync implementation.
+ ///
+ internal class ConcurrentQueueSemaphore
+ {
+ private readonly SemaphoreSlim _semaphore;
+ private readonly ConcurrentQueue> _queue =
+ new ConcurrentQueue>();
+
+ public ConcurrentQueueSemaphore(int initialCount)
+ {
+ _semaphore = new SemaphoreSlim(initialCount);
+ }
+
+ public Task WaitAsync(CancellationToken cancellationToken)
+ {
+ var tcs = new TaskCompletionSource();
+ _queue.Enqueue(tcs);
+ _semaphore.WaitAsync().ContinueWith(t =>
+ {
+ if (_queue.TryDequeue(out TaskCompletionSource popped))
+ popped.SetResult(true);
+ }, cancellationToken);
+ return tcs.Task;
+ }
+
+ public void Release()
+ {
+ _semaphore.Release();
+ }
+ }
+
}//namespace
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs
index 7e9f5a6d67..5d36213755 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs
@@ -3266,7 +3266,7 @@ internal Task WritePacket(byte flushMode, bool canAccumulate = false)
if (willCancel)
{
- // If we have been cancelled, then ensure that we write the ATTN packet as well
+ // If we have been canceled, then ensure that we write the ATTN packet as well
task = AsyncHelper.CreateContinuationTask(task, CancelWritePacket);
}
return task;
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs
index ea802107ff..9e9e281a32 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs
@@ -72,13 +72,13 @@ internal override void AssignPendingDNSInfo(string userProtocol, string DNSCache
internal void ReadAsyncCallback(SNIPacket packet, uint error)
{
ReadAsyncCallback(IntPtr.Zero, PacketHandle.FromManagedPacket(packet), error);
- _sessionHandle.ReturnPacket(packet);
+ _sessionHandle?.ReturnPacket(packet);
}
internal void WriteAsyncCallback(SNIPacket packet, uint sniError)
{
WriteAsyncCallback(IntPtr.Zero, PacketHandle.FromManagedPacket(packet), sniError);
- _sessionHandle.ReturnPacket(packet);
+ _sessionHandle?.ReturnPacket(packet);
}
protected override void RemovePacketFromPendingList(PacketHandle packet)
diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs
index dfcd78bc15..7ff00b8335 100644
--- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs
+++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/AsyncTest/AsyncCancelledConnectionsTest.cs
@@ -12,7 +12,9 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests
public class AsyncCancelledConnectionsTest
{
private readonly ITestOutputHelper _output;
+
private const int NumberOfTasks = 100; // How many attempts to poison the connection pool we will try
+
private const int NumberOfNonPoisoned = 10; // Number of normal requests for each attempt
public AsyncCancelledConnectionsTest(ITestOutputHelper output)