Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed dequeuing of incoming queue #1319

Merged
merged 2 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 52 additions & 43 deletions src/Renci.SshNet/ShellStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -282,19 +282,14 @@ public void Expect(TimeSpan timeout, params ExpectAction[] expectActions)

if (match.Success)
{
var returnText = matchText.Substring(0, match.Index + match.Length);
var returnLength = _encoding.GetByteCount(returnText);
#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER
var returnLength = _encoding.GetByteCount(matchText.AsSpan(0, match.Index + match.Length));
#else
var returnLength = _encoding.GetByteCount(matchText.Substring(0, match.Index + match.Length));
#endif

// Remove processed items from the queue
for (var i = 0; i < returnLength && _incoming.Count > 0; i++)
{
if (_expect.Count == _incoming.Count)
{
_ = _expect.Dequeue();
}

_ = _incoming.Dequeue();
}
var returnText = SyncQueuesAndReturn(returnLength);

expectAction.Action(returnText);
expectedFound = true;
Expand Down Expand Up @@ -385,19 +380,14 @@ public string Expect(Regex regex, TimeSpan timeout)

if (match.Success)
{
returnText = matchText.Substring(0, match.Index + match.Length);
var returnLength = _encoding.GetByteCount(returnText);
#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER
var returnLength = _encoding.GetByteCount(matchText.AsSpan(0, match.Index + match.Length));
#else
var returnLength = _encoding.GetByteCount(matchText.Substring(0, match.Index + match.Length));
#endif

// Remove processed items from the queue
for (var i = 0; i < returnLength && _incoming.Count > 0; i++)
{
if (_expect.Count == _incoming.Count)
{
_ = _expect.Dequeue();
}

_ = _incoming.Dequeue();
}
returnText = SyncQueuesAndReturn(returnLength);

break;
}
Expand Down Expand Up @@ -501,19 +491,14 @@ public IAsyncResult BeginExpect(TimeSpan timeout, AsyncCallback callback, object

if (match.Success)
{
returnText = matchText.Substring(0, match.Index + match.Length);
var returnLength = _encoding.GetByteCount(returnText);
#if NETSTANDARD2_1_OR_GREATER || NET6_0_OR_GREATER
var returnLength = _encoding.GetByteCount(matchText.AsSpan(0, match.Index + match.Length));
#else
var returnLength = _encoding.GetByteCount(matchText.Substring(0, match.Index + match.Length));
#endif

// Remove processed items from the queue
for (var i = 0; i < returnLength && _incoming.Count > 0; i++)
{
if (_expect.Count == _incoming.Count)
{
_ = _expect.Dequeue();
}

_ = _incoming.Dequeue();
}
returnText = SyncQueuesAndReturn(returnLength);

expectAction.Action(returnText);
callback?.Invoke(asyncResult);
Expand Down Expand Up @@ -614,15 +599,7 @@ public string ReadLine(TimeSpan timeout)
var bytesProcessed = _encoding.GetByteCount(text + CrLf);

// remove processed bytes from the queue
for (var i = 0; i < bytesProcessed; i++)
{
if (_expect.Count == _incoming.Count)
{
_ = _expect.Dequeue();
}

_ = _incoming.Dequeue();
}
SyncQueuesAndDequeue(bytesProcessed);

break;
}
Expand Down Expand Up @@ -687,7 +664,7 @@ public override int Read(byte[] buffer, int offset, int count)
{
for (; i < count && _incoming.Count > 0; i++)
{
if (_expect.Count == _incoming.Count)
if (_incoming.Count == _expect.Count)
{
_ = _expect.Dequeue();
}
Expand Down Expand Up @@ -869,5 +846,37 @@ private void OnDataReceived(byte[] data)
{
DataReceived?.Invoke(this, new ShellDataEventArgs(data));
}

private string SyncQueuesAndReturn(int bytesToDequeue)
{
string incomingText;

lock (_incoming)
{
var incomingLength = _incoming.Count - _expect.Count + bytesToDequeue;
incomingText = _encoding.GetString(_incoming.ToArray(), 0, incomingLength);

SyncQueuesAndDequeue(bytesToDequeue);
jscarle marked this conversation as resolved.
Show resolved Hide resolved
}

return incomingText;
}

private void SyncQueuesAndDequeue(int bytesToDequeue)
{
lock (_incoming)
{
while (_incoming.Count > _expect.Count)
{
_ = _incoming.Dequeue();
}

for (var count = 0; count < bytesToDequeue && _incoming.Count > 0; count++)
{
_ = _incoming.Dequeue();
_ = _expect.Dequeue();
}
}
}
}
}
31 changes: 29 additions & 2 deletions test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ namespace Renci.SshNet.Tests.Classes
[TestClass]
public class ShellStreamTest_ReadExpect
{
private const int BufferSize = 1024;
private const int ExpectSize = BufferSize * 2;
private ShellStream _shellStream;
private ChannelSessionStub _channelSessionStub;

Expand All @@ -42,8 +44,8 @@ public void Initialize()
width: 800,
height: 600,
terminalModeValues: null,
bufferSize: 1024,
expectSize: 2048);
bufferSize: BufferSize,
expectSize: ExpectSize);
}

[TestMethod]
Expand Down Expand Up @@ -244,6 +246,31 @@ public void Expect_String_LargeExpect()
Assert.AreEqual($"{new string('c', 100)}", _shellStream.Read());
}

[TestMethod]
public void Expect_String_DequeueChecks()
{
const string expected = "ccccc";

// Prime buffer
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string(' ', BufferSize)));
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string(' ', ExpectSize)));

// Test data
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('a', 100)));
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('b', 100)));
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(expected));
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('d', 100)));
_channelSessionStub.Receive(Encoding.UTF8.GetBytes(new string('e', 100)));

// Expected result
var expectedResult = $"{new string(' ', BufferSize)}{new string(' ', ExpectSize)}{new string('a', 100)}{new string('b', 100)}{expected}";
var expectedRead = $"{new string('d', 100)}{new string('e', 100)}";

Assert.AreEqual(expectedResult, _shellStream.Expect(expected));

Assert.AreEqual(expectedRead, _shellStream.Read());
}

[TestMethod]
public void Expect_Timeout()
{
Expand Down