diff --git a/src/main/java/com/github/luben/zstd/ZstdBufferDecompressingStreamNoFinalizer.java b/src/main/java/com/github/luben/zstd/ZstdBufferDecompressingStreamNoFinalizer.java
index e238bd7..3e3ba2d 100644
--- a/src/main/java/com/github/luben/zstd/ZstdBufferDecompressingStreamNoFinalizer.java
+++ b/src/main/java/com/github/luben/zstd/ZstdBufferDecompressingStreamNoFinalizer.java
@@ -43,7 +43,7 @@ long initDStream(long stream) {
     }
 
     @Override
-    long decompressStream(long stream, ByteBuffer dst, int dstOffset, int dstSize, ByteBuffer src, int srcOffset, int srcSize) {
+    long decompressStream(long stream, ByteBuffer dst, int dstBufPos, int dstSize, ByteBuffer src, int srcBufPos, int srcSize) {
         if (!src.hasArray()) {
             throw new IllegalArgumentException("provided source ByteBuffer lacks array");
         }
@@ -53,7 +53,10 @@ long decompressStream(long stream, ByteBuffer dst, int dstOffset, int dstSize, B
         byte[] targetArr = dst.array();
         byte[] sourceArr = src.array();
 
-        return decompressStreamNative(stream, targetArr, dstOffset, dstSize, sourceArr, srcOffset, srcSize);
+        // We are interested in the array data corresponding to the position represented by the ByteBuffer view.
+        // A ByteBuffer may share an underlying array with other ByteBuffers. In such a scenario, we need to
+        // adjust the array index by adding the offset returned by arrayOffset().
+        return decompressStreamNative(stream, targetArr, dstBufPos + dst.arrayOffset(), dstSize, sourceArr, srcBufPos + src.arrayOffset(), srcSize);
     }
 
     public static int recommendedTargetBufferSize() {
diff --git a/src/test/scala/Zstd.scala b/src/test/scala/Zstd.scala
index 9d57559..ff617f7 100644
--- a/src/test/scala/Zstd.scala
+++ b/src/test/scala/Zstd.scala
@@ -679,8 +679,15 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
     val channel = FileChannel.open(file.toPath, StandardOpenOption.READ)
     // write some garbage bytes at the beginning of buffer containing compressed data to prove that
     // this buffer's position doesn't have to start from 0.
-    val garbageBytes = "garbage bytes".getBytes(Charset.defaultCharset());
-    val readBuffer = ByteBuffer.allocate(channel.size().toInt + garbageBytes.length)
+    val garbageBytes = "garbage bytes".getBytes(Charset.defaultCharset())
+    // Add some extra bytes to the underlying array of the ByteBuffer. The ByteBuffer view does not include
+    // these extra bytes; they are added to the underlying array to test scenarios where the ByteBuffer view
+    // is a slice of the underlying array.
+    val extraBytes = "extra bytes".getBytes(Charset.defaultCharset())
+    // Create a read buffer with extraBytes prepended; we will later carve a slice out of it to store the compressed data.
+    val bigReadBuffer = ByteBuffer.allocate(channel.size().toInt + garbageBytes.length + extraBytes.length)
+    bigReadBuffer.put(extraBytes)
+    val readBuffer = bigReadBuffer.slice()
     readBuffer.put(garbageBytes)
     channel.read(readBuffer)
     // set pos to 0 and limit to containing bytes
@@ -694,7 +701,9 @@
     var pos = 0
     // write some garbage bytes at the beginning of buffer containing uncompressed data to prove that
     // this buffer's position doesn't have to start from 0.
-    val block = ByteBuffer.allocate(1 + garbageBytes.length)
+    val bigBlock = ByteBuffer.allocate(1 + garbageBytes.length + extraBytes.length)
+    bigBlock.put(extraBytes)
+    var block = bigBlock.slice()
     while (pos < length && zis.hasRemaining) {
       block.clear
       block.put(garbageBytes)
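
Why the arrayOffset() adjustment in the patch matters: a slice() of a heap ByteBuffer shares the parent's backing array, but positions within the slice are relative to the slice's own zero, not to the array. The standalone sketch below is a minimal illustration, not part of the patch (the class name SliceOffsetDemo and the literal values are ours); it shows that position + arrayOffset() yields the true index into the shared array, which is exactly the index decompressStreamNative needs.

import java.nio.ByteBuffer;

public class SliceOffsetDemo {
    public static void main(String[] args) {
        ByteBuffer parent = ByteBuffer.allocate(16);
        parent.position(5);                 // advance past the first 5 bytes
        ByteBuffer slice = parent.slice();  // view over bytes 5..15 of the same array

        // The slice shares the parent's backing array but starts counting from 0.
        System.out.println(slice.array() == parent.array()); // true
        System.out.println(slice.position());                // 0
        System.out.println(slice.arrayOffset());             // 5

        // A native call that indexes the raw array must therefore use
        // bufferPosition + buffer.arrayOffset(), as the patch does.
        int arrayIndex = slice.position() + slice.arrayOffset();
        System.out.println(arrayIndex);                      // 5
    }
}

Before the patch, a buffer produced by slice() (arrayOffset() > 0) would have been read and written at the wrong array indices; the test changes above construct exactly such sliced buffers on both the compressed and uncompressed sides to cover that case.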