Skip to content

Commit

Permalink
Fix bug in ZstdBufferDecompressingStream when array backing ByteBuffer is shared
Browse files Browse the repository at this point in the history
  • Loading branch information
divijvaidya authored and luben committed May 30, 2023
1 parent 709ffb5 commit 355b851
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ long initDStream(long stream) {
}

@Override
long decompressStream(long stream, ByteBuffer dst, int dstOffset, int dstSize, ByteBuffer src, int srcOffset, int srcSize) {
long decompressStream(long stream, ByteBuffer dst, int dstBufPos, int dstSize, ByteBuffer src, int srcBufPos, int srcSize) {
if (!src.hasArray()) {
throw new IllegalArgumentException("provided source ByteBuffer lacks array");
}
Expand All @@ -53,7 +53,10 @@ long decompressStream(long stream, ByteBuffer dst, int dstOffset, int dstSize, B
byte[] targetArr = dst.array();
byte[] sourceArr = src.array();

return decompressStreamNative(stream, targetArr, dstOffset, dstSize, sourceArr, srcOffset, srcSize);
// We are interested in the array data corresponding to the position represented by the ByteBuffer view.
// A ByteBuffer may share its underlying array with other ByteBuffers. In such a scenario, we need to adjust the
// array index by adding the offset returned by arrayOffset().
return decompressStreamNative(stream, targetArr, dstBufPos + dst.arrayOffset(), dstSize, sourceArr, srcBufPos + src.arrayOffset(), srcSize);
}

public static int recommendedTargetBufferSize() {
Expand Down
15 changes: 12 additions & 3 deletions src/test/scala/Zstd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -679,8 +679,15 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
val channel = FileChannel.open(file.toPath, StandardOpenOption.READ)
// write some garbage bytes at the beginning of buffer containing compressed data to prove that
// this buffer's position doesn't have to start from 0.
val garbageBytes = "garbage bytes".getBytes(Charset.defaultCharset());
val readBuffer = ByteBuffer.allocate(channel.size().toInt + garbageBytes.length)
val garbageBytes = "garbage bytes".getBytes(Charset.defaultCharset())
// add some extra bytes to the underlying array of the ByteBuffer. The ByteBuffer view does not include these
// extra bytes. These are added to the underlying array to test for scenarios where the ByteBuffer view is a slice
// of the underlying array.
val extraBytes = "extra bytes".getBytes(Charset.defaultCharset())
// Create a read buffer with extraBytes; we will later carve a slice out of it to store the compressed data.
val bigReadBuffer = ByteBuffer.allocate(channel.size().toInt + garbageBytes.length + extraBytes.length)
bigReadBuffer.put(extraBytes)
val readBuffer = bigReadBuffer.slice()
readBuffer.put(garbageBytes)
channel.read(readBuffer)
// set pos to 0 and limit to containing bytes
Expand All @@ -694,7 +701,9 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
var pos = 0
// write some garbage bytes at the beginning of buffer containing uncompressed data to prove that
// this buffer's position doesn't have to start from 0.
val block = ByteBuffer.allocate(1 + garbageBytes.length)
val bigBlock = ByteBuffer.allocate(1 + garbageBytes.length + extraBytes.length)
bigBlock.put(extraBytes)
var block = bigBlock.slice()
while (pos < length && zis.hasRemaining) {
block.clear
block.put(garbageBytes)
Expand Down

0 comments on commit 355b851

Please sign in to comment.