diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs index a8a02be484..466816e7e0 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -842,7 +842,7 @@ protected override void Dispose(bool disposing) } base.Dispose(disposing); } - catch(SqlException ex) + catch (SqlException ex) { SqlClientEventSource.Log.TryTraceEvent("SqlDataReader.Dispose | ERR | Error Message: {0}, Stack Trace: {1}", ex.Message, ex.StackTrace); } @@ -3767,26 +3767,18 @@ private bool TryReadColumnInternal(int i, bool readHeaderOnly = false) _sharedState._nextColumnDataToRead = _sharedState._nextColumnHeaderToRead; _sharedState._nextColumnHeaderToRead++; // We read this one - if (isNull) + // Trigger new behavior for RowVersion to send DBNull.Value by allowing entry for Timestamp or discard entry for Timestamp for legacy support. + // if LegacyRowVersionNullBehavior is enabled, Timestamp type must enter "else" block. + if (isNull && (!LocalAppContextSwitches.LegacyRowVersionNullBehavior || columnMetaData.type != SqlDbType.Timestamp)) { - if (columnMetaData.type == SqlDbType.Timestamp) - { - if (!LocalAppContextSwitches.LegacyRowVersionNullBehavior) - { - _data[i].SetToNullOfType(SqlBuffer.StorageType.SqlBinary); - } - } - else - { - TdsParser.GetNullSqlValue(_data[_sharedState._nextColumnDataToRead], - columnMetaData, + TdsParser.GetNullSqlValue(_data[_sharedState._nextColumnDataToRead], + columnMetaData, _command != null ? _command.ColumnEncryptionSetting : SqlCommandColumnEncryptionSetting.UseConnectionSetting, _parser.Connection); - if (!readHeaderOnly) - { - _sharedState._nextColumnDataToRead++; - } + if (!readHeaderOnly) + { + _sharedState._nextColumnDataToRead++; } } else @@ -4098,8 +4090,8 @@ internal bool TrySetMetaData(_SqlMetaDataSet metaData, bool moreInfo) if (_parser != null) { // There is a valid case where parser is null - // Peek, and if row token present, set _hasRows true since there is a - // row in the result + // Peek, and if row token present, set _hasRows true since there is a + // row in the result byte b; if (!_stateObj.TryPeekByte(out b)) { @@ -5021,7 +5013,7 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat { _stateObj._shouldHaveEnoughData = true; #endif - return Task.FromResult(GetFieldValueInternal(i)); + return Task.FromResult(GetFieldValueInternal(i)); #if DEBUG } finally diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index 04ec52be41..75bb8aef19 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -5429,6 +5429,10 @@ internal static object GetNullSqlValue(SqlBuffer nullVal, SqlMetaDataPriv md, Sq break; case SqlDbType.Timestamp: + if (!LocalAppContextSwitches.LegacyRowVersionNullBehavior) + { + nullVal.SetToNullOfType(SqlBuffer.StorageType.SqlBinary); + } break; default: diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs index eb54272de3..7bfa391b8d 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -4279,26 +4279,18 @@ private bool TryReadColumnInternal(int i, bool readHeaderOnly = false) _sharedState._nextColumnDataToRead = _sharedState._nextColumnHeaderToRead; _sharedState._nextColumnHeaderToRead++; // We read this one - if (isNull) + // Trigger new behavior for RowVersion to send DBNull.Value by allowing entry for Timestamp or discard entry for Timestamp for legacy support. + // if LegacyRowVersionNullBehavior is enabled, Timestamp type must enter "else" block. + if (isNull && (!LocalAppContextSwitches.LegacyRowVersionNullBehavior || columnMetaData.type != SqlDbType.Timestamp)) { - if (columnMetaData.type == SqlDbType.Timestamp) - { - if (!LocalAppContextSwitches.LegacyRowVersionNullBehavior) - { - _data[i].SetToNullOfType(SqlBuffer.StorageType.SqlBinary); - } - } - else - { - TdsParser.GetNullSqlValue(_data[_sharedState._nextColumnDataToRead], + TdsParser.GetNullSqlValue(_data[_sharedState._nextColumnDataToRead], columnMetaData, _command != null ? _command.ColumnEncryptionSetting : SqlCommandColumnEncryptionSetting.UseConnectionSetting, _parser.Connection); - if (!readHeaderOnly) - { - _sharedState._nextColumnDataToRead++; - } + if (!readHeaderOnly) + { + _sharedState._nextColumnDataToRead++; } } else diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataReaderTest/DataReaderTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataReaderTest/DataReaderTest.cs index 448aa9ed64..5fa95000ee 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataReaderTest/DataReaderTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataReaderTest/DataReaderTest.cs @@ -15,60 +15,46 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests { public static class DataReaderTest { - private static object s_rowVersionLock = new object(); + private static readonly object s_rowVersionLock = new(); [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] public static void LoadReaderIntoDataTableToTestGetSchemaTable() { - using (SqlConnection connection = new SqlConnection(DataTestUtility.TCPConnectionString)) - { - connection.Open(); - var dt = new DataTable(); - using (SqlCommand command = connection.CreateCommand()) - { - command.CommandText = "select 3 as [three], 4 as [four]"; - // Datatables internally call IDataReader.GetSchemaTable() - dt.Load(command.ExecuteReader()); - Assert.Equal(2, dt.Columns.Count); - Assert.Equal("three", dt.Columns[0].ColumnName); - Assert.Equal("four", dt.Columns[1].ColumnName); - Assert.Equal(1, dt.Rows.Count); - Assert.Equal(3, (int)dt.Rows[0][0]); - Assert.Equal(4, (int)dt.Rows[0][1]); - } - } + using SqlConnection connection = new(DataTestUtility.TCPConnectionString); + connection.Open(); + var dt = new DataTable(); + using SqlCommand command = connection.CreateCommand(); + command.CommandText = "select 3 as [three], 4 as [four]"; + // Datatables internally call IDataReader.GetSchemaTable() + dt.Load(command.ExecuteReader()); + Assert.Equal(2, dt.Columns.Count); + Assert.Equal("three", dt.Columns[0].ColumnName); + Assert.Equal("four", dt.Columns[1].ColumnName); + Assert.Equal(1, dt.Rows.Count); + Assert.Equal(3, (int)dt.Rows[0][0]); + Assert.Equal(4, (int)dt.Rows[0][1]); } [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] public static void MultiQuerySchema() { - SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(DataTestUtility.TCPConnectionString); - using (SqlConnection connection = new SqlConnection(DataTestUtility.TCPConnectionString)) + using SqlConnection connection = new(DataTestUtility.TCPConnectionString); + connection.Open(); + using SqlCommand command = connection.CreateCommand(); + // Use multiple queries + command.CommandText = "SELECT 1 as ColInteger; SELECT 'STRING' as ColString"; + using SqlDataReader reader = command.ExecuteReader(); + HashSet columnNames = new(); + do { - connection.Open(); - - using (SqlCommand command = connection.CreateCommand()) + DataTable schemaTable = reader.GetSchemaTable(); + foreach (DataRow myField in schemaTable.Rows) { - // Use multiple queries - command.CommandText = "SELECT 1 as ColInteger; SELECT 'STRING' as ColString"; - using (SqlDataReader reader = command.ExecuteReader()) - { - HashSet columnNames = new HashSet(); - do - { - DataTable schemaTable = reader.GetSchemaTable(); - foreach (DataRow myField in schemaTable.Rows) - { - columnNames.Add(myField["ColumnName"].ToString()); - } - - } while (reader.NextResult()); - - Assert.Contains("ColInteger", columnNames); - Assert.Contains("ColString", columnNames); - } + columnNames.Add(myField["ColumnName"].ToString()); } - } + } while (reader.NextResult()); + Assert.Contains("ColInteger", columnNames); + Assert.Contains("ColString", columnNames); } @@ -82,17 +68,17 @@ public static void CheckSparseColumnBit() // TSQL for "CREATE TABLE" with sparse columns // table name will be provided as an argument - StringBuilder createBuilder = new StringBuilder("CREATE TABLE {0} ([ID] int PRIMARY KEY, [CSET] xml COLUMN_SET FOR ALL_SPARSE_COLUMNS NULL"); + StringBuilder createBuilder = new("CREATE TABLE {0} ([ID] int PRIMARY KEY, [CSET] xml COLUMN_SET FOR ALL_SPARSE_COLUMNS NULL"); // TSQL to create the same table, but without the column set column and without sparse // also, it has only 1024 columns, which is the server limit in this case - StringBuilder createNonSparseBuilder = new StringBuilder("CREATE TABLE {0} ([ID] int PRIMARY KEY"); + StringBuilder createNonSparseBuilder = new("CREATE TABLE {0} ([ID] int PRIMARY KEY"); // TSQL to select all columns from the sparse table, without columnset one - StringBuilder selectBuilder = new StringBuilder("SELECT [ID]"); + StringBuilder selectBuilder = new("SELECT [ID]"); // TSQL to select all columns from the sparse table, with a limit of 1024 (for bulk-copy test) - StringBuilder selectNonSparseBuilder = new StringBuilder("SELECT [ID]"); + StringBuilder selectNonSparseBuilder = new("SELECT [ID]"); // add sparse columns for (int c = 0; c < sparseColumns; c++) @@ -109,29 +95,27 @@ public static void CheckSparseColumnBit() string createStatementFormat = createBuilder.ToString(); // add a row with nulls only - using (SqlConnection con = new SqlConnection(DataTestUtility.TCPConnectionString)) - using (SqlCommand cmd = con.CreateCommand()) + using SqlConnection con = new SqlConnection(DataTestUtility.TCPConnectionString); + using SqlCommand cmd = con.CreateCommand(); + try { - try - { - con.Open(); + con.Open(); - cmd.CommandType = CommandType.Text; - cmd.CommandText = string.Format(createStatementFormat, tempTableName); - cmd.ExecuteNonQuery(); + cmd.CommandType = CommandType.Text; + cmd.CommandText = string.Format(createStatementFormat, tempTableName); + cmd.ExecuteNonQuery(); - cmd.CommandText = string.Format("INSERT INTO {0} ([ID]) VALUES (0)", tempTableName);// insert row with values set to their defaults (DBNULL) - cmd.ExecuteNonQuery(); + cmd.CommandText = string.Format("INSERT INTO {0} ([ID]) VALUES (0)", tempTableName);// insert row with values set to their defaults (DBNULL) + cmd.ExecuteNonQuery(); - // run the test cases - Assert.True(IsColumnBitSet(con, string.Format("SELECT [ID], [CSET], [C1] FROM {0}", tempTableName), indexOfColumnSet: 1)); - } - finally - { - // drop the temp table to release its resources - cmd.CommandText = "DROP TABLE " + tempTableName; - cmd.ExecuteNonQuery(); - } + // run the test cases + Assert.True(IsColumnBitSet(con, string.Format("SELECT [ID], [CSET], [C1] FROM {0}", tempTableName), indexOfColumnSet: 1)); + } + finally + { + // drop the temp table to release its resources + cmd.CommandText = "DROP TABLE " + tempTableName; + cmd.ExecuteNonQuery(); } } @@ -143,50 +127,46 @@ public static void CollatedDataReaderTest() // Remove square brackets var dbName = databaseName.Substring(1, databaseName.Length - 2); - SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(DataTestUtility.TCPConnectionString) + SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString) { InitialCatalog = dbName, Pooling = false }; - using (SqlConnection con = new SqlConnection(DataTestUtility.TCPConnectionString)) - using (SqlCommand cmd = con.CreateCommand()) + using SqlConnection con = new(DataTestUtility.TCPConnectionString); + using SqlCommand cmd = con.CreateCommand(); + try { - try - { - con.Open(); - - // Create collated database - cmd.CommandText = $"CREATE DATABASE {databaseName} COLLATE KAZAKH_90_CI_AI"; - cmd.ExecuteNonQuery(); + con.Open(); - //Create connection without pooling in order to delete database later. - using (SqlConnection dbCon = new SqlConnection(builder.ConnectionString)) - using (SqlCommand dbCmd = dbCon.CreateCommand()) - { - var data = "TestData"; - - dbCon.Open(); - dbCmd.CommandText = $"SELECT '{data}'"; - using (SqlDataReader reader = dbCmd.ExecuteReader()) - { - reader.Read(); - Assert.Equal(data, reader.GetString(0)); - } - } + // Create collated database + cmd.CommandText = $"CREATE DATABASE {databaseName} COLLATE KAZAKH_90_CI_AI"; + cmd.ExecuteNonQuery(); - // Let connection close safely before dropping database for slow servers. - Thread.Sleep(500); - } - catch (SqlException e) - { - Assert.True(false, $"Unexpected Exception occurred: {e.Message}"); - } - finally + //Create connection without pooling in order to delete database later. + using (SqlConnection dbCon = new(builder.ConnectionString)) + using (SqlCommand dbCmd = dbCon.CreateCommand()) { - cmd.CommandText = $"DROP DATABASE {databaseName}"; - cmd.ExecuteNonQuery(); + var data = "TestData"; + + dbCon.Open(); + dbCmd.CommandText = $"SELECT '{data}'"; + using SqlDataReader reader = dbCmd.ExecuteReader(); + reader.Read(); + Assert.Equal(data, reader.GetString(0)); } + + // Let connection close safely before dropping database for slow servers. + Thread.Sleep(500); + } + catch (SqlException e) + { + Assert.True(false, $"Unexpected Exception occurred: {e.Message}"); + } + finally + { + cmd.CommandText = $"DROP DATABASE {databaseName}"; + cmd.ExecuteNonQuery(); } } @@ -194,22 +174,17 @@ private static bool IsColumnBitSet(SqlConnection con, string selectQuery, int in { bool columnSetPresent = false; { - using (SqlCommand cmd = con.CreateCommand()) + using SqlCommand cmd = con.CreateCommand(); + cmd.CommandText = selectQuery; + using SqlDataReader reader = cmd.ExecuteReader(); + DataTable schemaTable = reader.GetSchemaTable(); + for (int i = 0; i < schemaTable.Rows.Count; i++) { - cmd.CommandText = selectQuery; - using (SqlDataReader reader = cmd.ExecuteReader()) - { - DataTable schemaTable = reader.GetSchemaTable(); - - for (int i = 0; i < schemaTable.Rows.Count; i++) - { - bool isColumnSet = (bool)schemaTable.Rows[i]["IsColumnSet"]; + bool isColumnSet = (bool)schemaTable.Rows[i]["IsColumnSet"]; - if (indexOfColumnSet == i) - { - columnSetPresent = true; - } - } + if (indexOfColumnSet == i) + { + columnSetPresent = true; } } } @@ -229,49 +204,43 @@ public static void CheckHiddenColumns() string tempTableName = DataTestUtility.GenerateObjectName(); string createQuery = $@" -create table [{tempTableName}] ( - user_id int not null identity(1,1), - first_name varchar(100) null, - last_name varchar(100) null); + create table [{tempTableName}] ( + user_id int not null identity(1,1), + first_name varchar(100) null, + last_name varchar(100) null); alter table [{tempTableName}] add constraint pk_{tempTableName}_user_id primary key (user_id); -insert into [{tempTableName}] (first_name,last_name) values ('Joe','Smith') -"; + insert into [{tempTableName}] (first_name,last_name) values ('Joe','Smith') + "; string dataQuery = $@"select first_name, last_name from [{tempTableName}]"; - int fieldCount = 0; int visibleFieldCount = 0; Type[] types = null; string[] names = null; - using (SqlConnection connection = new SqlConnection(DataTestUtility.TCPConnectionString)) + using (SqlConnection connection = new(DataTestUtility.TCPConnectionString)) { connection.Open(); try { - using (SqlCommand createCommand = new SqlCommand(createQuery, connection)) + using (SqlCommand createCommand = new(createQuery, connection)) { createCommand.ExecuteNonQuery(); } - - using (SqlCommand queryCommand = new SqlCommand(dataQuery, connection)) + using SqlCommand queryCommand = new(dataQuery, connection); + using SqlDataReader reader = queryCommand.ExecuteReader(CommandBehavior.KeyInfo); + fieldCount = reader.FieldCount; + visibleFieldCount = reader.VisibleFieldCount; + types = new Type[fieldCount]; + names = new string[fieldCount]; + for (int index = 0; index < fieldCount; index++) { - using (SqlDataReader reader = queryCommand.ExecuteReader(CommandBehavior.KeyInfo)) - { - fieldCount = reader.FieldCount; - visibleFieldCount = reader.VisibleFieldCount; - types = new Type[fieldCount]; - names = new string[fieldCount]; - for (int index = 0; index < fieldCount; index++) - { - types[index] = reader.GetFieldType(index); - names[index] = reader.GetName(index); - } - } + types[index] = reader.GetFieldType(index); + names[index] = reader.GetName(index); } } finally @@ -301,20 +270,18 @@ public static void CheckNullRowVersionIsBDNull() bool? originalValue = SetLegacyRowVersionNullBehavior(false); try { - using (SqlConnection con = new SqlConnection(DataTestUtility.TCPConnectionString)) - { - con.Open(); - using (SqlCommand command = con.CreateCommand()) - { - command.CommandText = "select cast(null as rowversion) rv"; - using (SqlDataReader reader = command.ExecuteReader()) - { - reader.Read(); - Assert.True(reader.IsDBNull(0)); - Assert.Equal(reader[0], DBNull.Value); - } - } - } + using SqlConnection con = new(DataTestUtility.TCPConnectionString); + con.Open(); + using SqlCommand command = con.CreateCommand(); + command.CommandText = "select cast(null as rowversion) rv"; + using SqlDataReader reader = command.ExecuteReader(); + reader.Read(); + Assert.True(reader.IsDBNull(0)); + Assert.Equal(DBNull.Value, reader[0]); + var result = reader.GetValue(0); + Assert.IsType(result); + Assert.Equal(result, reader.GetFieldValue(0)); + Assert.Throws(() => reader.GetFieldValue(0)); } finally { @@ -332,23 +299,20 @@ public static void CheckLegacyNullRowVersionIsEmptyArray() bool? originalValue = SetLegacyRowVersionNullBehavior(true); try { - using (SqlConnection con = new SqlConnection(DataTestUtility.TCPConnectionString)) - { - con.Open(); - using (SqlCommand command = con.CreateCommand()) - { - command.CommandText = "select cast(null as rowversion) rv"; - using (SqlDataReader reader = command.ExecuteReader()) - { - reader.Read(); - Assert.False(reader.IsDBNull(0)); - SqlBinary value = reader.GetSqlBinary(0); - Assert.False(value.IsNull); - Assert.Equal(0, value.Length); - Assert.NotNull(value.Value); - } - } - } + using SqlConnection con = new(DataTestUtility.TCPConnectionString); + con.Open(); + using SqlCommand command = con.CreateCommand(); + command.CommandText = "select cast(null as rowversion) rv"; + using SqlDataReader reader = command.ExecuteReader(); + reader.Read(); + Assert.False(reader.IsDBNull(0)); + SqlBinary value = reader.GetSqlBinary(0); + Assert.False(value.IsNull); + Assert.Equal(0, value.Length); + Assert.NotNull(value.Value); + var result = reader.GetValue(0); + Assert.IsType(result); + Assert.Equal(result, reader.GetFieldValue(0)); } finally {