diff --git a/src/main/java/com/github/luben/zstd/Zstd.java b/src/main/java/com/github/luben/zstd/Zstd.java index 8f25b5c..522f3d1 100644 --- a/src/main/java/com/github/luben/zstd/Zstd.java +++ b/src/main/java/com/github/luben/zstd/Zstd.java @@ -561,6 +561,9 @@ public static long decompressDirectByteBufferFastDict(ByteBuffer dst, int dstOff public static native int loadDictCompress(long stream, byte[] dict, int dict_size); public static native int loadFastDictCompress(long stream, ZstdDictCompress dict); public static native void registerSequenceProducer(long stream, long seqProdState, long seqProdFunction); + public static native void generateSequences(long stream, long outSeqs, long outSeqsSize, long src, long srcSize); + static native long getBuiltinSequenceProducer(); // Used in tests + static native long getStubSequenceProducer(); // Used in tests public static native int setCompressionChecksums(long stream, boolean useChecksums); public static native int setCompressionMagicless(long stream, boolean useMagicless); public static native int setCompressionLevel(long stream, int level); @@ -578,6 +581,7 @@ public static long decompressDirectByteBufferFastDict(ByteBuffer dst, int dstOff public static native int setDecompressionLongMax(long stream, int windowLogMax); public static native int setDecompressionMagicless(long stream, boolean useMagicless); public static native int setRefMultipleDDicts(long stream, boolean useMultiple); + public static native int setValidateSequences(long stream, boolean validateSequences); /* Utility methods */ /** diff --git a/src/main/java/com/github/luben/zstd/ZstdCompressCtx.java b/src/main/java/com/github/luben/zstd/ZstdCompressCtx.java index cd8ff9b..29cd453 100644 --- a/src/main/java/com/github/luben/zstd/ZstdCompressCtx.java +++ b/src/main/java/com/github/luben/zstd/ZstdCompressCtx.java @@ -312,6 +312,24 @@ public ZstdCompressCtx setSequenceProducerFallback(boolean fallbackFlag){ } private static native void setSequenceProducerFallback0(long ptr, boolean fallbackFlag); + public ZstdCompressCtx setValidateSequences(boolean validateSequences) { + ensureOpen(); + acquireSharedLock(); + try { + long result = Zstd.setValidateSequences(nativePtr, validateSequences); + if (Zstd.isError(result)) { + throw new ZstdException(result); + } + } finally { + releaseSharedLock(); + } + return this; + } + + // Used in tests + long getNativePtr() { + return nativePtr; + } /** * Load compression dictionary to be used for subsequently compressed frames. diff --git a/src/main/native/jni_zstd.c b/src/main/native/jni_zstd.c index 198827a..8bed96f 100644 --- a/src/main/native/jni_zstd.c +++ b/src/main/native/jni_zstd.c @@ -293,6 +293,57 @@ JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_loadFastDictCompress return ZSTD_CCtx_refCDict((ZSTD_CCtx *)(intptr_t) stream, cdict); } +size_t builtinSequenceProducer( + void* sequenceProducerState, + ZSTD_Sequence* outSeqs, size_t outSeqsCapacity, + const void* src, size_t srcSize, + const void* dict, size_t dictSize, + int compressionLevel, + size_t windowSize +) { + ZSTD_CCtx *zc = (ZSTD_CCtx *)sequenceProducerState; + int windowLog = 0; + while (windowSize > 1) { + windowLog++; + windowSize >>= 1; + } + ZSTD_CCtx_setParameter(zc, ZSTD_c_compressionLevel, compressionLevel); + ZSTD_CCtx_setParameter(zc, ZSTD_c_windowLog, windowSize); + size_t numSeqs = ZSTD_generateSequences((ZSTD_CCtx *)sequenceProducerState, outSeqs, outSeqsCapacity, src, srcSize); + return ZSTD_isError(numSeqs) ? ZSTD_SEQUENCE_PRODUCER_ERROR : numSeqs; +} + +size_t stubSequenceProducer( + void* sequenceProducerState, + ZSTD_Sequence* outSeqs, size_t outSeqsCapacity, + const void* src, size_t srcSize, + const void* dict, size_t dictSize, + int compressionLevel, + size_t windowSize +) { + return ZSTD_SEQUENCE_PRODUCER_ERROR; +} + +/* + * Class: com_github_luben_zstd_Zstd + * Method: getBuiltinSequenceProducer + * Signature: ()J + */ +JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_Zstd_getBuiltinSequenceProducer + (JNIEnv *env, jclass obj) { + return (jlong)(intptr_t)&builtinSequenceProducer; +} + +/* + * Class: com_github_luben_zstd_Zstd + * Method: getBuiltinSequenceProducer + * Signature: ()J + */ +JNIEXPORT jlong JNICALL Java_com_github_luben_zstd_Zstd_getStubSequenceProducer + (JNIEnv *env, jclass obj) { + return (jlong)(intptr_t)&stubSequenceProducer; +} + /* * Class: com_github_luben_zstd_Zstd * Method: registerSequenceProducer @@ -489,6 +540,16 @@ JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_setRefMultipleDDicts return ZSTD_DCtx_setParameter((ZSTD_DCtx *)(intptr_t) stream, ZSTD_d_refMultipleDDicts, value); } +/* + * Class: com_github_luben_zstd_Zstd + * Method: setValidateSequences + * Signature: (JZ)I + */ +JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_setValidateSequences + (JNIEnv *env, jclass obj, jlong stream, jboolean validateSequences) { + return ZSTD_CCtx_setParameter((ZSTD_CCtx *)(intptr_t) stream, ZSTD_c_validateSequences, validateSequences); +} + /* * Class: com_github_luben_zstd_Zstd * Methods: header constants access diff --git a/src/test/scala/Zstd.scala b/src/test/scala/Zstd.scala index 0235f41..65a147a 100644 --- a/src/test/scala/Zstd.scala +++ b/src/test/scala/Zstd.scala @@ -9,8 +9,9 @@ import java.nio.channels.FileChannel import java.nio.channels.FileChannel.MapMode import java.nio.charset.Charset import java.nio.file.StandardOpenOption -import scala.io._ +import scala.annotation.unused import scala.collection.mutable.WrappedArray +import scala.io._ import scala.util.Using class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks { @@ -1105,7 +1106,7 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks { } } - "streaming compressiong and decompression" should "roundtrip" in { + "streaming compression and decompression" should "roundtrip" in { Using.Manager { use => val cctx = use(new ZstdCompressCtx()) val dctx = use(new ZstdDecompressCtx()) @@ -1149,7 +1150,7 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks { decompressedBuffer.flip() val comparison = inputBuffer.compareTo(decompressedBuffer) - comparison == 0 && Zstd.decompressedSize(compressedBuffer) == size && Zstd.getFrameContentSize(compressedBuffer) == size + assert(comparison == 0 && Zstd.decompressedSize(compressedBuffer) == size && Zstd.getFrameContentSize(compressedBuffer) == size) } } }.get @@ -1211,4 +1212,180 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks { } } }.get + + it should "be able to use a sequence producer" in { + Using.Manager { use => + val cctx = use(new ZstdCompressCtx()) + val cctx2 = use(new ZstdCompressCtx()) + val dctx = use(new ZstdDecompressCtx()) + + forAll { input: Array[Byte] => + { + val size = input.length + val inputBuffer = ByteBuffer.allocateDirect(size) + inputBuffer.put(input) + inputBuffer.flip() + cctx.reset() + cctx.setLevel(9) + val seqProd = new SequenceProducer { + def getFunctionPointer(): Long = { + Zstd.getBuiltinSequenceProducer() + } + + def createState(): Long = { + cctx2.getNativePtr() + } + + def freeState(@unused state: Long) = {} + } + cctx.registerSequenceProducer(seqProd) + cctx.setValidateSequences(true) + cctx.setSequenceProducerFallback(false) + cctx.setPledgedSrcSize(size) + val compressedBuffer = ByteBuffer.allocateDirect(Zstd.compressBound(size).toInt) + while (inputBuffer.hasRemaining) { + compressedBuffer.limit(compressedBuffer.position() + 1) + cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.CONTINUE) + } + + var frameProgression = cctx.getFrameProgression() + assert(frameProgression.getIngested() == size) + assert(frameProgression.getFlushed() == compressedBuffer.position()) + + compressedBuffer.limit(compressedBuffer.capacity()) + val done = cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.END) + assert(done) + + frameProgression = cctx.getFrameProgression() + assert(frameProgression.getConsumed() == size) + + compressedBuffer.flip() + val decompressedBuffer = ByteBuffer.allocateDirect(size) + dctx.reset() + while (compressedBuffer.hasRemaining) { + if (decompressedBuffer.limit() < decompressedBuffer.position()) { + decompressedBuffer.limit(compressedBuffer.position() + 1) + } + dctx.decompressDirectByteBufferStream(decompressedBuffer, compressedBuffer) + } + + inputBuffer.rewind() + compressedBuffer.rewind() + decompressedBuffer.flip() + + val comparison = inputBuffer.compareTo(decompressedBuffer) + assert(comparison == 0 && Zstd.decompressedSize(compressedBuffer) == size && Zstd.getFrameContentSize(compressedBuffer) == size) + } + } + }.get + } + + it should "fail with a stub sequence producer" in { + Using.Manager { use => + val cctx = use(new ZstdCompressCtx()) + + forAll(minSize(32)) { input: Array[Byte] => + { + val size = input.length + val inputBuffer = ByteBuffer.allocateDirect(size) + inputBuffer.put(input) + inputBuffer.flip() + cctx.reset() + cctx.setLevel(9) + + val seqProd = new SequenceProducer { + def getFunctionPointer(): Long = { + Zstd.getStubSequenceProducer() + } + + def createState(): Long = { 0 } + def freeState(@unused state: Long) = { 0 } + } + + cctx.registerSequenceProducer(seqProd) + cctx.setValidateSequences(true) + cctx.setSequenceProducerFallback(false) + cctx.setPledgedSrcSize(size) + + val compressedBuffer = ByteBuffer.allocateDirect(Zstd.compressBound(size).toInt) + try { + while (inputBuffer.hasRemaining) { + compressedBuffer.limit(compressedBuffer.position() + 1) + cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.CONTINUE) + } + cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.END) + fail("compression succeeded, but should have failed") + } catch { + case _: ZstdException => // compression should throw a ZstdException + } + } + } + }.get + } + + it should "succeed with a stub sequence producer and software fallback" in { + Using.Manager { use => + val cctx = use(new ZstdCompressCtx()) + val dctx = use(new ZstdDecompressCtx()) + + forAll { input: Array[Byte] => + { + val size = input.length + val inputBuffer = ByteBuffer.allocateDirect(size) + inputBuffer.put(input) + inputBuffer.flip() + cctx.reset() + cctx.setLevel(9) + + val seqProd = new SequenceProducer { + def getFunctionPointer(): Long = { + Zstd.getStubSequenceProducer() + } + + def createState(): Long = { 0 } + def freeState(@unused state: Long) = { 0 } + } + + cctx.registerSequenceProducer(seqProd) + cctx.setValidateSequences(true) + cctx.setSequenceProducerFallback(true) // !! + cctx.setPledgedSrcSize(size) + + val compressedBuffer = ByteBuffer.allocateDirect(Zstd.compressBound(size).toInt) + while (inputBuffer.hasRemaining) { + compressedBuffer.limit(compressedBuffer.position() + 1) + cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.CONTINUE) + } + + var frameProgression = cctx.getFrameProgression() + assert(frameProgression.getIngested() == size) + assert(frameProgression.getFlushed() == compressedBuffer.position()) + + compressedBuffer.limit(compressedBuffer.capacity()) + val done = cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective.END) + assert(done) + + frameProgression = cctx.getFrameProgression() + assert(frameProgression.getConsumed() == size) + + compressedBuffer.flip() + val decompressedBuffer = ByteBuffer.allocateDirect(size) + dctx.reset() + while (compressedBuffer.hasRemaining) { + if (decompressedBuffer.limit() < decompressedBuffer.position()) { + decompressedBuffer.limit(compressedBuffer.position() + 1) + } + dctx.decompressDirectByteBufferStream(decompressedBuffer, compressedBuffer) + } + + inputBuffer.rewind() + compressedBuffer.rewind() + decompressedBuffer.flip() + + val comparison = inputBuffer.compareTo(decompressedBuffer) + assert(comparison == 0 && Zstd.decompressedSize(compressedBuffer) == size && Zstd.getFrameContentSize(compressedBuffer) == size) + } + } + }.get + } }