Skip to content

Commit

Permalink
refactor SslStream internals (#68678)
Browse files Browse the repository at this point in the history
* refactor SslStream internals

* fix validation and certs

* update fakes

* feedback from review
  • Loading branch information
wfurt authored May 9, 2022
1 parent 5ecaae9 commit 3e5517b
Show file tree
Hide file tree
Showing 17 changed files with 347 additions and 624 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
<Compile Include="System\Net\Security\SslClientAuthenticationOptions.cs" />
<Compile Include="System\Net\Security\SslClientHelloInfo.cs" />
<Compile Include="System\Net\Security\SslServerAuthenticationOptions.cs" />
<Compile Include="System\Net\Security\SecureChannel.cs" />
<Compile Include="System\Net\Security\SslSessionsCache.cs" />
<Compile Include="System\Net\Security\SslStream.cs" />
<Compile Include="System\Net\Security\SslStream.Implementation.cs" />
<Compile Include="System\Net\Security\SslStream.IO.cs" />
<Compile Include="System\Net\Security\SslStream.Protocol.cs" />
<Compile Include="System\Net\Security\SslStreamCertificateContext.cs" />
<Compile Include="System\Net\Security\SslConnectionInfo.cs" />
<Compile Include="System\Net\Security\StreamSizes.cs" />
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,39 @@ namespace System.Net.Security
{
internal sealed class SslAuthenticationOptions
{
internal SslAuthenticationOptions(SslClientAuthenticationOptions sslClientAuthenticationOptions, RemoteCertificateValidationCallback? remoteCallback, LocalCertSelectionCallback? localCallback)
internal SslAuthenticationOptions()
{
TargetHost = string.Empty;
}

internal void UpdateOptions(SslClientAuthenticationOptions sslClientAuthenticationOptions)
{
Debug.Assert(sslClientAuthenticationOptions.TargetHost != null);

if (CertValidationDelegate == null)
{
CertValidationDelegate = sslClientAuthenticationOptions.RemoteCertificateValidationCallback;
}
else if (sslClientAuthenticationOptions.RemoteCertificateValidationCallback != null &&
CertValidationDelegate != sslClientAuthenticationOptions.RemoteCertificateValidationCallback)
{
// Callback was set in constructor to differet value.
throw new InvalidOperationException(SR.Format(SR.net_conflicting_options, nameof(RemoteCertificateValidationCallback)));
}

if (CertSelectionDelegate == null)
{
CertSelectionDelegate = sslClientAuthenticationOptions.LocalCertificateSelectionCallback;
}
else if (sslClientAuthenticationOptions.LocalCertificateSelectionCallback != null &&
CertSelectionDelegate != sslClientAuthenticationOptions.LocalCertificateSelectionCallback)
{
throw new InvalidOperationException(SR.Format(SR.net_conflicting_options, nameof(LocalCertificateSelectionCallback)));
}

// Common options.
AllowRenegotiation = sslClientAuthenticationOptions.AllowRenegotiation;
ApplicationProtocols = sslClientAuthenticationOptions.ApplicationProtocols;
CertValidationDelegate = remoteCallback;
CheckCertName = true;
EnabledSslProtocols = FilterOutIncompatibleSslProtocols(sslClientAuthenticationOptions.EnabledSslProtocols);
EncryptionPolicy = sslClientAuthenticationOptions.EncryptionPolicy;
Expand All @@ -27,32 +52,57 @@ internal SslAuthenticationOptions(SslClientAuthenticationOptions sslClientAuthen
TargetHost = sslClientAuthenticationOptions.TargetHost.TrimEnd('.');

// Client specific options.
CertSelectionDelegate = localCallback;
CertificateRevocationCheckMode = sslClientAuthenticationOptions.CertificateRevocationCheckMode;
ClientCertificates = sslClientAuthenticationOptions.ClientCertificates;
CipherSuitesPolicy = sslClientAuthenticationOptions.CipherSuitesPolicy;
}

internal SslAuthenticationOptions(SslServerAuthenticationOptions sslServerAuthenticationOptions)
internal void UpdateOptions(ServerOptionsSelectionCallback optionCallback, object? state)
{
// Common options.
AllowRenegotiation = sslServerAuthenticationOptions.AllowRenegotiation;
ApplicationProtocols = sslServerAuthenticationOptions.ApplicationProtocols;
CheckCertName = false;
EnabledSslProtocols = FilterOutIncompatibleSslProtocols(sslServerAuthenticationOptions.EnabledSslProtocols);
EncryptionPolicy = sslServerAuthenticationOptions.EncryptionPolicy;
TargetHost = string.Empty;
IsServer = true;
RemoteCertRequired = sslServerAuthenticationOptions.ClientCertificateRequired;
if (NetEventSource.Log.IsEnabled())
UserState = state;
ServerOptionDelegate = optionCallback;
}

internal void UpdateOptions(SslServerAuthenticationOptions sslServerAuthenticationOptions)
{
if (sslServerAuthenticationOptions.ServerCertificate == null &&
sslServerAuthenticationOptions.ServerCertificateContext == null &&
sslServerAuthenticationOptions.ServerCertificateSelectionCallback == null &&
CertSelectionDelegate == null)
{
NetEventSource.Info(this, $"Server RemoteCertRequired: {RemoteCertRequired}.");
throw new NotSupportedException(SR.net_ssl_io_no_server_cert);
}

if ((sslServerAuthenticationOptions.ServerCertificate != null ||
sslServerAuthenticationOptions.ServerCertificateContext != null ||
CertSelectionDelegate != null) &&
sslServerAuthenticationOptions.ServerCertificateSelectionCallback != null)
{
throw new InvalidOperationException(SR.Format(SR.net_conflicting_options, nameof(ServerCertificateSelectionCallback)));
}

if (CertValidationDelegate == null)
{
CertValidationDelegate = sslServerAuthenticationOptions.RemoteCertificateValidationCallback;
}
else if (sslServerAuthenticationOptions.RemoteCertificateValidationCallback != null &&
CertValidationDelegate != sslServerAuthenticationOptions.RemoteCertificateValidationCallback)
{
// Callback was set in constructor to differet value.
throw new InvalidOperationException(SR.Format(SR.net_conflicting_options, nameof(RemoteCertificateValidationCallback)));
}
TargetHost = string.Empty;

// Server specific options.
IsServer = true;
AllowRenegotiation = sslServerAuthenticationOptions.AllowRenegotiation;
ApplicationProtocols = sslServerAuthenticationOptions.ApplicationProtocols;
EnabledSslProtocols = FilterOutIncompatibleSslProtocols(sslServerAuthenticationOptions.EnabledSslProtocols);
EncryptionPolicy = sslServerAuthenticationOptions.EncryptionPolicy;
RemoteCertRequired = sslServerAuthenticationOptions.ClientCertificateRequired;
CipherSuitesPolicy = sslServerAuthenticationOptions.CipherSuitesPolicy;
CertificateRevocationCheckMode = sslServerAuthenticationOptions.CertificateRevocationCheckMode;

if (sslServerAuthenticationOptions.ServerCertificateContext != null)
{
CertificateContext = sslServerAuthenticationOptions.ServerCertificateContext;
Expand All @@ -70,7 +120,7 @@ internal SslAuthenticationOptions(SslServerAuthenticationOptions sslServerAuthen
{
// This is legacy fix-up. If the Certificate did not have key, we will search stores and we
// will try to find one with matching hash.
certificateWithKey = SecureChannel.FindCertificateWithPrivateKey(this, true, sslServerAuthenticationOptions.ServerCertificate);
certificateWithKey = SslStream.FindCertificateWithPrivateKey(this, true, sslServerAuthenticationOptions.ServerCertificate);
if (certificateWithKey == null)
{
throw new AuthenticationException(SR.net_ssl_io_no_server_cert);
Expand All @@ -80,45 +130,9 @@ internal SslAuthenticationOptions(SslServerAuthenticationOptions sslServerAuthen
}
}

if (sslServerAuthenticationOptions.RemoteCertificateValidationCallback != null)
if (sslServerAuthenticationOptions.ServerCertificateSelectionCallback != null)
{
CertValidationDelegate = sslServerAuthenticationOptions.RemoteCertificateValidationCallback;
}
}

internal SslAuthenticationOptions(ServerOptionsSelectionCallback optionCallback, object? state, RemoteCertificateValidationCallback? remoteCallback)
{
CheckCertName = false;
TargetHost = string.Empty;
IsServer = true;
UserState = state;
ServerOptionDelegate = optionCallback;
CertValidationDelegate = remoteCallback;
}

internal void UpdateOptions(SslServerAuthenticationOptions sslServerAuthenticationOptions)
{
AllowRenegotiation = sslServerAuthenticationOptions.AllowRenegotiation;
ApplicationProtocols = sslServerAuthenticationOptions.ApplicationProtocols;
EnabledSslProtocols = FilterOutIncompatibleSslProtocols(sslServerAuthenticationOptions.EnabledSslProtocols);
EncryptionPolicy = sslServerAuthenticationOptions.EncryptionPolicy;
RemoteCertRequired = sslServerAuthenticationOptions.ClientCertificateRequired;
CipherSuitesPolicy = sslServerAuthenticationOptions.CipherSuitesPolicy;
CertificateRevocationCheckMode = sslServerAuthenticationOptions.CertificateRevocationCheckMode;
if (sslServerAuthenticationOptions.ServerCertificateContext != null)
{
CertificateContext = sslServerAuthenticationOptions.ServerCertificateContext;
}
else if (sslServerAuthenticationOptions.ServerCertificate is X509Certificate2 certificateWithKey &&
certificateWithKey.HasPrivateKey)
{
// given cert is X509Certificate2 with key. We can use it directly.
CertificateContext = SslStreamCertificateContext.Create(certificateWithKey);
}

if (sslServerAuthenticationOptions.RemoteCertificateValidationCallback != null)
{
CertValidationDelegate = sslServerAuthenticationOptions.RemoteCertificateValidationCallback;
ServerCertSelectionDelegate = sslServerAuthenticationOptions.ServerCertificateSelectionCallback;
}
}

Expand Down Expand Up @@ -150,10 +164,10 @@ private static SslProtocols FilterOutIncompatibleSslProtocols(SslProtocols proto
internal bool RemoteCertRequired { get; set; }
internal bool CheckCertName { get; set; }
internal RemoteCertificateValidationCallback? CertValidationDelegate { get; set; }
internal LocalCertSelectionCallback? CertSelectionDelegate { get; set; }
internal ServerCertSelectionCallback? ServerCertSelectionDelegate { get; set; }
internal LocalCertificateSelectionCallback? CertSelectionDelegate { get; set; }
internal ServerCertificateSelectionCallback? ServerCertSelectionDelegate { get; set; }
internal CipherSuitesPolicy? CipherSuitesPolicy { get; set; }
internal object? UserState { get; }
internal ServerOptionsSelectionCallback? ServerOptionDelegate { get; }
internal object? UserState { get; set; }
internal ServerOptionsSelectionCallback? ServerOptionDelegate { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

namespace System.Net.Security
{
internal sealed partial class SslConnectionInfo
internal partial struct SslConnectionInfo
{
public SslConnectionInfo(SafeSslHandle sslContext)
public void UpdateSslConnectionInfo(SafeSslHandle sslContext)
{
string protocolString = Interop.AndroidCrypto.SSLStreamGetProtocol(sslContext);
SslProtocols protocol = protocolString switch
Expand All @@ -26,6 +26,7 @@ public SslConnectionInfo(SafeSslHandle sslContext)
_ => SslProtocols.None,
};
Protocol = (int)protocol;
ApplicationProtocol = Interop.AndroidCrypto.SSLStreamGetApplicationProtocol(sslContext);

// Enum value names should match the cipher suite name, so we just parse the
string cipherSuite = Interop.AndroidCrypto.SSLStreamGetCipherSuite(sslContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

namespace System.Net.Security
{
internal sealed partial class SslConnectionInfo
internal partial struct SslConnectionInfo
{
public SslConnectionInfo(SafeSslHandle sslContext)
public void UpdateSslConnectionInfo(SafeSslHandle sslContext)
{
Protocol = (int)MapProtocolVersion(Interop.Ssl.SslGetVersion(sslContext));
ApplicationProtocol = Interop.Ssl.SslGetAlpnSelected(sslContext);

MapCipherSuite(SslGetCurrentCipherSuite(sslContext));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

namespace System.Net.Security
{
internal sealed partial class SslConnectionInfo
internal partial struct SslConnectionInfo
{
public SslConnectionInfo(SafeSslHandle sslContext)
public void UpdateSslConnectionInfo(SafeSslHandle sslContext)
{
SslProtocols protocol;
TlsCipherSuite cipherSuite;
Expand All @@ -26,6 +26,7 @@ public SslConnectionInfo(SafeSslHandle sslContext)

Protocol = (int)protocol;
TlsCipherSuite = cipherSuite;
ApplicationProtocol = Interop.AppleCrypto.SslGetAlpnSelected(sslContext);

MapCipherSuite(cipherSuite);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

namespace System.Net.Security
{
internal sealed partial class SslConnectionInfo
internal partial struct SslConnectionInfo
{
private void MapCipherSuite(TlsCipherSuite cipherSuite)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,47 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;

namespace System.Net.Security
{
internal sealed partial class SslConnectionInfo
internal partial struct SslConnectionInfo
{
public SslConnectionInfo(SecPkgContext_ConnectionInfo interopConnectionInfo, TlsCipherSuite cipherSuite)
private static byte[]? GetNegotiatedApplicationProtocol(SafeDeleteContext context)
{
Interop.SecPkgContext_ApplicationProtocol alpnContext = default;
bool success = SSPIWrapper.QueryBlittableContextAttributes(GlobalSSPI.SSPISecureChannel, context, Interop.SspiCli.ContextAttribute.SECPKG_ATTR_APPLICATION_PROTOCOL, ref alpnContext);

// Check if the context returned is alpn data, with successful negotiation.
if (success &&
alpnContext.ProtoNegoExt == Interop.ApplicationProtocolNegotiationExt.ALPN &&
alpnContext.ProtoNegoStatus == Interop.ApplicationProtocolNegotiationStatus.Success)
{
return alpnContext.Protocol;
}

return null;
}

public void UpdateSslConnectionInfo(SafeDeleteContext securityContext)
{
SecPkgContext_ConnectionInfo interopConnectionInfo = default;
bool success = SSPIWrapper.QueryBlittableContextAttributes(
GlobalSSPI.SSPISecureChannel,
securityContext,
Interop.SspiCli.ContextAttribute.SECPKG_ATTR_CONNECTION_INFO,
ref interopConnectionInfo);
Debug.Assert(success);

TlsCipherSuite cipherSuite = default;
SecPkgContext_CipherInfo cipherInfo = default;

success = SSPIWrapper.QueryBlittableContextAttributes(GlobalSSPI.SSPISecureChannel, securityContext, Interop.SspiCli.ContextAttribute.SECPKG_ATTR_CIPHER_INFO, ref cipherInfo);
if (success)
{
cipherSuite = (TlsCipherSuite)cipherInfo.dwCipherSuite;
}

Protocol = interopConnectionInfo.Protocol;
DataCipherAlg = interopConnectionInfo.DataCipherAlg;
DataKeySize = interopConnectionInfo.DataKeySize;
Expand All @@ -16,6 +51,8 @@ public SslConnectionInfo(SecPkgContext_ConnectionInfo interopConnectionInfo, Tls
KeyExchKeySize = interopConnectionInfo.KeyExchKeySize;

TlsCipherSuite = cipherSuite;

ApplicationProtocol = GetNegotiatedApplicationProtocol(securityContext);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@

namespace System.Net.Security
{
internal sealed partial class SslConnectionInfo
internal partial struct SslConnectionInfo
{
public int Protocol { get; }
public int Protocol { get; private set; }
public TlsCipherSuite TlsCipherSuite { get; private set; }
public int DataCipherAlg { get; private set; }
public int DataKeySize { get; private set; }
public int DataHashAlg { get; private set; }
public int DataHashKeySize { get; private set; }
public int KeyExchangeAlg { get; private set; }
public int KeyExchKeySize { get; private set; }

public byte[]? ApplicationProtocol { get; internal set; }
}
}
Loading

0 comments on commit 3e5517b

Please sign in to comment.