diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs index 0f925d65db3ec..fee4e1aade7df 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicHelpers.cs @@ -58,30 +58,35 @@ internal static unsafe T GetMsQuicParameter(MsQuicSafeHandle handle, uint par where T : unmanaged { T value; - uint length = (uint)sizeof(T); - + GetMsQuicParameter(handle, parameter, (uint)sizeof(T), (byte*)&value); + return value; + } + internal static unsafe void GetMsQuicParameter(MsQuicSafeHandle handle, uint parameter, uint length, byte* value) + { int status = MsQuicApi.Api.GetParam( handle, parameter, &length, - (byte*)&value); + value); if (StatusFailed(status)) { ThrowHelper.ThrowMsQuicException(status, $"GetParam({handle}, {parameter}) failed"); } - - return value; } internal static unsafe void SetMsQuicParameter(MsQuicSafeHandle handle, uint parameter, T value) where T : unmanaged + { + SetMsQuicParameter(handle, parameter, (uint)sizeof(T), (byte*)&value); + } + internal static unsafe void SetMsQuicParameter(MsQuicSafeHandle handle, uint parameter, uint length, byte* value) { int status = MsQuicApi.Api.SetParam( handle, parameter, - (uint)sizeof(T), - (byte*)&value); + length, + value); if (StatusFailed(status)) { diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs index 8e87654f2e321..38a099ed9e49f 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; using System.Runtime.InteropServices; using Microsoft.Quic; @@ -92,22 +93,19 @@ internal sealed class MsQuicContextSafeHandle : MsQuicSafeHandle /// private readonly MsQuicSafeHandle? _parent; -#if DEBUG /// - /// Native memory to hold TLS secrets. It needs to live same cycle as the underlying connection. + /// Additional, dependent object to be disposed only after the safe handle gets released. /// - private unsafe QUIC_TLS_SECRETS* _tlsSecrets; + private IDisposable? _disposable; - public unsafe QUIC_TLS_SECRETS* GetSecretsBuffer() + public IDisposable Disposable { - if (_tlsSecrets == null) + set { - _tlsSecrets = (QUIC_TLS_SECRETS*)NativeMemory.Alloc((nuint)sizeof(QUIC_TLS_SECRETS)); + Debug.Assert(_disposable is null); + _disposable = value; } - - return _tlsSecrets; } -#endif public unsafe MsQuicContextSafeHandle(QUIC_HANDLE* handle, GCHandle context, SafeHandleType safeHandleType, MsQuicSafeHandle? parent = null) : base(handle, safeHandleType) @@ -140,13 +138,7 @@ protected override unsafe bool ReleaseHandle() NetEventSource.Info(this, $"{this} {_parent} ref count decremented"); } } -#if DEBUG - if (_tlsSecrets != null) - { - NativeMemory.Clear(_tlsSecrets, (nuint)sizeof(QUIC_TLS_SECRETS)); - NativeMemory.Free(_tlsSecrets); - } -#endif + _disposable?.Dispose(); return true; } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicTlsSecret.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicTlsSecret.cs index ad2b3a87ccf00..f151f12312909 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicTlsSecret.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicTlsSecret.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. #if DEBUG +using System.Diagnostics; using System.IO; using System.Runtime.InteropServices; using System.Text; @@ -19,88 +20,93 @@ internal sealed class MsQuicTlsSecret : IDisposable public static unsafe MsQuicTlsSecret? Create(MsQuicContextSafeHandle handle) { - if (s_fileStream != null) + if (s_fileStream is null) { - try - { - QUIC_TLS_SECRETS* ptr = handle.GetSecretsBuffer(); - if (ptr != null) - { - int status = MsQuicApi.Api.SetParam(handle, QUIC_PARAM_CONN_TLS_SECRETS, (uint)sizeof(QUIC_TLS_SECRETS), ptr); + return null; + } - if (StatusSucceeded(status)) - { - return new MsQuicTlsSecret(ptr); - } - else - { - if (NetEventSource.Log.IsEnabled()) - { - NetEventSource.Error(handle, "Failed to set native memory for TLS secret."); - } - } - } + QUIC_TLS_SECRETS* tlsSecrets = null; + try + { + tlsSecrets = (QUIC_TLS_SECRETS*)NativeMemory.AllocZeroed((nuint)sizeof(QUIC_TLS_SECRETS)); + MsQuicHelpers.SetMsQuicParameter(handle, QUIC_PARAM_CONN_TLS_SECRETS, (uint)sizeof(QUIC_TLS_SECRETS), (byte*)tlsSecrets); + MsQuicTlsSecret instance = new MsQuicTlsSecret(tlsSecrets); + handle.Disposable = instance; + return instance; + } + catch (Exception ex) + { + if (NetEventSource.Log.IsEnabled()) + { + NetEventSource.Error(handle, $"Failed to set native memory for TLS secret: {ex}"); + } + if (tlsSecrets is not null) + { + NativeMemory.Free(tlsSecrets); } - catch { }; + return null; } - - return null; } - private unsafe MsQuicTlsSecret(QUIC_TLS_SECRETS* memory) + private unsafe MsQuicTlsSecret(QUIC_TLS_SECRETS* tlsSecrets) { - _tlsSecrets = memory; + _tlsSecrets = tlsSecrets; } - public void WriteSecret() => WriteSecret(s_fileStream); - public unsafe void WriteSecret(FileStream? stream) + public unsafe void WriteSecret() { - if (stream != null && _tlsSecrets != null) + Debug.Assert(_tlsSecrets is not null); + Debug.Assert(s_fileStream is not null); + + lock (s_fileStream) { - lock (stream) + string clientRandom = string.Empty; + if (_tlsSecrets->IsSet.ClientRandom != 0) { - string clientRandom = string.Empty; - - if (_tlsSecrets->IsSet.ClientRandom != 0) - { - clientRandom = HexConverter.ToString(new ReadOnlySpan(_tlsSecrets->ClientRandom, 32)); - } - - if (_tlsSecrets->IsSet.ClientHandshakeTrafficSecret != 0) - { - stream.Write(Encoding.ASCII.GetBytes($"CLIENT_HANDSHAKE_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan(_tlsSecrets->ClientHandshakeTrafficSecret, _tlsSecrets->SecretLength))}\n")); - } - - if (_tlsSecrets->IsSet.ServerHandshakeTrafficSecret != 0) - { - stream.Write(Encoding.ASCII.GetBytes($"SERVER_HANDSHAKE_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan(_tlsSecrets->ServerHandshakeTrafficSecret, _tlsSecrets->SecretLength))}\n")); - } - - if (_tlsSecrets->IsSet.ClientTrafficSecret0 != 0) - { - stream.Write(Encoding.ASCII.GetBytes($"CLIENT_TRAFFIC_SECRET_0 {clientRandom} {HexConverter.ToString(new ReadOnlySpan(_tlsSecrets->ClientTrafficSecret0, _tlsSecrets->SecretLength))}\n")); - } - - if (_tlsSecrets->IsSet.ServerTrafficSecret0 != 0) - { - stream.Write(Encoding.ASCII.GetBytes($"SERVER_TRAFFIC_SECRET_0 {clientRandom} {HexConverter.ToString(new ReadOnlySpan(_tlsSecrets->ServerTrafficSecret0, _tlsSecrets->SecretLength))}\n")); - } - - if (_tlsSecrets->IsSet.ClientEarlyTrafficSecret != 0) - { - stream.Write(Encoding.ASCII.GetBytes($"CLIENT_EARLY_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan(_tlsSecrets->ClientEarlyTrafficSecret, _tlsSecrets->SecretLength))}\n")); - } - - stream.Flush(); + clientRandom = HexConverter.ToString(new ReadOnlySpan(_tlsSecrets->ClientRandom, 32)); } + if (_tlsSecrets->IsSet.ClientHandshakeTrafficSecret != 0) + { + s_fileStream.Write(Encoding.ASCII.GetBytes($"CLIENT_HANDSHAKE_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan(_tlsSecrets->ClientHandshakeTrafficSecret, _tlsSecrets->SecretLength))}\n")); + } + if (_tlsSecrets->IsSet.ServerHandshakeTrafficSecret != 0) + { + s_fileStream.Write(Encoding.ASCII.GetBytes($"SERVER_HANDSHAKE_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan(_tlsSecrets->ServerHandshakeTrafficSecret, _tlsSecrets->SecretLength))}\n")); + } + if (_tlsSecrets->IsSet.ClientTrafficSecret0 != 0) + { + s_fileStream.Write(Encoding.ASCII.GetBytes($"CLIENT_TRAFFIC_SECRET_0 {clientRandom} {HexConverter.ToString(new ReadOnlySpan(_tlsSecrets->ClientTrafficSecret0, _tlsSecrets->SecretLength))}\n")); + } + if (_tlsSecrets->IsSet.ServerTrafficSecret0 != 0) + { + s_fileStream.Write(Encoding.ASCII.GetBytes($"SERVER_TRAFFIC_SECRET_0 {clientRandom} {HexConverter.ToString(new ReadOnlySpan(_tlsSecrets->ServerTrafficSecret0, _tlsSecrets->SecretLength))}\n")); + } + if (_tlsSecrets->IsSet.ClientEarlyTrafficSecret != 0) + { + s_fileStream.Write(Encoding.ASCII.GetBytes($"CLIENT_EARLY_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan(_tlsSecrets->ClientEarlyTrafficSecret, _tlsSecrets->SecretLength))}\n")); + } + s_fileStream.Flush(); } + + NativeMemory.Clear(_tlsSecrets, (nuint)sizeof(QUIC_TLS_SECRETS)); } public unsafe void Dispose() { - if (_tlsSecrets != null) + if (_tlsSecrets is null) + { + return; + } + lock (this) { - NativeMemory.Clear(_tlsSecrets, (nuint)sizeof(QUIC_TLS_SECRETS)); + if (_tlsSecrets is null) + { + return; + } + + QUIC_TLS_SECRETS* tlsSecrets = _tlsSecrets; + _tlsSecrets = null; + NativeMemory.Free(_tlsSecrets); } } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs index 3b49667e9b32a..20e0de1771faf 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs @@ -38,13 +38,6 @@ namespace System.Net.Quic; /// public sealed partial class QuicConnection : IAsyncDisposable { -#if DEBUG - /// - /// The actual secret structure wrapper passed to MsQuic. - /// - private readonly MsQuicTlsSecret? _tlsSecret; -#endif - /// /// Returns true if QUIC is supported on the current machine and can be used; otherwise, false. /// @@ -152,6 +145,15 @@ static async ValueTask StartConnectAsync(QuicClientConnectionOpt /// Set when CONNECTED is received. /// private SslApplicationProtocol _negotiatedApplicationProtocol; + +#if DEBUG + /// + /// Will contain TLS secret after CONNECTED event is received and store it into SSLKEYLOGFILE. + /// MsQuic holds the underlying pointer so this object can be disposed only after connection native handle gets closed. + /// + private readonly MsQuicTlsSecret? _tlsSecret; +#endif + /// /// The remote endpoint used for this connection. /// @@ -467,6 +469,10 @@ private unsafe int HandleEventConnected(ref CONNECTED_DATA data) QuicAddr localAddress = MsQuicHelpers.GetMsQuicParameter(_handle, QUIC_PARAM_CONN_LOCAL_ADDRESS); _localEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(&localAddress); +#if DEBUG + _tlsSecret?.WriteSecret(); +#endif + if (NetEventSource.Log.IsEnabled()) { NetEventSource.Info(this, $"{this} Connection connected {LocalEndPoint} -> {RemoteEndPoint} for {_negotiatedApplicationProtocol} protocol"); @@ -596,9 +602,6 @@ public async ValueTask DisposeAsync() return; } -#if DEBUG - _tlsSecret?.Dispose(); -#endif // Check if the connection has been shut down and if not, shut it down. if (_shutdownTcs.TryInitialize(out ValueTask valueTask, this)) { diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicRemoteExecutorTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicRemoteExecutorTests.cs new file mode 100644 index 0000000000000..3aa7d11b508c5 --- /dev/null +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicRemoteExecutorTests.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Net.Security; +using System.Threading.Tasks; +using Microsoft.DotNet.RemoteExecutor; +using Microsoft.DotNet.XUnitExtensions; +using Xunit; +using Xunit.Abstractions; + +namespace System.Net.Quic.Tests +{ + [Collection(nameof(DisableParallelization))] + [ConditionalClass(typeof(QuicTestBase), nameof(QuicTestBase.IsSupported))] + public class MsQuicRemoteExecutorTests : QuicTestBase + { + public MsQuicRemoteExecutorTests() + : base(null!) { } + + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + public void SslKeyLogFile_IsCreatedAndFilled() + { + if (PlatformDetection.IsReleaseRuntime) + { + throw new SkipTestException("Retrieving SSL secrets is not supported in Release mode."); + } + + var psi = new ProcessStartInfo(); + var tempFile = Path.GetTempFileName(); + psi.Environment.Add("SSLKEYLOGFILE", tempFile); + + RemoteExecutor.Invoke(async () => + { + (QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection(); + await clientConnection.DisposeAsync(); + await serverConnection.DisposeAsync(); + }, new RemoteInvokeOptions { StartInfo = psi }).Dispose(); + + Assert.True(File.Exists(tempFile)); + Assert.True(File.ReadAllText(tempFile).Length > 0); + } + } +}