diff --git a/src/Renci.SshNet/SshCommand.cs b/src/Renci.SshNet/SshCommand.cs index f900e843b..c3d66e07c 100644 --- a/src/Renci.SshNet/SshCommand.cs +++ b/src/Renci.SshNet/SshCommand.cs @@ -21,7 +21,7 @@ public class SshCommand : IDisposable private readonly ISession _session; private readonly Encoding _encoding; - private IChannelSession? _channel; + private IChannelSession _channel; private TaskCompletionSource? _tcs; private CancellationTokenSource? _cts; private CancellationTokenRegistration _tokenRegistration; @@ -142,14 +142,14 @@ public int? ExitStatus /// public Stream CreateInputStream() { - if (_channel == null) + if (!_channel.IsOpen) { - throw new InvalidOperationException($"The input stream can be used only after calling BeginExecute and before calling EndExecute."); + throw new InvalidOperationException("The input stream can be used only during execution."); } if (_inputStream != null) { - throw new InvalidOperationException($"The input stream already exists."); + throw new InvalidOperationException("The input stream already exists."); } _inputStream = new ChannelInputStream(_channel); @@ -226,6 +226,7 @@ internal SshCommand(ISession session, string commandText, Encoding encoding) ExtendedOutputStream = new PipeStream(); _session.Disconnected += Session_Disconnected; _session.ErrorOccured += Session_ErrorOccured; + _channel = _session.CreateChannelSession(); } /// @@ -257,6 +258,8 @@ public Task ExecuteAsync(CancellationToken cancellationToken = default) throw new InvalidOperationException("Asynchronous operation is already in progress."); } + UnsubscribeFromChannelEvents(dispose: true); + OutputStream.Dispose(); ExtendedOutputStream.Dispose(); @@ -265,6 +268,7 @@ public Task ExecuteAsync(CancellationToken cancellationToken = default) // so we just need to reinitialise them for subsequent executions. OutputStream = new PipeStream(); ExtendedOutputStream = new PipeStream(); + _channel = _session.CreateChannelSession(); } _exitStatus = default; @@ -282,7 +286,6 @@ public Task ExecuteAsync(CancellationToken cancellationToken = default) _tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); _userToken = cancellationToken; - _channel = _session.CreateChannelSession(); _channel.DataReceived += Channel_DataReceived; _channel.ExtendedDataReceived += Channel_ExtendedDataReceived; _channel.RequestReceived += Channel_RequestReceived; @@ -542,7 +545,10 @@ private void SetAsyncComplete(bool setResult = true) } } - UnsubscribeFromEventsAndDisposeChannel(); + // We don't dispose the channel here to avoid a race condition + // where SSH_MSG_CHANNEL_CLOSE arrives before _channel starts + // waiting for a response in _channel.SendExecRequest(). + UnsubscribeFromChannelEvents(dispose: false); OutputStream.Dispose(); ExtendedOutputStream.Dispose(); @@ -568,7 +574,7 @@ private void Channel_RequestReceived(object? sender, ChannelRequestEventArgs e) Debug.Assert(!exitSignalInfo.WantReply, "exit-signal is want_reply := false by definition."); } - else if (e.Info.WantReply && _channel?.RemoteChannelNumber is uint remoteChannelNumber) + else if (e.Info.WantReply && sender is IChannel { RemoteChannelNumber: uint remoteChannelNumber }) { var replyMessage = new ChannelFailureMessage(remoteChannelNumber); _session.SendMessage(replyMessage); @@ -591,20 +597,13 @@ private void Channel_DataReceived(object? sender, ChannelDataEventArgs e) } /// - /// Unsubscribes the current from channel events, and disposes - /// the . + /// Unsubscribes the current from channel events, and optionally, + /// disposes . /// - private void UnsubscribeFromEventsAndDisposeChannel() + private void UnsubscribeFromChannelEvents(bool dispose) { var channel = _channel; - if (channel is null) - { - return; - } - - _channel = null; - // unsubscribe from events as we do not want to be signaled should these get fired // during the dispose of the channel channel.DataReceived -= Channel_DataReceived; @@ -612,8 +611,10 @@ private void UnsubscribeFromEventsAndDisposeChannel() channel.RequestReceived -= Channel_RequestReceived; channel.Closed -= Channel_Closed; - // actually dispose the channel - channel.Dispose(); + if (dispose) + { + channel.Dispose(); + } } /// @@ -645,7 +646,7 @@ protected virtual void Dispose(bool disposing) // unsubscribe from channel events to ensure other objects that we're going to dispose // are not accessed while disposing - UnsubscribeFromEventsAndDisposeChannel(); + UnsubscribeFromChannelEvents(dispose: true); _inputStream?.Dispose(); _inputStream = null; diff --git a/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute.cs b/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute.cs deleted file mode 100644 index 51d4de496..000000000 --- a/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute.cs +++ /dev/null @@ -1,72 +0,0 @@ -using System; -using System.Globalization; -using System.Text; - -using Microsoft.VisualStudio.TestTools.UnitTesting; - -using Moq; - -using Renci.SshNet.Channels; -using Renci.SshNet.Common; -using Renci.SshNet.Tests.Common; - -namespace Renci.SshNet.Tests.Classes -{ - [TestClass] - public class SshCommand_EndExecute : TestBase - { - private Mock _sessionMock; - private Mock _channelSessionMock; - private string _commandText; - private Encoding _encoding; - private SshCommand _sshCommand; - - protected override void OnInit() - { - base.OnInit(); - - _sessionMock = new Mock(MockBehavior.Strict); - _commandText = new Random().Next().ToString(CultureInfo.InvariantCulture); - _encoding = Encoding.UTF8; - _channelSessionMock = new Mock(MockBehavior.Strict); - - _sshCommand = new SshCommand(_sessionMock.Object, _commandText, _encoding); - } - - [TestMethod] - public void EndExecute_ChannelClosed_ShouldDisposeChannelSession() - { - var seq = new MockSequence(); - - _sessionMock.InSequence(seq).Setup(p => p.CreateChannelSession()).Returns(_channelSessionMock.Object); - _channelSessionMock.InSequence(seq).Setup(p => p.Open()); - _channelSessionMock.InSequence(seq).Setup(p => p.SendExecRequest(_commandText)) - .Returns(true) - .Raises(c => c.Closed += null, new ChannelEventArgs(5)); - _channelSessionMock.InSequence(seq).Setup(p => p.Dispose()); - - var asyncResult = _sshCommand.BeginExecute(); - _sshCommand.EndExecute(asyncResult); - - _channelSessionMock.Verify(p => p.Dispose(), Times.Once); - } - - [TestMethod] - public void EndExecute_ChannelOpen_ShouldSendEofAndCloseAndDisposeChannelSession() - { - var seq = new MockSequence(); - - _sessionMock.InSequence(seq).Setup(p => p.CreateChannelSession()).Returns(_channelSessionMock.Object); - _channelSessionMock.InSequence(seq).Setup(p => p.Open()); - _channelSessionMock.InSequence(seq).Setup(p => p.SendExecRequest(_commandText)) - .Returns(true) - .Raises(c => c.Closed += null, new ChannelEventArgs(5)); - _channelSessionMock.InSequence(seq).Setup(p => p.Dispose()); - - var asyncResult = _sshCommand.BeginExecute(); - _sshCommand.EndExecute(asyncResult); - - _channelSessionMock.Verify(p => p.Dispose(), Times.Once); - } - } -} diff --git a/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_AsyncResultIsNull.cs b/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_AsyncResultIsNull.cs index 9029c3e4c..bb997b28a 100644 --- a/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_AsyncResultIsNull.cs +++ b/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_AsyncResultIsNull.cs @@ -30,7 +30,7 @@ protected override void OnInit() private void Arrange() { - _sessionMock = new Mock(MockBehavior.Strict); + _sessionMock = new Mock(); _commandText = new Random().Next().ToString(CultureInfo.InvariantCulture); _encoding = Encoding.UTF8; _asyncResult = null; diff --git a/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_ChannelOpen.cs b/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_ChannelOpen.cs index 6935d74a0..1afc92234 100644 --- a/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_ChannelOpen.cs +++ b/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_ChannelOpen.cs @@ -81,12 +81,6 @@ private void Act() _actual = _sshCommand.EndExecute(_asyncResult); } - [TestMethod] - public void ChannelSessionShouldBeDisposedOnce() - { - _channelSessionMock.Verify(p => p.Dispose(), Times.Once); - } - [TestMethod] public void EndExecuteShouldReturnAllDataReceivedInSpecifiedEncoding() {