diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index 616c44a3f4..cacb7de65b 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -1877,7 +1877,7 @@ private void LoginNoFailover(ServerInfo serverInfo, string newPassword, SecureSt throw SQL.ROR_TimeoutAfterRoutingInfo(this); } - serverInfo = new ServerInfo(ConnectionOptions, _routingInfo, serverInfo.ResolvedServerName); + serverInfo = new ServerInfo(ConnectionOptions, _routingInfo, serverInfo.ResolvedServerName, serverInfo.ServerSPN); timeoutErrorInternal.SetInternalSourceType(SqlConnectionInternalSourceType.RoutingDestination); _originalClientConnectionId = _clientConnectionId; _routingDestination = serverInfo.UserServerName; @@ -2047,7 +2047,7 @@ TimeoutTimer timeout long timeoutUnitInterval; string protocol = ConnectionOptions.NetworkLibrary; - ServerInfo failoverServerInfo = new ServerInfo(connectionOptions, failoverHost); + ServerInfo failoverServerInfo = new ServerInfo(connectionOptions, failoverHost, connectionOptions.FailoverPartnerSPN); ResolveExtendedServerName(primaryServerInfo, !redirectedUserInstance, connectionOptions); if (null == ServerProvidedFailOverPartner) @@ -2150,7 +2150,7 @@ TimeoutTimer timeout _parser = new TdsParser(ConnectionOptions.MARS, ConnectionOptions.Asynchronous); Debug.Assert(SniContext.Undefined == Parser._physicalStateObj.SniContext, $"SniContext should be Undefined; actual Value: {Parser._physicalStateObj.SniContext}"); - currentServerInfo = new ServerInfo(ConnectionOptions, _routingInfo, currentServerInfo.ResolvedServerName); + currentServerInfo = new ServerInfo(ConnectionOptions, _routingInfo, currentServerInfo.ResolvedServerName, currentServerInfo.ServerSPN); timeoutErrorInternal.SetInternalSourceType(SqlConnectionInternalSourceType.RoutingDestination); _originalClientConnectionId = _clientConnectionId; _routingDestination = currentServerInfo.UserServerName; @@ -2296,13 +2296,9 @@ private void AttemptOneLogin(ServerInfo serverInfo, string newPassword, SecureSt this, ignoreSniOpenTimeout, timeout.LegacyTimerExpire, - ConnectionOptions.Encrypt, - ConnectionOptions.TrustServerCertificate, - ConnectionOptions.IntegratedSecurity, + ConnectionOptions, withFailover, isFirstTransparentAttempt, - ConnectionOptions.Authentication, - ConnectionOptions.Certificate, _serverCallback, _clientCallback, _originalNetworkAddressInfo != null, @@ -3244,6 +3240,7 @@ internal sealed class ServerInfo internal string ResolvedServerName { get; private set; } // the resolved servername only internal string ResolvedDatabaseName { get; private set; } // name of target database after resolution internal string UserProtocol { get; private set; } // the user specified protocol + internal string ServerSPN { get; private set; } // the server SPN // The original user-supplied server name from the connection string. // If connection string has no Data Source, the value is set to string.Empty. @@ -3264,10 +3261,16 @@ private set internal readonly string PreRoutingServerName; // Initialize server info from connection options, - internal ServerInfo(SqlConnectionString userOptions) : this(userOptions, userOptions.DataSource) { } + internal ServerInfo(SqlConnectionString userOptions) : this(userOptions, userOptions.DataSource, userOptions.ServerSPN) { } + + // Initialize server info from connection options, but override DataSource and ServerSPN with given server name and server SPN + internal ServerInfo(SqlConnectionString userOptions, string serverName, string serverSPN) : this(userOptions, serverName) + { + ServerSPN = serverSPN; + } // Initialize server info from connection options, but override DataSource with given server name - internal ServerInfo(SqlConnectionString userOptions, string serverName) + private ServerInfo(SqlConnectionString userOptions, string serverName) { //----------------- // Preconditions @@ -3286,7 +3289,7 @@ internal ServerInfo(SqlConnectionString userOptions, string serverName) // Initialize server info from connection options, but override DataSource with given server name - internal ServerInfo(SqlConnectionString userOptions, RoutingInfo routing, string preRoutingServerName) + internal ServerInfo(SqlConnectionString userOptions, RoutingInfo routing, string preRoutingServerName, string serverSPN) { //----------------- // Preconditions @@ -3307,6 +3310,7 @@ internal ServerInfo(SqlConnectionString userOptions, RoutingInfo routing, string UserProtocol = TdsEnums.TCP; SetDerivedNames(UserProtocol, UserServerName); ResolvedDatabaseName = userOptions.InitialCatalog; + ServerSPN = serverSPN; } internal void SetDerivedNames(string protocol, string serverName) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index e69df7dd19..68d4505518 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -493,18 +493,20 @@ internal void Connect(ServerInfo serverInfo, SqlInternalConnectionTds connHandler, bool ignoreSniOpenTimeout, long timerExpire, - bool encrypt, - bool trustServerCert, - bool integratedSecurity, + SqlConnectionString connectionOptions, bool withFailover, bool isFirstTransparentAttempt, - SqlAuthenticationMethod authType, - string certificate, ServerCertificateValidationCallback serverCallback, ClientCertificateRetrievalCallback clientCallback, bool useOriginalAddressInfo, bool disableTnir) { + bool encrypt = connectionOptions.Encrypt; + bool trustServerCert = connectionOptions.TrustServerCertificate; + bool integratedSecurity = connectionOptions.IntegratedSecurity; + SqlAuthenticationMethod authType = connectionOptions.Authentication; + string certificate = connectionOptions.Certificate; + if (_state != TdsParserState.Closed) { Debug.Fail("TdsParser.Connect called on non-closed connection!"); @@ -544,6 +546,9 @@ internal void Connect(ServerInfo serverInfo, LoadSSPILibrary(); // now allocate proper length of buffer _sniSpnBuffer = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength]; + byte[] srvSPN = Encoding.Unicode.GetBytes(serverInfo.ServerSPN); + Trace.Assert(srvSPN.Length <= _sniSpnBuffer.Length, "The provider SPN length exceeded the buffer size."); + Array.Copy(srvSPN, _sniSpnBuffer, srvSPN.Length); SqlClientEventSource.Log.TryTraceEvent(" SSPI or Active Directory Authentication Library for SQL Server based integrated authentication"); } else diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/DbConnectionStringCommon.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/DbConnectionStringCommon.cs index 0557ebaa75..1a779c02e8 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/DbConnectionStringCommon.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/DbConnectionStringCommon.cs @@ -976,6 +976,8 @@ internal static class DbConnectionStringDefaults internal const SqlConnectionAttestationProtocol AttestationProtocol = SqlConnectionAttestationProtocol.NotSpecified; internal const SqlConnectionIPAddressPreference IPAddressPreference = SqlConnectionIPAddressPreference.IPv4First; internal const PoolBlockingPeriod PoolBlockingPeriod = SqlClient.PoolBlockingPeriod.Auto; + internal const string ServerSPN = ""; + internal const string FailoverPartnerSPN = ""; } internal static class DbConnectionStringKeywords @@ -1029,6 +1031,8 @@ internal static class DbConnectionStringKeywords internal const string EnclaveAttestationUrl = "Enclave Attestation Url"; internal const string AttestationProtocol = "Attestation Protocol"; internal const string IPAddressPreference = "IP Address Preference"; + internal const string ServerSPN = "Server SPN"; + internal const string FailoverPartnerSPN = "Failover Partner SPN"; // common keywords (OleDb, OracleClient, SqlClient) internal const string DataSource = "Data Source"; @@ -1122,5 +1126,9 @@ internal static class DbConnectionStringSynonyms //internal const string WorkstationID = WSID; internal const string WSID = "wsid"; + + //internal const string server SPNs + internal const string ServerSPN = "ServerSPN"; + internal const string FailoverPartnerSPN = "FailoverPartnerSPN"; } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionString.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionString.cs index f1a488c4b6..67caeb6621 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionString.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionString.cs @@ -59,6 +59,8 @@ internal static class DEFAULT internal static readonly SqlAuthenticationMethod Authentication = DbConnectionStringDefaults.Authentication; internal static readonly SqlConnectionAttestationProtocol AttestationProtocol = DbConnectionStringDefaults.AttestationProtocol; internal static readonly SqlConnectionIPAddressPreference IpAddressPreference = DbConnectionStringDefaults.IPAddressPreference; + internal const string ServerSPN = DbConnectionStringDefaults.ServerSPN; + internal const string FailoverPartnerSPN = DbConnectionStringDefaults.FailoverPartnerSPN; #if NETFRAMEWORK internal static readonly bool TransparentNetworkIPResolution = DbConnectionStringDefaults.TransparentNetworkIPResolution; internal const bool Connection_Reset = DbConnectionStringDefaults.ConnectionReset; @@ -113,6 +115,8 @@ internal static class KEY internal const string Connect_Retry_Count = DbConnectionStringKeywords.ConnectRetryCount; internal const string Connect_Retry_Interval = DbConnectionStringKeywords.ConnectRetryInterval; internal const string Authentication = DbConnectionStringKeywords.Authentication; + internal const string Server_SPN = DbConnectionStringKeywords.ServerSPN; + internal const string Failover_Partner_SPN = DbConnectionStringKeywords.FailoverPartnerSPN; #if NETFRAMEWORK internal const string TransparentNetworkIPResolution = DbConnectionStringKeywords.TransparentNetworkIPResolution; #if ADONET_CERT_AUTH @@ -173,6 +177,9 @@ private static class SYNONYM internal const string User = DbConnectionStringSynonyms.User; // workstation id internal const string WSID = DbConnectionStringSynonyms.WSID; + // server SPNs + internal const string ServerSPN = DbConnectionStringSynonyms.ServerSPN; + internal const string FailoverPartnerSPN = DbConnectionStringSynonyms.FailoverPartnerSPN; #if NETFRAMEWORK internal const string TRANSPARENTNETWORKIPRESOLUTION = DbConnectionStringSynonyms.TRANSPARENTNETWORKIPRESOLUTION; @@ -212,9 +219,9 @@ internal static class TRANSACTIONBINDING } #if NETFRAMEWORK - internal const int SynonymCount = 29; + internal const int SynonymCount = 31; #else - internal const int SynonymCount = 26; + internal const int SynonymCount = 28; internal const int DeprecatedSynonymCount = 2; #endif // NETFRAMEWORK @@ -257,6 +264,8 @@ internal static class TRANSACTIONBINDING private readonly string _initialCatalog; private readonly string _password; private readonly string _userID; + private readonly string _serverSPN; + private readonly string _failoverPartnerSPN; private readonly string _workstationId; @@ -322,6 +331,8 @@ internal SqlConnectionString(string connectionString) : base(connectionString, G _enclaveAttestationUrl = ConvertValueToString(KEY.EnclaveAttestationUrl, DEFAULT.EnclaveAttestationUrl); _attestationProtocol = ConvertValueToAttestationProtocol(); _ipAddressPreference = ConvertValueToIPAddressPreference(); + _serverSPN = ConvertValueToString(KEY.Server_SPN, DEFAULT.ServerSPN); + _failoverPartnerSPN = ConvertValueToString(KEY.Failover_Partner_SPN, DEFAULT.FailoverPartnerSPN); // Temporary string - this value is stored internally as an enum. string typeSystemVersionString = ConvertValueToString(KEY.Type_System_Version, null); @@ -675,6 +686,8 @@ internal SqlConnectionString(SqlConnectionString connectionOptions, string dataS _columnEncryptionSetting = connectionOptions._columnEncryptionSetting; _enclaveAttestationUrl = connectionOptions._enclaveAttestationUrl; _attestationProtocol = connectionOptions._attestationProtocol; + _serverSPN = connectionOptions._serverSPN; + _failoverPartnerSPN = connectionOptions._failoverPartnerSPN; #if NETFRAMEWORK _connectionReset = connectionOptions._connectionReset; _contextConnection = connectionOptions._contextConnection; @@ -732,7 +745,8 @@ internal SqlConnectionString(SqlConnectionString connectionOptions, string dataS internal string UserID => _userID; internal string WorkstationId => _workstationId; internal PoolBlockingPeriod PoolBlockingPeriod => _poolBlockingPeriod; - + internal string ServerSPN => _serverSPN; + internal string FailoverPartnerSPN => _failoverPartnerSPN; internal TypeSystem TypeSystemVersion => _typeSystemVersion; internal Version TypeSystemAssemblyVersion => _typeSystemAssemblyVersion; @@ -843,6 +857,8 @@ internal static Dictionary GetParseSynonyms() { KEY.Connect_Retry_Interval, KEY.Connect_Retry_Interval }, { KEY.Authentication, KEY.Authentication }, { KEY.IPAddressPreference, KEY.IPAddressPreference }, + { KEY.Server_SPN, KEY.Server_SPN }, + { KEY.Failover_Partner_SPN, KEY.Failover_Partner_SPN }, { SYNONYM.APP, KEY.Application_Name }, { SYNONYM.APPLICATIONINTENT, KEY.ApplicationIntent }, @@ -871,6 +887,8 @@ internal static Dictionary GetParseSynonyms() { SYNONYM.UID, KEY.User_ID }, { SYNONYM.User, KEY.User_ID }, { SYNONYM.WSID, KEY.Workstation_Id }, + { SYNONYM.ServerSPN, KEY.Server_SPN }, + { SYNONYM.FailoverPartnerSPN, KEY.Failover_Partner_SPN }, #if NETFRAMEWORK #if ADONET_CERT_AUTH { KEY.Certificate, KEY.Certificate }, diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionStringBuilder.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionStringBuilder.cs index b78c2e392b..b73ac0d532 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionStringBuilder.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionStringBuilder.cs @@ -65,6 +65,8 @@ private enum Keywords AttestationProtocol, CommandTimeout, IPAddressPreference, + ServerSPN, + FailoverPartnerSPN, #if NETFRAMEWORK ConnectionReset, NetworkLibrary, @@ -122,6 +124,8 @@ private enum Keywords private string _enclaveAttestationUrl = DbConnectionStringDefaults.EnclaveAttestationUrl; private SqlConnectionAttestationProtocol _attestationProtocol = DbConnectionStringDefaults.AttestationProtocol; private SqlConnectionIPAddressPreference _ipAddressPreference = DbConnectionStringDefaults.IPAddressPreference; + private string _serverSPN = DbConnectionStringDefaults.ServerSPN; + private string _failoverPartnerSPN = DbConnectionStringDefaults.FailoverPartnerSPN; #if NETFRAMEWORK private bool _connectionReset = DbConnectionStringDefaults.ConnectionReset; @@ -176,11 +180,13 @@ private static string[] CreateValidKeywords() validKeywords[(int)Keywords.EnclaveAttestationUrl] = DbConnectionStringKeywords.EnclaveAttestationUrl; validKeywords[(int)Keywords.AttestationProtocol] = DbConnectionStringKeywords.AttestationProtocol; validKeywords[(int)Keywords.IPAddressPreference] = DbConnectionStringKeywords.IPAddressPreference; + validKeywords[(int)Keywords.ServerSPN] = DbConnectionStringKeywords.ServerSPN; + validKeywords[(int)Keywords.FailoverPartnerSPN] = DbConnectionStringKeywords.FailoverPartnerSPN; #if NETFRAMEWORK validKeywords[(int)Keywords.ConnectionReset] = DbConnectionStringKeywords.ConnectionReset; + validKeywords[(int)Keywords.NetworkLibrary] = DbConnectionStringKeywords.NetworkLibrary; validKeywords[(int)Keywords.ContextConnection] = DbConnectionStringKeywords.ContextConnection; validKeywords[(int)Keywords.TransparentNetworkIPResolution] = DbConnectionStringKeywords.TransparentNetworkIPResolution; - validKeywords[(int)Keywords.NetworkLibrary] = DbConnectionStringKeywords.NetworkLibrary; #if ADONET_CERT_AUTH validKeywords[(int)Keywords.Certificate] = DbConnectionStringKeywords.Certificate; #endif @@ -228,6 +234,8 @@ private static Dictionary CreateKeywordsDictionary() { DbConnectionStringKeywords.EnclaveAttestationUrl, Keywords.EnclaveAttestationUrl }, { DbConnectionStringKeywords.AttestationProtocol, Keywords.AttestationProtocol }, { DbConnectionStringKeywords.IPAddressPreference, Keywords.IPAddressPreference }, + { DbConnectionStringKeywords.ServerSPN, Keywords.ServerSPN }, + { DbConnectionStringKeywords.FailoverPartnerSPN, Keywords.FailoverPartnerSPN }, #if NETFRAMEWORK { DbConnectionStringKeywords.ConnectionReset, Keywords.ConnectionReset }, @@ -266,7 +274,9 @@ private static Dictionary CreateKeywordsDictionary() { DbConnectionStringSynonyms.PERSISTSECURITYINFO, Keywords.PersistSecurityInfo }, { DbConnectionStringSynonyms.UID, Keywords.UserID }, { DbConnectionStringSynonyms.User, Keywords.UserID }, - { DbConnectionStringSynonyms.WSID, Keywords.WorkstationID } + { DbConnectionStringSynonyms.WSID, Keywords.WorkstationID }, + { DbConnectionStringSynonyms.ServerSPN, Keywords.ServerSPN }, + { DbConnectionStringSynonyms.FailoverPartnerSPN, Keywords.FailoverPartnerSPN }, }; Debug.Assert((KeywordsCount + SqlConnectionString.SynonymCount) == pairs.Count, "initial expected size is incorrect"); return pairs; @@ -373,7 +383,10 @@ private object GetAt(Keywords index) return AttestationProtocol; case Keywords.IPAddressPreference: return IPAddressPreference; - + case Keywords.ServerSPN: + return ServerSPN; + case Keywords.FailoverPartnerSPN: + return FailoverPartnerSPN; #if NETFRAMEWORK #pragma warning disable 618 // Obsolete properties case Keywords.ConnectionReset: @@ -518,6 +531,12 @@ private void Reset(Keywords index) case Keywords.IPAddressPreference: _ipAddressPreference = DbConnectionStringDefaults.IPAddressPreference; break; + case Keywords.ServerSPN: + _serverSPN = DbConnectionStringDefaults.ServerSPN; + break; + case Keywords.FailoverPartnerSPN: + _failoverPartnerSPN = DbConnectionStringDefaults.FailoverPartnerSPN; + break; #if NETFRAMEWORK case Keywords.ConnectionReset: _connectionReset = DbConnectionStringDefaults.ConnectionReset; @@ -1010,6 +1029,12 @@ public override object this[string keyword] case Keywords.ConnectRetryInterval: ConnectRetryInterval = ConvertToInt32(value); break; + case Keywords.ServerSPN: + ServerSPN = ConvertToString(value); + break; + case Keywords.FailoverPartnerSPN: + FailoverPartnerSPN = ConvertToString(value); + break; #if NETFRAMEWORK #pragma warning disable 618 // Obsolete properties case Keywords.ConnectionReset: @@ -1165,6 +1190,23 @@ public string DataSource } } + /// + /// The SPN for the server. The default value is an empty string. An empty string causes SQL Server Native Client to use the default, provider-generated SPN. + /// + [DisplayName(DbConnectionStringKeywords.ServerSPN)] + [ResCategory(StringsHelper.ResourceNames.DataCategory_Source)] + //[ResDescription(StringsHelper.ResourceNames.DbConnectionString_ServerSPN)] + [RefreshProperties(RefreshProperties.All)] + public string ServerSPN + { + get => _serverSPN; + set + { + SetValue(DbConnectionStringKeywords.ServerSPN, value); + _serverSPN = value; + } + } + /// [DisplayName(DbConnectionStringKeywords.Encrypt)] [ResCategory(StringsHelper.ResourceNames.DataCategory_Security)] @@ -1303,6 +1345,22 @@ public string FailoverPartner } } + /// + /// The SPN for the failover partner. The default value is an empty string. An empty string causes SQL Server Native Client to use the default, provider-generated SPN. + /// + [DisplayName(DbConnectionStringKeywords.FailoverPartnerSPN)] + [ResCategory(StringsHelper.ResourceNames.DataCategory_Source)] + [RefreshProperties(RefreshProperties.All)] + public string FailoverPartnerSPN + { + get => _failoverPartnerSPN; + set + { + SetValue(DbConnectionStringKeywords.FailoverPartnerSPN, value); + _failoverPartnerSPN = value; + } + } + /// [DisplayName(DbConnectionStringKeywords.InitialCatalog)] [ResCategory(StringsHelper.ResourceNames.DataCategory_Source)]