diff --git a/doc/snippets/Microsoft.Data.SqlClient/SqlDataReader.xml b/doc/snippets/Microsoft.Data.SqlClient/SqlDataReader.xml index 8d83781431..3d1ad11722 100644 --- a/doc/snippets/Microsoft.Data.SqlClient/SqlDataReader.xml +++ b/doc/snippets/Microsoft.Data.SqlClient/SqlDataReader.xml @@ -319,7 +319,8 @@ Synchronously gets the value of the specified column as a type. is the asynchronous version of this method. The returned type object. - + .||| +|Stream|String|TextReader|UDT, which can be any CLR type marked with .| +|XmlReader|||| For more information, see [SqlClient Streaming Support](/sql/connect/ado-net/sqlclient-streaming-support). @@ -359,7 +361,8 @@ Asynchronously gets the value of the specified column as a type. is the synchronous version of this method. The returned type object. - + .||| +|Stream|String|TextReader|UDT, which can be any CLR type marked with .| +|XmlReader|||| For more information, see [SqlClient Streaming Support](/sql/connect/ado-net/sqlclient-streaming-support). diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCachedBuffer.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCachedBuffer.cs index 04aeaa552e..be5e1e330c 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCachedBuffer.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCachedBuffer.cs @@ -133,7 +133,7 @@ internal SqlXml ToSqlXml() [MethodImpl(MethodImplOptions.NoInlining)] internal XmlReader ToXmlReader() { - return SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(ToStream(), closeInput: false); + return SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(ToStream(), closeInput: false, async: false); } public bool IsNull diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs index 39c116bebc..f691da0e7a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -1499,7 +1499,7 @@ virtual public XmlReader GetXmlReader(int i) // Wrap the sequential stream in an XmlReader _currentStream = new SqlSequentialStream(this, i); _lastColumnWithDataChunkRead = i; - return SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(_currentStream, closeInput: true); + return SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(_currentStream, closeInput: true, async: false); } else { @@ -1509,7 +1509,7 @@ virtual public XmlReader GetXmlReader(int i) if (_data[i].IsNull) { // A 'null' stream - return SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(new MemoryStream(Array.Empty(), writable: false), closeInput: true); + return SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(new MemoryStream(Array.Empty(), writable: false), closeInput: true, async: false); } else { @@ -2644,7 +2644,7 @@ override public T GetFieldValue(int i) statistics = SqlStatistics.StartTimer(Statistics); SetTimeout(_defaultTimeoutMilliseconds); - return GetFieldValueInternal(i); + return GetFieldValueInternal(i, isAsync: false); } finally { @@ -2780,7 +2780,7 @@ private object GetValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData metaDa } } - private T GetFieldValueInternal(int i) + private T GetFieldValueInternal(int i, bool isAsync) { if (_currentTask != null) { @@ -2788,16 +2788,17 @@ private T GetFieldValueInternal(int i) } Debug.Assert(_stateObj == null || _stateObj._syncOverAsync, "Should not attempt pends in a synchronous call"); - bool result = TryReadColumn(i, setTimeout: false); + bool forStreaming = typeof(T) == typeof(XmlReader) || typeof(T) == typeof(TextReader) || typeof(T) == typeof(Stream); + bool result = TryReadColumn(i, setTimeout: false, forStreaming: forStreaming); if (!result) { throw SQL.SynchronousCallMayNotPend(); } - return GetFieldValueFromSqlBufferInternal(_data[i], _metaData[i]); + return GetFieldValueFromSqlBufferInternal(_data[i], _metaData[i], isAsync: isAsync); } - private T GetFieldValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData metaData) + private T GetFieldValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData metaData, bool isAsync) { // this block of type specific shortcuts uses RyuJIT jit behaviors to achieve fast implementations of the primitive types // RyuJIT will be able to determine at compilation time that the typeof(T)==typeof() options are constant @@ -2847,14 +2848,114 @@ private T GetFieldValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData met { return (T)(object)data.DateTime; } + else if (typeof(T) == typeof(XmlReader)) + { + // XmlReader only allowed on XML types + if (metaData.metaType.SqlDbType != SqlDbType.Xml) + { + throw SQL.XmlReaderNotSupportOnColumnType(metaData.column); + } + + if (IsCommandBehavior(CommandBehavior.SequentialAccess)) + { + // Wrap the sequential stream in an XmlReader + _currentStream = new SqlSequentialStream(this, metaData.ordinal); + _lastColumnWithDataChunkRead = metaData.ordinal; + return (T)(object)SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(_currentStream, closeInput: true, async: isAsync); + } + else + { + if (data.IsNull) + { + // A 'null' stream + return (T)(object)SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(new MemoryStream(Array.Empty(), writable: false), closeInput: true, async: isAsync); + } + else + { + // Grab already read data + return (T)(object)data.SqlXml.CreateReader(); + } + } + } + else if (typeof(T) == typeof(TextReader)) + { + // Xml type is not supported + MetaType metaType = metaData.metaType; + if (metaData.cipherMD != null) + { + Debug.Assert(metaData.baseTI != null, "_metaData[i].baseTI should not be null."); + metaType = metaData.baseTI.metaType; + } + + if ( + (!metaType.IsCharType && metaType.SqlDbType != SqlDbType.Variant) || + (metaType.SqlDbType == SqlDbType.Xml) + ) + { + throw SQL.TextReaderNotSupportOnColumnType(metaData.column); + } + + // For non-variant types with sequential access, we support proper streaming + if ((metaType.SqlDbType != SqlDbType.Variant) && IsCommandBehavior(CommandBehavior.SequentialAccess)) + { + if (metaData.cipherMD != null) + { + throw SQL.SequentialAccessNotSupportedOnEncryptedColumn(metaData.column); + } + + System.Text.Encoding encoding = SqlUnicodeEncoding.SqlUnicodeEncodingInstance; + if (!metaType.IsNCharType) + { + encoding = metaData.encoding; + } + + _currentTextReader = new SqlSequentialTextReader(this, metaData.ordinal, encoding); + _lastColumnWithDataChunkRead = metaData.ordinal; + return (T)(object)_currentTextReader; + } + else + { + string value = data.IsNull ? string.Empty : data.SqlString.Value; + return (T)(object)new StringReader(value); + } + + } + else if (typeof(T) == typeof(Stream)) + { + if (metaData != null && metaData.cipherMD != null) + { + throw SQL.StreamNotSupportOnEncryptedColumn(metaData.column); + } + + // Stream is only for Binary, Image, VarBinary, Udt, Xml and Timestamp(RowVersion) types + MetaType metaType = metaData.metaType; + if ( + (!metaType.IsBinType || metaType.SqlDbType == SqlDbType.Timestamp) && + metaType.SqlDbType != SqlDbType.Variant + ) + { + throw SQL.StreamNotSupportOnColumnType(metaData.column); + } + + if ((metaType.SqlDbType != SqlDbType.Variant) && (IsCommandBehavior(CommandBehavior.SequentialAccess))) + { + _currentStream = new SqlSequentialStream(this, metaData.ordinal); + _lastColumnWithDataChunkRead = metaData.ordinal; + return (T)(object)_currentStream; + } + else + { + byte[] value = data.IsNull ? Array.Empty() : data.SqlBinary.Value; + return (T)(object)new MemoryStream(value, writable: false); + } + } else { - Type typeofT = typeof(T); - if (_typeofINullable.IsAssignableFrom(typeofT)) + if (typeof(INullable).IsAssignableFrom(typeof(T))) { // If its a SQL Type or Nullable UDT object rawValue = GetSqlValueFromSqlBufferInternal(data, metaData); - if (typeofT == s_typeofSqlString) + if (typeof(T) == s_typeofSqlString) { // Special case: User wants SqlString, but we have a SqlXml // SqlXml can not be typecast into a SqlString, but we need to support SqlString on XML Types - so do a manual conversion @@ -2875,60 +2976,19 @@ private T GetFieldValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData met } else { - if (typeof(XmlReader) == typeofT) + // the requested type is likely to be one that isn't supported so try the cast and + // unless there is a null value conversion then feedback the cast exception with + // type named to the user so they know what went wrong. Supported types are listed + // in the documentation + try { - if (metaData.metaType.SqlDbType != SqlDbType.Xml) - { - throw SQL.XmlReaderNotSupportOnColumnType(metaData.column); - } - else - { - object clrValue = null; - if (!data.IsNull) - { - clrValue = GetValueFromSqlBufferInternal(data, metaData); - } - if (clrValue is null) // covers IsNull and when there is data which is present but is a clr null somehow - { - return (T)(object)SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader( - new MemoryStream(Array.Empty(), writable: false), - closeInput: true - ); - } - else if (clrValue.GetType() == typeof(string)) - { - return (T)(object)SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader( - new StringReader(clrValue as string), - closeInput: true - ); - } - else - { - // try the type cast to throw the invalid cast exception and inform the user what types they're trying to use and that why it is wrong - return (T)clrValue; - } - } + return (T)GetValueFromSqlBufferInternal(data, metaData); } - else + catch (InvalidCastException) when (data.IsNull) { - try - { - return (T)GetValueFromSqlBufferInternal(data, metaData); - } - catch (InvalidCastException) - { - if (data.IsNull) - { - // If the value was actually null, then we should throw a SqlNullValue instead - throw SQL.SqlNullValue(); - } - else - { - // Legitimate InvalidCast, rethrow - throw; - } - } + throw SQL.SqlNullValue(); } + } } } @@ -3622,7 +3682,7 @@ private void ReadColumn(int i, bool setTimeout = true, bool allowPartiallyReadCo } } - private bool TryReadColumn(int i, bool setTimeout, bool allowPartiallyReadColumn = false) + private bool TryReadColumn(int i, bool setTimeout, bool allowPartiallyReadColumn = false, bool forStreaming = false) { CheckDataIsReady(columnIndex: i, permitAsync: true, allowPartiallyReadColumn: allowPartiallyReadColumn, methodName: null); @@ -3634,7 +3694,7 @@ private bool TryReadColumn(int i, bool setTimeout, bool allowPartiallyReadColumn SetTimeout(_defaultTimeoutMilliseconds); } - if (!TryReadColumnInternal(i, readHeaderOnly: false)) + if (!TryReadColumnInternal(i, readHeaderOnly: false, forStreaming: forStreaming)) { return false; } @@ -3683,7 +3743,7 @@ private bool TryReadColumnHeader(int i) return TryReadColumnInternal(i, readHeaderOnly: true); } - private bool TryReadColumnInternal(int i, bool readHeaderOnly = false) + internal bool TryReadColumnInternal(int i, bool readHeaderOnly = false, bool forStreaming = false) { AssertReaderState(requireData: true, permitAsync: true, columnIndex: i); @@ -3747,17 +3807,69 @@ private bool TryReadColumnInternal(int i, bool readHeaderOnly = false) { _SqlMetaData columnMetaData = _metaData[_sharedState._nextColumnHeaderToRead]; - if ((isSequentialAccess) && (_sharedState._nextColumnHeaderToRead < i)) + if (isSequentialAccess) { - // SkipValue is no-op if the column appears in NBC bitmask - // if not, it skips regular and PLP types - if (!_parser.TrySkipValue(columnMetaData, _sharedState._nextColumnHeaderToRead, _stateObj)) + if (_sharedState._nextColumnHeaderToRead < i) { - return false; + // SkipValue is no-op if the column appears in NBC bitmask + // if not, it skips regular and PLP types + if (!_parser.TrySkipValue(columnMetaData, _sharedState._nextColumnHeaderToRead, _stateObj)) + { + return false; + } + + _sharedState._nextColumnDataToRead = _sharedState._nextColumnHeaderToRead; + _sharedState._nextColumnHeaderToRead++; } + else if (_sharedState._nextColumnHeaderToRead == i) + { + bool isNull; + ulong dataLength; + if (!_parser.TryProcessColumnHeader(columnMetaData, _stateObj, _sharedState._nextColumnHeaderToRead, out isNull, out dataLength)) + { + return false; + } - _sharedState._nextColumnDataToRead = _sharedState._nextColumnHeaderToRead; - _sharedState._nextColumnHeaderToRead++; + _sharedState._nextColumnDataToRead = _sharedState._nextColumnHeaderToRead; + _sharedState._nextColumnHeaderToRead++; // We read this one + _sharedState._columnDataBytesRemaining = (long)dataLength; + + if (isNull) + { + if (columnMetaData.type != SqlDbType.Timestamp) + { + TdsParser.GetNullSqlValue(_data[_sharedState._nextColumnDataToRead], + columnMetaData, + _command != null ? _command.ColumnEncryptionSetting : SqlCommandColumnEncryptionSetting.UseConnectionSetting, + _parser.Connection); + } + } + else + { + if (!readHeaderOnly && !forStreaming) + { + // If we're in sequential mode try to read the data and then if it succeeds update shared + // state so there are no remaining bytes and advance the next column to read + if (!_parser.TryReadSqlValue(_data[_sharedState._nextColumnDataToRead], columnMetaData, (int)dataLength, _stateObj, + _command != null ? _command.ColumnEncryptionSetting : SqlCommandColumnEncryptionSetting.UseConnectionSetting, + columnMetaData.column)) + { // will read UDTs as VARBINARY. + return false; + } + _sharedState._columnDataBytesRemaining = 0; + _sharedState._nextColumnDataToRead++; + } + else + { + _sharedState._columnDataBytesRemaining = (long)dataLength; + } + } + } + else + { + // we have read past the column somehow, this is an error + Debug.Assert(false, "We have read past the column somehow, this is an error"); + } } else { @@ -4967,7 +5079,7 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat var metaData = _metaData; if ((data != null) && (metaData != null)) { - return Task.FromResult(GetFieldValueFromSqlBufferInternal(data[i], metaData[i])); + return Task.FromResult(GetFieldValueFromSqlBufferInternal(data[i], metaData[i], isAsync:false)); } else { @@ -5007,7 +5119,7 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat { _stateObj._shouldHaveEnoughData = true; #endif - return Task.FromResult(GetFieldValueInternal(i)); + return Task.FromResult(GetFieldValueInternal(i, isAsync:true)); #if DEBUG } finally @@ -5069,9 +5181,17 @@ private static Task GetFieldValueAsyncExecute(Task task, object state) reader.PrepareForAsyncContinuation(); } + if (typeof(T) == typeof(Stream) || typeof(T) == typeof(TextReader) || typeof(T) == typeof(XmlReader)) + { + if (reader.IsCommandBehavior(CommandBehavior.SequentialAccess) && reader._sharedState._dataReady && reader.TryReadColumnInternal(context._columnIndex, readHeaderOnly: true)) + { + return Task.FromResult(reader.GetFieldValueFromSqlBufferInternal(reader._data[columnIndex], reader._metaData[columnIndex], isAsync: true)); + } + } + if (reader.TryReadColumn(columnIndex, setTimeout: false)) { - return Task.FromResult(reader.GetFieldValueFromSqlBufferInternal(reader._data[columnIndex], reader._metaData[columnIndex])); + return Task.FromResult(reader.GetFieldValueFromSqlBufferInternal(reader._data[columnIndex], reader._metaData[columnIndex], isAsync:false)); } else { diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCachedBuffer.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCachedBuffer.cs index 62f4d85591..000d2647cc 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCachedBuffer.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCachedBuffer.cs @@ -7,9 +7,9 @@ using System.Data.SqlTypes; using System.Diagnostics; using System.IO; -using System.Reflection; using System.Runtime.CompilerServices; using System.Xml; +using Microsoft.Data.SqlTypes; namespace Microsoft.Data.SqlClient { @@ -134,26 +134,7 @@ internal SqlXml ToSqlXml() [MethodImpl(MethodImplOptions.NoInlining)] internal XmlReader ToXmlReader() { - //XmlTextReader xr = new XmlTextReader(fragment, XmlNodeType.Element, null); - XmlReaderSettings readerSettings = new XmlReaderSettings(); - readerSettings.ConformanceLevel = ConformanceLevel.Fragment; - - // Call internal XmlReader.CreateSqlReader from System.Xml. - // Signature: internal static XmlReader CreateSqlReader(Stream input, XmlReaderSettings settings, XmlParserContext inputContext); - MethodInfo createSqlReaderMethodInfo = typeof(System.Xml.XmlReader).GetMethod("CreateSqlReader", BindingFlags.Static | BindingFlags.NonPublic); - object[] args = new object[3] { ToStream(), readerSettings, null }; - XmlReader xr; - - new System.Security.Permissions.ReflectionPermission(System.Security.Permissions.ReflectionPermissionFlag.MemberAccess).Assert(); - try - { - xr = (XmlReader)createSqlReaderMethodInfo.Invoke(null, args); - } - finally - { - System.Security.Permissions.ReflectionPermission.RevertAssert(); - } - return xr; + return SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(ToStream(), closeInput: false, async: false); } public bool IsNull @@ -163,7 +144,5 @@ public bool IsNull return (_cachedBytes == null) ? true : false; } } - } - } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs index 9605e27035..0c2200321e 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -95,10 +95,6 @@ internal class SharedState private CancellationTokenSource _cancelAsyncOnCloseTokenSource; private CancellationToken _cancelAsyncOnCloseToken; - // Used for checking if the Type parameter provided to GetValue is an INullable - internal static readonly Type _typeofINullable = typeof(INullable); - private static readonly Type _typeofSqlString = typeof(SqlString); - private SqlSequentialStream _currentStream; private SqlSequentialTextReader _currentTextReader; @@ -1739,7 +1735,7 @@ virtual public XmlReader GetXmlReader(int i) // Wrap the sequential stream in an XmlReader _currentStream = new SqlSequentialStream(this, i); _lastColumnWithDataChunkRead = i; - return SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(_currentStream, closeInput: true); + return SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(_currentStream, closeInput: true, async: false); } else { @@ -1749,7 +1745,7 @@ virtual public XmlReader GetXmlReader(int i) if (_data[i].IsNull) { // A 'null' stream - return SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(new MemoryStream(new byte[0], writable: false), closeInput: true); + return SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(new MemoryStream(new byte[0], writable: false), closeInput: true, async: false); } else { @@ -3036,7 +3032,7 @@ override public T GetFieldValue(int i) statistics = SqlStatistics.StartTimer(Statistics); SetTimeout(_defaultTimeoutMilliseconds); - return GetFieldValueInternal(i); + return GetFieldValueInternal(i, isAsync: false); } finally { @@ -3172,7 +3168,7 @@ private object GetValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData metaDa } } - private T GetFieldValueInternal(int i) + private T GetFieldValueInternal(int i, bool isAsync) { if (_currentTask != null) { @@ -3180,22 +3176,125 @@ private T GetFieldValueInternal(int i) } Debug.Assert(_stateObj == null || _stateObj._syncOverAsync, "Should not attempt pends in a synchronous call"); - bool result = TryReadColumn(i, setTimeout: false); + bool forStreaming = typeof(T) == typeof(XmlReader) || typeof(T) == typeof(TextReader) || typeof(T) == typeof(Stream); + bool result = TryReadColumn(i, setTimeout: false, forStreaming: forStreaming); if (!result) - { throw SQL.SynchronousCallMayNotPend(); } + { + throw SQL.SynchronousCallMayNotPend(); + } - return GetFieldValueFromSqlBufferInternal(_data[i], _metaData[i]); + return GetFieldValueFromSqlBufferInternal(_data[i], _metaData[i], isAsync: isAsync); } - private T GetFieldValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData metaData) + private T GetFieldValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData metaData, bool isAsync) { - Type typeofT = typeof(T); - if (_typeofINullable.IsAssignableFrom(typeofT)) + if (typeof(T) == typeof(XmlReader)) + { + // XmlReader only allowed on XML types + if (metaData.metaType.SqlDbType != SqlDbType.Xml) + { + throw SQL.XmlReaderNotSupportOnColumnType(metaData.column); + } + + if (IsCommandBehavior(CommandBehavior.SequentialAccess)) + { + // Wrap the sequential stream in an XmlReader + _currentStream = new SqlSequentialStream(this, metaData.ordinal); + _lastColumnWithDataChunkRead = metaData.ordinal; + return (T)(object)SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(_currentStream, closeInput: true, async: isAsync); + } + else + { + if (data.IsNull) + { + // A 'null' stream + return (T)(object)SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(new MemoryStream(Array.Empty(), writable: false), closeInput: true, async: isAsync); + } + else + { + // Grab already read data + return (T)(object)data.SqlXml.CreateReader(); + } + } + } + else if (typeof(T) == typeof(TextReader)) + { + // Xml type is not supported + MetaType metaType = metaData.metaType; + if (metaData.cipherMD != null) + { + Debug.Assert(metaData.baseTI != null, "_metaData[i].baseTI should not be null."); + metaType = metaData.baseTI.metaType; + } + + if ( + (!metaType.IsCharType && metaType.SqlDbType != SqlDbType.Variant) || + (metaType.SqlDbType == SqlDbType.Xml) + ) + { + throw SQL.TextReaderNotSupportOnColumnType(metaData.column); + } + + // For non-variant types with sequential access, we support proper streaming + if ((metaType.SqlDbType != SqlDbType.Variant) && IsCommandBehavior(CommandBehavior.SequentialAccess)) + { + if (metaData.cipherMD != null) + { + throw SQL.SequentialAccessNotSupportedOnEncryptedColumn(metaData.column); + } + + System.Text.Encoding encoding = SqlUnicodeEncoding.SqlUnicodeEncodingInstance; + if (!metaType.IsNCharType) + { + encoding = metaData.encoding; + } + + _currentTextReader = new SqlSequentialTextReader(this, metaData.ordinal, encoding); + _lastColumnWithDataChunkRead = metaData.ordinal; + return (T)(object)_currentTextReader; + } + else + { + string value = data.IsNull ? string.Empty : data.SqlString.Value; + return (T)(object)new StringReader(value); + } + + } + else if (typeof(T) == typeof(Stream)) + { + if (metaData != null && metaData.cipherMD != null) + { + throw SQL.StreamNotSupportOnEncryptedColumn(metaData.column); + } + + // Stream is only for Binary, Image, VarBinary, Udt, Xml and Timestamp(RowVersion) types + MetaType metaType = metaData.metaType; + if ( + (!metaType.IsBinType || metaType.SqlDbType == SqlDbType.Timestamp) && + metaType.SqlDbType != SqlDbType.Variant + ) + { + throw SQL.StreamNotSupportOnColumnType(metaData.column); + } + + if ((metaType.SqlDbType != SqlDbType.Variant) && (IsCommandBehavior(CommandBehavior.SequentialAccess))) + { + _currentStream = new SqlSequentialStream(this, metaData.ordinal); + _lastColumnWithDataChunkRead = metaData.ordinal; + return (T)(object)_currentStream; + } + else + { + byte[] value = data.IsNull ? Array.Empty() : data.SqlBinary.Value; + return (T)(object)new MemoryStream(value, writable: false); + } + } + else if (typeof(INullable).IsAssignableFrom(typeof(T))) { // If its a SQL Type or Nullable UDT object rawValue = GetSqlValueFromSqlBufferInternal(data, metaData); - if (typeofT == _typeofSqlString) + if (typeof(T) == typeof(SqlString)) { // Special case: User wants SqlString, but we have a SqlXml // SqlXml can not be typecast into a SqlString, but we need to support SqlString on XML Types - so do a manual conversion @@ -3217,61 +3316,17 @@ private T GetFieldValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData met } else { - if (typeof(XmlReader) == typeofT) + // Otherwise Its a CLR or non-Nullable UDT + try { - if (metaData.metaType.SqlDbType != SqlDbType.Xml) - { - throw SQL.XmlReaderNotSupportOnColumnType(metaData.column); - } - else - { - object clrValue = null; - if (!data.IsNull) - { - clrValue = GetValueFromSqlBufferInternal(data, metaData); - } - if (clrValue is null) - { // covers IsNull and when there is data which is present but is a clr null somehow - return (T)(object)SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader( - new MemoryStream(Array.Empty(), writable: false), - closeInput: true - ); - } - else if (clrValue.GetType() == typeof(string)) - { - return (T)(object)SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader( - new StringReader(clrValue as string), - closeInput: true - ); - } - else - { - // try the type cast to throw the invalid cast exception and inform the user what types they're trying to use and that why it is wrong - return (T)clrValue; - } - } + return (T)GetValueFromSqlBufferInternal(data, metaData); } - else + catch (InvalidCastException) when (data.IsNull) { - // Otherwise Its a CLR or non-Nullable UDT - try - { - return (T)GetValueFromSqlBufferInternal(data, metaData); - } - catch (InvalidCastException) - { - if (data.IsNull) - { - // If the value was actually null, then we should throw a SqlNullValue instead - throw SQL.SqlNullValue(); - } - else - { - // Legitmate InvalidCast, rethrow - throw; - } - } + // If the value was actually null, then we should throw a SqlNullValue instead + throw SQL.SqlNullValue(); } + } } @@ -4040,10 +4095,12 @@ private void ReadColumn(int i, bool setTimeout = true, bool allowPartiallyReadCo Debug.Assert(_stateObj == null || _stateObj._syncOverAsync, "Should not attempt pends in a synchronous call"); bool result = TryReadColumn(i, setTimeout, allowPartiallyReadColumn); if (!result) - { throw SQL.SynchronousCallMayNotPend(); } + { + throw SQL.SynchronousCallMayNotPend(); + } } - private bool TryReadColumn(int i, bool setTimeout, bool allowPartiallyReadColumn = false) + private bool TryReadColumn(int i, bool setTimeout, bool allowPartiallyReadColumn = false, bool forStreaming = false) { CheckDataIsReady(columnIndex: i, permitAsync: true, allowPartiallyReadColumn: allowPartiallyReadColumn); @@ -4068,7 +4125,7 @@ private bool TryReadColumn(int i, bool setTimeout, bool allowPartiallyReadColumn SetTimeout(_defaultTimeoutMilliseconds); } - if (!TryReadColumnInternal(i, readHeaderOnly: false)) + if (!TryReadColumnInternal(i, readHeaderOnly: false, forStreaming: forStreaming)) { return false; } @@ -4158,7 +4215,7 @@ private bool TryReadColumnHeader(int i) { tdsReliabilitySection.Start(); #endif //DEBUG - return TryReadColumnInternal(i, readHeaderOnly: true); + return TryReadColumnInternal(i, readHeaderOnly: true, forStreaming: false); #if DEBUG } finally @@ -4196,7 +4253,7 @@ private bool TryReadColumnHeader(int i) } } - private bool TryReadColumnInternal(int i, bool readHeaderOnly = false) + internal bool TryReadColumnInternal(int i, bool readHeaderOnly/* = false*/, bool forStreaming) { AssertReaderState(requireData: true, permitAsync: true, columnIndex: i); @@ -4260,17 +4317,84 @@ private bool TryReadColumnInternal(int i, bool readHeaderOnly = false) { _SqlMetaData columnMetaData = _metaData[_sharedState._nextColumnHeaderToRead]; - if ((isSequentialAccess) && (_sharedState._nextColumnHeaderToRead < i)) + if (isSequentialAccess) { - // SkipValue is no-op if the column appears in NBC bitmask - // if not, it skips regular and PLP types - if (!_parser.TrySkipValue(columnMetaData, _sharedState._nextColumnHeaderToRead, _stateObj)) + if (_sharedState._nextColumnHeaderToRead < i) { - return false; + // SkipValue is no-op if the column appears in NBC bitmask + // if not, it skips regular and PLP types + if (!_parser.TrySkipValue(columnMetaData, _sharedState._nextColumnHeaderToRead, _stateObj)) + { + return false; + } + + _sharedState._nextColumnDataToRead = _sharedState._nextColumnHeaderToRead; + _sharedState._nextColumnHeaderToRead++; } + else if (_sharedState._nextColumnHeaderToRead == i) + { + bool isNull; + ulong dataLength; + if ( + !_parser.TryProcessColumnHeader( + columnMetaData, + _stateObj, + _sharedState._nextColumnHeaderToRead, + out isNull, + out dataLength + ) + ) + { + return false; + } - _sharedState._nextColumnDataToRead = _sharedState._nextColumnHeaderToRead; - _sharedState._nextColumnHeaderToRead++; + _sharedState._nextColumnDataToRead = _sharedState._nextColumnHeaderToRead; + _sharedState._nextColumnHeaderToRead++; // We read this one + _sharedState._columnDataBytesRemaining = (long)dataLength; + + if (isNull) + { + if (columnMetaData.type != SqlDbType.Timestamp) + { + TdsParser.GetNullSqlValue( + _data[_sharedState._nextColumnDataToRead], + columnMetaData, + _command != null ? _command.ColumnEncryptionSetting : SqlCommandColumnEncryptionSetting.UseConnectionSetting, + _parser.Connection + ); + } + } + else + { + if (!readHeaderOnly && !forStreaming) + { + // If we're in sequential mode try to read the data and then if it succeeds update shared + // state so there are no remaining bytes and advance the next column to read + if ( + !_parser.TryReadSqlValue( + _data[_sharedState._nextColumnDataToRead], + columnMetaData, + (int)dataLength, _stateObj, + _command != null ? _command.ColumnEncryptionSetting : SqlCommandColumnEncryptionSetting.UseConnectionSetting, + columnMetaData.column + ) + ) + { // will read UDTs as VARBINARY. + return false; + } + _sharedState._columnDataBytesRemaining = 0; + _sharedState._nextColumnDataToRead++; + } + else + { + _sharedState._columnDataBytesRemaining = (long)dataLength; + } + } + } + else + { + Debug.Assert(false, "we have read past the column somehow, this is an error"); + } } else { @@ -4288,10 +4412,12 @@ private bool TryReadColumnInternal(int i, bool readHeaderOnly = false) // if LegacyRowVersionNullBehavior is enabled, Timestamp type must enter "else" block. if (isNull && (!LocalAppContextSwitches.LegacyRowVersionNullBehavior || columnMetaData.type != SqlDbType.Timestamp)) { - TdsParser.GetNullSqlValue(_data[_sharedState._nextColumnDataToRead], - columnMetaData, - _command != null ? _command.ColumnEncryptionSetting : SqlCommandColumnEncryptionSetting.UseConnectionSetting, - _parser.Connection); + TdsParser.GetNullSqlValue( + _data[_sharedState._nextColumnDataToRead], + columnMetaData, + _command != null ? _command.ColumnEncryptionSetting : SqlCommandColumnEncryptionSetting.UseConnectionSetting, + _parser.Connection + ); if (!readHeaderOnly) { @@ -5467,7 +5593,7 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat var metaData = _metaData; if ((data != null) && (metaData != null)) { - return Task.FromResult(GetFieldValueFromSqlBufferInternal(data[i], metaData[i])); + return Task.FromResult(GetFieldValueFromSqlBufferInternal(data[i], metaData[i], isAsync: false)); } else { @@ -5507,7 +5633,7 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat { _stateObj._shouldHaveEnoughData = true; #endif - return Task.FromResult(GetFieldValueInternal(i)); + return Task.FromResult(GetFieldValueInternal(i, isAsync: true)); #if DEBUG } finally @@ -5563,9 +5689,33 @@ private static Task GetFieldValueAsyncExecute(Task task, object state) reader.PrepareForAsyncContinuation(); } + if (typeof(T) == typeof(Stream) || typeof(T) == typeof(TextReader) || typeof(T) == typeof(XmlReader)) + { + if (reader.IsCommandBehavior(CommandBehavior.SequentialAccess) && reader._sharedState._dataReady) + { + bool internalReadSuccess = false; + TdsParser.ReliabilitySection tdsReliabilitySection = new TdsParser.ReliabilitySection(); + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + tdsReliabilitySection.Start(); + internalReadSuccess = reader.TryReadColumnInternal(context._columnIndex, readHeaderOnly: true, forStreaming: false); + } + finally + { + tdsReliabilitySection.Stop(); + } + + if (internalReadSuccess) + { + return Task.FromResult(reader.GetFieldValueFromSqlBufferInternal(reader._data[columnIndex], reader._metaData[columnIndex], isAsync: true)); + } + } + } + if (reader.TryReadColumn(columnIndex, setTimeout: false)) { - return Task.FromResult(reader.GetFieldValueFromSqlBufferInternal(reader._data[columnIndex], reader._metaData[columnIndex])); + return Task.FromResult(reader.GetFieldValueFromSqlBufferInternal(reader._data[columnIndex], reader._metaData[columnIndex], isAsync: false)); } else { diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReaderSmi.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReaderSmi.cs index 1207285e5b..7c8d3fe7d9 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReaderSmi.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReaderSmi.cs @@ -135,7 +135,7 @@ public override T GetFieldValue(int ordinal) EnsureCanGetCol("GetFieldValue", ordinal); SmiQueryMetaData metaData = _currentMetaData[ordinal]; - if (_typeofINullable.IsAssignableFrom(typeof(T))) + if (typeof(INullable).IsAssignableFrom(typeof(T))) { // If its a SQL Type or Nullable UDT if (_currentConnection.IsKatmaiOrNewer) @@ -1044,7 +1044,7 @@ public override XmlReader GetXmlReader(int ordinal) stream = ValueUtilsSmi.GetStream(_readerEventSink, _currentColumnValuesV3, ordinal, _currentMetaData[ordinal], bypassTypeCheck: true); } - return SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(stream); + return SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(stream, closeInput: false, async: false); } // diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlTypeWorkarounds.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlTypeWorkarounds.cs index 8bf15ead0e..2cbd875cf8 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlTypeWorkarounds.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlTypes/SqlTypeWorkarounds.cs @@ -31,7 +31,7 @@ internal static class SqlTypeWorkarounds SqlCompareOptions.IgnoreNonSpace | SqlCompareOptions.IgnoreKanaType | SqlCompareOptions.BinarySort | SqlCompareOptions.BinarySort2; - internal static XmlReader SqlXmlCreateSqlXmlReader(Stream stream, bool closeInput = false, bool async = false) + internal static XmlReader SqlXmlCreateSqlXmlReader(Stream stream, bool closeInput, bool async) { Debug.Assert(closeInput || !async, "Currently we do not have pre-created settings for !closeInput+async"); @@ -42,7 +42,7 @@ internal static XmlReader SqlXmlCreateSqlXmlReader(Stream stream, bool closeInpu return XmlReader.Create(stream, settingsToUse); } - internal static XmlReader SqlXmlCreateSqlXmlReader(TextReader textReader, bool closeInput = false, bool async = false) + internal static XmlReader SqlXmlCreateSqlXmlReader(TextReader textReader, bool closeInput, bool async) { Debug.Assert(closeInput || !async, "Currently we do not have pre-created settings for !closeInput+async"); diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index 94288d7de8..23ed738f37 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -66,6 +66,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataReaderTest/DataReaderStreamsTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataReaderTest/DataReaderStreamsTest.cs new file mode 100644 index 0000000000..8c8b2bc1d4 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataReaderTest/DataReaderStreamsTest.cs @@ -0,0 +1,679 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Data; +using System.IO; +using System.Text; +using System.Threading.Tasks; +using System.Xml; +using Xunit; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests +{ + public static class DataReaderStreamsTest + { + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [MemberData(nameof(GetCommandBehavioursAndIsAsync))] + public static async Task GetFieldValueAsync_OfStream(CommandBehavior behavior, bool isExecuteAsync) + { + const int PacketSize = 512; // force minimun packet size so that the test data spans multiple packets to test sequential access spanning + string connectionString = SetConnectionStringPacketSize(DataTestUtility.TCPConnectionString, PacketSize); + byte[] originalData = CreateBinaryData(PacketSize, forcedPacketCount: 4); + string query = CreateBinaryDataQuery(originalData); + + string streamTypeName = null; + byte[] outputData = null; + using (SqlConnection connection = new SqlConnection(connectionString)) + using (SqlCommand command = new SqlCommand(query, connection)) + { + connection.Open(); + using (SqlDataReader reader = await ExecuteReader(command, behavior, isExecuteAsync)) + { + if (await Read(reader, isExecuteAsync)) + { + using (MemoryStream buffer = new MemoryStream(originalData.Length)) + using (Stream stream = await reader.GetFieldValueAsync(1)) + { + streamTypeName = stream.GetType().Name; + await stream.CopyToAsync(buffer); + outputData = buffer.ToArray(); + } + } + } + } + + Assert.True(behavior != CommandBehavior.SequentialAccess || streamTypeName.Contains("Sequential")); + Assert.NotNull(outputData); + Assert.Equal(originalData.Length, outputData.Length); + Assert.Equal(originalData, outputData); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [MemberData(nameof(GetCommandBehavioursAndIsAsync))] + public static async Task GetFieldValueAsync_OfXmlReader(CommandBehavior behavior, bool isExecuteAsync) + { + const int PacketSize = 512; // force minimun packet size so that the test data spans multiple packets to test sequential access spanning + string connectionString = SetConnectionStringPacketSize(DataTestUtility.TCPConnectionString, PacketSize); + string originalXml = CreateXmlData(PacketSize, forcedPacketCount: 4); + string query = CreateXmlDataQuery(originalXml); + + bool isAsync = false; + string outputXml = null; + using (SqlConnection connection = new SqlConnection(connectionString)) + using (SqlCommand command = new SqlCommand(query, connection)) + { + connection.Open(); + using (SqlDataReader reader = await ExecuteReader(command, behavior, isExecuteAsync)) + { + if (await Read(reader, isExecuteAsync)) + { + using (XmlReader xmlReader = await reader.GetFieldValueAsync(1)) + { + isAsync = xmlReader.Settings.Async; + outputXml = GetXmlDocumentContents(xmlReader); + } + } + } + } + + Assert.True(behavior != CommandBehavior.SequentialAccess || isAsync); + Assert.NotNull(outputXml); + Assert.Equal(originalXml.Length, outputXml.Length); + Assert.Equal(originalXml, outputXml); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [MemberData(nameof(GetCommandBehavioursAndIsAsync))] + public static async Task GetFieldValueAsync_OfTextReader(CommandBehavior behavior, bool isExecuteAsync) + { + const int PacketSize = 512; // force minimun packet size so that the test data spans multiple packets to test sequential access spanning + string connectionString = SetConnectionStringPacketSize(DataTestUtility.TCPConnectionString, PacketSize); + string originalText = CreateXmlData(PacketSize, forcedPacketCount: 4); + string query = CreateTextDataQuery(originalText); + + string streamTypeName = null; + string outputText = null; + using (SqlConnection connection = new SqlConnection(connectionString)) + using (SqlCommand command = new SqlCommand(query, connection)) + { + connection.Open(); + using (SqlDataReader reader = await ExecuteReader(command, behavior, isExecuteAsync)) + { + if (await Read(reader, isExecuteAsync)) + { + using (TextReader textReader = await reader.GetFieldValueAsync(1)) + { + streamTypeName = textReader.GetType().Name; + outputText = await textReader.ReadToEndAsync(); + } + } + } + } + + Assert.True(behavior != CommandBehavior.SequentialAccess || streamTypeName.Contains("Sequential")); + Assert.NotNull(outputText); + Assert.Equal(originalText.Length, outputText.Length); + Assert.Equal(originalText, outputText); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [MemberData(nameof(GetCommandBehavioursAndIsAsync))] + public static async Task GetFieldValueAsync_Char_OfTextReader(CommandBehavior behavior, bool isExecuteAsync) + { + const int PacketSize = 512; // force minimun packet size so that the test data spans multiple packets to test sequential access spanning + string connectionString = SetConnectionStringPacketSize(DataTestUtility.TCPConnectionString, PacketSize); + string originalText = new ('c', PacketSize * 4); + string query = CreateCharDataQuery(originalText); + + string streamTypeName = null; + string outputText = null; + using (SqlConnection connection = new SqlConnection(connectionString)) + using (SqlCommand command = new SqlCommand(query, connection)) + { + connection.Open(); + using (SqlDataReader reader = await ExecuteReader(command, behavior, isExecuteAsync)) + { + if (await Read(reader, isExecuteAsync)) + { + using (TextReader textReader = await reader.GetFieldValueAsync(1)) + { + streamTypeName = textReader.GetType().Name; + outputText = await textReader.ReadToEndAsync(); + } + } + } + } + + Assert.True(behavior != CommandBehavior.SequentialAccess || streamTypeName.Contains("Sequential")); + Assert.NotNull(outputText); + Assert.Equal(originalText.Length, outputText.Length); + Assert.Equal(originalText, outputText); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [MemberData(nameof(GetCommandBehavioursAndIsAsync))] + public static async void GetFieldValue_OfXmlReader(CommandBehavior behavior, bool isExecuteAsync) + { + const int PacketSize = 512; // force minimun packet size so that the test data spans multiple packets to test sequential access spanning + string connectionString = SetConnectionStringPacketSize(DataTestUtility.TCPConnectionString, PacketSize); + string originalXml = CreateXmlData(PacketSize, forcedPacketCount: 4); + string query = CreateXmlDataQuery(originalXml); + + bool isAsync = false; + string outputXml = null; + using (SqlConnection connection = new SqlConnection(connectionString)) + using (SqlCommand command = new SqlCommand(query, connection)) + { + connection.Open(); + using (SqlDataReader reader = await ExecuteReader(command, behavior, isExecuteAsync)) + { + if (await Read(reader, isExecuteAsync)) + { + using (XmlReader xmlReader = reader.GetFieldValue(1)) + { + isAsync = xmlReader.Settings.Async; + outputXml = GetXmlDocumentContents(xmlReader); + } + } + } + } + + Assert.False(isAsync); + Assert.NotNull(outputXml); + Assert.Equal(originalXml.Length, outputXml.Length); + Assert.Equal(originalXml, outputXml); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [MemberData(nameof(GetCommandBehavioursAndIsAsync))] + public static async void GetFieldValue_OfStream(CommandBehavior behavior, bool isExecuteAsync) + { + const int PacketSize = 512; // force minimun packet size so that the test data spans multiple packets to test sequential access spanning + string connectionString = SetConnectionStringPacketSize(DataTestUtility.TCPConnectionString, PacketSize); + byte[] originalData = CreateBinaryData(PacketSize, forcedPacketCount: 4); + string query = CreateBinaryDataQuery(originalData); + + string streamTypeName = null; + byte[] outputData = null; + using (SqlConnection connection = new SqlConnection(connectionString)) + using (SqlCommand command = new SqlCommand(query, connection)) + { + connection.Open(); + using (SqlDataReader reader = await ExecuteReader(command, behavior, isExecuteAsync)) + { + if (await Read(reader, isExecuteAsync)) + { + using (Stream stream = reader.GetFieldValue(1)) + { + streamTypeName = stream.GetType().Name; + outputData = GetStreamContents(stream); + } + } + } + } + Assert.True(behavior != CommandBehavior.SequentialAccess || streamTypeName.Contains("Sequential")); + Assert.NotNull(outputData); + Assert.Equal(originalData.Length, outputData.Length); + Assert.Equal(originalData, outputData); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [MemberData(nameof(GetCommandBehavioursAndIsAsync))] + public static async void GetFieldValue_OfTextReader(CommandBehavior behavior, bool isExecuteAsync) + { + const int PacketSize = 512; // force minimun packet size so that the test data spans multiple packets to test sequential access spanning + string connectionString = SetConnectionStringPacketSize(DataTestUtility.TCPConnectionString, PacketSize); + string originalText = CreateXmlData(PacketSize, forcedPacketCount: 4); + string query = CreateTextDataQuery(originalText); + + string streamTypeName = null; + string outputText = null; + using (SqlConnection connection = new SqlConnection(connectionString)) + using (SqlCommand command = new SqlCommand(query, connection)) + { + connection.Open(); + using (SqlDataReader reader = await ExecuteReader(command, behavior, isExecuteAsync)) + { + if (await Read(reader, isExecuteAsync)) + { + using (TextReader textReader = reader.GetFieldValue(1)) + { + streamTypeName = textReader.GetType().Name; + outputText = textReader.ReadToEnd(); + } + } + } + } + + Assert.True(behavior != CommandBehavior.SequentialAccess || streamTypeName.Contains("Sequential")); + Assert.NotNull(outputText); + Assert.Equal(originalText.Length, outputText.Length); + Assert.Equal(originalText, outputText); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [MemberData(nameof(GetCommandBehavioursAndIsAsync))] + public static async void GetStream(CommandBehavior behavior, bool isExecuteAsync) + { + const int PacketSize = 512; // force minimun packet size so that the test data spans multiple packets to test sequential access spanning + string connectionString = SetConnectionStringPacketSize(DataTestUtility.TCPConnectionString, PacketSize); + byte[] originalData = CreateBinaryData(PacketSize, forcedPacketCount: 4); + string query = CreateBinaryDataQuery(originalData); + + string streamTypeName = null; + byte[] outputData = null; + using (SqlConnection connection = new SqlConnection(connectionString)) + using (SqlCommand command = new SqlCommand(query, connection)) + { + connection.Open(); + using (SqlDataReader reader = await ExecuteReader(command, behavior, isExecuteAsync)) + { + if (await Read(reader, isExecuteAsync)) + { + using (MemoryStream buffer = new MemoryStream(originalData.Length)) + using (Stream stream = reader.GetStream(1)) + { + streamTypeName = stream.GetType().Name; + stream.CopyTo(buffer); + outputData = buffer.ToArray(); + } + } + } + } + + Assert.True(behavior != CommandBehavior.SequentialAccess || streamTypeName.Contains("Sequential")); + Assert.NotNull(outputData); + Assert.Equal(originalData.Length, outputData.Length); + Assert.Equal(originalData, outputData); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [MemberData(nameof(GetCommandBehavioursAndIsAsync))] + public static async void GetXmlReader(CommandBehavior behavior, bool isExecuteAsync) + { + const int PacketSize = 512; // force minimun packet size so that the test data spans multiple packets to test sequential access spanning + string connectionString = SetConnectionStringPacketSize(DataTestUtility.TCPConnectionString, PacketSize); + string originalXml = CreateXmlData(PacketSize, forcedPacketCount: 4); + string query = CreateXmlDataQuery(originalXml); + + bool isAsync = false; + string outputXml = null; + using (SqlConnection connection = new SqlConnection(connectionString)) + using (SqlCommand command = new SqlCommand(query, connection)) + { + connection.Open(); + using (SqlDataReader reader = await ExecuteReader(command, behavior, isExecuteAsync)) + { + if (await Read(reader, isExecuteAsync)) + { + using (XmlReader xmlReader = reader.GetXmlReader(1)) + { + isAsync = xmlReader.Settings.Async; + outputXml = GetXmlDocumentContents(xmlReader); + } + } + } + } + + Assert.False(isAsync); + Assert.NotNull(outputXml); + Assert.Equal(originalXml.Length, outputXml.Length); + Assert.Equal(originalXml, outputXml); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [MemberData(nameof(GetCommandBehavioursAndIsAsync))] + public static async void GetTextReader(CommandBehavior behavior, bool isExecuteAsync) + { + const int PacketSize = 512; // force minimun packet size so that the test data spans multiple packets to test sequential access spanning + string connectionString = SetConnectionStringPacketSize(DataTestUtility.TCPConnectionString, PacketSize); + string originalText = CreateXmlData(PacketSize, forcedPacketCount: 4); + string query = CreateTextDataQuery(originalText); + + string streamTypeName = null; + string outputText = null; + using (SqlConnection connection = new SqlConnection(connectionString)) + using (SqlCommand command = new SqlCommand(query, connection)) + { + connection.Open(); + using (SqlDataReader reader = await ExecuteReader(command, behavior, isExecuteAsync)) + { + if (await Read(reader, isExecuteAsync)) + { + using (TextReader textReader = reader.GetTextReader(1)) + { + streamTypeName = textReader.GetType().Name; + outputText = textReader.ReadToEnd(); + } + } + } + } + + Assert.True(behavior != CommandBehavior.SequentialAccess || streamTypeName.Contains("Sequential")); + Assert.NotNull(outputText); + Assert.Equal(originalText.Length, outputText.Length); + Assert.Equal(originalText, outputText); + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [MemberData(nameof(GetCommandBehaviourAndAccessorTypes))] + public static void NullStreamProperties(CommandBehavior behavior, AccessorType accessorType) + { + string query = "SELECT convert(xml,NULL) AS XmlData, convert(nvarchar(max),NULL) as TextData, convert(varbinary(max),NULL) as StreamData"; + + using (SqlConnection connection = new SqlConnection(DataTestUtility.TCPConnectionString)) + using (SqlCommand command = new SqlCommand(query, connection)) + { + connection.Open(); + + // do clean queries to get field values again in case of sequential mode + + using (SqlDataReader reader = command.ExecuteReader(behavior)) + { + if (reader.Read()) + { + Assert.True(reader.IsDBNull(0)); + Assert.True(reader.IsDBNull(1)); + Assert.True(reader.IsDBNull(2)); + } + } + + using (SqlDataReader reader = command.ExecuteReader(behavior)) + { + if (reader.Read()) + { + Assert.True(reader.IsDBNullAsync(0).GetAwaiter().GetResult()); + Assert.True(reader.IsDBNullAsync(1).GetAwaiter().GetResult()); + Assert.True(reader.IsDBNullAsync(2).GetAwaiter().GetResult()); + } + } + + using (SqlDataReader reader = command.ExecuteReader(behavior)) + { + if (reader.Read()) + { + using (XmlReader xmlReader = GetValue(reader, 0, accessorType)) + { + Assert.NotNull(xmlReader); + Assert.Equal(accessorType == AccessorType.GetFieldValueAsync, xmlReader.Settings.Async); + Assert.Equal(xmlReader.Value, string.Empty); + Assert.False(xmlReader.Read()); + Assert.True(xmlReader.EOF); + } + + using (TextReader textReader = GetValue(reader, 1, accessorType)) + { + Assert.NotNull(textReader); + Assert.True(behavior != CommandBehavior.SequentialAccess || textReader.GetType().Name.Contains("Sequential")); + Assert.Equal(textReader.ReadToEnd(), string.Empty); + } + + using (Stream stream = GetValue(reader, 2, accessorType)) + { + Assert.NotNull(stream); + Assert.True(behavior != CommandBehavior.SequentialAccess || stream.GetType().Name.Contains("Sequential")); + } + } + } + + using (SqlDataReader reader = command.ExecuteReader(behavior)) + { + if (reader.Read()) + { + // get a clean reader over the same field and check that the value is empty + using (XmlReader xmlReader = GetValue(reader, 0, accessorType)) + { + Assert.Equal(GetXmlDocumentContents(xmlReader), string.Empty); + } + + using (TextReader textReader = GetValue(reader, 1, accessorType)) + { + Assert.Equal(textReader.ReadToEnd(), string.Empty); + } + + using (Stream stream = GetValue(reader, 2, accessorType)) + { + Assert.Equal(GetStreamContents(stream), Array.Empty()); + } + } + } + } + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [MemberData(nameof(GetCommandBehaviourAndAccessorTypes))] + public static void InvalidCastExceptionStream(CommandBehavior behavior, AccessorType accessorType) + { + string query = "SELECT convert(xml,NULL) AS XmlData, convert(nvarchar(max),NULL) as TextData"; + + using (SqlConnection connection = new SqlConnection(DataTestUtility.TCPConnectionString)) + using (SqlCommand command = new SqlCommand(query, connection)) + { + connection.Open(); + + using (SqlDataReader reader = command.ExecuteReader(behavior)) + { + Assert.True(reader.Read(), "It's excpected to read a row."); + + InvalidCastException ex = Assert.Throws(() => GetValue(reader, 0, accessorType)); + Assert.Contains("The GetTextReader function can only be used on columns of type Char, NChar, NText, NVarChar, Text or VarChar.", ex.Message); + + ex = Assert.Throws(() => GetValue(reader, 0, accessorType)); + Assert.Contains("The GetStream function can only be used on columns of type Binary, Image, Udt or VarBinary.", ex.Message); + + ex = Assert.Throws(() => GetValue(reader, 1, accessorType)); + Assert.Contains("The GetXmlReader function can only be used on columns of type Xml.", ex.Message); + } + } + } + + private static async Task ExecuteReader(SqlCommand command, CommandBehavior behavior, bool isExecuteAsync) + => isExecuteAsync ? await command.ExecuteReaderAsync(behavior) : command.ExecuteReader(behavior); + + private static async Task Read(SqlDataReader reader, bool isExecuteAsync) + => isExecuteAsync ? await reader.ReadAsync() : reader.Read(); + + public static IEnumerable GetCommandBehaviourAndAccessorTypes() + { + foreach (CommandBehavior behavior in new CommandBehavior[] { CommandBehavior.Default, CommandBehavior.SequentialAccess }) + { + foreach (AccessorType accessorType in new AccessorType[] { AccessorType.GetNamedValue, AccessorType.GetFieldValue, AccessorType.GetFieldValueAsync }) + { + yield return new object[] { behavior, accessorType }; + } + } + } + + public static IEnumerable GetCommandBehavioursAndIsAsync() + { + foreach (CommandBehavior behavior in new CommandBehavior[] { CommandBehavior.Default, CommandBehavior.SequentialAccess }) + { + yield return new object[] { behavior, true }; + yield return new object[] { behavior, false }; + } + } + + public enum AccessorType + { + GetNamedValue, // GetStream, GetXmlReader, GetTextReader + GetFieldValue, + GetFieldValueAsync + } + + private static T GetValue(SqlDataReader reader, int ordinal, AccessorType accesor) + { + switch (accesor) + { + case AccessorType.GetFieldValue: + return GetFieldValue(reader, ordinal); + case AccessorType.GetFieldValueAsync: + return GetFieldValueAsync(reader, ordinal); + case AccessorType.GetNamedValue: + return GetNamedValue(reader, ordinal); + default: + throw new NotSupportedException(); + } + } + + private static T GetFieldValueAsync(SqlDataReader reader, int ordinal) + { + return reader.GetFieldValueAsync(ordinal).GetAwaiter().GetResult(); + } + + private static T GetFieldValue(SqlDataReader reader, int ordinal) + { + return reader.GetFieldValue(ordinal); + } + + private static T GetNamedValue(SqlDataReader reader, int ordinal) + { + if (typeof(T) == typeof(XmlReader)) + { + return (T)(object)reader.GetXmlReader(ordinal); + } + else if (typeof(T) == typeof(TextReader)) + { + return (T)(object)reader.GetTextReader(ordinal); + } + else if (typeof(T) == typeof(Stream)) + { + return (T)(object)reader.GetStream(ordinal); + } + else + { + throw new NotSupportedException($"type {typeof(T).Name} is not a supported field type"); + } + } + + + private static string SetConnectionStringPacketSize(string connectionString, int packetSize) + { + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(connectionString); + builder.PersistSecurityInfo = true; + builder.PacketSize = packetSize; + return builder.ToString(); + } + + private static byte[] CreateBinaryData(int packetSize, int forcedPacketCount) + { + byte[] originalData = new byte[packetSize * forcedPacketCount]; // with header overhead this should cause forcedPacketCount+1 packets of data + Random random = new Random(100); // static seed for ease of debugging reproducibility + random.NextBytes(originalData); + return originalData; + } + + private static string CreateXmlData(int packetSize, int forcedPacketCount) + { + XmlWriterSettings settings = new XmlWriterSettings + { + ConformanceLevel = ConformanceLevel.Fragment, + Encoding = Encoding.Unicode, + Indent = true, + OmitXmlDeclaration = true + }; + StringBuilder buffer = new StringBuilder(2048); + using (StringWriter stringWriter = new StringWriter(buffer)) + using (XmlWriter xmlWriter = XmlWriter.Create(stringWriter, settings)) + { + int index = 1; + xmlWriter.WriteStartElement("root"); + while (buffer.Length / 2 < (packetSize * forcedPacketCount)) + { + xmlWriter.WriteStartElement("block"); + { + xmlWriter.WriteStartElement("value1"); + xmlWriter.WriteValue(index++); + xmlWriter.WriteEndElement(); + + xmlWriter.WriteStartElement("value2"); + xmlWriter.WriteValue(index++); + xmlWriter.WriteEndElement(); + + xmlWriter.WriteStartElement("value3"); + xmlWriter.WriteValue(index++); + xmlWriter.WriteEndElement(); + } + xmlWriter.WriteEndElement(); + } + xmlWriter.WriteEndElement(); + } + return buffer.ToString(); + } + + private static string CreateBinaryDataQuery(byte[] originalData) + { + StringBuilder queryBuilder = new StringBuilder(originalData.Length * 2 + 128); + queryBuilder.Append("SELECT 1 as DummyField, 0x"); + for (int index = 0; index < originalData.Length; index++) + { + queryBuilder.AppendFormat("{0:X2}", originalData[index]); + } + queryBuilder.Append(" AS Data"); + return queryBuilder.ToString(); + } + + private static string CreateXmlDataQuery(string originalXml) + { + StringBuilder queryBuilder = new StringBuilder(originalXml.Length + 128); + queryBuilder.Append("SELECT 1 as DummyField, convert(xml,'"); + queryBuilder.Append(originalXml); + queryBuilder.Append("') AS Data"); + return queryBuilder.ToString(); + } + + private static string CreateTextDataQuery(string originalText) + { + StringBuilder queryBuilder = new StringBuilder(originalText.Length + 128); + queryBuilder.Append("SELECT 1 as DummyField, convert(nvarchar(max),'"); + queryBuilder.Append(originalText); + queryBuilder.Append("') AS Data"); + return queryBuilder.ToString(); + } + + private static string CreateCharDataQuery(string originalText) + { + StringBuilder queryBuilder = new StringBuilder(originalText.Length + 128); + queryBuilder.Append($"SELECT 1 as DummyField, convert(char({originalText.Length}),'"); + queryBuilder.Append(originalText); + queryBuilder.Append("') AS Data"); + return queryBuilder.ToString(); + } + + private static string GetXmlDocumentContents(XmlReader xmlReader) + { + string outputXml; + XmlDocument document = new XmlDocument(); + document.Load(xmlReader); + + XmlWriterSettings settings = new XmlWriterSettings + { + ConformanceLevel = ConformanceLevel.Document, + Encoding = Encoding.Unicode, + Indent = true, + OmitXmlDeclaration = true + }; + + StringBuilder buffer = new StringBuilder(2048); + using (StringWriter stringWriter = new StringWriter(buffer)) + using (XmlWriter xmlWriter = XmlWriter.Create(stringWriter, settings)) + { + document.WriteContentTo(xmlWriter); + } + outputXml = buffer.ToString(); + return outputXml; + } + + private static byte[] GetStreamContents(Stream stream) + { + using (MemoryStream buffer = new MemoryStream()) + { + stream.CopyTo(buffer); + buffer.Flush(); + return buffer.ToArray(); + } + } + } +}