Commit

Relax the requirement for source and target ByteBuffer in ZstdBufferDecompressingStream
divijvaidya authored and luben committed Apr 26, 2023
1 parent 4793b0b commit 100c434
Showing 5 changed files with 65 additions and 10 deletions.
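The new test further down exercises source and target buffers whose positions do not start at 0. The sketch below shows the same usage in plain Java against the public ZstdBufferDecompressingStream wrapper used in that test; it is illustrative only (the class name, payload, and buffer sizes are made up), and it assumes the wrapper is Closeable, as the base class in this diff is.

```java
import com.github.luben.zstd.Zstd;
import com.github.luben.zstd.ZstdBufferDecompressingStream;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

public class OffsetBuffersExample {
    public static void main(String[] args) throws Exception {
        byte[] compressed = Zstd.compress("hello zstd".getBytes(StandardCharsets.UTF_8));

        // Compressed frame preceded by unrelated bytes: the source position starts past them, not at 0.
        byte[] prefix = "garbage bytes".getBytes(StandardCharsets.UTF_8);
        ByteBuffer source = ByteBuffer.allocate(prefix.length + compressed.length);
        source.put(prefix).put(compressed);
        source.flip();
        source.position(prefix.length);

        // Target buffer whose position also does not start at 0.
        ByteBuffer target = ByteBuffer.allocate(64);
        target.position(8);

        try (ZstdBufferDecompressingStream stream = new ZstdBufferDecompressingStream(source)) {
            while (stream.hasRemaining() && target.hasRemaining()) {
                stream.read(target);
            }
        }

        target.flip();
        target.position(8); // skip the bytes we deliberately left in front
        System.out.println(StandardCharsets.UTF_8.decode(target).toString());
    }
}
```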
@@ -10,7 +10,13 @@ public abstract class BaseZstdBufferDecompressingStreamNoFinalizer implements Cl
protected boolean closed = false;
private boolean finishedFrame = false;
private boolean streamEnd = false;
/**
* This field is set by the native call to represent the number of bytes consumed from the {@link #source} buffer.
*/
private int consumed;
/**
* This field is set by the native call to represent the number of bytes produced into the target buffer.
*/
private int produced;
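A hypothetical illustration of the role these two counters play, not the actual implementation: after each native call, the Java side can advance the buffer positions by the reported amounts (the read(ByteBuffer) contract further down states that both positions are updated). The helper class and method names are made up.

```java
import java.nio.ByteBuffer;

final class PositionBookkeepingSketch {
    // Hypothetical illustration only: given the byte counts a native decompress call
    // reports, advance the source and target buffers past the data it touched.
    static void advance(ByteBuffer source, ByteBuffer target, int consumed, int produced) {
        source.position(source.position() + consumed); // bytes read from the compressed input
        target.position(target.position() + produced); // bytes written as decompressed output
    }
}
```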

BaseZstdBufferDecompressingStreamNoFinalizer(ByteBuffer source) {
@@ -27,6 +33,9 @@ protected ByteBuffer refill(ByteBuffer toRefill) {
return toRefill;
}

/**
* @return false if all data is processed and no more data is available from the {@link #source}
*/
public boolean hasRemaining() {
return !streamEnd && (source.hasRemaining() || !finishedFrame);
}
@@ -52,6 +61,15 @@ public BaseZstdBufferDecompressingStreamNoFinalizer setDict(ZstdDictDecompress d
return this;
}

/**
* Set the value of zstd parameter <code>ZSTD_d_windowLogMax</code>.
*
* @param windowLogMax the base-2 logarithm of the maximum window size
* @return this instance of {@link BaseZstdBufferDecompressingStreamNoFinalizer}
* @throws ZstdIOException if there is an error while setting the configuration natively.
*
* @see <a href="https://github.com/facebook/zstd/blob/0525d1cec64a8df749ff293ee476f616de79f7b0/lib/zstd.h#L606"> Zstd's ZSTD_d_windowLogMax parameter</a>
*/
public BaseZstdBufferDecompressingStreamNoFinalizer setLongMax(int windowLogMax) throws IOException {
long size = Zstd.setDecompressionLongMax(stream, windowLogMax);
if (Zstd.isError(size)) {
@@ -106,6 +124,23 @@ public void close() {
}
}
}

/**
* Reads the content of the decompressed stream into the target buffer.
* <p>This method will block until the chunk of compressed data stored in {@link #source} has been decompressed and
* written into the target buffer. After each execution, this method refills the {@link #source} buffer using
* {@link #refill(ByteBuffer)}.
* <p>To read the full stream of decompressed data, this method should be called in a loop while {@link #hasRemaining()}
* is <code>true</code>.
* <p>The target buffer is written to starting from its {@link ByteBuffer#position()}. The positions of the source
* and target buffers are modified to reflect the data read and written, respectively.
*
* @param target buffer to store the read bytes from uncompressed stream.
* @return the number of bytes read into the target buffer.
* @throws ZstdIOException if an error occurs while reading.
* @throws IllegalArgumentException if provided source or target buffers are incorrectly configured.
* @throws IOException if the stream is closed before reading.
*/
public abstract int read(ByteBuffer target) throws IOException;
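A minimal read-loop sketch following the contract documented above: keep calling read(ByteBuffer) while hasRemaining() is true. The helper name, the buffer sizes, and the setLongMax value of 27 (zstd's default window-log limit) are illustrative; the concrete class used is the heap-buffer subclass from this diff.

```java
import com.github.luben.zstd.ZstdBufferDecompressingStreamNoFinalizer;

import java.io.IOException;
import java.nio.ByteBuffer;

final class ReadLoopSketch {
    // Decompresses one or more frames from `compressed` (a non-direct, array-backed buffer)
    // into a freshly allocated buffer of `expectedSize` bytes.
    static ByteBuffer drain(ByteBuffer compressed, int expectedSize) throws IOException {
        ByteBuffer target = ByteBuffer.allocate(expectedSize);
        ZstdBufferDecompressingStreamNoFinalizer stream =
                new ZstdBufferDecompressingStreamNoFinalizer(compressed);
        try {
            stream.setLongMax(27); // optional: cap ZSTD_d_windowLogMax as documented above
            while (stream.hasRemaining() && target.hasRemaining()) {
                stream.read(target); // consumes from `compressed`, produces into `target`
            }
        } finally {
            stream.close();
        }
        target.flip();
        return target;
    }
}
```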

abstract long createDStream();
@@ -15,8 +15,8 @@ public ZstdBufferDecompressingStreamNoFinalizer(ByteBuffer source) {
if (source.isDirect()) {
throw new IllegalArgumentException("Source buffer should be a non-direct buffer");
}
- stream = createDStreamNative();
- initDStreamNative(stream);
+ stream = createDStream();
+ initDStream(stream);
}

@Override
@@ -44,8 +44,14 @@ long initDStream(long stream) {

@Override
long decompressStream(long stream, ByteBuffer dst, int dstOffset, int dstSize, ByteBuffer src, int srcOffset, int srcSize) {
- byte[] targetArr = Zstd.extractArray(dst);
- byte[] sourceArr = Zstd.extractArray(source);
+ if (!src.hasArray()) {
+     throw new IllegalArgumentException("provided source ByteBuffer lacks array");
+ }
+ if (!dst.hasArray()) {
+     throw new IllegalArgumentException("provided destination ByteBuffer lacks array");
+ }
+ byte[] targetArr = dst.array();
+ byte[] sourceArr = src.array();

return decompressStreamNative(stream, targetArr, dstOffset, dstSize, sourceArr, srcOffset, srcSize);
}
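The old code went through Zstd.extractArray to obtain the backing arrays; the new code only requires ByteBuffer.hasArray() and fails fast with a clear message otherwise, which is what the commit title's relaxation refers to. A quick JDK-only sketch of which buffers pass that check (names and contents are illustrative):

```java
import java.nio.ByteBuffer;

final class HasArrayCheckSketch {
    public static void main(String[] args) {
        byte[] backing = new byte[32];

        ByteBuffer heap = ByteBuffer.allocate(32);             // hasArray() == true
        ByteBuffer wrapped = ByteBuffer.wrap(backing, 8, 16);  // hasArray() == true, position starts at 8
        ByteBuffer readOnly = heap.asReadOnlyBuffer();         // hasArray() == false -> rejected
        ByteBuffer direct = ByteBuffer.allocateDirect(32);     // hasArray() == false -> rejected

        System.out.println(heap.hasArray() + " " + wrapped.hasArray()
                + " " + readOnly.hasArray() + " " + direct.hasArray()); // true true false false
    }
}
```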
@@ -16,8 +16,8 @@ public ZstdDirectBufferDecompressingStreamNoFinalizer(ByteBuffer source) {
throw new IllegalArgumentException("Source buffer should be a direct buffer");
}
this.source = source;
- stream = createDStreamNative();
- initDStreamNative(stream);
+ stream = createDStream();
+ initDStream(stream);
}
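The two concrete classes keep mirror-image requirements: the heap variant above rejects direct source buffers, while this direct variant rejects non-direct ones. A small illustrative helper (the factory name is made up) that picks the matching implementation, assuming both classes extend the BaseZstdBufferDecompressingStreamNoFinalizer shown earlier, as their overridden createDStream/initDStream methods suggest:

```java
import com.github.luben.zstd.BaseZstdBufferDecompressingStreamNoFinalizer;
import com.github.luben.zstd.ZstdBufferDecompressingStreamNoFinalizer;
import com.github.luben.zstd.ZstdDirectBufferDecompressingStreamNoFinalizer;

import java.nio.ByteBuffer;

final class DecompressingStreamFactory {
    // Illustrative helper mirroring the IllegalArgumentException checks in both constructors.
    static BaseZstdBufferDecompressingStreamNoFinalizer openFor(ByteBuffer source) {
        return source.isDirect()
                ? new ZstdDirectBufferDecompressingStreamNoFinalizer(source)
                : new ZstdBufferDecompressingStreamNoFinalizer(source);
    }
}
```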

@Override
@@ -140,7 +140,7 @@ int readInternal(byte[] dst, int offset, int len) throws IOException {
throw new IOException("Stream closed");
}

- // guard agains buffer overflows
+ // guard against buffer overflows
if (offset < 0 || len > dst.length - offset) {
throw new IndexOutOfBoundsException("Requested length " + len
+ " from offset " + offset + " in buffer of size " + dst.length);
20 changes: 17 additions & 3 deletions src/test/scala/Zstd.scala
@@ -2,12 +2,13 @@ package com.github.luben.zstd

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks

import java.io._
import java.nio._
import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
import java.nio.charset.Charset
import java.nio.file.StandardOpenOption

import scala.io._
import scala.collection.mutable.WrappedArray
import scala.util.Using
@@ -676,21 +677,34 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
val orig = new File("src/test/resources/xml")
val file = new File(s"src/test/resources/xml-$level.zst")
val channel = FileChannel.open(file.toPath, StandardOpenOption.READ)
- val readBuffer = ByteBuffer.allocate(channel.size().toInt)
+ // write some garbage bytes at the beginning of buffer containing compressed data to prove that
+ // this buffer's position doesn't have to start from 0.
+ val garbageBytes = "garbage bytes".getBytes(Charset.defaultCharset());
+ val readBuffer = ByteBuffer.allocate(channel.size().toInt + garbageBytes.length)
+ readBuffer.put(garbageBytes)
channel.read(readBuffer)
// set pos to 0 and limit to containing bytes
readBuffer.flip()
// advance the position after garbage data
readBuffer.position(garbageBytes.length)

val zis = new ZstdBufferDecompressingStream(readBuffer)
val length = orig.length.toInt
val buff = Array.fill[Byte](length)(0)
var pos = 0
- val block = ByteBuffer.allocate(1)
+ // write some garbage bytes at the beginning of buffer containing uncompressed data to prove that
+ // this buffer's position doesn't have to start from 0.
+ val block = ByteBuffer.allocate(1 + garbageBytes.length)
while (pos < length && zis.hasRemaining) {
block.clear
block.put(garbageBytes)
val read = zis.read(block)
if (read != 1) {
sys.error(s"Failed reading compressed file before end. Bytes read: $read")
}
block.flip()
// advance the position after garbage data
block.position(garbageBytes.length);
buff.update(pos, block.get())
pos += 1
}
