From 49c429317f48d038d4816972dd5b2f40f29d6b37 Mon Sep 17 00:00:00 2001 From: Cheena Malhotra Date: Tue, 19 May 2020 14:57:42 -0700 Subject: [PATCH 1/2] Fix SqlBulkCopy to work with Data Classification enabled tables --- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 52 ++--- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 52 ++--- .../ManualTests/DataCommon/DataTestUtility.cs | 25 +++ ....Data.SqlClient.ManualTesting.Tests.csproj | 1 + .../ConnectionBehaviorTest.cs | 4 +- .../DataClassificationTest.cs | 209 ++++++++++++++++++ 6 files changed, 289 insertions(+), 54 deletions(-) create mode 100644 src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataClassificationTest/DataClassificationTest.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index cfac61e4c7..1a7821dcd5 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -2253,40 +2253,40 @@ internal bool TryRun(RunBehavior runBehavior, SqlCommand cmdHandler, SqlDataRead } } - if (null != dataStream) + byte peekedToken; + if (!stateObj.TryPeekByte(out peekedToken)) + { // temporarily cache next byte + return false; + } + + if (TdsEnums.SQLDATACLASSIFICATION == peekedToken) { - byte peekedToken; - if (!stateObj.TryPeekByte(out peekedToken)) - { // temporarily cache next byte + byte dataClassificationToken; + if (!stateObj.TryReadByte(out dataClassificationToken)) + { return false; } + Debug.Assert(TdsEnums.SQLDATACLASSIFICATION == dataClassificationToken); - if (TdsEnums.SQLDATACLASSIFICATION == peekedToken) + SensitivityClassification sensitivityClassification; + if (!TryProcessDataClassification(stateObj, out sensitivityClassification)) { - byte dataClassificationToken; - if (!stateObj.TryReadByte(out dataClassificationToken)) - { - return false; - } - Debug.Assert(TdsEnums.SQLDATACLASSIFICATION == dataClassificationToken); - - SensitivityClassification sensitivityClassification; - if (!TryProcessDataClassification(stateObj, out sensitivityClassification)) - { - return false; - } - if (!dataStream.TrySetSensitivityClassification(sensitivityClassification)) - { - return false; - } + return false; + } + if (null != dataStream && !dataStream.TrySetSensitivityClassification(sensitivityClassification)) + { + return false; + } - // update peekedToken - if (!stateObj.TryPeekByte(out peekedToken)) - { - return false; - } + // update peekedToken + if (!stateObj.TryPeekByte(out peekedToken)) + { + return false; } + } + if (null != dataStream) + { if (!dataStream.TrySetMetaData(stateObj._cleanupMetaData, (TdsEnums.SQLTABNAME == peekedToken || TdsEnums.SQLCOLINFO == peekedToken))) { return false; diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index d97a5aa2bd..e48364ad6f 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -2584,40 +2584,40 @@ internal bool TryRun(RunBehavior runBehavior, SqlCommand cmdHandler, SqlDataRead } } - if (null != dataStream) + byte peekedToken; + if (!stateObj.TryPeekByte(out peekedToken)) + { // temporarily cache next byte + return false; + } + + if (TdsEnums.SQLDATACLASSIFICATION == peekedToken) { - byte peekedToken; - if (!stateObj.TryPeekByte(out peekedToken)) - { // temporarily cache next byte + byte dataClassificationToken; + if (!stateObj.TryReadByte(out dataClassificationToken)) + { return false; } + Debug.Assert(TdsEnums.SQLDATACLASSIFICATION == dataClassificationToken); - if (TdsEnums.SQLDATACLASSIFICATION == peekedToken) + SensitivityClassification sensitivityClassification; + if (!TryProcessDataClassification(stateObj, out sensitivityClassification)) { - byte dataClassificationToken; - if (!stateObj.TryReadByte(out dataClassificationToken)) - { - return false; - } - Debug.Assert(TdsEnums.SQLDATACLASSIFICATION == dataClassificationToken); - - SensitivityClassification sensitivityClassification; - if (!TryProcessDataClassification(stateObj, out sensitivityClassification)) - { - return false; - } - if (!dataStream.TrySetSensitivityClassification(sensitivityClassification)) - { - return false; - } + return false; + } + if (null != dataStream && !dataStream.TrySetSensitivityClassification(sensitivityClassification)) + { + return false; + } - // update peekedToken - if (!stateObj.TryPeekByte(out peekedToken)) - { - return false; - } + // update peekedToken + if (!stateObj.TryPeekByte(out peekedToken)) + { + return false; } + } + if (null != dataStream) + { if (!dataStream.TrySetMetaData(stateObj._cleanupMetaData, (TdsEnums.SQLTABNAME == peekedToken || TdsEnums.SQLCOLINFO == peekedToken))) { return false; diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs index b26a732d16..590ae5a5a6 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs @@ -233,6 +233,31 @@ public static bool IsDatabasePresent(string name) return present; } + /// + /// Checks if object SYS.SENSITIVITY_CLASSIFICATIONS exists in SQL Server + /// + /// True, if target SQL Server supports Data Classification + public static bool IsSupportedDataClassification() + { + try + { + using (var connection = new SqlConnection(TCPConnectionString)) + using (var command = new SqlCommand("SELECT * FROM SYS.SENSITIVITY_CLASSIFICATIONS", connection)) + { + connection.Open(); + command.ExecuteNonQuery(); + } + } + catch (SqlException e) + { + // Check for Error 208: Invalid Object Name + if (e.Errors != null && e.Errors[0].Number == 208) + { + return false; + } + } + return true; + } public static bool IsUdtTestDatabasePresent() => IsDatabasePresent(UdtTestDbName); public static bool AreConnStringsSetup() diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index 7db44f4358..1e6154dc92 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -59,6 +59,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/ConnectionBehaviorTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/ConnectionBehaviorTest.cs index 5b108400b0..2e2f948ace 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/ConnectionBehaviorTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/ConnectionBehaviorTest.cs @@ -10,7 +10,7 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests { public class ConnectionBehaviorTest { - [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [CheckConnStrSetupFact] public void ConnectionBehaviorClose() { using (SqlConnection sqlConnection = new SqlConnection((new SqlConnectionStringBuilder(DataTestUtility.TCPConnectionString) { MaxPoolSize = 1 }).ConnectionString)) @@ -31,7 +31,7 @@ public void ConnectionBehaviorClose() } } - [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [CheckConnStrSetupFact] public void ConnectionBehaviorCloseAsync() { using (SqlConnection sqlConnection = new SqlConnection((new SqlConnectionStringBuilder(DataTestUtility.TCPConnectionString) { MaxPoolSize = 1 }).ConnectionString)) diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataClassificationTest/DataClassificationTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataClassificationTest/DataClassificationTest.cs new file mode 100644 index 0000000000..51c8ebf5eb --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataClassificationTest/DataClassificationTest.cs @@ -0,0 +1,209 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.ObjectModel; +using System.Data; +using Microsoft.Data.SqlClient.DataClassification; +using Xunit; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests +{ + public static class DataClassificationTest + { + private static string s_tableName; + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsSupportedDataClassification))] + public static void TestDataClassificationResultSet() + { + s_tableName = DataTestUtility.GetUniqueNameForSqlServer("DC"); + using (SqlConnection sqlConnection = new SqlConnection(DataTestUtility.TCPConnectionString)) + using (SqlCommand sqlCommand = sqlConnection.CreateCommand()) + { + try + { + sqlConnection.Open(); + Assert.True(DataTestUtility.IsSupportedDataClassification()); + CreateTable(sqlCommand); + RunTestsForServer(sqlCommand); + } + finally + { + DataTestUtility.DropTable(sqlConnection, s_tableName); + } + } + } + + private static void RunTestsForServer(SqlCommand sqlCommand) + { + sqlCommand.CommandText = "SELECT * FROM " + s_tableName; + using (SqlDataReader reader = sqlCommand.ExecuteReader()) + { + VerifySensitivityClassification(reader); + } + } + + private static void VerifySensitivityClassification(SqlDataReader reader) + { + if (null != reader.SensitivityClassification) + { + for (int columnPos = 0; columnPos < reader.SensitivityClassification.ColumnSensitivities.Count; + columnPos++) + { + foreach (SensitivityProperty sp in reader.SensitivityClassification.ColumnSensitivities[columnPos].SensitivityProperties) + { + ReadOnlyCollection infoTypes = reader.SensitivityClassification.InformationTypes; + Assert.Equal(3, infoTypes.Count); + for (int i = 0; i < infoTypes.Count; i++) + { + VerifyInfoType(infoTypes[i], i + 1); + } + + ReadOnlyCollection