Skip to content

Commit

Permalink
Fixes "InvalidOperationException" errors by performing async operatio…
Browse files Browse the repository at this point in the history
…ns in SemaphoreSlim (#796)
  • Loading branch information
cheenamalhotra authored Nov 19, 2020
1 parent f0572f3 commit cde615e
Show file tree
Hide file tree
Showing 10 changed files with 201 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,7 @@
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIPhysicalHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIProxy.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNITcpHandle.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNIStreams.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SNICommon.cs" />
<Compile Include="Microsoft\Data\SqlClient\SNI\SspiClientContextStatus.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public SNINpHandle(string serverName, string pipeName, long timerExpire, object
}

_sslOverTdsStream = new SslOverTdsStream(_pipeStream);
_sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);
_sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));

_stream = _pipeStream;
_status = TdsEnums.SNI_SUCCESS;
Expand Down Expand Up @@ -286,7 +286,7 @@ public override uint Send(SNIPacket packet)
}

// this lock ensures that two packets are not being written to the transport at the same time
// so that sending a standard and an out-of-band packet are both written atomically no data is
// so that sending a standard and an out-of-band packet are both written atomically no data is
// interleaved
lock (_sendSync)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Net.Security;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using System.Net.Sockets;

namespace Microsoft.Data.SqlClient.SNI
{
/// <summary>
/// This class extends SslStream to customize stream behavior for Managed SNI implementation.
/// </summary>
internal class SNISslStream : SslStream
{
private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore;
private readonly ConcurrentQueueSemaphore _readAsyncSemaphore;

public SNISslStream(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertificateValidationCallback userCertificateValidationCallback)
: base(innerStream, leaveInnerStreamOpen, userCertificateValidationCallback)
{
_writeAsyncSemaphore = new ConcurrentQueueSemaphore(1);
_readAsyncSemaphore = new ConcurrentQueueSemaphore(1);
}

// Prevent ReadAsync collisions by running the task in a Semaphore Slim
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}

/// <summary>
/// This class extends NetworkStream to customize stream behavior for Managed SNI implementation.
/// </summary>
internal class SNINetworkStream : NetworkStream
{
private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore;
private readonly ConcurrentQueueSemaphore _readAsyncSemaphore;

public SNINetworkStream(Socket socket, bool ownsSocket) : base(socket, ownsSocket)
{
_writeAsyncSemaphore = new ConcurrentQueueSemaphore(1);
_readAsyncSemaphore = new ConcurrentQueueSemaphore(1);
}

// Prevent ReadAsync collisions by running the task in a Semaphore Slim
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_readAsyncSemaphore.Release();
}
}

// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
_writeAsyncSemaphore.Release();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
bool reportError = true;

// We will always first try to connect with serverName as before and let the DNS server to resolve the serverName.
// If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with IPv4 first and followed by IPv6 if
// If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with IPv4 first and followed by IPv6 if
// IPv4 fails. The exceptions will be throw to upper level and be handled as before.
try
{
Expand All @@ -160,14 +160,14 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
{
// Retry with cached IP address
if (ex is SocketException || ex is ArgumentException || ex is AggregateException)
{
{
if (hasCachedDNSInfo == false)
{
throw;
}
else
{
int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port);
int portRetry = String.IsNullOrEmpty(cachedDNSInfo.Port) ? port : Int32.Parse(cachedDNSInfo.Port);

try
{
Expand All @@ -180,9 +180,9 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
_socket = Connect(cachedDNSInfo.AddrIPv4, portRetry, ts, isInfiniteTimeOut, cachedFQDN, ref pendingDNSInfo);
}
}
catch(Exception exRetry)
catch (Exception exRetry)
{
if (exRetry is SocketException || exRetry is ArgumentNullException
if (exRetry is SocketException || exRetry is ArgumentNullException
|| exRetry is ArgumentException || exRetry is ArgumentOutOfRangeException || exRetry is AggregateException)
{
if (parallel)
Expand All @@ -199,7 +199,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
throw;
}
}
}
}
}
else
{
Expand All @@ -223,10 +223,10 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
}

_socket.NoDelay = true;
_tcpStream = new NetworkStream(_socket, true);
_tcpStream = new SNINetworkStream(_socket, true);

_sslOverTdsStream = new SslOverTdsStream(_tcpStream);
_sslStream = new SslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate), null);
_sslStream = new SNISslStream(_sslOverTdsStream, true, new RemoteCertificateValidationCallback(ValidateServerCertificate));
}
catch (SocketException se)
{
Expand Down Expand Up @@ -331,7 +331,7 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
}

CancellationTokenSource cts = null;

void Cancel()
{
for (int i = 0; i < sockets.Length; ++i)
Expand All @@ -355,7 +355,7 @@ void Cancel()
}

Socket availableSocket = null;
try
try
{
for (int i = 0; i < sockets.Length; ++i)
{
Expand Down Expand Up @@ -566,45 +566,45 @@ public override uint Send(SNIPacket packet)
{
bool releaseLock = false;
try
{
// is the packet is marked out out-of-band (attention packets only) it must be
// sent immediately even if a send of recieve operation is already in progress
// because out of band packets are used to cancel ongoing operations
// so try to take the lock if possible but continue even if it can't be taken
if (packet.IsOutOfBand)
{
Monitor.TryEnter(this, ref releaseLock);
}
else
{
Monitor.Enter(this);
releaseLock = true;
}

// this lock ensures that two packets are not being written to the transport at the same time
// so that sending a standard and an out-of-band packet are both written atomically no data is
// interleaved
lock (_sendSync)
{
try
{
packet.WriteToStream(_stream);
return TdsEnums.SNI_SUCCESS;
}
catch (ObjectDisposedException ode)
// is the packet is marked out out-of-band (attention packets only) it must be
// sent immediately even if a send of recieve operation is already in progress
// because out of band packets are used to cancel ongoing operations
// so try to take the lock if possible but continue even if it can't be taken
if (packet.IsOutOfBand)
{
return ReportTcpSNIError(ode);
Monitor.TryEnter(this, ref releaseLock);
}
catch (SocketException se)
else
{
return ReportTcpSNIError(se);
Monitor.Enter(this);
releaseLock = true;
}
catch (IOException ioe)

// this lock ensures that two packets are not being written to the transport at the same time
// so that sending a standard and an out-of-band packet are both written atomically no data is
// interleaved
lock (_sendSync)
{
return ReportTcpSNIError(ioe);
try
{
packet.WriteToStream(_stream);
return TdsEnums.SNI_SUCCESS;
}
catch (ObjectDisposedException ode)
{
return ReportTcpSNIError(ode);
}
catch (SocketException se)
{
return ReportTcpSNIError(se);
}
catch (IOException ioe)
{
return ReportTcpSNIError(ioe);
}
}
}
}
finally
{
if (releaseLock)
Expand Down Expand Up @@ -633,7 +633,8 @@ public override uint Receive(out SNIPacket packet, int timeoutInMilliseconds)
_socket.ReceiveTimeout = timeoutInMilliseconds;
}
else if (timeoutInMilliseconds == -1)
{ // SqlCient internally represents infinite timeout by -1, and for TcpClient this is translated to a timeout of 0
{
// SqlClient internally represents infinite timeout by -1, and for TcpClient this is translated to a timeout of 0
_socket.ReceiveTimeout = 0;
}
else
Expand Down Expand Up @@ -706,12 +707,17 @@ public override void SetAsyncCallbacks(SNIAsyncCallback receiveCallback, SNIAsyn
/// <returns>SNI error code</returns>
public override uint SendAsync(SNIPacket packet, SNIAsyncCallback callback = null)
{
SNIAsyncCallback cb = callback ?? _sendCallback;
lock (this)
long scopeID = SqlClientEventSource.Log.TrySNIScopeEnterEvent("<sc.SNI.SNIMarsHandle.SendAsync |SNI|INFO|SCOPE>");
try
{
SNIAsyncCallback cb = callback ?? _sendCallback;
packet.WriteToStreamAsync(_stream, cb, SNIProviders.TCP_PROV);
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}
finally
{
SqlClientEventSource.Log.TrySNIScopeLeaveEvent(scopeID);
}
return TdsEnums.SNI_SUCCESS_IO_PENDING;
}

/// <summary>
Expand Down Expand Up @@ -745,15 +751,15 @@ public override uint CheckConnection()
{
try
{
// _socket.Poll method with argument SelectMode.SelectRead returns
// _socket.Poll method with argument SelectMode.SelectRead returns
// True : if Listen has been called and a connection is pending, or
// True : if data is available for reading, or
// True : if the connection has been closed, reset, or terminated, i.e no active connection.
// False : otherwise.
// _socket.Available property returns the number of bytes of data available to read.
//
// Since _socket.Connected alone doesn't guarantee if the connection is still active, we use it in
// combination with _socket.Poll method and _socket.Available == 0 check. When both of them
// Since _socket.Connected alone doesn't guarantee if the connection is still active, we use it in
// combination with _socket.Poll method and _socket.Available == 0 check. When both of them
// return true we can safely determine that the connection is no longer active.
if (!_socket.Connected || (_socket.Poll(100, SelectMode.SelectRead) && _socket.Available == 0))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,16 @@ namespace Microsoft.Data.SqlClient.SNI
internal sealed partial class SslOverTdsStream
{
public override int Read(byte[] buffer, int offset, int count)
{
return Read(buffer.AsSpan(offset, count));
}
=> Read(buffer.AsSpan(offset, count));

public override void Write(byte[] buffer, int offset, int count)
{
Write(buffer.AsSpan(offset, count));
}
=> Write(buffer.AsSpan(offset, count));

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
}
=> ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
}
=> WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();

public override int Read(Span<byte> buffer)
{
Expand Down Expand Up @@ -288,7 +280,6 @@ public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, Cancella

await _stream.FlushAsync().ConfigureAwait(false);


remaining = remaining.Slice(dataLength);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1319,7 +1319,7 @@ private void ThrowIfReconnectionHasBeenCanceled()
if (_stateObj == null)
{
var reconnectionCompletionSource = _reconnectionCompletionSource;
if (reconnectionCompletionSource != null && reconnectionCompletionSource.Task.IsCanceled)
if (reconnectionCompletionSource != null && reconnectionCompletionSource.Task != null && reconnectionCompletionSource.Task.IsCanceled)
{
throw SQL.CR_ReconnectionCancelled();
}
Expand Down
Loading

0 comments on commit cde615e

Please sign in to comment.