Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QUIC] Cleaned up TlsSecret and added test. #93119

Merged
merged 5 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,30 +58,35 @@ internal static unsafe T GetMsQuicParameter<T>(MsQuicSafeHandle handle, uint par
where T : unmanaged
{
T value;
uint length = (uint)sizeof(T);

GetMsQuicParameter(handle, parameter, (uint)sizeof(T), (byte*)&value);
return value;
}
internal static unsafe void GetMsQuicParameter(MsQuicSafeHandle handle, uint parameter, uint length, byte* value)
{
int status = MsQuicApi.Api.GetParam(
handle,
parameter,
&length,
(byte*)&value);
value);

if (StatusFailed(status))
{
ThrowHelper.ThrowMsQuicException(status, $"GetParam({handle}, {parameter}) failed");
}

return value;
}

internal static unsafe void SetMsQuicParameter<T>(MsQuicSafeHandle handle, uint parameter, T value)
where T : unmanaged
{
SetMsQuicParameter(handle, parameter, (uint)sizeof(T), (byte*)&value);
}
internal static unsafe void SetMsQuicParameter(MsQuicSafeHandle handle, uint parameter, uint length, byte* value)
{
int status = MsQuicApi.Api.SetParam(
handle,
parameter,
(uint)sizeof(T),
(byte*)&value);
length,
value);

if (StatusFailed(status))
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Runtime.InteropServices;
using Microsoft.Quic;

Expand Down Expand Up @@ -92,22 +93,19 @@ internal sealed class MsQuicContextSafeHandle : MsQuicSafeHandle
/// </summary>
private readonly MsQuicSafeHandle? _parent;

#if DEBUG
/// <summary>
/// Native memory to hold TLS secrets. It needs to live same cycle as the underlying connection.
/// Additional, dependent object to be disposed only after the safe handle gets released.
/// </summary>
private unsafe QUIC_TLS_SECRETS* _tlsSecrets;
private IDisposable? _disposable;

public unsafe QUIC_TLS_SECRETS* GetSecretsBuffer()
public IDisposable Disposable
{
if (_tlsSecrets == null)
set
{
_tlsSecrets = (QUIC_TLS_SECRETS*)NativeMemory.Alloc((nuint)sizeof(QUIC_TLS_SECRETS));
Debug.Assert(_disposable is null);
_disposable = value;
}

return _tlsSecrets;
}
#endif

public unsafe MsQuicContextSafeHandle(QUIC_HANDLE* handle, GCHandle context, SafeHandleType safeHandleType, MsQuicSafeHandle? parent = null)
: base(handle, safeHandleType)
Expand Down Expand Up @@ -140,13 +138,7 @@ protected override unsafe bool ReleaseHandle()
NetEventSource.Info(this, $"{this} {_parent} ref count decremented");
}
}
#if DEBUG
if (_tlsSecrets != null)
{
NativeMemory.Clear(_tlsSecrets, (nuint)sizeof(QUIC_TLS_SECRETS));
NativeMemory.Free(_tlsSecrets);
}
#endif
_disposable?.Dispose();
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

#if DEBUG
using System.Diagnostics;
using System.IO;
using System.Runtime.InteropServices;
using System.Text;
Expand All @@ -19,88 +20,83 @@ internal sealed class MsQuicTlsSecret : IDisposable

public static unsafe MsQuicTlsSecret? Create(MsQuicContextSafeHandle handle)
{
if (s_fileStream != null)
if (s_fileStream is null)
{
try
{
QUIC_TLS_SECRETS* ptr = handle.GetSecretsBuffer();
if (ptr != null)
{
int status = MsQuicApi.Api.SetParam(handle, QUIC_PARAM_CONN_TLS_SECRETS, (uint)sizeof(QUIC_TLS_SECRETS), ptr);
return null;
}

if (StatusSucceeded(status))
{
return new MsQuicTlsSecret(ptr);
}
else
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Error(handle, "Failed to set native memory for TLS secret.");
}
}
}
QUIC_TLS_SECRETS* tlsSecret = null;
try
{
tlsSecret = (QUIC_TLS_SECRETS*)NativeMemory.AllocZeroed((nuint)sizeof(QUIC_TLS_SECRETS));
MsQuicHelpers.SetMsQuicParameter(handle, QUIC_PARAM_CONN_TLS_SECRETS, (uint)sizeof(QUIC_TLS_SECRETS), (byte*)tlsSecret);
MsQuicTlsSecret instance = new MsQuicTlsSecret(tlsSecret);
handle.Disposable = instance;
return instance;
}
catch (Exception ex)
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Error(handle, $"Failed to set native memory for TLS secret: {ex}");
}
if (tlsSecret is not null)
{
NativeMemory.Free(tlsSecret);
}
catch { };
return null;
}

return null;
}

private unsafe MsQuicTlsSecret(QUIC_TLS_SECRETS* memory)
private unsafe MsQuicTlsSecret(QUIC_TLS_SECRETS* tlsSecret)
{
_tlsSecrets = memory;
_tlsSecrets = tlsSecret;
}

public void WriteSecret() => WriteSecret(s_fileStream);
public unsafe void WriteSecret(FileStream? stream)
public unsafe void WriteSecret()
{
if (stream != null && _tlsSecrets != null)
Debug.Assert(_tlsSecrets is not null);
Debug.Assert(s_fileStream is not null);

lock (s_fileStream)
{
lock (stream)
string clientRandom = string.Empty;
if (_tlsSecrets->IsSet.ClientRandom != 0)
{
string clientRandom = string.Empty;

if (_tlsSecrets->IsSet.ClientRandom != 0)
{
clientRandom = HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ClientRandom, 32));
}

if (_tlsSecrets->IsSet.ClientHandshakeTrafficSecret != 0)
{
stream.Write(Encoding.ASCII.GetBytes($"CLIENT_HANDSHAKE_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ClientHandshakeTrafficSecret, _tlsSecrets->SecretLength))}\n"));
}

if (_tlsSecrets->IsSet.ServerHandshakeTrafficSecret != 0)
{
stream.Write(Encoding.ASCII.GetBytes($"SERVER_HANDSHAKE_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ServerHandshakeTrafficSecret, _tlsSecrets->SecretLength))}\n"));
}

if (_tlsSecrets->IsSet.ClientTrafficSecret0 != 0)
{
stream.Write(Encoding.ASCII.GetBytes($"CLIENT_TRAFFIC_SECRET_0 {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ClientTrafficSecret0, _tlsSecrets->SecretLength))}\n"));
}

if (_tlsSecrets->IsSet.ServerTrafficSecret0 != 0)
{
stream.Write(Encoding.ASCII.GetBytes($"SERVER_TRAFFIC_SECRET_0 {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ServerTrafficSecret0, _tlsSecrets->SecretLength))}\n"));
}

if (_tlsSecrets->IsSet.ClientEarlyTrafficSecret != 0)
{
stream.Write(Encoding.ASCII.GetBytes($"CLIENT_EARLY_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ClientEarlyTrafficSecret, _tlsSecrets->SecretLength))}\n"));
}

stream.Flush();
clientRandom = HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ClientRandom, 32));
}
if (_tlsSecrets->IsSet.ClientHandshakeTrafficSecret != 0)
{
s_fileStream.Write(Encoding.ASCII.GetBytes($"CLIENT_HANDSHAKE_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ClientHandshakeTrafficSecret, _tlsSecrets->SecretLength))}\n"));
}
if (_tlsSecrets->IsSet.ServerHandshakeTrafficSecret != 0)
{
s_fileStream.Write(Encoding.ASCII.GetBytes($"SERVER_HANDSHAKE_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ServerHandshakeTrafficSecret, _tlsSecrets->SecretLength))}\n"));
}
if (_tlsSecrets->IsSet.ClientTrafficSecret0 != 0)
{
s_fileStream.Write(Encoding.ASCII.GetBytes($"CLIENT_TRAFFIC_SECRET_0 {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ClientTrafficSecret0, _tlsSecrets->SecretLength))}\n"));
}
if (_tlsSecrets->IsSet.ServerTrafficSecret0 != 0)
{
s_fileStream.Write(Encoding.ASCII.GetBytes($"SERVER_TRAFFIC_SECRET_0 {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ServerTrafficSecret0, _tlsSecrets->SecretLength))}\n"));
}
if (_tlsSecrets->IsSet.ClientEarlyTrafficSecret != 0)
{
s_fileStream.Write(Encoding.ASCII.GetBytes($"CLIENT_EARLY_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ClientEarlyTrafficSecret, _tlsSecrets->SecretLength))}\n"));
}
s_fileStream.Flush();
}

NativeMemory.Clear(_tlsSecrets, (nuint)sizeof(QUIC_TLS_SECRETS));
}

public unsafe void Dispose()
{
if (_tlsSecrets != null)
if (_tlsSecrets is not null)
ManickaP marked this conversation as resolved.
Show resolved Hide resolved
{
NativeMemory.Clear(_tlsSecrets, (nuint)sizeof(QUIC_TLS_SECRETS));
NativeMemory.Free(_tlsSecrets);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is safe. It feels like we can free the memory while MsQuic still may be trying to write to that. That was one reason why originally the native memory was tight to the MsQuic handle lifecycle.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dispose is called when the connection handle gets released, not from QuicConnection.DisposeAsync.

_tlsSecrets = null;
}
}
}
Expand Down
23 changes: 13 additions & 10 deletions src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,6 @@ namespace System.Net.Quic;
/// </remarks>
public sealed partial class QuicConnection : IAsyncDisposable
{
#if DEBUG
/// <summary>
/// The actual secret structure wrapper passed to MsQuic.
/// </summary>
private readonly MsQuicTlsSecret? _tlsSecret;
#endif

/// <summary>
/// Returns <c>true</c> if QUIC is supported on the current machine and can be used; otherwise, <c>false</c>.
/// </summary>
Expand Down Expand Up @@ -152,6 +145,15 @@ static async ValueTask<QuicConnection> StartConnectAsync(QuicClientConnectionOpt
/// Set when CONNECTED is received.
/// </summary>
private SslApplicationProtocol _negotiatedApplicationProtocol;

#if DEBUG
/// <summary>
/// Will contain TLS secret after CONNECTED event is received and store it into SSLKEYLOGFILE.
/// MsQuic holds the underlying pointer so this object can be disposed only after connection native handle gets closed.
/// </summary>
private readonly MsQuicTlsSecret? _tlsSecret;
#endif

/// <summary>
/// The remote endpoint used for this connection.
/// </summary>
Expand Down Expand Up @@ -467,6 +469,10 @@ private unsafe int HandleEventConnected(ref CONNECTED_DATA data)
QuicAddr localAddress = MsQuicHelpers.GetMsQuicParameter<QuicAddr>(_handle, QUIC_PARAM_CONN_LOCAL_ADDRESS);
_localEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(&localAddress);

#if DEBUG
_tlsSecret?.WriteSecret();
#endif

if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(this, $"{this} Connection connected {LocalEndPoint} -> {RemoteEndPoint} for {_negotiatedApplicationProtocol} protocol");
Expand Down Expand Up @@ -596,9 +602,6 @@ public async ValueTask DisposeAsync()
return;
}

#if DEBUG
_tlsSecret?.Dispose();
#endif
// Check if the connection has been shut down and if not, shut it down.
if (_shutdownTcs.TryInitialize(out ValueTask valueTask, this))
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net.Security;
using System.Threading.Tasks;
using Microsoft.DotNet.RemoteExecutor;
using Microsoft.DotNet.XUnitExtensions;
using Xunit;
using Xunit.Abstractions;

namespace System.Net.Quic.Tests
{
[Collection(nameof(DisableParallelization))]
[ConditionalClass(typeof(QuicTestBase), nameof(QuicTestBase.IsSupported))]
public class MsQuicRemoteExecutorTests : QuicTestBase
{
public MsQuicRemoteExecutorTests()
: base(null!) { }

[ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
public void SslKeyLogFile_IsCreatedAndFilled()
{
var psi = new ProcessStartInfo();
var tempFile = Path.GetTempFileName();
psi.Environment.Add("SSLKEYLOGFILE", tempFile);

RemoteExecutor.Invoke(async () =>
{
(QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection();
await clientConnection.DisposeAsync();
await serverConnection.DisposeAsync();
}, new RemoteInvokeOptions { StartInfo = psi }).Dispose();

Assert.True(File.Exists(tempFile));
Assert.True(File.ReadAllText(tempFile).Length > 0);
}
}
}
Loading