From 6ddffb2de37953b80e79058cb40290fc7a26df88 Mon Sep 17 00:00:00 2001 From: Rob Hague Date: Thu, 23 May 2024 23:18:37 +0200 Subject: [PATCH 1/2] Send "signal" and set the wait handle on completion always --- .editorconfig | 2 +- src/Renci.SshNet/Channels/ChannelSession.cs | 29 -- src/Renci.SshNet/Channels/IChannelSession.cs | 21 - src/Renci.SshNet/CommandAsyncResult.cs | 10 - src/Renci.SshNet/SshClient.cs | 2 +- src/Renci.SshNet/SshCommand.cs | 443 ++++++++---------- .../OldIntegrationTests/SshCommandTest.cs | 280 ++++------- .../Classes/ShellStreamTest_ReadExpect.cs | 4 - .../Classes/SshCommandTest_Dispose.cs | 12 - ...EndExecute_AsyncResultFromOtherInstance.cs | 3 +- .../SshCommandTest_EndExecute_ChannelOpen.cs | 15 +- 11 files changed, 303 insertions(+), 518 deletions(-) diff --git a/.editorconfig b/.editorconfig index 3577f7db4..05a2600d0 100644 --- a/.editorconfig +++ b/.editorconfig @@ -688,7 +688,7 @@ dotnet_diagnostic.CA1852.severity = none # CA1859: Change return type for improved performance # # By default, this diagnostic is only reported for private members. -dotnet_code_quality.CA1859.api_surface = all +dotnet_code_quality.CA1859.api_surface = private,internal # CA2208: Instantiate argument exceptions correctly # https://learn.microsoft.com/en-us/dotnet/fundamentals/code-analysis/quality-rules/ca2208 diff --git a/src/Renci.SshNet/Channels/ChannelSession.cs b/src/Renci.SshNet/Channels/ChannelSession.cs index f222fc647..52c31522c 100644 --- a/src/Renci.SshNet/Channels/ChannelSession.cs +++ b/src/Renci.SshNet/Channels/ChannelSession.cs @@ -274,35 +274,6 @@ public bool SendSignalRequest(string signalName) return true; } - /// - /// Sends the exit status request. - /// - /// The exit status. - /// - /// if request was successful; otherwise . - /// - public bool SendExitStatusRequest(uint exitStatus) - { - SendMessage(new ChannelRequestMessage(RemoteChannelNumber, new ExitStatusRequestInfo(exitStatus))); - return true; - } - - /// - /// Sends the exit signal request. - /// - /// Name of the signal. - /// if set to [core dumped]. - /// The error message. - /// The language. - /// - /// if request was successful; otherwise . - /// - public bool SendExitSignalRequest(string signalName, bool coreDumped, string errorMessage, string language) - { - SendMessage(new ChannelRequestMessage(RemoteChannelNumber, new ExitSignalRequestInfo(signalName, coreDumped, errorMessage, language))); - return true; - } - /// /// Sends eow@openssh.com request. /// diff --git a/src/Renci.SshNet/Channels/IChannelSession.cs b/src/Renci.SshNet/Channels/IChannelSession.cs index 2b18efa63..21fe4c158 100644 --- a/src/Renci.SshNet/Channels/IChannelSession.cs +++ b/src/Renci.SshNet/Channels/IChannelSession.cs @@ -120,27 +120,6 @@ bool SendPseudoTerminalRequest(string environmentVariable, /// bool SendSignalRequest(string signalName); - /// - /// Sends the exit status request. - /// - /// The exit status. - /// - /// if request was successful; otherwise . - /// - bool SendExitStatusRequest(uint exitStatus); - - /// - /// Sends the exit signal request. - /// - /// Name of the signal. - /// if set to [core dumped]. - /// The error message. - /// The language. - /// - /// if request was successful; otherwise . - /// - bool SendExitSignalRequest(string signalName, bool coreDumped, string errorMessage, string language); - /// /// Sends eow@openssh.com request. /// diff --git a/src/Renci.SshNet/CommandAsyncResult.cs b/src/Renci.SshNet/CommandAsyncResult.cs index c956083b0..ad96d39de 100644 --- a/src/Renci.SshNet/CommandAsyncResult.cs +++ b/src/Renci.SshNet/CommandAsyncResult.cs @@ -56,15 +56,5 @@ internal CommandAsyncResult() /// true if the operation is complete; otherwise, false. /// public bool IsCompleted { get; internal set; } - - /// - /// Gets or sets a value indicating whether was already called for this - /// . - /// - /// - /// if was already called for this ; - /// otherwise, . - /// - internal bool EndCalled { get; set; } } } diff --git a/src/Renci.SshNet/SshClient.cs b/src/Renci.SshNet/SshClient.cs index 6517a4268..033bd93e2 100644 --- a/src/Renci.SshNet/SshClient.cs +++ b/src/Renci.SshNet/SshClient.cs @@ -235,7 +235,7 @@ public SshCommand CreateCommand(string commandText, Encoding encoding) EnsureSessionIsOpen(); ConnectionInfo.Encoding = encoding; - return new SshCommand(Session, commandText, encoding); + return new SshCommand(Session!, commandText, encoding); } /// diff --git a/src/Renci.SshNet/SshCommand.cs b/src/Renci.SshNet/SshCommand.cs index 4bd3abd97..1ee8d752d 100644 --- a/src/Renci.SshNet/SshCommand.cs +++ b/src/Renci.SshNet/SshCommand.cs @@ -1,5 +1,6 @@ -using System; -using System.Globalization; +#nullable enable +using System; +using System.Diagnostics; using System.IO; using System.Runtime.ExceptionServices; using System.Text; @@ -18,22 +19,25 @@ namespace Renci.SshNet /// public class SshCommand : IDisposable { + private static readonly object CompletedResult = new(); + + private readonly ISession _session; private readonly Encoding _encoding; - private readonly object _endExecuteLock = new object(); - - private ISession _session; - private IChannelSession _channel; - private CommandAsyncResult _asyncResult; - private AsyncCallback _callback; - private EventWaitHandle _sessionErrorOccuredWaitHandle; - private EventWaitHandle _commandCancelledWaitHandle; - private Exception _exception; - private string _result; - private string _error; + + /// + /// The result of the command: an exception, + /// or . + /// + private object? _result; + + private IChannelSession? _channel; + private CommandAsyncResult? _asyncResult; + private AsyncCallback? _callback; + private string? _stdOut; + private string? _stdErr; private bool _hasError; private bool _isDisposed; - private bool _isCancelled; - private ChannelInputStream _inputStream; + private ChannelInputStream? _inputStream; private TimeSpan _commandTimeout; /// @@ -66,19 +70,26 @@ public TimeSpan CommandTimeout /// public int ExitStatus { get; private set; } + /// + /// Gets the name of the signal due to which the command + /// terminated violently, if applicable, otherwise . + /// + /// + /// The value (if it exists) is supplied by the server and is usually one of the + /// following, as described in https://datatracker.ietf.org/doc/html/rfc4254#section-6.10: + /// ABRT, ALRM, FPE, HUP, ILL, INT, KILL, PIPE, QUIT, SEGV, TER, USR1, USR2. + /// + public string? ExitSignal { get; private set; } + /// /// Gets the output stream. /// -#pragma warning disable CA1859 // Use concrete types when possible for improved performance public Stream OutputStream { get; private set; } -#pragma warning restore CA1859 // Use concrete types when possible for improved performance /// /// Gets the extended output stream. /// -#pragma warning disable CA1859 // Use concrete types when possible for improved performance public Stream ExtendedOutputStream { get; private set; } -#pragma warning restore CA1859 // Use concrete types when possible for improved performance /// /// Creates and returns the input stream for the command. @@ -109,21 +120,19 @@ public string Result { get { - if (_result is not null) + if (_stdOut is not null) { - return _result; + return _stdOut; } - if (OutputStream is null) + if (_asyncResult is null) { return string.Empty; } - using (var sr = new StreamReader(OutputStream, - _encoding, - detectEncodingFromByteOrderMarks: true)) + using (var sr = new StreamReader(OutputStream, _encoding)) { - return _result = sr.ReadToEnd(); + return _stdOut = sr.ReadToEnd(); } } } @@ -135,21 +144,19 @@ public string Error { get { - if (_error is not null) + if (_stdErr is not null) { - return _error; + return _stdErr; } - if (ExtendedOutputStream is null || !_hasError) + if (_asyncResult is null || !_hasError) { return string.Empty; } - using (var sr = new StreamReader(ExtendedOutputStream, - _encoding, - detectEncodingFromByteOrderMarks: true)) + using (var sr = new StreamReader(ExtendedOutputStream, _encoding)) { - return _error = sr.ReadToEnd(); + return _stdErr = sr.ReadToEnd(); } } } @@ -182,8 +189,8 @@ internal SshCommand(ISession session, string commandText, Encoding encoding) CommandText = commandText; _encoding = encoding; CommandTimeout = Timeout.InfiniteTimeSpan; - _sessionErrorOccuredWaitHandle = new AutoResetEvent(initialState: false); - _commandCancelledWaitHandle = new AutoResetEvent(initialState: false); + OutputStream = new PipeStream(); + ExtendedOutputStream = new PipeStream(); _session.Disconnected += Session_Disconnected; _session.ErrorOccured += Session_ErrorOccured; } @@ -216,7 +223,7 @@ public IAsyncResult BeginExecute() /// CommandText property is empty. /// Client is not connected. /// Operation has timed out. - public IAsyncResult BeginExecute(AsyncCallback callback) + public IAsyncResult BeginExecute(AsyncCallback? callback) { return BeginExecute(callback, state: null); } @@ -234,47 +241,54 @@ public IAsyncResult BeginExecute(AsyncCallback callback) /// CommandText property is empty. /// Client is not connected. /// Operation has timed out. -#pragma warning disable CA1859 // Use concrete types when possible for improved performance - public IAsyncResult BeginExecute(AsyncCallback callback, object state) -#pragma warning restore CA1859 // Use concrete types when possible for improved performance + public IAsyncResult BeginExecute(AsyncCallback? callback, object? state) { - // Prevent from executing BeginExecute before calling EndExecute - if (_asyncResult != null && !_asyncResult.EndCalled) +#if NET7_0_OR_GREATER + ObjectDisposedException.ThrowIf(_isDisposed, this); +#else + if (_isDisposed) { - throw new InvalidOperationException("Asynchronous operation is already in progress."); + throw new ObjectDisposedException(GetType().FullName); + } +#endif + + if (_asyncResult is not null) + { + if (!_asyncResult.AsyncWaitHandle.WaitOne(0)) + { + throw new InvalidOperationException("Asynchronous operation is already in progress."); + } + + OutputStream.Dispose(); + ExtendedOutputStream.Dispose(); + + // Initialize output streams. We already initialised them for the first + // execution in the constructor (to allow passing them around before execution) + // so we just need to reinitialise them for subsequent executions. + OutputStream = new PipeStream(); + ExtendedOutputStream = new PipeStream(); } // Create new AsyncResult object _asyncResult = new CommandAsyncResult { AsyncWaitHandle = new ManualResetEvent(initialState: false), - IsCompleted = false, AsyncState = state, }; - if (_channel is not null) - { - throw new SshException("Invalid operation."); - } - - if (string.IsNullOrEmpty(CommandText)) - { - throw new ArgumentException("CommandText property is empty."); - } - - OutputStream?.Dispose(); - ExtendedOutputStream?.Dispose(); - - // Initialize output streams - OutputStream = new PipeStream(); - ExtendedOutputStream = new PipeStream(); - + ExitStatus = default; + ExitSignal = null; _result = null; - _error = null; + _stdOut = null; + _stdErr = null; _hasError = false; _callback = callback; - _channel = CreateChannel(); + _channel = _session.CreateChannelSession(); + _channel.DataReceived += Channel_DataReceived; + _channel.ExtendedDataReceived += Channel_ExtendedDataReceived; + _channel.RequestReceived += Channel_RequestReceived; + _channel.Closed += Channel_Closed; _channel.Open(); _ = _channel.SendExecRequest(CommandText); @@ -293,8 +307,13 @@ public IAsyncResult BeginExecute(AsyncCallback callback, object state) /// /// Client is not connected. /// Operation has timed out. - public IAsyncResult BeginExecute(string commandText, AsyncCallback callback, object state) + public IAsyncResult BeginExecute(string commandText, AsyncCallback? callback, object? state) { + if (commandText is null) + { + throw new ArgumentNullException(nameof(commandText)); + } + CommandText = commandText; return BeginExecute(callback, state); @@ -314,55 +333,88 @@ public string EndExecute(IAsyncResult asyncResult) throw new ArgumentNullException(nameof(asyncResult)); } - if (asyncResult is not CommandAsyncResult commandAsyncResult || _asyncResult != commandAsyncResult) + if (_asyncResult != asyncResult) { - throw new ArgumentException(string.Format("The {0} object was not returned from the corresponding asynchronous method on this class.", nameof(IAsyncResult))); + throw new ArgumentException("Argument does not correspond to the currently executing command.", nameof(asyncResult)); } - lock (_endExecuteLock) - { - if (commandAsyncResult.EndCalled) - { - throw new ArgumentException("EndExecute can only be called once for each asynchronous operation."); - } - - _inputStream?.Close(); + _inputStream?.Dispose(); - try - { - // wait for operation to complete (or time out) - WaitOnHandle(_asyncResult.AsyncWaitHandle); - } - finally - { - UnsubscribeFromEventsAndDisposeChannel(_channel); - _channel = null; + if (!_asyncResult.AsyncWaitHandle.WaitOne(CommandTimeout)) + { + // Complete the operation with a TimeoutException (which will be thrown below). + SetAsyncComplete(new SshOperationTimeoutException($"Command '{CommandText}' timed out. ({nameof(CommandTimeout)}: {CommandTimeout}).")); + } - OutputStream?.Dispose(); - ExtendedOutputStream?.Dispose(); + Debug.Assert(_asyncResult.IsCompleted); - commandAsyncResult.EndCalled = true; - } + if (_result is Exception exception) + { + ExceptionDispatchInfo.Capture(exception).Throw(); + } - if (!_isCancelled) - { - return Result; - } + Debug.Assert(_result == CompletedResult); + Debug.Assert(!OutputStream.CanWrite, $"{nameof(OutputStream)} should have been disposed (else we will block)."); - SetAsyncComplete(); - throw new OperationCanceledException(); - } + return Result; } /// - /// Cancels command execution in asynchronous scenarios. + /// Cancels a running command by sending a signal to the remote process. /// /// if true send SIGKILL instead of SIGTERM. - public void CancelAsync(bool forceKill = false) + /// Time to wait for the server to reply. + /// + /// + /// This method stops the command running on the server by sending a SIGTERM + /// (or SIGKILL, depending on ) signal to the remote + /// process. When the server implements signals, it will send a response which + /// populates with the signal with which the command terminated. + /// + /// + /// When the server does not implement signals, it may send no response. As a fallback, + /// this method waits up to for a response + /// and then completes the object anyway if there was none. + /// + /// + /// If the command has already finished (with or without cancellation), this method does + /// nothing. + /// + /// + /// Command has not been started. + public void CancelAsync(bool forceKill = false, int millisecondsTimeout = 500) { - var signal = forceKill ? "KILL" : "TERM"; - _ = _channel?.SendExitSignalRequest(signal, coreDumped: false, "Command execution has been cancelled.", "en"); - _ = _commandCancelledWaitHandle?.Set(); + if (_asyncResult is not { } asyncResult) + { + throw new InvalidOperationException("Command has not been started."); + } + + var exception = new OperationCanceledException($"Command '{CommandText}' was cancelled."); + + if (Interlocked.CompareExchange(ref _result, exception, comparand: null) is not null) + { + // Command has already completed. + return; + } + + // Try to send the cancellation signal. + if (_channel?.SendSignalRequest(forceKill ? "KILL" : "TERM") is null) + { + // Command has completed (in the meantime since the last check). + // We won the race above and the command has finished by some other means, + // but will throw the OperationCanceledException. + return; + } + + // Having sent the "signal" message, we expect to receive "exit-signal" + // and then a close message. But since a server may not implement signals, + // we can't guarantee that, so we wait a short time for that to happen and + // if it doesn't, just set the WaitHandle ourselves to unblock EndExecute. + + if (!asyncResult.AsyncWaitHandle.WaitOne(millisecondsTimeout)) + { + SetAsyncComplete(asyncResult); + } } /// @@ -394,88 +446,72 @@ public string Execute(string commandText) return Execute(); } - private IChannelSession CreateChannel() + private void Session_Disconnected(object? sender, EventArgs e) { - var channel = _session.CreateChannelSession(); - channel.DataReceived += Channel_DataReceived; - channel.ExtendedDataReceived += Channel_ExtendedDataReceived; - channel.RequestReceived += Channel_RequestReceived; - channel.Closed += Channel_Closed; - return channel; + SetAsyncComplete(new SshConnectionException("An established connection was aborted by the software in your host machine.", DisconnectReason.ConnectionLost)); } - private void Session_Disconnected(object sender, EventArgs e) + private void Session_ErrorOccured(object? sender, ExceptionEventArgs e) { - // If objected is disposed or being disposed don't handle this event - if (_isDisposed) - { - return; - } - - _exception = new SshConnectionException("An established connection was aborted by the software in your host machine.", DisconnectReason.ConnectionLost); - - _ = _sessionErrorOccuredWaitHandle.Set(); + SetAsyncComplete(e.Exception); } - private void Session_ErrorOccured(object sender, ExceptionEventArgs e) + private void SetAsyncComplete(object result) { - // If objected is disposed or being disposed don't handle this event - if (_isDisposed) + _ = Interlocked.CompareExchange(ref _result, result, comparand: null); + + if (_asyncResult is CommandAsyncResult asyncResult) { - return; + SetAsyncComplete(asyncResult); } - - _exception = e.Exception; - - _ = _sessionErrorOccuredWaitHandle.Set(); } - private void SetAsyncComplete() + private void SetAsyncComplete(CommandAsyncResult asyncResult) { - OutputStream?.Dispose(); - ExtendedOutputStream?.Dispose(); + UnsubscribeFromEventsAndDisposeChannel(); + + OutputStream.Dispose(); + ExtendedOutputStream.Dispose(); + + asyncResult.IsCompleted = true; - _asyncResult.IsCompleted = true; + _ = ((EventWaitHandle)asyncResult.AsyncWaitHandle).Set(); - if (_callback is not null && !_isCancelled) + if (Interlocked.Exchange(ref _callback, value: null) is AsyncCallback callback) { - // Execute callback on different thread - ThreadAbstraction.ExecuteThread(() => _callback(_asyncResult)); + ThreadAbstraction.ExecuteThread(() => callback(asyncResult)); } - - _ = ((EventWaitHandle)_asyncResult.AsyncWaitHandle).Set(); } - private void Channel_Closed(object sender, ChannelEventArgs e) + private void Channel_Closed(object? sender, ChannelEventArgs e) { - SetAsyncComplete(); + SetAsyncComplete(CompletedResult); } - private void Channel_RequestReceived(object sender, ChannelRequestEventArgs e) + private void Channel_RequestReceived(object? sender, ChannelRequestEventArgs e) { if (e.Info is ExitStatusRequestInfo exitStatusInfo) { ExitStatus = (int)exitStatusInfo.ExitStatus; - if (exitStatusInfo.WantReply) - { - var replyMessage = new ChannelSuccessMessage(_channel.RemoteChannelNumber); - _session.SendMessage(replyMessage); - } + Debug.Assert(!exitStatusInfo.WantReply, "exit-status is want_reply := false by definition."); } - else + else if (e.Info is ExitSignalRequestInfo exitSignalInfo) { - if (e.Info.WantReply) - { - var replyMessage = new ChannelFailureMessage(_channel.RemoteChannelNumber); - _session.SendMessage(replyMessage); - } + ExitSignal = exitSignalInfo.SignalName; + + Debug.Assert(!exitSignalInfo.WantReply, "exit-signal is want_reply := false by definition."); + } + else if (e.Info.WantReply && _channel?.RemoteChannelNumber is uint remoteChannelNumber) + { + var replyMessage = new ChannelFailureMessage(remoteChannelNumber); + _session.SendMessage(replyMessage); } } - private void Channel_ExtendedDataReceived(object sender, ChannelExtendedDataEventArgs e) + private void Channel_ExtendedDataReceived(object? sender, ChannelExtendedDataEventArgs e) { - ExtendedOutputStream?.Write(e.Data, 0, e.Data.Length); + ExtendedOutputStream.Write(e.Data, 0, e.Data.Length); if (e.DataTypeCode == 1) { @@ -483,64 +519,34 @@ private void Channel_ExtendedDataReceived(object sender, ChannelExtendedDataEven } } - private void Channel_DataReceived(object sender, ChannelDataEventArgs e) + private void Channel_DataReceived(object? sender, ChannelDataEventArgs e) { - OutputStream?.Write(e.Data, 0, e.Data.Length); + OutputStream.Write(e.Data, 0, e.Data.Length); - if (_asyncResult != null) + if (_asyncResult is CommandAsyncResult asyncResult) { - lock (_asyncResult) + lock (asyncResult) { - _asyncResult.BytesReceived += e.Data.Length; + asyncResult.BytesReceived += e.Data.Length; } } } - /// Command '{0}' has timed out. - /// The actual command will be included in the exception message. - private void WaitOnHandle(WaitHandle waitHandle) - { - var waitHandles = new[] - { - _sessionErrorOccuredWaitHandle, - waitHandle, - _commandCancelledWaitHandle - }; - - var signaledElement = WaitHandle.WaitAny(waitHandles, CommandTimeout); - switch (signaledElement) - { - case 0: - ExceptionDispatchInfo.Capture(_exception).Throw(); - break; - case 1: - // Specified waithandle was signaled - break; - case 2: - _isCancelled = true; - break; - case WaitHandle.WaitTimeout: - throw new SshOperationTimeoutException(string.Format(CultureInfo.CurrentCulture, "Command '{0}' has timed out.", CommandText)); - default: - throw new SshException($"Unexpected element '{signaledElement.ToString(CultureInfo.InvariantCulture)}' signaled."); - } - } - /// /// Unsubscribes the current from channel events, and disposes - /// the . + /// the . /// - /// The channel. - /// - /// Does nothing when is . - /// - private void UnsubscribeFromEventsAndDisposeChannel(IChannel channel) + private void UnsubscribeFromEventsAndDisposeChannel() { + var channel = _channel; + if (channel is null) { return; } + _channel = null; + // unsubscribe from events as we do not want to be signaled should these get fired // during the dispose of the channel channel.DataReceived -= Channel_DataReceived; @@ -576,66 +582,27 @@ protected virtual void Dispose(bool disposing) { // unsubscribe from session events to ensure other objects that we're going to dispose // are not accessed while disposing - var session = _session; - if (session != null) - { - session.Disconnected -= Session_Disconnected; - session.ErrorOccured -= Session_ErrorOccured; - _session = null; - } + _session.Disconnected -= Session_Disconnected; + _session.ErrorOccured -= Session_ErrorOccured; // unsubscribe from channel events to ensure other objects that we're going to dispose // are not accessed while disposing - var channel = _channel; - if (channel != null) - { - UnsubscribeFromEventsAndDisposeChannel(channel); - _channel = null; - } + UnsubscribeFromEventsAndDisposeChannel(); - var inputStream = _inputStream; - if (inputStream != null) - { - inputStream.Dispose(); - _inputStream = null; - } + _inputStream?.Dispose(); + _inputStream = null; - var outputStream = OutputStream; - if (outputStream != null) - { - outputStream.Dispose(); - OutputStream = null; - } - - var extendedOutputStream = ExtendedOutputStream; - if (extendedOutputStream != null) - { - extendedOutputStream.Dispose(); - ExtendedOutputStream = null; - } + OutputStream.Dispose(); + ExtendedOutputStream.Dispose(); - var sessionErrorOccuredWaitHandle = _sessionErrorOccuredWaitHandle; - if (sessionErrorOccuredWaitHandle != null) + if (_asyncResult is not null && _result is null) { - sessionErrorOccuredWaitHandle.Dispose(); - _sessionErrorOccuredWaitHandle = null; + // In case an operation is still running, try to complete it with an ObjectDisposedException. + SetAsyncComplete(new ObjectDisposedException(GetType().FullName)); } - _commandCancelledWaitHandle?.Dispose(); - _commandCancelledWaitHandle = null; - _isDisposed = true; } } - - /// - /// Finalizes an instance of the class. - /// Releases unmanaged resources and performs other cleanup operations before the - /// is reclaimed by garbage collection. - /// - ~SshCommand() - { - Dispose(disposing: false); - } } } diff --git a/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs b/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs index f9ef378cd..eb229b9c3 100644 --- a/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs +++ b/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs @@ -19,14 +19,14 @@ public void Test_Run_SingleCommand() client.Connect(); var testValue = Guid.NewGuid().ToString(); - var command = client.RunCommand(string.Format("echo {0}", testValue)); + using var command = client.RunCommand(string.Format("echo {0}", testValue)); var result = command.Result; result = result.Substring(0, result.Length - 1); // Remove \n character returned by command client.Disconnect(); #endregion - Assert.IsTrue(result.Equals(testValue)); + Assert.AreEqual(testValue, result); } } @@ -39,15 +39,14 @@ public void Test_Execute_SingleCommand() client.Connect(); var testValue = Guid.NewGuid().ToString(); - var command = string.Format("echo {0}", testValue); - var cmd = client.CreateCommand(command); + var command = string.Format("echo -n {0}", testValue); + using var cmd = client.CreateCommand(command); var result = cmd.Execute(); - result = result.Substring(0, result.Length - 1); // Remove \n character returned by command client.Disconnect(); #endregion - Assert.IsTrue(result.Equals(testValue)); + Assert.AreEqual(testValue, result); } } @@ -56,75 +55,62 @@ public void Test_Execute_SingleCommand() public void Test_CancelAsync_Unfinished_Command() { using var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password); - #region Example SshCommand CancelAsync Unfinished Command Without Sending exit-signal client.Connect(); var testValue = Guid.NewGuid().ToString(); - var command = $"sleep 15s; echo {testValue}"; - using var cmd = client.CreateCommand(command); + using var cmd = client.CreateCommand($"sleep 15s; echo {testValue}"); + var asyncResult = cmd.BeginExecute(); + cmd.CancelAsync(); + Assert.ThrowsException(() => cmd.EndExecute(asyncResult)); Assert.IsTrue(asyncResult.IsCompleted); - client.Disconnect(); - Assert.AreEqual(string.Empty, cmd.Result.Trim()); - #endregion + Assert.IsTrue(asyncResult.AsyncWaitHandle.WaitOne(0)); + Assert.AreEqual(string.Empty, cmd.Result); + Assert.AreEqual("TERM", cmd.ExitSignal); } [TestMethod] - public async Task Test_CancelAsync_Finished_Command() + [Timeout(5000)] + public async Task Test_CancelAsync_Kill_Unfinished_Command() { using var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password); - #region Example SshCommand CancelAsync Finished Command client.Connect(); var testValue = Guid.NewGuid().ToString(); - var command = $"echo {testValue}"; - using var cmd = client.CreateCommand(command); + using var cmd = client.CreateCommand($"sleep 15s; echo {testValue}"); + var asyncResult = cmd.BeginExecute(); - while (!asyncResult.IsCompleted) - { - await Task.Delay(200); - } - cmd.CancelAsync(); - cmd.EndExecute(asyncResult); - client.Disconnect(); + Task executeTask = Task.Factory.FromAsync(asyncResult, cmd.EndExecute); + cmd.CancelAsync(forceKill: true); + + await Assert.ThrowsExceptionAsync(() => executeTask); Assert.IsTrue(asyncResult.IsCompleted); - Assert.AreEqual(testValue, cmd.Result.Trim()); - #endregion + Assert.IsTrue(asyncResult.AsyncWaitHandle.WaitOne(0)); + Assert.AreEqual(string.Empty, cmd.Result); + Assert.AreEqual("KILL", cmd.ExitSignal); } [TestMethod] - public void Test_Execute_OutputStream() + public void Test_CancelAsync_Finished_Command() { - using (var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password)) - { - #region Example SshCommand CreateCommand Execute OutputStream - client.Connect(); - - var cmd = client.CreateCommand("ls -l"); // very long list - var asynch = cmd.BeginExecute(); - - var reader = new StreamReader(cmd.OutputStream); - - while (!asynch.IsCompleted) - { - var result = reader.ReadToEnd(); - if (string.IsNullOrEmpty(result)) - { - continue; - } + using var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password); + client.Connect(); + var testValue = Guid.NewGuid().ToString(); + using var cmd = client.CreateCommand($"echo -n {testValue}"); - Console.Write(result); - } + var asyncResult = cmd.BeginExecute(); - _ = cmd.EndExecute(asynch); + Assert.IsTrue(asyncResult.AsyncWaitHandle.WaitOne(TimeSpan.FromSeconds(5))); - client.Disconnect(); - #endregion + cmd.CancelAsync(); // Should not throw + Assert.AreEqual(testValue, cmd.EndExecute(asyncResult)); // Should not throw + cmd.CancelAsync(); // Should not throw - Assert.Inconclusive(); - } + Assert.IsTrue(asyncResult.IsCompleted); + Assert.AreEqual(testValue, cmd.Result); + Assert.IsNull(cmd.ExitSignal); } [TestMethod] @@ -135,51 +121,33 @@ public void Test_Execute_ExtendedOutputStream() #region Example SshCommand CreateCommand Execute ExtendedOutputStream client.Connect(); - var cmd = client.CreateCommand("echo 12345; echo 654321 >&2"); - var result = cmd.Execute(); - - Console.Write(result); + using var cmd = client.CreateCommand("echo 12345; echo 654321 >&2"); + using var reader = new StreamReader(cmd.ExtendedOutputStream); - var reader = new StreamReader(cmd.ExtendedOutputStream); - Console.WriteLine("DEBUG:"); - Console.Write(reader.ReadToEnd()); + Assert.AreEqual("12345\n", cmd.Execute()); + Assert.AreEqual("654321\n", reader.ReadToEnd()); client.Disconnect(); #endregion - - Assert.Inconclusive(); } } [TestMethod] - [ExpectedException(typeof(SshOperationTimeoutException))] public void Test_Execute_Timeout() { using (var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password)) { #region Example SshCommand CreateCommand Execute CommandTimeout client.Connect(); - var cmd = client.CreateCommand("sleep 10s"); - cmd.CommandTimeout = TimeSpan.FromSeconds(5); - cmd.Execute(); + using var cmd = client.CreateCommand("sleep 10s"); + cmd.CommandTimeout = TimeSpan.FromSeconds(2); + Assert.ThrowsException(cmd.Execute); client.Disconnect(); #endregion } } - [TestMethod] - public void Test_Execute_Infinite_Timeout() - { - using (var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password)) - { - client.Connect(); - var cmd = client.CreateCommand("sleep 10s"); - cmd.Execute(); - client.Disconnect(); - } - } - [TestMethod] public void Test_Execute_InvalidCommand() { @@ -187,7 +155,7 @@ public void Test_Execute_InvalidCommand() { client.Connect(); - var cmd = client.CreateCommand(";"); + using var cmd = client.CreateCommand(";"); cmd.Execute(); if (string.IsNullOrEmpty(cmd.Error)) { @@ -205,7 +173,7 @@ public void Test_Execute_InvalidCommand_Then_Execute_ValidCommand() using (var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password)) { client.Connect(); - var cmd = client.CreateCommand(";"); + using var cmd = client.CreateCommand(";"); cmd.Execute(); if (string.IsNullOrEmpty(cmd.Error)) { @@ -221,24 +189,6 @@ public void Test_Execute_InvalidCommand_Then_Execute_ValidCommand() } } - [TestMethod] - public void Test_Execute_Command_with_ExtendedOutput() - { - using (var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password)) - { - client.Connect(); - var cmd = client.CreateCommand("echo 12345; echo 654321 >&2"); - cmd.Execute(); - - //var extendedData = Encoding.ASCII.GetString(cmd.ExtendedOutputStream.ToArray()); - var extendedData = new StreamReader(cmd.ExtendedOutputStream, Encoding.ASCII).ReadToEnd(); - client.Disconnect(); - - Assert.AreEqual("12345\n", cmd.Result); - Assert.AreEqual("654321\n", extendedData); - } - } - [TestMethod] public void Test_Execute_Command_Reconnect_Execute_Command() { @@ -261,17 +211,12 @@ public void Test_Execute_Command_ExitStatus() { using (var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password)) { - #region Example SshCommand RunCommand ExitStatus client.Connect(); - var cmd = client.RunCommand("exit 128"); + using var cmd = client.RunCommand("exit 128"); - Console.WriteLine(cmd.ExitStatus); - - client.Disconnect(); - #endregion - - Assert.IsTrue(cmd.ExitStatus == 128); + Assert.AreEqual(128, cmd.ExitStatus); + Assert.IsNull(cmd.ExitSignal); } } @@ -282,61 +227,45 @@ public void Test_Execute_Command_Asynchronously() { client.Connect(); - var cmd = client.CreateCommand("sleep 5s; echo 'test'"); - var asyncResult = cmd.BeginExecute(null, null); - while (!asyncResult.IsCompleted) - { - Thread.Sleep(100); - } - - cmd.EndExecute(asyncResult); - - Assert.IsTrue(cmd.Result == "test\n"); - - client.Disconnect(); - } - } + using var callbackCalled = new ManualResetEventSlim(); - [TestMethod] - public void Test_Execute_Command_Asynchronously_With_Error() - { - using (var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password)) - { - client.Connect(); + using var cmd = client.CreateCommand("sleep 2s; echo 'test'"); + var asyncResult = cmd.BeginExecute(new AsyncCallback((s) => + { + callbackCalled.Set(); + }), state: null); - var cmd = client.CreateCommand("sleep 5s; ;"); - var asyncResult = cmd.BeginExecute(null, null); while (!asyncResult.IsCompleted) { Thread.Sleep(100); } + Assert.IsTrue(asyncResult.AsyncWaitHandle.WaitOne(0)); + cmd.EndExecute(asyncResult); - Assert.IsFalse(string.IsNullOrEmpty(cmd.Error)); + Assert.AreEqual("test\n", cmd.Result); + Assert.IsTrue(callbackCalled.Wait(TimeSpan.FromSeconds(1))); client.Disconnect(); } } [TestMethod] - public void Test_Execute_Command_Asynchronously_With_Callback() + public void Test_Execute_Command_Asynchronously_With_Error() { using (var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password)) { client.Connect(); - using var callbackCalled = new ManualResetEventSlim(); + using var cmd = client.CreateCommand("sleep 2s; ;"); + var asyncResult = cmd.BeginExecute(null, null); - using var cmd = client.CreateCommand("sleep 5s; echo 'test'"); - var asyncResult = cmd.BeginExecute(new AsyncCallback((s) => - { - callbackCalled.Set(); - }), null); + Assert.IsTrue(asyncResult.AsyncWaitHandle.WaitOne(TimeSpan.FromSeconds(5))); cmd.EndExecute(asyncResult); - Assert.IsTrue(callbackCalled.Wait(TimeSpan.FromSeconds(1))); + Assert.IsFalse(string.IsNullOrEmpty(cmd.Error)); client.Disconnect(); } @@ -353,7 +282,7 @@ public void Test_Execute_Command_Asynchronously_With_Callback_On_Different_Threa int callbackThreadId = 0; using var callbackCalled = new ManualResetEventSlim(); - using var cmd = client.CreateCommand("sleep 5s; echo 'test'"); + using var cmd = client.CreateCommand("sleep 2s; echo 'test'"); var asyncResult = cmd.BeginExecute(new AsyncCallback((s) => { callbackThreadId = Thread.CurrentThread.ManagedThreadId; @@ -379,7 +308,7 @@ public void Test_Execute_Command_Same_Object_Different_Commands() using (var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password)) { client.Connect(); - var cmd = client.CreateCommand("echo 12345"); + using var cmd = client.CreateCommand("echo 12345"); cmd.Execute(); Assert.AreEqual("12345\n", cmd.Result); cmd.Execute("echo 23456"); @@ -394,35 +323,22 @@ public void Test_Get_Result_Without_Execution() using (var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password)) { client.Connect(); - var cmd = client.CreateCommand("ls -l"); - - Assert.IsTrue(string.IsNullOrEmpty(cmd.Result)); - client.Disconnect(); - } - } - - [TestMethod] - public void Test_Get_Error_Without_Execution() - { - using (var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password)) - { - client.Connect(); - var cmd = client.CreateCommand("ls -l"); + using var cmd = client.CreateCommand("ls -l"); - Assert.IsTrue(string.IsNullOrEmpty(cmd.Error)); + Assert.AreEqual(string.Empty, cmd.Result); + Assert.AreEqual(string.Empty, cmd.Error); client.Disconnect(); } } [WorkItem(703), TestMethod] - [ExpectedException(typeof(ArgumentNullException))] public void Test_EndExecute_Before_BeginExecute() { using (var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password)) { client.Connect(); - var cmd = client.CreateCommand("ls -l"); - cmd.EndExecute(null); + using var cmd = client.CreateCommand("ls -l"); + Assert.ThrowsException(() => cmd.EndExecute(null)); client.Disconnect(); } } @@ -442,15 +358,12 @@ public void BeginExecuteTest() client.Connect(); - var cmd = client.CreateCommand("sleep 15s;echo 123"); // Perform long running task + using var cmd = client.CreateCommand("sleep 2s;echo 123"); // Perform long running task var asynch = cmd.BeginExecute(); - while (!asynch.IsCompleted) - { - // Waiting for command to complete... - Thread.Sleep(2000); - } + Assert.IsTrue(asynch.AsyncWaitHandle.WaitOne(TimeSpan.FromSeconds(5))); + result = cmd.EndExecute(asynch); client.Disconnect(); @@ -461,30 +374,6 @@ public void BeginExecuteTest() } } - [TestMethod] - public void Test_Execute_Invalid_Command() - { - using (var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password)) - { - #region Example SshCommand CreateCommand Error - - client.Connect(); - - var cmd = client.CreateCommand(";"); - cmd.Execute(); - if (!string.IsNullOrEmpty(cmd.Error)) - { - Console.WriteLine(cmd.Error); - } - - client.Disconnect(); - - #endregion - - Assert.Inconclusive(); - } - } - [TestMethod] public void Test_MultipleThread_100_MultipleConnections() @@ -542,13 +431,30 @@ public void Test_MultipleThread_100_MultipleSessions() } } + [TestMethod] + public void Test_ExecuteAsync_Dispose_CommandFinishes() + { + using (var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password)) + { + client.Connect(); + + var cmd = client.CreateCommand("sleep 5s"); + var asyncResult = cmd.BeginExecute(null, null); + + cmd.Dispose(); + + Assert.IsTrue(asyncResult.AsyncWaitHandle.WaitOne(0)); + + Assert.ThrowsException(() => cmd.EndExecute(asyncResult)); + } + } + private static bool ExecuteTestCommand(SshClient s) { var testValue = Guid.NewGuid().ToString(); - var command = string.Format("echo {0}", testValue); - var cmd = s.CreateCommand(command); + var command = string.Format("echo -n {0}", testValue); + using var cmd = s.CreateCommand(command); var result = cmd.Execute(); - result = result.Substring(0, result.Length - 1); // Remove \n character returned by command return result.Equals(testValue); } } diff --git a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs index e4ec77f37..6dfdf3970 100644 --- a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs +++ b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs @@ -390,10 +390,6 @@ public void Open() public bool SendExecRequest(string command) => throw new NotImplementedException(); - public bool SendExitSignalRequest(string signalName, bool coreDumped, string errorMessage, string language) => throw new NotImplementedException(); - - public bool SendExitStatusRequest(uint exitStatus) => throw new NotImplementedException(); - public bool SendKeepAliveRequest() => throw new NotImplementedException(); public bool SendLocalFlowRequest(bool clientCanDo) => throw new NotImplementedException(); diff --git a/test/Renci.SshNet.Tests/Classes/SshCommandTest_Dispose.cs b/test/Renci.SshNet.Tests/Classes/SshCommandTest_Dispose.cs index 21542f83b..a3d11f231 100644 --- a/test/Renci.SshNet.Tests/Classes/SshCommandTest_Dispose.cs +++ b/test/Renci.SshNet.Tests/Classes/SshCommandTest_Dispose.cs @@ -64,24 +64,12 @@ public void ChannelSessionShouldBeDisposedOnce() _channelSessionMock.Verify(p => p.Dispose(), Times.Once); } - [TestMethod] - public void OutputStreamShouldReturnNull() - { - Assert.IsNull(_sshCommand.OutputStream); - } - [TestMethod] public void OutputStreamShouldHaveBeenDisposed() { Assert.AreEqual(-1, _outputStream.ReadByte()); } - [TestMethod] - public void ExtendedOutputStreamShouldReturnNull() - { - Assert.IsNull(_sshCommand.ExtendedOutputStream); - } - [TestMethod] public void ExtendedOutputStreamShouldHaveBeenDisposed() { diff --git a/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_AsyncResultFromOtherInstance.cs b/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_AsyncResultFromOtherInstance.cs index 5584786c5..79de6bb21 100644 --- a/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_AsyncResultFromOtherInstance.cs +++ b/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_AsyncResultFromOtherInstance.cs @@ -85,8 +85,7 @@ public void EndExecuteShouldHaveThrownArgumentException() { Assert.IsNotNull(_actualException); Assert.IsNull(_actualException.InnerException); - Assert.AreEqual(string.Format("The {0} object was not returned from the corresponding asynchronous method on this class.", nameof(IAsyncResult)), _actualException.Message); - Assert.IsNull(_actualException.ParamName); + Assert.AreEqual("asyncResult", _actualException.ParamName); } } } diff --git a/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_ChannelOpen.cs b/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_ChannelOpen.cs index 293439680..6935d74a0 100644 --- a/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_ChannelOpen.cs +++ b/test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute_ChannelOpen.cs @@ -94,20 +94,9 @@ public void EndExecuteShouldReturnAllDataReceivedInSpecifiedEncoding() } [TestMethod] - public void EndExecuteShouldThrowArgumentExceptionWhenInvokedAgainWithSameAsyncResult() + public void EndExecuteShouldNotThrowWhenInvokedAgainWithSameAsyncResult() { - try - { - _sshCommand.EndExecute(_asyncResult); - Assert.Fail(); - } - catch (ArgumentException ex) - { - Assert.AreEqual(typeof(ArgumentException), ex.GetType()); - Assert.IsNull(ex.InnerException); - Assert.AreEqual("EndExecute can only be called once for each asynchronous operation.", ex.Message); - Assert.IsNull(ex.ParamName); - } + Assert.AreEqual(_sshCommand.Result, _sshCommand.EndExecute(_asyncResult)); } [TestMethod] From d74faaf4fb418aace73de7c20c9be1b4fdf68beb Mon Sep 17 00:00:00 2001 From: Rob Hague Date: Tue, 4 Jun 2024 00:00:59 +0200 Subject: [PATCH 2/2] Make ExitStatus nullable --- src/Renci.SshNet/SshCommand.cs | 26 ++++++++++++++++--- .../OldIntegrationTests/SshCommandTest.cs | 3 +++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/Renci.SshNet/SshCommand.cs b/src/Renci.SshNet/SshCommand.cs index 1ee8d752d..cfede42dd 100644 --- a/src/Renci.SshNet/SshCommand.cs +++ b/src/Renci.SshNet/SshCommand.cs @@ -40,6 +40,9 @@ public class SshCommand : IDisposable private ChannelInputStream? _inputStream; private TimeSpan _commandTimeout; + private int _exitStatus; + private volatile bool _haveExitStatus; // volatile to prevent re-ordering of reads/writes of _exitStatus. + /// /// Gets the command text. /// @@ -66,9 +69,22 @@ public TimeSpan CommandTimeout } /// - /// Gets the command exit status. + /// Gets the number representing the exit status of the command, if applicable, + /// otherwise . /// - public int ExitStatus { get; private set; } + /// + /// The value is not when an exit status code has been returned + /// from the server. If the command terminated due to a signal, + /// may be not instead. + /// + /// + public int? ExitStatus + { + get + { + return _haveExitStatus ? _exitStatus : null; + } + } /// /// Gets the name of the signal due to which the command @@ -276,7 +292,8 @@ public IAsyncResult BeginExecute(AsyncCallback? callback, object? state) AsyncState = state, }; - ExitStatus = default; + _exitStatus = default; + _haveExitStatus = false; ExitSignal = null; _result = null; _stdOut = null; @@ -492,7 +509,8 @@ private void Channel_RequestReceived(object? sender, ChannelRequestEventArgs e) { if (e.Info is ExitStatusRequestInfo exitStatusInfo) { - ExitStatus = (int)exitStatusInfo.ExitStatus; + _exitStatus = (int)exitStatusInfo.ExitStatus; + _haveExitStatus = true; Debug.Assert(!exitStatusInfo.WantReply, "exit-status is want_reply := false by definition."); } diff --git a/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs b/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs index eb229b9c3..68013cfa1 100644 --- a/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs +++ b/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs @@ -68,6 +68,7 @@ public void Test_CancelAsync_Unfinished_Command() Assert.IsTrue(asyncResult.AsyncWaitHandle.WaitOne(0)); Assert.AreEqual(string.Empty, cmd.Result); Assert.AreEqual("TERM", cmd.ExitSignal); + Assert.IsNull(cmd.ExitStatus); } [TestMethod] @@ -90,6 +91,7 @@ public async Task Test_CancelAsync_Kill_Unfinished_Command() Assert.IsTrue(asyncResult.AsyncWaitHandle.WaitOne(0)); Assert.AreEqual(string.Empty, cmd.Result); Assert.AreEqual("KILL", cmd.ExitSignal); + Assert.IsNull(cmd.ExitStatus); } [TestMethod] @@ -110,6 +112,7 @@ public void Test_CancelAsync_Finished_Command() Assert.IsTrue(asyncResult.IsCompleted); Assert.AreEqual(testValue, cmd.Result); + Assert.AreEqual(0, cmd.ExitStatus); Assert.IsNull(cmd.ExitSignal); }