Skip to content

Commit

Permalink
[Hotfix 4.0.2] | Parallelize SSRP requests when MSF is specified (#1578
Browse files Browse the repository at this point in the history
…) (#1720)
  • Loading branch information
DavoudEshtehari authored Aug 23, 2022
1 parent be9731c commit f0eaa13
Show file tree
Hide file tree
Showing 5 changed files with 319 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ internal class SNICommon
internal const int ConnTimeoutError = 11;
internal const int ConnNotUsableError = 19;
internal const int InvalidConnStringError = 25;
internal const int ErrorLocatingServerInstance = 26;
internal const int HandshakeFailureError = 31;
internal const int InternalExceptionError = 35;
internal const int ConnOpenFailedError = 40;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode)
/// <param name="isIntegratedSecurity"></param>
/// <param name="ipPreference">IP address preference</param>
/// <param name="cachedFQDN">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNI handle</returns>
internal static SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer,
bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
Expand Down Expand Up @@ -263,7 +263,7 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
/// <param name="parallel">Should MultiSubnetFailover be used</param>
/// <param name="ipPreference">IP address preference</param>
/// <param name="cachedFQDN">Key for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNITCPHandle</returns>
private static SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool parallel, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
Expand All @@ -285,12 +285,12 @@ private static SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire
try
{
port = isAdminConnection ?
SSRP.GetDacPortByInstanceName(hostName, details.InstanceName) :
SSRP.GetPortByInstanceName(hostName, details.InstanceName);
SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timerExpire, parallel, ipPreference) :
SSRP.GetPortByInstanceName(hostName, details.InstanceName, timerExpire, parallel, ipPreference);
}
catch (SocketException se)
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InvalidConnStringError, se);
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.ErrorLocatingServerInstance, se);
return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ public SNITCPHandle(string serverName, int port, long timerExpire, bool parallel
bool reportError = true;

SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Connecting to serverName {1} and port {2}", args0: _connectionId, args1: serverName, args2: port);
// We will always first try to connect with serverName as before and let the DNS server to resolve the serverName.
// If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with cached IPs based on IPAddressPreference.
// The exceptions will be throw to upper level and be handled as before.
// We will always first try to connect with serverName as before and let DNS resolve the serverName.
// If DNS resolution fails, we will try with IPs in the DNS cache if they exist. We try with cached IPs based on IPAddressPreference.
// Exceptions will be thrown to the caller and be handled as before.
try
{
if (parallel)
Expand Down Expand Up @@ -280,7 +280,12 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i
Task<Socket> connectTask;

Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(hostName);
serverAddrTask.Wait(ts);
bool complete = serverAddrTask.Wait(ts);

// DNS timed out - don't block
if (!complete)
return null;

IPAddress[] serverAddresses = serverAddrTask.Result;

if (serverAddresses.Length > MaxParallelIpAddresses)
Expand Down Expand Up @@ -324,7 +329,6 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i

availableSocket = connectTask.Result;
return availableSocket;

}

// Connect to server with hostName and port.
Expand All @@ -334,7 +338,14 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference));

IPAddress[] ipAddresses = Dns.GetHostAddresses(serverName);
Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(serverName);
bool complete = serverAddrTask.Wait(timeout);

// DNS timed out - don't block
if (!complete)
return null;

IPAddress[] ipAddresses = serverAddrTask.Result;

string IPv4String = null;
string IPv6String = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
Expand All @@ -21,8 +24,11 @@ internal class SSRP
/// </summary>
/// <param name="browserHostName">SQL Sever Browser hostname</param>
/// <param name="instanceName">instance name to find port number</param>
/// <param name="timerExpire">Connection timer expiration</param>
/// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
/// <param name="ipPreference">IP address preference</param>
/// <returns>port number for given instance name</returns>
internal static int GetPortByInstanceName(string browserHostName, string instanceName)
internal static int GetPortByInstanceName(string browserHostName, string instanceName, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference)
{
Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace");
Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace");
Expand All @@ -32,7 +38,7 @@ internal static int GetPortByInstanceName(string browserHostName, string instanc
byte[] responsePacket = null;
try
{
responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest);
responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest, timerExpire, allIPsInParallel, ipPreference);
}
catch (SocketException se)
{
Expand Down Expand Up @@ -87,14 +93,17 @@ private static byte[] CreateInstanceInfoRequest(string instanceName)
/// </summary>
/// <param name="browserHostName">SQL Sever Browser hostname</param>
/// <param name="instanceName">instance name to lookup DAC port</param>
/// <param name="timerExpire">Connection timer expiration</param>
/// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
/// <param name="ipPreference">IP address preference</param>
/// <returns>DAC port for given instance name</returns>
internal static int GetDacPortByInstanceName(string browserHostName, string instanceName)
internal static int GetDacPortByInstanceName(string browserHostName, string instanceName, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference)
{
Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace");
Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace");

byte[] dacPortInfoRequest = CreateDacPortInfoRequest(instanceName);
byte[] responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest);
byte[] responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest, timerExpire, allIPsInParallel, ipPreference);

const byte SvrResp = 0x05;
const byte ProtocolVersion = 0x01;
Expand Down Expand Up @@ -131,43 +140,198 @@ private static byte[] CreateDacPortInfoRequest(string instanceName)
return requestPacket;
}

private class SsrpResult
{
public byte[] ResponsePacket;
public Exception Error;
}

/// <summary>
/// Sends request to server, and receives response from server by UDP.
/// </summary>
/// <param name="browserHostname">UDP server hostname</param>
/// <param name="port">UDP server port</param>
/// <param name="requestPacket">request packet</param>
/// <param name="timerExpire">Connection timer expiration</param>
/// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
/// <param name="ipPreference">IP address preference</param>
/// <returns>response packet from UDP server</returns>
private static byte[] SendUDPRequest(string browserHostname, int port, byte[] requestPacket)
private static byte[] SendUDPRequest(string browserHostname, int port, byte[] requestPacket, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference)
{
using (TrySNIEventScope.Create(nameof(SSRP)))
{
Debug.Assert(!string.IsNullOrWhiteSpace(browserHostname), "browserhostname should not be null, empty, or whitespace");
Debug.Assert(port >= 0 && port <= 65535, "Invalid port");
Debug.Assert(requestPacket != null && requestPacket.Length > 0, "requestPacket should not be null or 0-length array");

const int sendTimeOutMs = 1000;
const int receiveTimeOutMs = 1000;
bool isIpAddress = IPAddress.TryParse(browserHostname, out IPAddress address);

IPAddress address = null;
bool isIpAddress = IPAddress.TryParse(browserHostname, out address);
TimeSpan ts = default;
// In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count
// The infinite Timeout is a function of ConnectionString Timeout=0
if (long.MaxValue != timerExpire)
{
ts = DateTime.FromFileTime(timerExpire) - DateTime.Now;
ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts;
}

byte[] responsePacket = null;
using (UdpClient client = new UdpClient(!isIpAddress ? AddressFamily.InterNetwork : address.AddressFamily))
IPAddress[] ipAddresses = null;
if (!isIpAddress)
{
Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(browserHostname);
bool taskComplete;
try
{
taskComplete = serverAddrTask.Wait(ts);
}
catch (AggregateException ae)
{
throw ae.InnerException;
}

// If DNS took too long, need to return instead of blocking
if (!taskComplete)
return null;

ipAddresses = serverAddrTask.Result;
}

Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve");

switch (ipPreference)
{
Task<int> sendTask = client.SendAsync(requestPacket, requestPacket.Length, browserHostname, port);
case SqlConnectionIPAddressPreference.IPv4First:
{
SsrpResult response4 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetwork).ToArray(), port, requestPacket, allIPsInParallel);
if (response4 != null && response4.ResponsePacket != null)
return response4.ResponsePacket;

SsrpResult response6 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetworkV6).ToArray(), port, requestPacket, allIPsInParallel);
if (response6 != null && response6.ResponsePacket != null)
return response6.ResponsePacket;

// No responses so throw first error
if (response4 != null && response4.Error != null)
throw response4.Error;
else if (response6 != null && response6.Error != null)
throw response6.Error;

break;
}
case SqlConnectionIPAddressPreference.IPv6First:
{
SsrpResult response6 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetworkV6).ToArray(), port, requestPacket, allIPsInParallel);
if (response6 != null && response6.ResponsePacket != null)
return response6.ResponsePacket;

SsrpResult response4 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetwork).ToArray(), port, requestPacket, allIPsInParallel);
if (response4 != null && response4.ResponsePacket != null)
return response4.ResponsePacket;

// No responses so throw first error
if (response6 != null && response6.Error != null)
throw response6.Error;
else if (response4 != null && response4.Error != null)
throw response4.Error;

break;
}
default:
{
SsrpResult response = SendUDPRequest(ipAddresses, port, requestPacket, true); // allIPsInParallel);
if (response != null && response.ResponsePacket != null)
return response.ResponsePacket;
else if (response != null && response.Error != null)
throw response.Error;

break;
}
}

return null;
}
}

/// <summary>
/// Sends request to server, and receives response from server by UDP.
/// </summary>
/// <param name="ipAddresses">IP Addresses</param>
/// <param name="port">UDP server port</param>
/// <param name="requestPacket">request packet</param>
/// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
/// <returns>response packet from UDP server</returns>
private static SsrpResult SendUDPRequest(IPAddress[] ipAddresses, int port, byte[] requestPacket, bool allIPsInParallel)
{
if (ipAddresses.Length == 0)
return null;

if (allIPsInParallel) // Used for MultiSubnetFailover
{
List<Task<SsrpResult>> tasks = new(ipAddresses.Length);
CancellationTokenSource cts = new CancellationTokenSource();
for (int i = 0; i < ipAddresses.Length; i++)
{
IPEndPoint endPoint = new IPEndPoint(ipAddresses[i], port);
tasks.Add(Task.Factory.StartNew<SsrpResult>(() => SendUDPRequest(endPoint, requestPacket)));
}

List<Task<SsrpResult>> completedTasks = new();
while (tasks.Count > 0)
{
int first = Task.WaitAny(tasks.ToArray());
if (tasks[first].Result.ResponsePacket != null)
{
cts.Cancel();
return tasks[first].Result;
}
else
{
completedTasks.Add(tasks[first]);
tasks.Remove(tasks[first]);
}
}

Debug.Assert(completedTasks.Count > 0, "completedTasks should never be 0");

// All tasks failed. Return the error from the first failure.
return completedTasks[0].Result;
}
else
{
// If not parallel, use the first IP address provided
IPEndPoint endPoint = new IPEndPoint(ipAddresses[0], port);
return SendUDPRequest(endPoint, requestPacket);
}
}

private static SsrpResult SendUDPRequest(IPEndPoint endPoint, byte[] requestPacket)
{
const int sendTimeOutMs = 1000;
const int receiveTimeOutMs = 1000;

SsrpResult result = new();

try
{
using (UdpClient client = new UdpClient(endPoint.AddressFamily))
{
Task<int> sendTask = client.SendAsync(requestPacket, requestPacket.Length, endPoint);
Task<UdpReceiveResult> receiveTask = null;

SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Waiting for UDP Client to fetch Port info.");
if (sendTask.Wait(sendTimeOutMs) && (receiveTask = client.ReceiveAsync()).Wait(receiveTimeOutMs))
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Received Port info from UDP Client.");
responsePacket = receiveTask.Result.Buffer;
result.ResponsePacket = receiveTask.Result.Buffer;
}
}

return responsePacket;
}
catch (Exception e)
{
result.Error = e;
}

return result;
}
}
}
Loading

0 comments on commit f0eaa13

Please sign in to comment.