diff --git a/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java b/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java
index a5558030e..81e95fe11 100644
--- a/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java
+++ b/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java
@@ -220,16 +220,49 @@ public int read(byte[] into, int off, int len) throws IOException {
 
     public class ReadAheadRemoteFileInputStream
             extends InputStream {
+        /**
+         * A read request that has been sent via {@code asyncRead} and not yet consumed,
+         * together with the file range (offset and length) it asked for.
+         */
+        private class UnconfirmedRead {
+            private final long offset;
+            private final Promise<Response, SFTPException> promise;
+            private final int length;
+
+            private UnconfirmedRead(long offset, int length, Promise<Response, SFTPException> promise) {
+                this.offset = offset;
+                this.length = length;
+                this.promise = promise;
+            }
+
+            UnconfirmedRead(long offset, int length) throws IOException {
+                this(offset, length, RemoteFile.this.asyncRead(offset, length));
+            }
+
+            public long getOffset() {
+                return offset;
+            }
+
+            public Promise<Response, SFTPException> getPromise() {
+                return promise;
+            }
+
+            public int getLength() {
+                return length;
+            }
+        }
 
         private final byte[] b = new byte[1];
 
         private final int maxUnconfirmedReads;
         private final long readAheadLimit;
-        private final Queue<Promise<Response, SFTPException>> unconfirmedReads = new LinkedList<Promise<Response, SFTPException>>();
-        private final Queue<Long> unconfirmedReadOffsets = new LinkedList<Long>();
+        // Declared as LinkedList (already imported by this file) because read() needs peekLast().
+        private final LinkedList<UnconfirmedRead> unconfirmedReads = new LinkedList<>();
 
-        private long requestOffset;
-        private long responseOffset;
+        /** File offset at which the next accepted response has to start. */
+        private long currentOffset;
+        /** Upper bound for one read request; lowered when a response is shorter than requested. */
+        private int maxReadLength = Integer.MAX_VALUE;
         private boolean eof;
 
         public ReadAheadRemoteFileInputStream(int maxUnconfirmedReads) {
@@ -247,28 +280,43 @@ public ReadAheadRemoteFileInputStream(int maxUnconfirmedReads, long fileOffset,
             assert 0 <= fileOffset;
 
             this.maxUnconfirmedReads = maxUnconfirmedReads;
-            this.requestOffset = this.responseOffset = fileOffset;
+            this.currentOffset = fileOffset;
             this.readAheadLimit = readAheadLimit > 0 ? fileOffset + readAheadLimit : Long.MAX_VALUE;
         }
 
         private ByteArrayInputStream pending = new ByteArrayInputStream(new byte[0]);
 
         private boolean retrieveUnconfirmedRead(boolean blocking) throws IOException {
-            if (unconfirmedReads.size() <= 0) {
+            final UnconfirmedRead unconfirmedRead = unconfirmedReads.peek();
+            if (unconfirmedRead == null || !blocking && !unconfirmedRead.getPromise().isDelivered()) {
                 return false;
             }
+            unconfirmedReads.remove();
 
-            if (!blocking && !unconfirmedReads.peek().isDelivered()) {
-                return false;
-            }
-
-            unconfirmedReadOffsets.remove();
-            final Response res = unconfirmedReads.remove().retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS);
+            final Response res = unconfirmedRead.promise.retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS);
             switch (res.getType()) {
                 case DATA:
                     int recvLen = res.readUInt32AsInt();
-                    responseOffset += recvLen;
-                    pending = new ByteArrayInputStream(res.array(), res.rpos(), recvLen);
+                    // A response is only usable if it starts exactly at currentOffset; any other one is dropped.
+                    if (unconfirmedRead.offset == currentOffset) {
+                        currentOffset += recvLen;
+                        pending = new ByteArrayInputStream(res.array(), res.rpos(), recvLen);
+
+                        if (recvLen < unconfirmedRead.length) {
+                            // The server returned a packet smaller than the client had requested.
+                            // It can be caused by at least one of the following:
+                            // * The file has been read fully. Then, a few futile read requests can be sent during
+                            //   the next read(), but the file will be downloaded correctly anyway.
+                            // * The server shapes the request length. Then, the read window will be adjusted,
+                            //   and all further read-ahead requests won't be shaped.
+                            // * The file on the server is not a regular file, it is something like fifo.
+                            //   Then, the window will shrink, and the client will start reading the file slower than it
+                            //   hypothetically can. It must be a rare case, and it is not worth implementing a sort of
+                            //   congestion control algorithm here.
+                            maxReadLength = recvLen;
+                            unconfirmedReads.clear();
+                        }
+                    }
                     break;
 
                 case STATUS:
@@ -296,49 +344,30 @@ public int read(byte[] into, int off, int len) throws IOException {
                 // we also need to go here for len <= 0, because pending may be at
                 // EOF in which case it would return -1 instead of 0
 
+                // Continue right after the last request that is still in flight. Restarting from
+                // currentOffset would re-request ranges that are already queued, so chunks would be
+                // downloaded more than once and the duplicates thrown away.
+                final UnconfirmedRead lastRequest = unconfirmedReads.peekLast();
+                long requestOffset = lastRequest == null
+                        ? currentOffset
+                        : lastRequest.offset + lastRequest.length;
                 while (unconfirmedReads.size() <= maxUnconfirmedReads) {
                     // Send read requests as long as there is no EOF and we have not reached the maximum parallelism
-                    int reqLen = Math.max(1024, len); // don't be shy!
+                    int reqLen = Math.min(Math.max(1024, len), maxReadLength);
                     if (readAheadLimit > requestOffset) {
                         long remaining = readAheadLimit - requestOffset;
                         if (reqLen > remaining) {
                             reqLen = (int) remaining;
                         }
                     }
-                    unconfirmedReads.add(RemoteFile.this.asyncRead(requestOffset, reqLen));
-                    unconfirmedReadOffsets.add(requestOffset);
+                    unconfirmedReads.add(new UnconfirmedRead(requestOffset, reqLen));
                     requestOffset += reqLen;
                     if (requestOffset >= readAheadLimit) {
                         break;
                     }
                 }
 
-                long nextOffset = unconfirmedReadOffsets.peek();
-                if (responseOffset != nextOffset) {
-
-                    // the server could not give us all the data we needed, so
-                    // we try to fill the gap synchronously
-
-                    assert responseOffset < nextOffset;
-                    assert 0 < (nextOffset - responseOffset);
-                    assert (nextOffset - responseOffset) <= Integer.MAX_VALUE;
-
-                    byte[] buf = new byte[(int) (nextOffset - responseOffset)];
-                    int recvLen = RemoteFile.this.read(responseOffset, buf, 0, buf.length);
-
-                    if (recvLen < 0) {
-                        eof = true;
-                        return -1;
-                    }
-
-                    if (0 == recvLen) {
-                        // avoid infinite loops
-                        throw new SFTPException("Unexpected response size (0), bailing out");
-                    }
-
-                    responseOffset += recvLen;
-                    pending = new ByteArrayInputStream(buf, 0, recvLen);
-                } else if (!retrieveUnconfirmedRead(true /*blocking*/)) {
+                if (!retrieveUnconfirmedRead(true /*blocking*/)) {
 
                     // this may happen if we change prefetch strategy
                     // currently, we should never get here...
diff --git a/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java b/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java
index 949a917c0..c69d12406 100644
--- a/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java
+++ b/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java
@@ -17,22 +17,31 @@
 
 import com.hierynomus.sshj.test.SshFixture;
 import net.schmizz.sshj.SSHClient;
+import net.schmizz.sshj.common.ByteArrayUtils;
 import net.schmizz.sshj.sftp.OpenMode;
 import net.schmizz.sshj.sftp.RemoteFile;
 import net.schmizz.sshj.sftp.SFTPEngine;
 import net.schmizz.sshj.sftp.SFTPException;
+import org.apache.sshd.common.util.io.IoUtils;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 
+import java.io.BufferedInputStream;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
 import java.io.File;
+import java.io.FileOutputStream;
 import java.io.IOException;
 import java.io.InputStream;
+import java.io.OutputStream;
+import java.security.SecureRandom;
 import java.util.EnumSet;
 import java.util.Random;
 
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
 
 public class RemoteFileTest {
@@ -174,4 +183,59 @@ public void limitedReadAheadInputStream() throws IOException {
 
         assertThat("The written and received data should match", data, equalTo(test2));
     }
+
+    @Test
+    public void shouldReadCorrectlyWhenWrappedInBufferedStream_FullSizeBuffer() throws IOException {
+        doTestShouldReadCorrectlyWhenWrappedInBufferedStream(1024 * 1024, 1024 * 1024);
+    }
+
+    @Test
+    public void shouldReadCorrectlyWhenWrappedInBufferedStream_HalfSizeBuffer() throws IOException {
+        doTestShouldReadCorrectlyWhenWrappedInBufferedStream(1024 * 1024, 512 * 1024);
+    }
+
+    @Test
+    public void shouldReadCorrectlyWhenWrappedInBufferedStream_QuarterSizeBuffer() throws IOException {
+        doTestShouldReadCorrectlyWhenWrappedInBufferedStream(1024 * 1024, 256 * 1024);
+    }
+
+    @Test
+    public void shouldReadCorrectlyWhenWrappedInBufferedStream_SmallSizeBuffer() throws IOException {
+        doTestShouldReadCorrectlyWhenWrappedInBufferedStream(1024 * 1024, 1024);
+    }
+
+    /**
+     * Writes {@code fileSize} random bytes to a temporary file, reads them back through a
+     * {@link BufferedInputStream} with a {@code bufferSize}-byte buffer that wraps a
+     * {@code ReadAheadRemoteFileInputStream}, and checks that the content is unchanged.
+     */
+    private void doTestShouldReadCorrectlyWhenWrappedInBufferedStream(int fileSize, int bufferSize) throws IOException {
+        SSHClient ssh = fixture.setupConnectedDefaultClient();
+        ssh.authPassword("test", "test");
+        SFTPEngine sftp = new SFTPEngine(ssh).init();
+
+        final byte[] expected = new byte[fileSize];
+        new SecureRandom(new byte[] { 31 }).nextBytes(expected);
+
+        File file = temp.newFile("shouldReadCorrectlyWhenWrappedInBufferedStream.bin");
+        try (OutputStream fStream = new FileOutputStream(file)) {
+            IoUtils.copy(new ByteArrayInputStream(expected), fStream);
+        }
+
+        final byte[] actual;
+        // Close the RemoteFile too; closing only the wrapping stream may leave its handle open.
+        try (RemoteFile rf = sftp.open(file.getPath());
+             InputStream inputStream = new BufferedInputStream(
+                     rf.new ReadAheadRemoteFileInputStream(10),
+                     bufferSize)
+        ) {
+            ByteArrayOutputStream baos = new ByteArrayOutputStream();
+            IoUtils.copy(inputStream, baos, expected.length);
+            actual = baos.toByteArray();
+        }
+
+        assertEquals("The file should be fully read", expected.length, actual.length);
+        assertThat("The file should be read correctly",
+                ByteArrayUtils.equals(expected, 0, actual, 0, expected.length));
+    }
 }