From 53907ad508304b3decc5cee9f3406042efcf3c16 Mon Sep 17 00:00:00 2001 From: Qing Lan Date: Wed, 11 Aug 2021 10:43:00 -0700 Subject: [PATCH] let paddle hold the creation buffer --- .../main/java/ai/djl/paddlepaddle/engine/PpNDArray.java | 7 ++++++- .../src/main/java/ai/djl/paddlepaddle/jni/JniUtils.java | 4 ++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java index 2c65232bc4a..2645413f1fe 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java @@ -26,6 +26,8 @@ public class PpNDArray extends NativeResource implements NDArrayAdapter { private PpNDManager manager; + // we keep the data to prevent GC from early collecting native memory + private ByteBuffer data; private Shape shape; private DataType dataType; @@ -33,11 +35,13 @@ public class PpNDArray extends NativeResource implements NDArrayAdapter { * Constructs an PpNDArray from a native handle (internal. Use {@link NDManager} instead). * * @param manager the manager to attach the new array to + * @param data bytebuffer that holds the native memory * @param handle the pointer to the native MxNDArray memory */ - public PpNDArray(PpNDManager manager, long handle) { + public PpNDArray(PpNDManager manager, ByteBuffer data, long handle) { super(handle); this.manager = manager; + this.data = data; manager.attachInternal(getUid(), this); } @@ -148,6 +152,7 @@ public void close() { Long pointer = handle.getAndSet(null); if (pointer != null) { JniUtils.deleteNd(pointer); + data = null; } } } diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/JniUtils.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/JniUtils.java index a899dfc5e96..56db235d4d7 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/JniUtils.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/jni/JniUtils.java @@ -37,7 +37,7 @@ public static PpNDArray createNdArray( long handle = PaddleLibrary.LIB.paddleCreateTensor( data, data.remaining(), intShape, PpDataType.toPaddlePaddle(dtype)); - return new PpNDArray(manager, handle); + return new PpNDArray(manager, data, handle); } public static DataType getDTypeFromNd(PpNDArray array) { @@ -119,7 +119,7 @@ public static PpNDArray[] predictorForward( PpNDManager manager = (PpNDManager) inputs[0].getManager(); PpNDArray[] arrays = new PpNDArray[outputs.length]; for (int i = 0; i < outputs.length; i++) { - arrays[i] = new PpNDArray(manager, outputs[i]); + arrays[i] = new PpNDArray(manager, null, outputs[i]); } return arrays; }