Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't call user callbacks on MsQuic worker thread. #98361

Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ internal static class CertificateValidation
private static readonly IdnMapping s_idnMapping = new IdnMapping();

// WARNING: This function will do the verification using OpenSSL. If the intention is to use OS function, caller should use CertificatePal interface.
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool _ /*isServer*/, string? hostName, IntPtr certificateBuffer, int bufferLength = 0)
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool _ /*isServer*/, string? hostName, Span<byte> certificateBuffer)
{
SslPolicyErrors errors = chain.Build(remoteCertificate) ?
SslPolicyErrors.None :
Expand All @@ -31,15 +31,24 @@ internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X
}

SafeX509Handle certHandle;
if (certificateBuffer != IntPtr.Zero && bufferLength > 0)
unsafe
{
certHandle = Interop.Crypto.DecodeX509(certificateBuffer, bufferLength);
}
else
{
// We dont't have DER encoded buffer.
byte[] der = remoteCertificate.Export(X509ContentType.Cert);
certHandle = Interop.Crypto.DecodeX509(Marshal.UnsafeAddrOfPinnedArrayElement(der, 0), der.Length);
if (certificateBuffer.Length > 0)
{
fixed (byte* pCert = certificateBuffer)
{
certHandle = Interop.Crypto.DecodeX509((IntPtr)pCert, certificateBuffer.Length);
}
}
else
{
// We dont't have DER encoded buffer.
byte[] der = remoteCertificate.Export(X509ContentType.Cert);
fixed (byte* pDer = der)
{
certHandle = Interop.Crypto.DecodeX509((IntPtr)pDer, der.Length);
}
}
}

int hostNameMatch;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ internal static class CertificateValidation
private static readonly IdnMapping s_idnMapping = new IdnMapping();

#pragma warning disable IDE0060
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool isServer, string? hostName, IntPtr certificateBuffer, int bufferLength)
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool isServer, string? hostName, Span<byte> certificateBuffer)
=> BuildChainAndVerifyProperties(chain, remoteCertificate, checkCertName, isServer, hostName);
#pragma warning restore IDE0060

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace System.Net
internal static partial class CertificateValidation
{
#pragma warning disable IDE0060
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool isServer, string? hostName, IntPtr certificateBuffer, int bufferLength)
internal static SslPolicyErrors BuildChainAndVerifyProperties(X509Chain chain, X509Certificate2 remoteCertificate, bool checkCertName, bool isServer, string? hostName, Span<byte> certificateBuffer)
=> BuildChainAndVerifyProperties(chain, remoteCertificate, checkCertName, isServer, hostName);
#pragma warning restore IDE0060

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,4 +375,55 @@ public int StreamReceiveSetEnabled(MsQuicSafeHandle stream, byte enabled)
}
}
}

public int DatagramSend(MsQuicSafeHandle connection, QUIC_BUFFER* buffers, uint buffersCount, QUIC_SEND_FLAGS flags, void* context)
{
bool success = false;
try
{
connection.DangerousAddRef(ref success);
return ApiTable->DatagramSend(connection.QuicHandle, buffers, buffersCount, flags, context);
}
finally
{
if (success)
{
connection.DangerousRelease();
}
}
}

public int ConnectionResumptionTicketValidationComplete(MsQuicSafeHandle connection, byte result)
{
bool success = false;
try
{
connection.DangerousAddRef(ref success);
return ApiTable->ConnectionResumptionTicketValidationComplete(connection.QuicHandle, result);
}
finally
{
if (success)
{
connection.DangerousRelease();
}
}
}

public int ConnectionCertificateValidationComplete(MsQuicSafeHandle connection, byte result, QUIC_TLS_ALERT_CODES alert)
{
bool success = false;
try
{
connection.DangerousAddRef(ref success);
return ApiTable->ConnectionCertificateValidationComplete(connection.QuicHandle, result, alert);
}
finally
{
if (success)
{
connection.DangerousRelease();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,16 @@ private MsQuicApi(QUIC_API_TABLE* apiTable)
private static readonly Lazy<MsQuicApi> _api = new Lazy<MsQuicApi>(AllocateMsQuicApi);
internal static MsQuicApi Api => _api.Value;

internal static Version? Version { get; private set; }

internal static bool IsQuicSupported { get; }

internal static string MsQuicLibraryVersion { get; } = "unknown";
internal static string? NotSupportedReason { get; }

// workaround for https://github.com/microsoft/msquic/issues/4132
internal static bool SupportsAsyncCertValidation => Version >= new Version(2, 4, 0);

internal static bool UsesSChannelBackend { get; }

internal static bool Tls13ServerMayBeDisabled { get; }
Expand All @@ -69,6 +74,7 @@ static MsQuicApi()
{
bool loaded = false;
IntPtr msQuicHandle;
Version = default;

// MsQuic is using DualMode sockets and that will fail even for IPv4 if AF_INET6 is not available.
if (!Socket.OSSupportsIPv6)
Expand Down Expand Up @@ -135,7 +141,7 @@ static MsQuicApi()
}
return;
}
Version version = new Version((int)libVersion[0], (int)libVersion[1], (int)libVersion[2], (int)libVersion[3]);
Version = new Version((int)libVersion[0], (int)libVersion[1], (int)libVersion[2], (int)libVersion[3]);

paramSize = 64 * sizeof(sbyte);
sbyte* libGitHash = stackalloc sbyte[64];
Expand All @@ -150,11 +156,11 @@ static MsQuicApi()
}
string? gitHash = Marshal.PtrToStringUTF8((IntPtr)libGitHash);

MsQuicLibraryVersion = $"{Interop.Libraries.MsQuic} {version} ({gitHash})";
MsQuicLibraryVersion = $"{Interop.Libraries.MsQuic} {Version} ({gitHash})";

if (version < s_minMsQuicVersion)
if (Version < s_minMsQuicVersion)
{
NotSupportedReason = $"Incompatible MsQuic library version '{version}', expecting higher than '{s_minMsQuicVersion}'.";
NotSupportedReason = $"Incompatible MsQuic library version '{Version}', expecting higher than '{s_minMsQuicVersion}'.";
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(null, NotSupportedReason);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Diagnostics;
using System.Net.Security;
using System.Security.Authentication;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using Microsoft.Quic;
using static Microsoft.Quic.MsQuic;

Expand Down Expand Up @@ -63,18 +66,122 @@ public SslConnectionOptions(QuicConnection connection, bool isClient,
_certificateChainPolicy = certificateChainPolicy;
}

public unsafe int ValidateCertificate(QUIC_BUFFER* certificatePtr, QUIC_BUFFER* chainPtr, out X509Certificate2? certificate)
internal async Task<bool> StartAsyncCertificateValidation(IntPtr certificatePtr, IntPtr chainPtr)
{
//
// The provided data pointers are valid only while still inside this function, so they need to be
// copied to separate buffers which are then handed off to threadpool.
//

X509Certificate2? certificate = null;

byte[]? certDataRented = null;
Memory<byte> certData = default;
byte[]? chainDataRented = null;
Memory<byte> chainData = default;

if (certificatePtr != IntPtr.Zero)
{
if (MsQuicApi.UsesSChannelBackend)
{
// provided data is a pointer to a CERT_CONTEXT
certificate = new X509Certificate2(certificatePtr);
// TODO: what about chainPtr?
}
else
{
unsafe
{
// On non-SChannel backends we specify USE_PORTABLE_CERTIFICATES and the contents are buffers
// with DER encoded cert and chain.
QUIC_BUFFER* certificateBuffer = (QUIC_BUFFER*)certificatePtr;
QUIC_BUFFER* chainBuffer = (QUIC_BUFFER*)chainPtr;

if (certificateBuffer->Length > 0)
{
certDataRented = ArrayPool<byte>.Shared.Rent((int)certificateBuffer->Length);
certData = certDataRented.AsMemory(0, (int)certificateBuffer->Length);
certificateBuffer->Span.CopyTo(certData.Span);
}

if (chainBuffer->Length > 0)
{
chainDataRented = ArrayPool<byte>.Shared.Rent((int)chainBuffer->Length);
chainData = chainDataRented.AsMemory(0, (int)chainBuffer->Length);
chainBuffer->Span.CopyTo(chainData.Span);
}
}
}
}

// We wan't to do the certificate validation asynchronously, but due to a bug in MsQuic, we need to call the callback synchronously on some versions
if (MsQuicApi.SupportsAsyncCertValidation)
{
// force yield to the thread pool to free up MsQuic worker thread.
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
}

// certificatePtr and chainPtr are invalid beyond this point

QUIC_TLS_ALERT_CODES result;
try
{
if (certData.Length > 0)
{
Debug.Assert(certificate == null);
certificate = new X509Certificate2(certData.Span);
}

result = _connection._sslConnectionOptions.ValidateCertificate(certificate, certData.Span, chainData.Span);
_connection._remoteCertificate = certificate;
}
catch (Exception ex)
{
certificate?.Dispose();
_connection._connectedTcs.TrySetException(ex);
result = QUIC_TLS_ALERT_CODES.USER_CANCELED;
}
finally
{
if (certDataRented != null)
{
ArrayPool<byte>.Shared.Return(certDataRented);
}

if (chainDataRented != null)
{
ArrayPool<byte>.Shared.Return(chainDataRented);
}
}

if (MsQuicApi.SupportsAsyncCertValidation)
{
int status = MsQuicApi.Api.ConnectionCertificateValidationComplete(
_connection._handle,
result == QUIC_TLS_ALERT_CODES.SUCCESS ? (byte)1 : (byte)0,
result);

if (MsQuic.StatusFailed(status))
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Error(_connection, $"{_connection} ConnectionCertificateValidationComplete failed with {ThrowHelper.GetErrorMessageForStatus(status)}");
}
}
}

return result == QUIC_TLS_ALERT_CODES.SUCCESS;
}

private QUIC_TLS_ALERT_CODES ValidateCertificate(X509Certificate2? certificate, Span<byte> certData, Span<byte> chainData)
{
SslPolicyErrors sslPolicyErrors = SslPolicyErrors.None;
IntPtr certificateBuffer = 0;
int certificateLength = 0;
bool wrapException = false;

X509Chain? chain = null;
X509Certificate2? result = null;
try
{
if (certificatePtr is not null)
if (certificate is not null)
{
chain = new X509Chain();
if (_certificateChainPolicy != null)
Expand All @@ -96,51 +203,34 @@ public unsafe int ValidateCertificate(QUIC_BUFFER* certificatePtr, QUIC_BUFFER*
chain.ChainPolicy.ApplicationPolicy.Add(_isClient ? s_serverAuthOid : s_clientAuthOid);
}

if (MsQuicApi.UsesSChannelBackend)
if (chainData.Length > 0)
{
result = new X509Certificate2((IntPtr)certificatePtr);
X509Certificate2Collection additionalCertificates = new X509Certificate2Collection();
additionalCertificates.Import(chainData);
chain.ChainPolicy.ExtraStore.AddRange(additionalCertificates);
}
else
{
if (certificatePtr->Length > 0)
{
certificateBuffer = (IntPtr)certificatePtr->Buffer;
certificateLength = (int)certificatePtr->Length;
result = new X509Certificate2(certificatePtr->Span);
}

if (chainPtr->Length > 0)
{
X509Certificate2Collection additionalCertificates = new X509Certificate2Collection();
additionalCertificates.Import(chainPtr->Span);
chain.ChainPolicy.ExtraStore.AddRange(additionalCertificates);
}
}
}

if (result is not null)
{
bool checkCertName = !chain!.ChainPolicy!.VerificationFlags.HasFlag(X509VerificationFlags.IgnoreInvalidName);
sslPolicyErrors |= CertificateValidation.BuildChainAndVerifyProperties(chain!, result, checkCertName, !_isClient, TargetHostNameHelper.NormalizeHostName(_targetHost), certificateBuffer, certificateLength);
sslPolicyErrors |= CertificateValidation.BuildChainAndVerifyProperties(chain!, certificate, checkCertName, !_isClient, TargetHostNameHelper.NormalizeHostName(_targetHost), certData);
}
else if (_certificateRequired)
{
sslPolicyErrors |= SslPolicyErrors.RemoteCertificateNotAvailable;
}

int status = QUIC_STATUS_SUCCESS;
QUIC_TLS_ALERT_CODES result = QUIC_TLS_ALERT_CODES.SUCCESS;
if (_validationCallback is not null)
{
wrapException = true;
if (!_validationCallback(_connection, result, chain, sslPolicyErrors))
if (!_validationCallback(_connection, certificate, chain, sslPolicyErrors))
{
wrapException = false;
if (_isClient)
{
throw new AuthenticationException(SR.net_quic_cert_custom_validation);
}

status = QUIC_STATUS_USER_CANCELED;
result = QUIC_TLS_ALERT_CODES.BAD_CERTIFICATE;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we losing now the specific errors we were returning? Like discerning user_cancelled? Does it matter (affects the other side and what info they get), or not?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MsQuic does not use the specific error code value, it only checks for PENDING for async validation, SUCCESS for accept, and any error simply means reject.

With the async validation, we will actually be able to select the right TLS alert code which will go out over the wire, If the user validation callback throws, this will default to USER_CANCELLED as before, but if the callback returns false then it becomes BAD_CERTIFICATE. We can perhaps be more specific when no custom validation is present (there are alerts for stuff like UNTRUSTED_CERTIFICATE and similar) but I don't think we do that even in SslStream on some platforms.

}
}
else if (sslPolicyErrors != SslPolicyErrors.None)
Expand All @@ -150,15 +240,13 @@ public unsafe int ValidateCertificate(QUIC_BUFFER* certificatePtr, QUIC_BUFFER*
throw new AuthenticationException(SR.Format(SR.net_quic_cert_chain_validation, sslPolicyErrors));
}

status = QUIC_STATUS_HANDSHAKE_FAILURE;
result = QUIC_TLS_ALERT_CODES.BAD_CERTIFICATE;
}

certificate = result;
return status;
return result;
}
catch (Exception ex)
{
result?.Dispose();
if (wrapException)
{
throw new QuicException(QuicError.CallbackError, null, SR.net_quic_callback_error, ex);
Expand Down
Loading
Loading