Skip to content

Commit

Permalink
Initial Commit of Transport Creation Handler (#2569)
Browse files Browse the repository at this point in the history
* Initial code for transport creation handler

* Completing initial code for opening TCP stream.

* Addressing a handful of comments from @edwardneal

* Properly expose protocol via DataSource

* I'm honestly not sure I'm doing this isAsync thing right...

* Addressing more PR comments

* Fixing list of files from merge

* Addressing more PR comments

* Removing linq code

* More changes as per PR comments

* Fixing merge conflict on connection handler context
  • Loading branch information
benrr101 authored Jun 20, 2024
1 parent 6ec34b6 commit ec7485b
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 5 deletions.
1 change: 0 additions & 1 deletion src/Microsoft.Data.SqlClient.sln
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{4F3CD363-B1E6-4D6D-9466-97D78A56BE45}"
ProjectSection(SolutionItems) = preProject
Directory.Build.props = Directory.Build.props
NuGet.config = NuGet.config
EndProjectSection
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.SqlServer.Server", "Microsoft.SqlServer.Server\Microsoft.SqlServer.Server.csproj", "{A314812A-7820-4565-A2A8-ABBE391C11E4}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,8 @@
<Compile Include="Microsoft\Data\SqlClientX\Handlers\HandlerRequest.cs" />
<Compile Include="Microsoft\Data\SqlClientX\Handlers\HandlerRequestType.cs" />
<Compile Include="Microsoft\Data\SqlClientX\Handlers\IHandler.cs" />
<Compile Include="Microsoft\Data\SqlClientX\Handlers\TransportCreation\IpAddressVersionComparer.cs" />
<Compile Include="Microsoft\Data\SqlClientX\Handlers\TransportCreation\TransportCreationHandler.cs" />
<Compile Include="Microsoft\Data\SqlClientX\IO\TdsWriteStream.cs" />
<Compile Include="Microsoft\Data\SqlClientX\SqlConnectionX.cs" />
<EmbeddedResource Include="$(CommonSourceRoot)Resources\Strings.resx">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ namespace Microsoft.Data.SqlClientX.Handlers
internal class ConnectionHandlerContext : HandlerRequest
{
/// <summary>
/// Class that contains data required to handle a connection request.
/// Stream used by readers.
/// </summary>
public SqlConnectionString ConnectionString { get; set; }
public Stream ConnectionStream { get; set; }

/// <summary>
/// Stream used by readers.
/// Class that contains data required to handle a connection request.
/// </summary>
public Stream ConnectionStream { get; set; }
public SqlConnectionString ConnectionString { get; set; }

/// <summary>
/// Class required by DataSourceParser and Transport layer.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;

namespace Microsoft.Data.SqlClientX.Handlers.TransportCreation
{
/// <summary>
/// Comparer that sorts IP addresses based on the version of the internet protocol it is using.
/// This class cannot be instantiated, so to use it, use the singleton instances (doubleton?)
/// <see cref="InstanceV4"/> and <see cref="InstanceV6"/>.
/// </summary>
internal sealed class IpAddressVersionSorter : IComparer<IPAddress>
{
private readonly AddressFamily _preferredAddressFamily;

private IpAddressVersionSorter(AddressFamily preferredAddressFamily)
{
_preferredAddressFamily = preferredAddressFamily;
}

/// <summary>
/// Gets a singleton instance that ranks IPv4 addresses higher than IPv6 addresses.
/// </summary>
public static IpAddressVersionSorter InstanceV4 { get; } =
new IpAddressVersionSorter(AddressFamily.InterNetwork);

/// <summary>
/// Gets a singleton instance that ranks IPv6 addresses higher than IPv4 addresses.
/// </summary>
public static IpAddressVersionSorter InstanceV6 { get; } =
new IpAddressVersionSorter(AddressFamily.InterNetworkV6);

/// <inheritdoc />
public int Compare(IPAddress x, IPAddress y)
{
if (x is null) { throw new ArgumentNullException(nameof(x)); }
if (y is null) { throw new ArgumentNullException(nameof(y)); }

if (x.AddressFamily == y.AddressFamily)
{
// Versions are the same, it's a tie.
return 0;
}

return x.AddressFamily == _preferredAddressFamily ? 1 : -1;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.SqlClient;
using Microsoft.Data.SqlClient.SNI;

namespace Microsoft.Data.SqlClientX.Handlers.TransportCreation
{
internal sealed class TransportCreationHandler : IHandler<ConnectionHandlerContext>
{
private const int KeepAliveIntervalSeconds = 1;
private const int KeepAliveTimeSeconds = 30;

#if NET8_0_OR_GREATER
private static readonly TimeSpan DefaultPollTimeout = TimeSpan.FromSeconds(30);
#else
private const int DefaultPollTimeout = 30 * 100000; // 30 seconds as microseconds
#endif

/// <inheritdoc />
public IHandler<ConnectionHandlerContext> NextHandler { get; set; }

/// <inheritdoc />
public async ValueTask Handle(ConnectionHandlerContext context, bool isAsync, CancellationToken ct)
{
Debug.Assert(context.DataSource is not null, "context.DataSource is null");

try
{
// @TODO: Build CoR for handling the different protocols in order
if (context.DataSource.ResolvedProtocol is DataSource.Protocol.TCP)
{
context.ConnectionStream = await HandleTcpRequest(context, isAsync, ct).ConfigureAwait(false);
}
else
{
throw new NotImplementedException();
}
}
catch (Exception e)
{
context.Error = e;
return;
}

if (NextHandler is not null)
{
await NextHandler.Handle(context, isAsync, ct).ConfigureAwait(false);
}
}

private ValueTask<Stream> HandleNamedPipeRequest()
{
throw new NotImplementedException();
}

private ValueTask<Stream> HandleSharedMemoryRequest()
{
throw new NotImplementedException();
}

private async ValueTask<Stream> HandleTcpRequest(ConnectionHandlerContext context, bool isAsync, CancellationToken ct)
{
ct.ThrowIfCancellationRequested();

// DNS lookup
IPAddress[] ipAddresses = isAsync
? await Dns.GetHostAddressesAsync(context.DataSource.ServerName, ct).ConfigureAwait(false)
: Dns.GetHostAddresses(context.DataSource.ServerName);
if (ipAddresses is null || ipAddresses.Length == 0)
{
throw new SocketException((int)SocketError.HostNotFound);
}

// If there is an IP version preference, apply it
switch (context.ConnectionString.IPAddressPreference)
{
case SqlConnectionIPAddressPreference.IPv4First:
Array.Sort(ipAddresses, IpAddressVersionSorter.InstanceV4);
break;

case SqlConnectionIPAddressPreference.IPv6First:
Array.Sort(ipAddresses, IpAddressVersionSorter.InstanceV6);
break;

case SqlConnectionIPAddressPreference.UsePlatformDefault:
default:
// Not sorting necessary
break;
}

// Attempt to connect to one of the matching IP addresses
// @TODO: Handle opening in parallel
Socket socket = null;
var socketOpenExceptions = new List<Exception>();

int portToUse = context.DataSource.ResolvedPort < 0
? context.DataSource.Port
: context.DataSource.ResolvedPort;
var ipEndpoint = new IPEndPoint(IPAddress.None, portToUse); // Allocate once
foreach (IPAddress ipAddress in ipAddresses)
{
ipEndpoint.Address = ipAddress;
try
{
socket = await OpenSocket(ipEndpoint, isAsync, ct).ConfigureAwait(false);
break;
}
catch(Exception e)
{
socketOpenExceptions.Add(e);
}
}

// If no socket succeeded, throw
if (socket is null)
{
// If there are any socket exceptions in the collected exceptions, throw the first
// one. If there are not, collect all exceptions and throw them as an aggregate.
foreach (Exception exception in socketOpenExceptions)
{
if (exception is SocketException)
{
throw exception;
}
}

throw new AggregateException(socketOpenExceptions);
}

// Create the stream for the socket
return new NetworkStream(socket);
}

private async ValueTask<Socket> OpenSocket(IPEndPoint ipEndPoint, bool isAsync, CancellationToken ct)
{
ct.ThrowIfCancellationRequested();

var socket = new Socket(ipEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { Blocking = false };

// Enable keep-alive
socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, true);
socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval, KeepAliveIntervalSeconds);
socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime, KeepAliveTimeSeconds);

try
{
if (isAsync)
{
#if NET6_0_OR_GREATER
await socket.ConnectAsync(ipEndPoint, ct).ConfigureAwait(false);
#else
// @TODO: Only real way to cancel this is to register a cancellation token event and dispose of the socket.
await new TaskFactory(ct).FromAsync(socket.BeginConnect, socket.EndConnect, ipEndPoint, null)
.ConfigureAwait(false);
#endif
}
else
{
OpenSocketSync(socket, ipEndPoint, ct);
}
}
catch (Exception)
{
socket.Dispose();
throw;
}

// Connection is established
socket.Blocking = true;
socket.NoDelay = true;

return socket;
}

private void OpenSocketSync(Socket socket, IPEndPoint ipEndPoint, CancellationToken ct)
{
ct.ThrowIfCancellationRequested();

try
{
socket.Connect(ipEndPoint);
}
catch (SocketException e)
{
// Because the socket is configured to be non-blocking, any operation that would
// block will throw an exception indicating it would block. Since opening a TCP
// connection will always block, we expect to get an exception for it, and will
// ignore it. This allows us to immediately return from connect and poll it,
// allowing us to observe timeouts and cancellation.
if (e.SocketErrorCode is not SocketError.WouldBlock)
{
throw;
}
}

// Poll the socket until it is open
// @TODO: This method can't be cancelled, so we should consider pooling smaller timeouts and looping while
// there is still time left on the timer, checking cancellation token each time.
if (!socket.Poll(DefaultPollTimeout, SelectMode.SelectWrite))
{
throw new TimeoutException("Socket failed to open within timeout period.");
}
}
}
}

0 comments on commit ec7485b

Please sign in to comment.