From 475ed5d456fb31bbc04c7b1002a9c65ebc10de66 Mon Sep 17 00:00:00 2001 From: Jacob Greenfield Date: Tue, 28 Nov 2023 00:17:18 +0000 Subject: [PATCH] More robust error handling --- src/main/java/com/github/luben/zstd/Zstd.java | 2 +- .../github/luben/zstd/ZstdCompressCtx.java | 28 ++++++++++++------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/main/java/com/github/luben/zstd/Zstd.java b/src/main/java/com/github/luben/zstd/Zstd.java index b9d9f45..1a7c81e 100644 --- a/src/main/java/com/github/luben/zstd/Zstd.java +++ b/src/main/java/com/github/luben/zstd/Zstd.java @@ -591,7 +591,7 @@ public static long decompressDirectByteBufferFastDict(ByteBuffer dst, int dstOff public static native int loadDictCompress(long stream, byte[] dict, int dict_size); public static native int loadFastDictCompress(long stream, ZstdDictCompress dict); public static native void registerSequenceProducer(long stream, long seqProdState, long seqProdFunction); - public static native void generateSequences(long stream, long outSeqs, long outSeqsSize, long src, long srcSize); + static native void generateSequences(long stream, long outSeqs, long outSeqsSize, long src, long srcSize); static native long getBuiltinSequenceProducer(); // Used in tests static native long getStubSequenceProducer(); // Used in tests public static native int setCompressionChecksums(long stream, boolean useChecksums); diff --git a/src/main/java/com/github/luben/zstd/ZstdCompressCtx.java b/src/main/java/com/github/luben/zstd/ZstdCompressCtx.java index 69099b2..b1a378d 100644 --- a/src/main/java/com/github/luben/zstd/ZstdCompressCtx.java +++ b/src/main/java/com/github/luben/zstd/ZstdCompressCtx.java @@ -283,19 +283,27 @@ public ZstdCompressCtx setLong(int windowLog) { public ZstdCompressCtx registerSequenceProducer(SequenceProducer producer) { ensureOpen(); acquireSharedLock(); - if (this.seqprod != null) { - this.seqprod.freeState(seqprod_state); - } + try { + if (this.seqprod != null) { + this.seqprod.freeState(seqprod_state); + this.seqprod = null; + } - if (producer == null) { - seqprod_state = 0; + if (producer == null) { + Zstd.registerSequenceProducer(nativePtr, 0, 0); + } else { + seqprod_state = producer.createState(); + Zstd.registerSequenceProducer(nativePtr, seqprod_state, producer.getFunctionPointer()); + this.seqprod = producer; + } + } catch (Exception e) { + this.seqprod = null; Zstd.registerSequenceProducer(nativePtr, 0, 0); - } else { - seqprod_state = producer.createState(); - Zstd.registerSequenceProducer(nativePtr, seqprod_state, producer.getFunctionPointer()); + releaseSharedLock(); + throw e; + } finally { + releaseSharedLock(); } - this.seqprod = producer; - releaseSharedLock(); return this; }