Skip to content

Commit

Permalink
[Release/2.0] Fix pooled connection re-use on access token expiry (#639)
Browse files Browse the repository at this point in the history
  • Loading branch information
cheenamalhotra authored Jul 21, 2020
1 parent 8b8196a commit 2881d67
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ internal static partial class ADP
internal static Task<bool> FalseTask => _falseTask ?? (_falseTask = Task.FromResult(false));

internal const CompareOptions DefaultCompareOptions = CompareOptions.IgnoreKanaType | CompareOptions.IgnoreWidth | CompareOptions.IgnoreCase;

internal const int DefaultConnectionTimeout = DbConnectionStringDefaults.ConnectTimeout;
internal const int InfiniteConnectionTimeout = 0; // infinite connection timeout identifier in seconds
internal const int MaxBufferAccessTokenExpiry = 600; // max duration for buffer in seconds

static private void TraceException(string trace, Exception e)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ virtual protected bool ReadyToPrepareTransaction
}
}

internal virtual bool IsAccessTokenExpired => false;

abstract protected void Activate(Transaction transaction);

internal void ActivateConnection(Transaction transaction)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,13 @@ private bool TryGetConnection(DbConnection owningObject, uint waitForMultipleObj
_waitHandles.CreationSemaphore.Release(1);
}
}

// Do not use this pooled connection if access token is about to expire soon before we can connect.
if(null != obj && obj.IsAccessTokenExpired)
{
DestroyObject(obj);
obj = null;
}
} while (null == obj);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
{
SqlConnectionString opt = (SqlConnectionString)options;
SqlConnectionPoolKey key = (SqlConnectionPoolKey)poolKey;
SqlInternalConnection result = null;
SessionData recoverySessionData = null;

SqlConnection sqlOwningConnection = (SqlConnection)owningConnection;
Expand Down Expand Up @@ -131,8 +130,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
opt = new SqlConnectionString(opt, instanceName, userInstance: false, setEnlistValue: null);
poolGroupProviderInfo = null; // null so we do not pass to constructor below...
}
result = new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool);
return result;
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool);
}

protected override DbConnectionOptions CreateConnectionOptions(string connectionString, DbConnectionOptions previous)
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -2180,6 +2180,8 @@ static internal Exception InvalidArgumentValue(string methodName)
internal const int DecimalMaxPrecision28 = 28; // there are some cases in Odbc where we need that ...
internal const int DefaultCommandTimeout = 30;
internal const int DefaultConnectionTimeout = DbConnectionStringDefaults.ConnectTimeout;
internal const int InfiniteConnectionTimeout = 0; // infinite connection timeout identifier in seconds
internal const int MaxBufferAccessTokenExpiry = 600; // max duration for buffer in seconds
internal const float FailoverTimeoutStep = 0.08F; // fraction of timeout to use for fast failover connections
internal const float FailoverTimeoutStepForTnir = 0.125F; // Fraction of timeout to use in case of Transparent Network IP resolution.
internal const int MinimumTimeoutForTnirMs = 500; // The first login attempt in Transparent network IP Resolution
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ public ConnectionState State
}
}

internal virtual bool IsAccessTokenExpired => false;

abstract protected void Activate(SysTx.Transaction transaction);

internal void ActivateConnection(SysTx.Transaction transaction)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,13 @@ private bool TryGetConnection(DbConnection owningObject, uint waitForMultipleObj
{
Marshal.ThrowExceptionForHR(releaseSemaphoreResult); // will only throw if (hresult < 0)
}

// Do not use this pooled connection if access token is about to expire soon before we can connect.
if (null != obj && obj.IsAccessTokenExpired)
{
DestroyObject(obj);
obj = null;
}
} while (null == obj);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ sealed internal class SqlInternalConnectionTds : SqlInternalConnection, IDisposa
// Connection Resiliency
private bool _sessionRecoveryRequested;
internal bool _sessionRecoveryAcknowledged;
internal SessionData _currentSessionData; // internal for use from TdsParser only, otehr should use CurrentSessionData property that will fix database and language
internal SessionData _currentSessionData; // internal for use from TdsParser only, other should use CurrentSessionData property that will fix database and language
private SessionData _recoverySessionData;

// Federated Authentication
Expand All @@ -131,13 +131,14 @@ sealed internal class SqlInternalConnectionTds : SqlInternalConnection, IDisposa
internal bool _federatedAuthenticationInfoRequested; // Keep this distinct from _federatedAuthenticationRequested, since some fedauth library types may not need more info
internal bool _federatedAuthenticationInfoReceived;

// The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken.
SqlFedAuthToken _fedAuthToken = null;
internal byte[] _accessTokenInBytes;

private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper;
private readonly SqlAuthenticationProviderManager _sqlAuthenticationProviderManager;

// Certificate auth calbacks.
//
ServerCertificateValidationCallback _serverCallback;
ClientCertificateRetrievalCallback _clientCallback;
SqlClientOriginalNetworkAddressInfo _originalNetworkAddressInfo;
Expand All @@ -146,6 +147,18 @@ sealed internal class SqlInternalConnectionTds : SqlInternalConnection, IDisposa

private bool _serverSupportsDNSCaching = false;

/// <summary>
/// Returns buffer time allowed before access token expiry to continue using the access token.
/// </summary>
private int accessTokenExpirationBufferTime
{
get
{
return (ConnectionOptions.ConnectTimeout == ADP.InfiniteConnectionTimeout || ConnectionOptions.ConnectTimeout >= ADP.MaxBufferAccessTokenExpiry)
? ADP.MaxBufferAccessTokenExpiry : ConnectionOptions.ConnectTimeout;
}
}

/// <summary>
/// Get or set if SQLDNSCaching FeatureExtAck is supported by the server.
/// </summary>
Expand Down Expand Up @@ -808,6 +821,10 @@ protected override bool UnbindOnTransactionCompletion
}
}

/// <summary>
/// Validates if federated authentication is used, Access Token used by this connection is active for the value of 'accessTokenExpirationBufferTime'.
/// </summary>
internal override bool IsAccessTokenExpired => _federatedAuthenticationInfoRequested && DateTime.FromFileTimeUtc(_fedAuthToken.expirationFileTime) < DateTime.UtcNow.AddSeconds(accessTokenExpirationBufferTime);

////////////////////////////////////////////////////////////////////////////////////////
// GENERAL METHODS
Expand Down Expand Up @@ -1321,10 +1338,10 @@ internal void ExecuteTransactionYukon(
ThreadHasParserLockForClose = false;
_parserLock.Release();
releaseConnectionLock = false;
}, 0);
}, ADP.InfiniteConnectionTimeout);
if (reconnectTask != null)
{
AsyncHelper.WaitForCompletion(reconnectTask, 0); // there is no specific timeout for BeginTransaction, uses ConnectTimeout
AsyncHelper.WaitForCompletion(reconnectTask, ADP.InfiniteConnectionTimeout); // there is no specific timeout for BeginTransaction, uses ConnectTimeout
internalTransaction.ConnectionHasBeenRestored = true;
return;
}
Expand Down Expand Up @@ -2538,9 +2555,6 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
// We want to refresh the token, if taking the lock on the authentication context is successful.
bool attemptRefreshTokenLocked = false;

// The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken.
SqlFedAuthToken fedAuthToken = null;

if (_dbConnectionPool != null)
{
Debug.Assert(_dbConnectionPool.AuthenticationContexts != null);
Expand Down Expand Up @@ -2575,7 +2589,7 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
}
else if (_forceExpiryLocked)
{
attemptRefreshTokenLocked = TryGetFedAuthTokenLocked(fedAuthInfo, dbConnectionPoolAuthenticationContext, out fedAuthToken);
attemptRefreshTokenLocked = TryGetFedAuthTokenLocked(fedAuthInfo, dbConnectionPoolAuthenticationContext, out _fedAuthToken);
}
#endif

Expand All @@ -2589,11 +2603,11 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)

// Call the function which tries to acquire a lock over the authentication context before trying to update.
// If the lock could not be obtained, it will return false, without attempting to fetch a new token.
attemptRefreshTokenLocked = TryGetFedAuthTokenLocked(fedAuthInfo, dbConnectionPoolAuthenticationContext, out fedAuthToken);
attemptRefreshTokenLocked = TryGetFedAuthTokenLocked(fedAuthInfo, dbConnectionPoolAuthenticationContext, out _fedAuthToken);

// If TryGetFedAuthTokenLocked returns true, it means lock was obtained and fedAuthToken should not be null.
// If TryGetFedAuthTokenLocked returns true, it means lock was obtained and _fedAuthToken should not be null.
// If there was an exception in retrieving the new token, TryGetFedAuthTokenLocked should have thrown, so we won't be here.
Debug.Assert(!attemptRefreshTokenLocked || fedAuthToken != null, "Either Lock should not have been obtained or fedAuthToken should not be null.");
Debug.Assert(!attemptRefreshTokenLocked || _fedAuthToken != null, "Either Lock should not have been obtained or _fedAuthToken should not be null.");
Debug.Assert(!attemptRefreshTokenLocked || _newDbConnectionPoolAuthenticationContext != null, "Either Lock should not have been obtained or _newDbConnectionPoolAuthenticationContext should not be null.");

// Indicate in Bid Trace that we are successful with the update.
Expand All @@ -2610,8 +2624,8 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
if (dbConnectionPoolAuthenticationContext == null || attemptRefreshTokenUnLocked)
{
// Get the Federated Authentication Token.
fedAuthToken = GetFedAuthToken(fedAuthInfo);
Debug.Assert(fedAuthToken != null, "fedAuthToken should not be null.");
_fedAuthToken = GetFedAuthToken(fedAuthInfo);
Debug.Assert(_fedAuthToken != null, "_fedAuthToken should not be null.");

if (_dbConnectionPool != null)
{
Expand All @@ -2622,18 +2636,19 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
else if (!attemptRefreshTokenLocked)
{
Debug.Assert(dbConnectionPoolAuthenticationContext != null, "dbConnectionPoolAuthenticationContext should not be null.");
Debug.Assert(fedAuthToken == null, "fedAuthToken should be null in this case.");
Debug.Assert(_fedAuthToken == null, "_fedAuthToken should be null in this case.");
Debug.Assert(_newDbConnectionPoolAuthenticationContext == null, "_newDbConnectionPoolAuthenticationContext should be null.");

fedAuthToken = new SqlFedAuthToken();
_fedAuthToken = new SqlFedAuthToken();

// If the code flow is here, then we are re-using the context from the cache for this connection attempt and not
// generating a new access token on this thread.
fedAuthToken.accessToken = dbConnectionPoolAuthenticationContext.AccessToken;
_fedAuthToken.accessToken = dbConnectionPoolAuthenticationContext.AccessToken;
_fedAuthToken.expirationFileTime = dbConnectionPoolAuthenticationContext.ExpirationTime.ToFileTime();
}

Debug.Assert(fedAuthToken != null && fedAuthToken.accessToken != null, "fedAuthToken and fedAuthToken.accessToken cannot be null.");
_parser.SendFedAuthToken(fedAuthToken);
Debug.Assert(_fedAuthToken != null && _fedAuthToken.accessToken != null, "_fedAuthToken and _fedAuthToken.accessToken cannot be null.");
_parser.SendFedAuthToken(_fedAuthToken);
}

/// <summary>
Expand Down Expand Up @@ -2873,7 +2888,8 @@ internal void OnFeatureExtAck(int featureId, byte[] data)
{
if (_routingInfo != null)
{
if (TdsEnums.FEATUREEXT_SQLDNSCACHING != featureId) {
if (TdsEnums.FEATUREEXT_SQLDNSCACHING != featureId)
{
return;
}
}
Expand Down Expand Up @@ -3101,16 +3117,18 @@ internal void OnFeatureExtAck(int featureId, byte[] data)
throw SQL.ParsingError(ParsingErrorState.CorruptedTdsStream);
}

if (1 == data[0]) {
if (1 == data[0])
{
IsSQLDNSCachingSupported = true;
_cleanSQLDNSCaching = false;

if (_routingInfo != null)
{
IsDNSCachingBeforeRedirectSupported = true;
}
}
else {
else
{
// we receive the IsSupported whose value is 0
IsSQLDNSCachingSupported = false;
_cleanSQLDNSCaching = true;
Expand Down

0 comments on commit 2881d67

Please sign in to comment.