Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Change SqlClient to use strongly typed packet and session handles #33155

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,10 @@ internal struct SNI_Error

#region DLL Imports
[DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIAddProviderWrapper")]
internal static extern uint SNIAddProvider(SNIHandle pConn, ProviderEnum ProvNum, [In] ref uint pInfo);
internal static extern uint SNIAddProvider(SNISessionHandle pConn, ProviderEnum ProvNum, [In] ref uint pInfo);
Wraith2 marked this conversation as resolved.
Show resolved Hide resolved

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNICheckConnectionWrapper")]
internal static extern uint SNICheckConnection([In] SNIHandle pConn);
internal static extern uint SNICheckConnection([In] SNISessionHandle pConn);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNICloseWrapper")]
internal static extern uint SNIClose(IntPtr pConn);
Expand All @@ -197,7 +197,7 @@ internal struct SNI_Error
internal static extern void SNIPacketRelease(IntPtr pPacket);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIPacketResetWrapper")]
internal static extern void SNIPacketReset([In] SNIHandle pConn, IOType IOType, SNIPacket pPacket, ConsumerNumber ConsNum);
internal static extern void SNIPacketReset([In] SNISessionHandle pConn, IOType IOType, SNIPacketHandle pPacket, ConsumerNumber ConsNum);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
internal static extern uint SNIQueryInfo(QTypes QType, ref uint pbQInfo);
Expand All @@ -206,25 +206,25 @@ internal struct SNI_Error
internal static extern uint SNIQueryInfo(QTypes QType, ref IntPtr pbQInfo);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIReadAsyncWrapper")]
internal static extern uint SNIReadAsync(SNIHandle pConn, ref IntPtr ppNewPacket);
internal static extern uint SNIReadAsync(SNISessionHandle pConn, ref IntPtr ppNewPacket);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
internal static extern uint SNIReadSyncOverAsync(SNIHandle pConn, ref IntPtr ppNewPacket, int timeout);
internal static extern uint SNIReadSyncOverAsync(SNISessionHandle pConn, ref IntPtr ppNewPacket, int timeout);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIRemoveProviderWrapper")]
internal static extern uint SNIRemoveProvider(SNIHandle pConn, ProviderEnum ProvNum);
internal static extern uint SNIRemoveProvider(SNISessionHandle pConn, ProviderEnum ProvNum);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
internal static extern uint SNISecInitPackage(ref uint pcbMaxToken);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNISetInfoWrapper")]
internal static extern uint SNISetInfo(SNIHandle pConn, QTypes QType, [In] ref uint pbQInfo);
internal static extern uint SNISetInfo(SNISessionHandle pConn, QTypes QType, [In] ref uint pbQInfo);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
internal static extern uint SNITerminate();

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNIWaitForSSLHandshakeToCompleteWrapper")]
internal static extern uint SNIWaitForSSLHandshakeToComplete([In] SNIHandle pConn, int dwMilliseconds);
internal static extern uint SNIWaitForSSLHandshakeToComplete([In] SNISessionHandle pConn, int dwMilliseconds);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
internal static extern uint UnmanagedIsTokenRestricted([In] IntPtr token, [MarshalAs(UnmanagedType.Bool)] out bool isRestricted);
Expand All @@ -233,7 +233,7 @@ internal struct SNI_Error
private static extern uint GetSniMaxComposedSpnLength();

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern uint SNIGetInfoWrapper([In] SNIHandle pConn, SNINativeMethodWrapper.QTypes QType, out Guid pbQInfo);
private static extern uint SNIGetInfoWrapper([In] SNISessionHandle pConn, SNINativeMethodWrapper.QTypes QType, out Guid pbQInfo);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern uint SNIInitialize([In] IntPtr pmo);
Expand All @@ -245,7 +245,7 @@ internal struct SNI_Error
private static extern uint SNIOpenWrapper(
[In] ref Sni_Consumer_Info pConsumerInfo,
[MarshalAs(UnmanagedType.LPStr)] string szConnect,
[In] SNIHandle pConn,
[In] SNISessionHandle pConn,
out IntPtr ppConn,
[MarshalAs(UnmanagedType.Bool)] bool fSync);

Expand All @@ -256,11 +256,11 @@ private static extern uint SNIOpenWrapper(
private static extern uint SNIPacketGetDataWrapper([In] IntPtr packet, [In, Out] byte[] readBuffer, uint readBufferLength, out uint dataSize);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern unsafe void SNIPacketSetData(SNIPacket pPacket, [In] byte* pbBuf, uint cbBuf);
private static extern unsafe void SNIPacketSetData(SNIPacketHandle pPacket, [In] byte* pbBuf, uint cbBuf);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In] SNISessionHandle pConn,
[In, Out] byte[] pIn,
uint cbIn,
[In, Out] byte[] pOut,
Expand All @@ -272,13 +272,13 @@ private static extern unsafe uint SNISecGenClientContextWrapper(
[MarshalAsAttribute(UnmanagedType.LPWStr)] string pwszPassword);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern uint SNIWriteAsyncWrapper(SNIHandle pConn, [In] SNIPacket pPacket);
private static extern uint SNIWriteAsyncWrapper(SNISessionHandle pConn, [In] SNIPacketHandle pPacket);

[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern uint SNIWriteSyncOverAsync(SNIHandle pConn, [In] SNIPacket pPacket);
private static extern uint SNIWriteSyncOverAsync(SNISessionHandle pConn, [In] SNIPacketHandle pPacket);
#endregion

internal static uint SniGetConnectionId(SNIHandle pConn, ref Guid connId)
internal static uint SniGetConnectionId(SNISessionHandle pConn, ref Guid connId)
{
return SNIGetInfoWrapper(pConn, QTypes.SNI_QUERY_CONN_CONNID, out connId);
}
Expand All @@ -288,7 +288,7 @@ internal static uint SNIInitialize()
return SNIInitialize(IntPtr.Zero);
}

internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNIHandle parent, ref IntPtr pConn, bool fSync)
internal static unsafe uint SNIOpenMarsSession(ConsumerInfo consumerInfo, SNISessionHandle parent, ref IntPtr pConn, bool fSync)
{
// initialize consumer info for MARS
Sni_Consumer_Info native_consumerInfo = new Sni_Consumer_Info();
Expand Down Expand Up @@ -347,15 +347,15 @@ internal static unsafe uint SNIPacketGetData(IntPtr packet, byte[] readBuffer, r
return SNIPacketGetDataWrapper(packet, readBuffer, (uint)readBuffer.Length, out dataSize);
}

internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int length)
internal static unsafe void SNIPacketSetData(SNIPacketHandle packet, byte[] data, int length)
{
fixed (byte* pin_data = &data[0])
{
SNIPacketSetData(packet, pin_data, (uint)length);
}
}

internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, byte[] inBuff, uint receivedLength, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
internal static unsafe uint SNISecGenClientContext(SNISessionHandle pConnectionObject, byte[] inBuff, uint receivedLength, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
{
fixed (byte* pin_serverUserName = &serverUserName[0])
{
Expand All @@ -374,7 +374,7 @@ internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject,
}
}

internal static uint SNIWritePacket(SNIHandle pConn, SNIPacket packet, bool sync)
internal static uint SNIWritePacket(SNISessionHandle pConn, SNIPacketHandle packet, bool sync)
{
if (sync)
{
Expand Down
1 change: 1 addition & 0 deletions src/System.Data.SqlClient/src/System.Data.SqlClient.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
<AssemblyVersion Condition="'$(TargetGroup)' == 'netstandard1.2'">4.0.0.0</AssemblyVersion>
<AssemblyVersion Condition="'$(TargetGroup)' == 'netstandard1.3'">4.1.0.0</AssemblyVersion>
<DefineConstants Condition="'$(TargetsNetCoreApp)' == 'true'">$(DefineConstants);netcoreapp</DefineConstants>
<DefineConstants Condition=" '$(TargetsWindows)' == 'true' And '$(IsPartialFacadeAssembly)' != 'true' and '$(IsUAPAssembly)' != 'true'">$(DefineConstants);snidll</DefineConstants>
<Configurations>net461-Windows_NT-Debug;net461-Windows_NT-Release;netcoreapp-Debug;netcoreapp-Release;netcoreapp-Unix-Debug;netcoreapp-Unix-Release;netcoreapp-Windows_NT-Debug;netcoreapp-Windows_NT-Release;netcoreapp2.1-Debug;netcoreapp2.1-Release;netcoreapp2.1-Unix-Debug;netcoreapp2.1-Unix-Release;netcoreapp2.1-Windows_NT-Debug;netcoreapp2.1-Windows_NT-Release;netfx-Windows_NT-Debug;netfx-Windows_NT-Release;netstandard-Debug;netstandard-Release;netstandard-Unix-Debug;netstandard-Unix-Release;netstandard-Windows_NT-Debug;netstandard-Windows_NT-Release;netstandard1.2-Debug;netstandard1.2-Release;netstandard1.3-Debug;netstandard1.3-Release;uap-Windows_NT-Debug;uap-Windows_NT-Release;uap10.0.16299-Windows_NT-Debug;uap10.0.16299-Windows_NT-Release</Configurations>
</PropertyGroup>
<ItemGroup Condition="'$(TargetGroup)' == 'netstandard' OR '$(TargetsNetCoreApp)' == 'true' OR '$(IsUAPAssembly)' == 'true' ">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ public void HandleReceiveError(SNIPacket packet)
_packetEvent.Set();
}

((TdsParserStateObject)_callbackObject).ReadAsyncCallback(packet, 1);
((TdsParserStateObject)_callbackObject).ReadAsyncCallback(PacketHandle.FromManagedPacket(packet), 1);
}

/// <summary>
Expand All @@ -331,7 +331,7 @@ public void HandleSendComplete(SNIPacket packet, uint sniErrorCode)
{
Debug.Assert(_callbackObject != null);

((TdsParserStateObject)_callbackObject).WriteAsyncCallback(packet, sniErrorCode);
((TdsParserStateObject)_callbackObject).WriteAsyncCallback(PacketHandle.FromManagedPacket(packet), sniErrorCode);
}
}

Expand Down Expand Up @@ -377,7 +377,7 @@ public void HandleReceiveComplete(SNIPacket packet, SNISMUXHeader header)
_asyncReceives--;
Debug.Assert(_callbackObject != null);

((TdsParserStateObject)_callbackObject).ReadAsyncCallback(packet, 0);
((TdsParserStateObject)_callbackObject).ReadAsyncCallback(PacketHandle.FromManagedPacket(packet), 0);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ internal class SNIPacket : IDisposable, IEquatable<SNIPacket>
private string _description;
private SNIAsyncCallback _completionCallback;

private ArrayPool<byte> _arrayPool = ArrayPool<byte>.Shared;
//private ArrayPool<byte> _arrayPool = ArrayPool<byte>.Shared;
Wraith2 marked this conversation as resolved.
Show resolved Hide resolved
private bool _isBufferFromArrayPool = false;

public SNIPacket() { }
Expand Down Expand Up @@ -98,14 +98,14 @@ public void Allocate(int capacity)
{
if (_isBufferFromArrayPool)
{
_arrayPool.Return(_data);
ArrayPool<byte>.Shared.Return(_data);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the default shared pool is being used it seems wasteful to carry around the extra field referring to it. The default pool can also be devirtualized by the jit possibly making the calls slightly faster.
I did wonder if it might be worth having an sql specific pool for packet buffers but considered the advice warning of creating too many pools and decided to err on the side of caution.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default pool can also be devirtualized by the jit possibly making the calls slightly faster.

Yep, when referred to directly via ArrayPool<byte>.Shared dotnet/coreclr#20637

}
_data = null;
}

if (_data == null)
{
_data = _arrayPool.Rent(capacity);
_data = ArrayPool<byte>.Shared.Rent(capacity);
_isBufferFromArrayPool = true;
}

Expand Down Expand Up @@ -221,7 +221,7 @@ public void Release()
{
if(_isBufferFromArrayPool)
{
_arrayPool.Return(_data);
ArrayPool<byte>.Shared.Return(_data);
}
_data = null;
_capacity = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,22 @@ internal void PostReadAsyncForMars()
// Have to post read to initialize MARS - will get callback on this when connection goes
// down or is closed.

IntPtr temp = IntPtr.Zero;
PacketHandle temp = default;
uint error = TdsEnums.SNI_SUCCESS;

_pMarsPhysicalConObj.IncrementPendingCallbacks();
object handle = _pMarsPhysicalConObj.SessionHandle;
temp = (IntPtr)_pMarsPhysicalConObj.ReadAsync(out error, ref handle);
SessionHandle handle = _pMarsPhysicalConObj.SessionHandle;
temp = _pMarsPhysicalConObj.ReadAsync(handle, out error);

if (temp != IntPtr.Zero)
Debug.Assert(temp.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer");

if (temp.NativePointer != IntPtr.Zero)
{
// Be sure to release packet, otherwise it will be leaked by native.
_pMarsPhysicalConObj.ReleasePacket(temp);
}

Debug.Assert(IntPtr.Zero == temp, "unexpected syncReadPacket without corresponding SNIPacketRelease");
Debug.Assert(IntPtr.Zero == temp.NativePointer, "unexpected syncReadPacket without corresponding SNIPacketRelease");
if (TdsEnums.SNI_SUCCESS_IO_PENDING != error)
{
Debug.Assert(TdsEnums.SNI_SUCCESS != error, "Unexpected successful read async on physical connection before enabling MARS!");
Expand Down Expand Up @@ -118,4 +120,4 @@ private SNIErrorDetails GetSniErrorDetails()
}

} // tdsparser
}//namespace
}//namespace
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ private static void ReadDispatcher(IntPtr key, IntPtr packet, uint error)

if (null != stateObj)
{
stateObj.ReadAsyncCallback(IntPtr.Zero, packet, error);
stateObj.ReadAsyncCallback(IntPtr.Zero, PacketHandle.FromNativePointer(packet), error);
}
}
}
Expand All @@ -125,19 +125,19 @@ private static void WriteDispatcher(IntPtr key, IntPtr packet, uint error)

if (null != stateObj)
{
stateObj.WriteAsyncCallback(IntPtr.Zero, packet, error);
stateObj.WriteAsyncCallback(IntPtr.Zero, PacketHandle.FromNativePointer(packet), error);
}
}
}
}

internal sealed class SNIHandle : SafeHandle
internal sealed class SNISessionHandle : SafeHandle
{
private readonly uint _status = TdsEnums.SNI_UNINITIALIZED;
private readonly bool _fSync = false;

// creates a physical connection
internal SNIHandle(
internal SNISessionHandle(
SNINativeMethodWrapper.ConsumerInfo myInfo,
string serverName,
byte[] spnBuffer,
Expand Down Expand Up @@ -165,7 +165,7 @@ internal SNIHandle(
}

// constructs SNI Handle for MARS session
internal SNIHandle(SNINativeMethodWrapper.ConsumerInfo myInfo, SNIHandle parent) : base(IntPtr.Zero, true)
internal SNISessionHandle(SNINativeMethodWrapper.ConsumerInfo myInfo, SNISessionHandle parent) : base(IntPtr.Zero, true)
{
try { }
finally
Expand Down Expand Up @@ -206,9 +206,9 @@ internal uint Status
}
}

internal sealed class SNIPacket : SafeHandle
internal sealed class SNIPacketHandle : SafeHandle
{
internal SNIPacket(SafeHandle sniHandle) : base(IntPtr.Zero, true)
internal SNIPacketHandle(SafeHandle sniHandle) : base(IntPtr.Zero, true)
{
SNINativeMethodWrapper.SNIPacketAllocate(sniHandle, SNINativeMethodWrapper.IOType.WRITE, ref base.handle);
if (IntPtr.Zero == base.handle)
Expand Down Expand Up @@ -241,17 +241,17 @@ override protected bool ReleaseHandle()
internal sealed class WritePacketCache : IDisposable
{
private bool _disposed;
private Stack<SNIPacket> _packets;
private Stack<SNIPacketHandle> _packets;

public WritePacketCache()
{
_disposed = false;
_packets = new Stack<SNIPacket>();
_packets = new Stack<SNIPacketHandle>();
}

public SNIPacket Take(SNIHandle sniHandle)
public SNIPacketHandle Take(SNISessionHandle sniHandle)
{
SNIPacket packet;
SNIPacketHandle packet;
if (_packets.Count > 0)
{
// Success - reset the packet
Expand All @@ -261,12 +261,12 @@ public SNIPacket Take(SNIHandle sniHandle)
else
{
// Failed to take a packet - create a new one
packet = new SNIPacket(sniHandle);
packet = new SNIPacketHandle(sniHandle);
}
return packet;
}

public void Add(SNIPacket packet)
public void Add(SNIPacketHandle packet)
{
if (!_disposed)
{
Expand Down Expand Up @@ -296,4 +296,4 @@ public void Dispose()
}
}
}
}
}
Loading