Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial Commit of Transport Creation Handler #2569

Merged
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/Microsoft.Data.SqlClient.sln
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{4F3CD363-B1E6-4D6D-9466-97D78A56BE45}"
ProjectSection(SolutionItems) = preProject
Directory.Build.props = Directory.Build.props
NuGet.config = NuGet.config
EndProjectSection
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.SqlServer.Server", "Microsoft.SqlServer.Server\Microsoft.SqlServer.Server.csproj", "{A314812A-7820-4565-A2A8-ABBE391C11E4}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -923,9 +923,12 @@
<Compile Include="Microsoft\Data\SqlClientX\IO\ITdsWriteStream.cs" />
<Compile Include="Microsoft\Data\SqlClientX\IO\TdsStreamPacketType.cs" />
<Compile Include="Microsoft\Data\SqlClientX\IO\TdsStream.cs" />
<Compile Include="Microsoft\Data\SqlClientX\Handlers\Connection\ConnectionHandlerContext.cs" />
<Compile Include="Microsoft\Data\SqlClientX\Handlers\HandlerRequest.cs" />
<Compile Include="Microsoft\Data\SqlClientX\Handlers\HandlerRequestType.cs" />
<Compile Include="Microsoft\Data\SqlClientX\Handlers\IHandler.cs" />
<Compile Include="Microsoft\Data\SqlClientX\Handlers\TransportCreation\IpAddressVersionComparer.cs" />
<Compile Include="Microsoft\Data\SqlClientX\Handlers\TransportCreation\TransportCreationHandler.cs" />
<Compile Include="Microsoft\Data\SqlClientX\IO\TdsWriteStream.cs" />
<Compile Include="Microsoft\Data\SqlClientX\SqlConnectionX.cs" />
<EmbeddedResource Include="$(CommonSourceRoot)Resources\Strings.resx">
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.IO;
using Microsoft.Data.SqlClient;
using Microsoft.Data.SqlClient.SNI;

namespace Microsoft.Data.SqlClientX.Handlers.Connection
{
/// <summary>
/// Class that contains data required to handle a connection request.
/// </summary>
// TODO: This will be updated as more information is available.
internal class ConnectionHandlerContext : HandlerRequest
{
/// <summary>
/// Stream that is created during connection.
/// </summary>
public Stream ConnectionStream { get; set; }

/// <summary>
/// Gets or sets the data source as parsed from the connection string. It will be used by
/// the transport creation handler to create the connection stream.
/// </summary>
public DataSource DataSource { get; set; }

/// <summary>
/// Gets or sets an exception that halted execution of the connection chain of handlers.
/// </summary>
public Exception Error { get; set; }

public SqlConnectionIPAddressPreference IpAddressPreference { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;

namespace Microsoft.Data.SqlClientX.Handlers.TransportCreation
{
/// <summary>
/// Comparer that sorts IP addresses based on the version of the internet protocol it is using.
/// This class cannot be instantiated, so to use it, use the singleton instances (doubleton?)
/// <see cref="InstanceV4"/> and <see cref="InstanceV6"/>.
/// </summary>
internal sealed class IpAddressVersionSorter : IComparer<IPAddress>
{
private readonly AddressFamily _preferredAddressFamily;

private IpAddressVersionSorter(AddressFamily preferredAddressFamily)
{
_preferredAddressFamily = preferredAddressFamily;
}

/// <summary>
/// Gets a singleton instance that ranks IPv4 addresses higher than IPv6 addresses.
/// </summary>
public static IpAddressVersionSorter InstanceV4 { get; } =
new IpAddressVersionSorter(AddressFamily.InterNetwork);

/// <summary>
/// Gets a singleton instance that ranks IPv6 addresses higher than IPv4 addresses.
/// </summary>
public static IpAddressVersionSorter InstanceV6 { get; } =
new IpAddressVersionSorter(AddressFamily.InterNetworkV6);

/// <inheritdoc />
public int Compare(IPAddress x, IPAddress y)
{
if (x is null) { throw new ArgumentNullException(nameof(x)); }
if (y is null) { throw new ArgumentNullException(nameof(y)); }

if (x.AddressFamily == y.AddressFamily)
{
// Versions are the same, it's a tie.
return 0;
}

return x.AddressFamily == _preferredAddressFamily ? 1 : -1;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.SqlClient;
using Microsoft.Data.SqlClient.SNI;
using Microsoft.Data.SqlClientX.Handlers.Connection;

namespace Microsoft.Data.SqlClientX.Handlers.TransportCreation
{
internal class TransportCreationHandler : IHandler<ConnectionHandlerContext>
{
private const int KeepAliveIntervalSeconds = 1;
private const int KeepAliveTimeSeconds = 30;

#if NET8_0_OR_GREATER
private static readonly TimeSpan DefaultPollTimeout = TimeSpan.FromSeconds(30);
#else
private const int DefaultPollTimeout = 30 * 100000; // 30 seconds as microseconds
#endif

/// <inheritdoc />
public IHandler<ConnectionHandlerContext> NextHandler { get; set; }

/// <inheritdoc />
public async ValueTask Handle(ConnectionHandlerContext context, bool isAsync, CancellationToken ct)
{
if (context.DataSource is null)
{
context.Error = new ArgumentNullException(nameof(context));
return;
}

try
{
// @TODO: Build CoR for handling the different protocols in order
if (context.DataSource.ResolvedProtocol is DataSource.Protocol.TCP)
{
context.ConnectionStream = await HandleTcpRequest(context, isAsync, ct).ConfigureAwait(false);
}
}
catch (Exception e)
{
context.Error = e;
}

if (NextHandler is not null)
{
await NextHandler.Handle(context, isAsync, ct).ConfigureAwait(false);
}
}

private ValueTask<Stream> HandleNamedPipeRequest()
{
throw new NotImplementedException();
}

private ValueTask<Stream> HandleSharedMemoryRequest()
{
throw new NotImplementedException();
}

private async ValueTask<Stream> HandleTcpRequest(ConnectionHandlerContext context, bool isAsync, CancellationToken ct)
{
ct.ThrowIfCancellationRequested();

// DNS lookup
var ipAddresses = isAsync
? await Dns.GetHostAddressesAsync(context.DataSource.ServerName, ct).ConfigureAwait(false)
: Dns.GetHostAddresses(context.DataSource.ServerName);
if (ipAddresses.Length == 0)
{
throw new Exception("Hostname did not resolve");
}

// If there is an IP version preference, apply it
switch (context.IpAddressPreference)
{
case SqlConnectionIPAddressPreference.IPv4First:
Array.Sort(ipAddresses, IpAddressVersionSorter.InstanceV4);
break;

case SqlConnectionIPAddressPreference.IPv6First:
Array.Sort(ipAddresses, IpAddressVersionSorter.InstanceV6);
break;

case SqlConnectionIPAddressPreference.UsePlatformDefault:
default:
// Not sorting necessary
break;
}

// Attempt to connect to one of the matching IP addresses
// @TODO: Handle opening in parallel
Socket socket = null;
var socketOpenExceptions = new List<Exception>();

var ipEndpoint = new IPEndPoint(IPAddress.None, context.DataSource.Port); // Allocate once
foreach (var ipAddress in ipAddresses)
{
ipEndpoint.Address = ipAddress;
try
{
socket = await OpenSocket(ipEndpoint, isAsync, ct).ConfigureAwait(false);
break;
}
catch(Exception e)
{
socketOpenExceptions.Add(e);
}
}

// If no socket succeeded, throw
if (socket is null)
{
throw socketOpenExceptions.OfType<SocketException>().FirstOrDefault()
?? (Exception)new AggregateException(socketOpenExceptions);
}

// Create the stream for the socket
return new NetworkStream(socket);
}

private async ValueTask<Socket> OpenSocket(IPEndPoint ipEndPoint, bool isAsync, CancellationToken ct)
{
ct.ThrowIfCancellationRequested();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For sync cases, the callers will provide a default Cancellation token. the ct.ThrowIfCancellationRequested() may not be needed in sync paths.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to utilize the cancellation token as the way to signal that either the timeout has expired or the user has requested cancellation. This will be adopted in a later PR. As such, I think it would still be valuable to keep the cancellation token in the sync code path. @edwardneal had a good suggestion for how to cancel while the socket is opening by registering event on the cancellation token.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am still not sure how this would look end to end. If this theory holds, then great, else we will modify.


var socket = new Socket(ipEndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { Blocking = false };

// Enable keep-alive
socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, true);
socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval, KeepAliveIntervalSeconds);
socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime, KeepAliveTimeSeconds);

if (isAsync)
{
#if NET6_0_OR_GREATER
await socket.ConnectAsync(ipEndPoint, ct).ConfigureAwait(false);
#else
await new TaskFactory(ct).FromAsync(socket.BeginConnect, socket.EndConnect, ipEndpoint, null)
.ConfigureAwait(false);
#endif
}
else
{
OpenSocketSync(socket, ipEndPoint, ct);
}

// Connection is established
socket.Blocking = true;
socket.NoDelay = true;

return socket;
}

private void OpenSocketSync(Socket socket, IPEndPoint ipEndPoint, CancellationToken ct)
{
ct.ThrowIfCancellationRequested();

try
{
socket.Connect(ipEndPoint);
}
catch (SocketException e)
{
// Because the socket is configured to be non-blocking, any operation that would
// block will throw an exception indicating it would block. Since opening a TCP
// connection will always block, we expect to get an exception for it, and will
// ignore it. This allows us to immediately return from connect and poll it,
// allowing us to observe timeouts and cancellation.
if (e.SocketErrorCode is not SocketError.WouldBlock)
{
throw;
}
}

// Poll the socket until it is open
if (!socket.Poll(DefaultPollTimeout, SelectMode.SelectWrite))
{
throw new TimeoutException("Socket failed to open within timeout period.");
}
}
}
}