diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs index 0a32356fe9..9e69b7a39f 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs @@ -952,6 +952,9 @@ public static bool ParseDataSource(string dataSource, out string hostname, out i port = -1; instanceName = string.Empty; + // Remove leading and trailing spaces + dataSource = dataSource.Trim(); + if (dataSource.Contains(":")) { dataSource = dataSource.Substring(dataSource.IndexOf(":", StringComparison.Ordinal) + 1); diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index 7a3dad589d..0ec61b5420 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -179,6 +179,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataSourceParserTest/DataSourceParserTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataSourceParserTest/DataSourceParserTest.cs new file mode 100644 index 0000000000..38d4685537 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataSourceParserTest/DataSourceParserTest.cs @@ -0,0 +1,75 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Xunit; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.DataSourceParserTest +{ + public class DataSourceParserTest + { + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsNotAzureSynapse))] + [InlineData("localhost")] + [InlineData("tcp:localhost")] + [InlineData(" localhost ")] + [InlineData(" tcp:localhost ")] + [InlineData(" localhost")] + [InlineData(" tcp:localhost")] + [InlineData("localhost ")] + [InlineData("tcp:localhost ")] + public void ParseDataSourceWithoutInstanceNorPortTestShouldSucceed(string dataSource) + { + DataTestUtility.ParseDataSource(dataSource, out string hostname, out _, out _); + Assert.Equal("localhost", hostname); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsNotAzureSynapse))] + [InlineData("localhost,1433")] + [InlineData("tcp:localhost,1433")] + [InlineData(" localhost,1433 ")] + [InlineData(" tcp:localhost,1433 ")] + [InlineData(" localhost,1433")] + [InlineData(" tcp:localhost,1433")] + [InlineData("localhost,1433 ")] + [InlineData("tcp:localhost,1433 ")] + public void ParseDataSourceWithoutInstanceButWithPortTestShouldSucceed(string dataSource) + { + DataTestUtility.ParseDataSource(dataSource, out string hostname, out int port, out _); + Assert.Equal("localhost", hostname); + Assert.Equal(1433, port); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsNotAzureSynapse))] + [InlineData("localhost\\MSSQLSERVER02")] + [InlineData("tcp:localhost\\MSSQLSERVER02")] + [InlineData(" localhost\\MSSQLSERVER02 ")] + [InlineData(" tcp:localhost\\MSSQLSERVER02 ")] + [InlineData(" localhost\\MSSQLSERVER02")] + [InlineData(" tcp:localhost\\MSSQLSERVER02")] + [InlineData("localhost\\MSSQLSERVER02 ")] + [InlineData("tcp:localhost\\MSSQLSERVER02 ")] + public void ParseDataSourceWithInstanceButWithoutPortTestShouldSucceed(string dataSource) + { + DataTestUtility.ParseDataSource(dataSource, out string hostname, out _, out string instanceName); + Assert.Equal("localhost", hostname); + Assert.Equal("MSSQLSERVER02", instanceName); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsNotAzureSynapse))] + [InlineData("localhost\\MSSQLSERVER02,1433")] + [InlineData("tcp:localhost\\MSSQLSERVER02,1433")] + [InlineData(" localhost\\MSSQLSERVER02,1433 ")] + [InlineData(" tcp:localhost\\MSSQLSERVER02,1433 ")] + [InlineData(" localhost\\MSSQLSERVER02,1433")] + [InlineData(" tcp:localhost\\MSSQLSERVER02,1433")] + [InlineData("localhost\\MSSQLSERVER02,1433 ")] + [InlineData("tcp:localhost\\MSSQLSERVER02,1433 ")] + public void ParseDataSourceWithInstanceAndPortTestShouldSucceed(string dataSource) + { + DataTestUtility.ParseDataSource(dataSource, out string hostname, out int port, out string instanceName); + Assert.Equal("localhost", hostname); + Assert.Equal("MSSQLSERVER02", instanceName); + Assert.Equal(1433, port); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs index 61474ece7f..751cd0a35b 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Data.Common; using System.Net; using System.Net.Sockets; using System.Reflection; @@ -87,9 +88,8 @@ public static void ConnectManagedWithInstanceNameTest(bool useMultiSubnetFailove } #if NET6_0_OR_GREATER - [ActiveIssue("27824")] // When specifying instance name and port number, this method call always returns false [ConditionalFact(nameof(IsSPNPortNumberTestForTCP))] - public static void PortNumberInSPNTestForTCP() + public static void SPNTestForTCPMustReturnPortNumber() { string connectionString = DataTestUtility.TCPConnectionString; SqlConnectionStringBuilder builder = new(connectionString); @@ -98,11 +98,23 @@ public static void PortNumberInSPNTestForTCP() Assert.True(port > 0, "Named instance must have a valid port number."); builder.DataSource = $"{builder.DataSource},{port}"; - PortNumberInSPNTest(builder.ConnectionString, port); + PortNumberInSPNTest(connectionString: builder.ConnectionString, expectedPortNumber: port); + } + + [ConditionalFact(nameof(IsSPNPortNumberTestForNP))] + public static void SPNTestForNPMustReturnNamedInstance() + { + string connectionString = DataTestUtility.NPConnectionString; + SqlConnectionStringBuilder builder = new(connectionString); + + DataTestUtility.ParseDataSource(builder.DataSource, out _, out _, out string instanceName); + + Assert.True(!string.IsNullOrEmpty(instanceName), "Instance name must be included in data source."); + PortNumberInSPNTest(connectionString: builder.ConnectionString, expectedInstanceName: instanceName.ToUpper()); } #endif - private static void PortNumberInSPNTest(string connectionString, int expectedPortNumber) + private static void PortNumberInSPNTest(string connectionString, int expectedPortNumber = 0, string expectedInstanceName = null) { if (DataTestUtility.IsIntegratedSecuritySetup()) { @@ -124,20 +136,27 @@ private static void PortNumberInSPNTest(string connectionString, int expectedPor { connection.Open(); - string spnInfo = GetSPNInfo(builder.DataSource); - Assert.Matches(@"MSSQLSvc\/.*:[\d]", spnInfo); - - string[] spnStrs = spnInfo.Split(':'); - int portInSPN = 0; - if (spnStrs.Length > 1) + string spnInfo = GetSPNInfo(builder.DataSource, instanceName); + if (expectedPortNumber > 0) { - int.TryParse(spnStrs[1], out portInSPN); + Assert.Matches(@"MSSQLSvc\/.*:[\d]", spnInfo); + string[] spnStrs = spnInfo.Split(':'); + int portInSPN = 0; + if (spnStrs.Length > 1) + { + int.TryParse(spnStrs[1], out portInSPN); + } + Assert.Equal(expectedPortNumber, portInSPN); + } + else + { + string[] spnStrs = spnInfo.Split(':'); + Assert.Equal(expectedInstanceName, spnStrs[1].ToUpper()); } - Assert.Equal(expectedPortNumber, portInSPN); } } - private static string GetSPNInfo(string dataSource) + private static string GetSPNInfo(string dataSource, string inInstanceName) { Assembly sqlConnectionAssembly = Assembly.GetAssembly(typeof(SqlConnection)); @@ -178,9 +197,12 @@ private static string GetSPNInfo(string dataSource) PropertyInfo serverInfo = dataSrcInfo.GetType().GetProperty("ServerName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); string serverName = serverInfo.GetValue(dataSrcInfo, null).ToString(); - + // Set the instance name from the data source + PropertyInfo instanceNameToSetInfo = dataSrcInfo.GetType().GetProperty("InstanceName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); + instanceNameToSetInfo.SetValue(dataSrcInfo, inInstanceName, null); + // Ensure that the instance name is set PropertyInfo instanceNameInfo = dataSrcInfo.GetType().GetProperty("InstanceName", BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); - string instanceName = instanceNameInfo.GetValue(dataSrcInfo, null).ToString(); + string instanceName = instanceNameInfo.GetValue(dataSrcInfo, null).ToString().ToUpper(); object port = getPortByInstanceNameInfo.Invoke(ssrpObj, parameters: new object[] { serverName, instanceName, timeoutTimerObj, false, 0 }); @@ -205,6 +227,13 @@ private static bool IsSPNPortNumberTestForTCP() && DataTestUtility.IsNotAzureSynapse()); } + private static bool IsSPNPortNumberTestForNP() + { + return (IsInstanceNameValid(DataTestUtility.NPConnectionString) + && DataTestUtility.IsUsingManagedSNI() + && DataTestUtility.IsNotAzureServer() + && DataTestUtility.IsNotAzureSynapse()); + } private static bool IsInstanceNameValid(string connectionString) { string instanceName = "";