From 100c434dfcec17a865ca2c2b844afe1046ce1b10 Mon Sep 17 00:00:00 2001 From: Divij Vaidya Date: Wed, 26 Apr 2023 19:36:17 +0200 Subject: [PATCH] Relax the requirement for source and target ByteBuffer in ZstdBufferDecompressingStream --- ...dBufferDecompressingStreamNoFinalizer.java | 35 +++++++++++++++++++ ...dBufferDecompressingStreamNoFinalizer.java | 14 +++++--- ...tBufferDecompressingStreamNoFinalizer.java | 4 +-- .../zstd/ZstdInputStreamNoFinalizer.java | 2 +- src/test/scala/Zstd.scala | 20 +++++++++-- 5 files changed, 65 insertions(+), 10 deletions(-) diff --git a/src/main/java/com/github/luben/zstd/BaseZstdBufferDecompressingStreamNoFinalizer.java b/src/main/java/com/github/luben/zstd/BaseZstdBufferDecompressingStreamNoFinalizer.java index c62ebee..abed022 100644 --- a/src/main/java/com/github/luben/zstd/BaseZstdBufferDecompressingStreamNoFinalizer.java +++ b/src/main/java/com/github/luben/zstd/BaseZstdBufferDecompressingStreamNoFinalizer.java @@ -10,7 +10,13 @@ public abstract class BaseZstdBufferDecompressingStreamNoFinalizer implements Cl protected boolean closed = false; private boolean finishedFrame = false; private boolean streamEnd = false; + /** + * This field is set by the native call to represent the number of bytes consumed from {@link #source} buffer. + */ private int consumed; + /** + * This field is set by the native call to represent the number of bytes produced into the target buffer. + */ private int produced; BaseZstdBufferDecompressingStreamNoFinalizer(ByteBuffer source) { @@ -27,6 +33,9 @@ protected ByteBuffer refill(ByteBuffer toRefill) { return toRefill; } + /** + * @return false if all data is processed and no more data is available from the {@link #source} + */ public boolean hasRemaining() { return !streamEnd && (source.hasRemaining() || !finishedFrame); } @@ -52,6 +61,15 @@ public BaseZstdBufferDecompressingStreamNoFinalizer setDict(ZstdDictDecompress d return this; } + /** + * Set the value of zstd parameter ZSTD_d_windowLogMax. 
+ * + * @param windowLogMax maximum window size, expressed as a power of two (log2 of the size in bytes) + * @return this instance of {@link BaseZstdBufferDecompressingStreamNoFinalizer} + * @throws ZstdIOException if there is an error while setting the configuration natively. + * + * @see Zstd's ZSTD_d_windowLogMax parameter + */ public BaseZstdBufferDecompressingStreamNoFinalizer setLongMax(int windowLogMax) throws IOException { long size = Zstd.setDecompressionLongMax(stream, windowLogMax); if (Zstd.isError(size)) { @@ -106,6 +124,23 @@ public void close() { } } } + + /** + * Reads the content of the de-compressed stream into the target buffer. + *

<p>This method will block until the chunk of compressed data stored in {@link #source} has been decompressed and + * written into the target buffer. After each execution, this method will refill the {@link #source} buffer, using + * {@link #refill(ByteBuffer)}. + *

<p>To read the full stream of decompressed data, this method should be called in a loop while {@link #hasRemaining()} + * is true. + *

<p>The target buffer will be written starting from {@link ByteBuffer#position()}. The {@link ByteBuffer#position()} + * of source and the target buffers will be modified to represent the data read and written respectively. + * + * @param target buffer to store the read bytes from uncompressed stream. + * @return the number of bytes read into the target buffer. + * @throws ZstdIOException if an error occurs while reading. + * @throws IllegalArgumentException if provided source or target buffers are incorrectly configured. + * @throws IOException if the stream is closed before reading. + */ public abstract int read(ByteBuffer target) throws IOException; abstract long createDStream(); diff --git a/src/main/java/com/github/luben/zstd/ZstdBufferDecompressingStreamNoFinalizer.java b/src/main/java/com/github/luben/zstd/ZstdBufferDecompressingStreamNoFinalizer.java index 52ed7fc..e238bd7 100644 --- a/src/main/java/com/github/luben/zstd/ZstdBufferDecompressingStreamNoFinalizer.java +++ b/src/main/java/com/github/luben/zstd/ZstdBufferDecompressingStreamNoFinalizer.java @@ -15,8 +15,8 @@ public ZstdBufferDecompressingStreamNoFinalizer(ByteBuffer source) { if (source.isDirect()) { throw new IllegalArgumentException("Source buffer should be a non-direct buffer"); } - stream = createDStreamNative(); - initDStreamNative(stream); + stream = createDStream(); + initDStream(stream); } @Override @@ -44,8 +44,14 @@ long initDStream(long stream) { @Override long decompressStream(long stream, ByteBuffer dst, int dstOffset, int dstSize, ByteBuffer src, int srcOffset, int srcSize) { - byte[] targetArr = Zstd.extractArray(dst); - byte[] sourceArr = Zstd.extractArray(source); + if (!src.hasArray()) { + throw new IllegalArgumentException("provided source ByteBuffer lacks array"); + } + if (!dst.hasArray()) { + throw new IllegalArgumentException("provided destination ByteBuffer lacks array"); + } + byte[] targetArr = dst.array(); + byte[] sourceArr = src.array(); return 
decompressStreamNative(stream, targetArr, dstOffset, dstSize, sourceArr, srcOffset, srcSize); } diff --git a/src/main/java/com/github/luben/zstd/ZstdDirectBufferDecompressingStreamNoFinalizer.java b/src/main/java/com/github/luben/zstd/ZstdDirectBufferDecompressingStreamNoFinalizer.java index 8055848..92b93ad 100644 --- a/src/main/java/com/github/luben/zstd/ZstdDirectBufferDecompressingStreamNoFinalizer.java +++ b/src/main/java/com/github/luben/zstd/ZstdDirectBufferDecompressingStreamNoFinalizer.java @@ -16,8 +16,8 @@ public ZstdDirectBufferDecompressingStreamNoFinalizer(ByteBuffer source) { throw new IllegalArgumentException("Source buffer should be a direct buffer"); } this.source = source; - stream = createDStreamNative(); - initDStreamNative(stream); + stream = createDStream(); + initDStream(stream); } @Override diff --git a/src/main/java/com/github/luben/zstd/ZstdInputStreamNoFinalizer.java b/src/main/java/com/github/luben/zstd/ZstdInputStreamNoFinalizer.java index 49816c9..5e7a760 100644 --- a/src/main/java/com/github/luben/zstd/ZstdInputStreamNoFinalizer.java +++ b/src/main/java/com/github/luben/zstd/ZstdInputStreamNoFinalizer.java @@ -140,7 +140,7 @@ int readInternal(byte[] dst, int offset, int len) throws IOException { throw new IOException("Stream closed"); } - // guard agains buffer overflows + // guard against buffer overflows if (offset < 0 || len > dst.length - offset) { throw new IndexOutOfBoundsException("Requested length " + len + " from offset " + offset + " in buffer of size " + dst.length); diff --git a/src/test/scala/Zstd.scala b/src/test/scala/Zstd.scala index ee4d79a..9d57559 100644 --- a/src/test/scala/Zstd.scala +++ b/src/test/scala/Zstd.scala @@ -2,12 +2,13 @@ package com.github.luben.zstd import org.scalatest.flatspec.AnyFlatSpec import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks + import java.io._ import java.nio._ import java.nio.channels.FileChannel import java.nio.channels.FileChannel.MapMode +import java.nio.charset.Charset 
import java.nio.file.StandardOpenOption - import scala.io._ import scala.collection.mutable.WrappedArray import scala.util.Using @@ -676,21 +677,34 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks { val orig = new File("src/test/resources/xml") val file = new File(s"src/test/resources/xml-$level.zst") val channel = FileChannel.open(file.toPath, StandardOpenOption.READ) - val readBuffer = ByteBuffer.allocate(channel.size().toInt) + // write some garbage bytes at the beginning of buffer containing compressed data to prove that + // this buffer's position doesn't have to start from 0. + val garbageBytes = "garbage bytes".getBytes(Charset.defaultCharset()); + val readBuffer = ByteBuffer.allocate(channel.size().toInt + garbageBytes.length) + readBuffer.put(garbageBytes) channel.read(readBuffer) + // set pos to 0 and limit to containing bytes readBuffer.flip() + // advance the position after garbage data + readBuffer.position(garbageBytes.length) + val zis = new ZstdBufferDecompressingStream(readBuffer) val length = orig.length.toInt val buff = Array.fill[Byte](length)(0) var pos = 0 - val block = ByteBuffer.allocate(1) + // write some garbage bytes at the beginning of buffer containing uncompressed data to prove that + // this buffer's position doesn't have to start from 0. + val block = ByteBuffer.allocate(1 + garbageBytes.length) while (pos < length && zis.hasRemaining) { block.clear + block.put(garbageBytes) val read = zis.read(block) if (read != 1) { sys.error(s"Failed reading compressed file before end. Bytes read: $read") } block.flip() + // advance the position after garbage data + block.position(garbageBytes.length); buff.update(pos, block.get()) pos += 1 }