Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement OpenSSH strict key exchange extension #1366

Merged
merged 19 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 83 additions & 34 deletions src/Renci.SshNet/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,17 @@ public class Session : ISession
/// </summary>
private bool _isDisconnecting;

/// <summary>
/// Indicates whether it is the init kex.
/// </summary>
private bool _isInitialKex;

/// <summary>
/// Indicates whether server supports strict key exchange.
/// <see href="https://github.com/openssh/openssh-portable/blob/master/PROTOCOL"/> 1.10.
/// </summary>
private bool _isStrictKex;

private IKeyExchange _keyExchange;

private HashAlgorithm _serverMac;
Expand Down Expand Up @@ -281,35 +292,11 @@ public bool IsConnected
/// </value>
public byte[] SessionId { get; private set; }

private Message _clientInitMessage;

/// <summary>
/// Gets the client init message.
/// </summary>
/// <value>The client init message.</value>
public Message ClientInitMessage
{
get
{
_clientInitMessage ??= new KeyExchangeInitMessage
{
KeyExchangeAlgorithms = ConnectionInfo.KeyExchangeAlgorithms.Keys.ToArray(),
ServerHostKeyAlgorithms = ConnectionInfo.HostKeyAlgorithms.Keys.ToArray(),
EncryptionAlgorithmsClientToServer = ConnectionInfo.Encryptions.Keys.ToArray(),
EncryptionAlgorithmsServerToClient = ConnectionInfo.Encryptions.Keys.ToArray(),
MacAlgorithmsClientToServer = ConnectionInfo.HmacAlgorithms.Keys.ToArray(),
MacAlgorithmsServerToClient = ConnectionInfo.HmacAlgorithms.Keys.ToArray(),
CompressionAlgorithmsClientToServer = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(),
CompressionAlgorithmsServerToClient = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(),
LanguagesClientToServer = new[] { string.Empty },
LanguagesServerToClient = new[] { string.Empty },
FirstKexPacketFollows = false,
Reserved = 0
};

return _clientInitMessage;
}
}
public Message ClientInitMessage { get; private set; }

/// <summary>
/// Gets the server version string.
Expand Down Expand Up @@ -617,6 +604,8 @@ public void Connect()
// Send our key exchange init.
// We need to do this before starting the message listener to avoid the case where we receive the server
// key exchange init and we continue the key exchange before having sent our own init.
_isInitialKex = true;
ClientInitMessage = BuildClientInitMessage(includeStrictKexPseudoAlgorithm: true);
SendMessage(ClientInitMessage);

// Mark the message listener threads as started
Expand Down Expand Up @@ -741,6 +730,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
// Send our key exchange init.
// We need to do this before starting the message listener to avoid the case where we receive the server
// key exchange init and we continue the key exchange before having sent our own init.
_isInitialKex = true;
ClientInitMessage = BuildClientInitMessage(includeStrictKexPseudoAlgorithm: true);
SendMessage(ClientInitMessage);

// Mark the message listener threads as started
Expand Down Expand Up @@ -1107,13 +1098,20 @@ internal void SendMessage(Message message)
SendPacket(data, 0, data.Length);
}

// increment the packet sequence number only after we're sure the packet has
// been sent; even though it's only used for the MAC, it needs to be incremented
// for each package sent.
//
// the server will use it to verify the data integrity, and as such the order in
// which messages are sent must follow the outbound packet sequence number
_outboundPacketSequence++;
if (_isStrictKex && message is NewKeysMessage)
{
_outboundPacketSequence = 0;
}
else
{
// increment the packet sequence number only after we're sure the packet has
// been sent; even though it's only used for the MAC, it needs to be incremented
// for each package sent.
//
// the server will use it to verify the data integrity, and as such the order in
// which messages are sent must follow the outbound packet sequence number
_outboundPacketSequence++;
}
}
}

Expand Down Expand Up @@ -1344,6 +1342,13 @@ private Message ReceiveMessage(Socket socket)

_inboundPacketSequence++;

// The below code mirrors from https://github.com/openssh/openssh-portable/commit/1edb00c58f8a6875fad6a497aa2bacf37f9e6cd5
// It ensures the integrity of key exchange process.
if (_inboundPacketSequence == uint.MaxValue && _isInitialKex)
{
throw new SshConnectionException("Inbound packet sequence number is about to wrap during initial key exchange.", DisconnectReason.KeyExchangeFailed);
}
scott-xu marked this conversation as resolved.
Show resolved Hide resolved

return LoadMessage(data, messagePayloadOffset, messagePayloadLength);
}

Expand Down Expand Up @@ -1455,8 +1460,20 @@ internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message)

_keyExchangeCompletedWaitHandle.Reset();

if (_isInitialKex && message.KeyExchangeAlgorithms.Contains("kex-strict-s-v00@openssh.com"))
scott-xu marked this conversation as resolved.
Show resolved Hide resolved
{
_isStrictKex = true;

DiagnosticAbstraction.Log(string.Format("[{0}] Enabling strict key exchange extension.", ToHex(SessionId)));

if (_inboundPacketSequence != 1)
{
throw new SshConnectionException("KEXINIT was not the first packet during strict key exchange.", DisconnectReason.KeyExchangeFailed);
}
}

// Disable messages that are not key exchange related
_sshMessageFactory.DisableNonKeyExchangeMessages();
_sshMessageFactory.DisableNonKeyExchangeMessages(_isStrictKex);

_keyExchange = _serviceFactory.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms,
message.KeyExchangeAlgorithms);
Expand Down Expand Up @@ -1533,6 +1550,17 @@ internal void OnNewKeysReceived(NewKeysMessage message)
// Enable activated messages that are not key exchange related
_sshMessageFactory.EnableActivatedMessages();

if (_isInitialKex)
{
_isInitialKex = false;
ClientInitMessage = BuildClientInitMessage(includeStrictKexPseudoAlgorithm: false);
}

if (_isStrictKex)
{
_inboundPacketSequence = 0;
}

NewKeysReceived?.Invoke(this, new MessageEventArgs<NewKeysMessage>(message));

// Signal that key exchange completed
Expand Down Expand Up @@ -2067,7 +2095,28 @@ private void Reset()
private static SshConnectionException CreateConnectionAbortedByServerException()
{
return new SshConnectionException("An established connection was aborted by the server.",
DisconnectReason.ConnectionLost);
DisconnectReason.ConnectionLost);
}

private KeyExchangeInitMessage BuildClientInitMessage(bool includeStrictKexPseudoAlgorithm)
{
return new KeyExchangeInitMessage
{
KeyExchangeAlgorithms = includeStrictKexPseudoAlgorithm ?
ConnectionInfo.KeyExchangeAlgorithms.Keys.Concat(["kex-strict-c-v00@openssh.com"]).ToArray() :
ConnectionInfo.KeyExchangeAlgorithms.Keys.ToArray(),
ServerHostKeyAlgorithms = ConnectionInfo.HostKeyAlgorithms.Keys.ToArray(),
EncryptionAlgorithmsClientToServer = ConnectionInfo.Encryptions.Keys.ToArray(),
EncryptionAlgorithmsServerToClient = ConnectionInfo.Encryptions.Keys.ToArray(),
MacAlgorithmsClientToServer = ConnectionInfo.HmacAlgorithms.Keys.ToArray(),
MacAlgorithmsServerToClient = ConnectionInfo.HmacAlgorithms.Keys.ToArray(),
CompressionAlgorithmsClientToServer = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(),
CompressionAlgorithmsServerToClient = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(),
LanguagesClientToServer = new[] { string.Empty },
LanguagesServerToClient = new[] { string.Empty },
FirstKexPacketFollows = false,
Reserved = 0,
};
}

private bool _disposed;
Expand Down
31 changes: 28 additions & 3 deletions src/Renci.SshNet/SshMessageFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,41 @@ public Message Create(byte messageNumber)
return enabledMessageMetadata.Create();
}

public void DisableNonKeyExchangeMessages()
/// <summary>
/// Disables non-KeyExchange messages.
/// </summary>
/// <param name="strict">
/// <see langword="true"/> to indicate the strict key exchange mode; otherwise <see langword="false"/>.
/// <para>In strict key exchange mode, only below messages are allowed:</para>
/// <list type="bullet">
/// <item>SSH_MSG_KEXINIT -> 20</item>
/// <item>SSH_MSG_NEWKEYS -> 21</item>
/// <item>SSH_MSG_DISCONNECT -> 1</item>
/// </list>
/// <para>Note:</para>
/// <para> The relevant KEX Reply MSG will be allowed from a sub class of KeyExchange class.</para>
/// <para> For example, it calls <c>Session.RegisterMessage("SSH_MSG_KEX_ECDH_REPLY");</c> if the curve25519-sha256 KEX algorithm is selected per negotiation.</para>
/// </param>
public void DisableNonKeyExchangeMessages(bool strict)
{
for (var i = 0; i < AllMessages.Length; i++)
{
var messageMetadata = AllMessages[i];

var messageNumber = messageMetadata.Number;
if (messageNumber is (> 2 and < 20) or > 30)
if (strict)
{
if (messageNumber is not 20 and not 21 and not 1)
{
_enabledMessagesByNumber[messageNumber] = null;
}
}
else
{
_enabledMessagesByNumber[messageNumber] = null;
if (messageNumber is (> 2 and < 20) or > 30)
{
_enabledMessagesByNumber[messageNumber] = null;
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public void IsConnectedShouldReturnFalse()
}

[TestMethod]
public void SendMessageShouldThrowShhConnectionException()
public void SendMessageShouldThrowSshConnectionException()
{
try
{
Expand Down Expand Up @@ -189,7 +189,7 @@ public void ISession_MessageListenerCompletedShouldBeSignaled()
}

[TestMethod]
public void ISession_SendMessageShouldThrowShhConnectionException()
public void ISession_SendMessageShouldThrowSshConnectionException()
{
var session = (ISession)_session;

Expand Down
26 changes: 26 additions & 0 deletions test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Linq;
using System.Threading;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Moq;
Expand Down Expand Up @@ -30,6 +31,31 @@ public void ClientVersionIsRenciSshNet()
Assert.AreEqual("SSH-2.0-Renci.SshNet.SshClient.0.0.1", Session.ClientVersion);
}

[TestMethod]
public void IncludeStrictKexPseudoAlgorithmInInitKex()
{
Assert.IsTrue(ServerBytesReceivedRegister.Count > 0);

var kexInitMessage = new KeyExchangeInitMessage();
kexInitMessage.Load(ServerBytesReceivedRegister[0], 4 + 1 + 1, ServerBytesReceivedRegister[0].Length - 4 - 1 - 1);
Assert.IsTrue(kexInitMessage.KeyExchangeAlgorithms.Contains("kex-strict-c-v00@openssh.com"));
}

[TestMethod]
public void ShouldNotIncludeStrictKexPseudoAlgorithmInSubsequenceKex()
Rob-Hague marked this conversation as resolved.
Show resolved Hide resolved
{
ServerBytesReceivedRegister.Clear();
Session.SendMessage(Session.ClientInitMessage);

Thread.Sleep(100);

Assert.IsTrue(ServerBytesReceivedRegister.Count > 0);

var kexInitMessage = new KeyExchangeInitMessage();
kexInitMessage.Load(ServerBytesReceivedRegister[0], 4 + 1 + 1, ServerBytesReceivedRegister[0].Length - 4 - 1 - 1);
Assert.IsFalse(kexInitMessage.KeyExchangeAlgorithms.Contains("kex-strict-c-v00@openssh.com"));
}

[TestMethod]
public void ConnectionInfoShouldReturnConnectionInfoPassedThroughConstructor()
{
Expand Down
12 changes: 3 additions & 9 deletions test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ public abstract class SessionTest_ConnectedBase
protected Session Session { get; private set; }
protected Socket ClientSocket { get; private set; }
protected Socket ServerSocket { get; private set; }
internal SshIdentification ServerIdentification { get; set; }
protected bool CallSessionConnectWhenArrange { get; set; }
protected SshIdentification ServerIdentification { get; private set; }

/// <summary>
/// Should the "server" wait for the client kexinit before sending its own.
Expand Down Expand Up @@ -163,8 +162,6 @@ protected virtual void SetupData()

ClientSocket = new DirectConnector(_socketFactory).Connect(ConnectionInfo);

CallSessionConnectWhenArrange = true;

void SendKeyExchangeInit()
{
var keyExchangeInitMessage = new KeyExchangeInitMessage
Expand Down Expand Up @@ -204,7 +201,7 @@ private void SetupMocks()
_ = ServiceFactoryMock.Setup(p => p.CreateProtocolVersionExchange())
.Returns(_protocolVersionExchangeMock.Object);
_ = _protocolVersionExchangeMock.Setup(p => p.Start(Session.ClientVersion, ClientSocket, ConnectionInfo.Timeout))
.Returns(() => ServerIdentification);
.Returns(ServerIdentification);
_ = ServiceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, new[] { _keyExchangeAlgorithm })).Returns(_keyExchangeMock.Object);
_ = _keyExchangeMock.Setup(p => p.Name)
.Returns(_keyExchangeAlgorithm);
Expand Down Expand Up @@ -252,10 +249,7 @@ protected void Arrange()
SetupData();
SetupMocks();

if (CallSessionConnectWhenArrange)
{
Session.Connect();
}
Session.Connect();
}

protected virtual void ClientAuthentication_Callback()
Expand Down
Loading