diff --git a/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java b/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java index e8299201529..86eff274d7f 100644 --- a/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java +++ b/engines/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDManager.java @@ -55,7 +55,7 @@ public ByteBuffer allocateDirect(int capacity) { /** {@inheritDoc} */ @Override public DlrNDArray from(NDArray array) { - if (array instanceof DlrNDArray) { + if (array == null || array instanceof DlrNDArray) { return (DlrNDArray) array; } return (DlrNDArray) create(array.toByteBuffer(), array.getShape(), array.getDataType()); diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java index 509e27ad656..fc984013135 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java @@ -52,7 +52,7 @@ public ByteBuffer allocateDirect(int capacity) { /** {@inheritDoc} */ @Override public XgbNDArray from(NDArray array) { - if (array instanceof XgbNDArray) { + if (array == null || array instanceof XgbNDArray) { return (XgbNDArray) array; } return (XgbNDArray) create(array.toByteBuffer(), array.getShape(), array.getDataType()); diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java index ad0ac616f71..90b61955e2a 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDManager.java @@ -65,7 +65,7 @@ public ByteBuffer allocateDirect(int capacity) { /** {@inheritDoc} */ @Override public MxNDArray from(NDArray array) { - if (array instanceof MxNDArray) { + if (array == null || array instanceof MxNDArray) { return (MxNDArray) array; } MxNDArray ret = create(array.getShape(), array.getDataType()); diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java index 56235e806ff..d9b7f31d252 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDManager.java @@ -55,7 +55,7 @@ public ByteBuffer allocateDirect(int capacity) { /** {@inheritDoc} */ @Override public OrtNDArray from(NDArray array) { - if (array instanceof OrtNDArray) { + if (array == null || array instanceof OrtNDArray) { return (OrtNDArray) array; } return create(array.toByteBuffer(), array.getShape(), array.getDataType()); diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java index 5d4f54cc159..445eadcf4ed 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDManager.java @@ -69,7 +69,7 @@ public ByteBuffer allocateDirect(int capacity) { /** {@inheritDoc} */ @Override public PpNDArray from(NDArray array) { - if (array instanceof PpNDArray) { + if (array == null || array instanceof PpNDArray) { return (PpNDArray) array; } return create(array.toByteBuffer(), array.getShape(), array.getDataType()); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java index cb92cd3b978..1cfb2e933cd 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java @@ -30,6 +30,7 @@ public class PtNDArrayEx implements NDArrayEx { private NDArrayIndexer indexer; private PtNDArray array; + private PtNDManager manager; /** * Constructs an {@code PtNDArrayEx} given a {@link NDArray}. @@ -38,13 +39,14 @@ public class PtNDArrayEx implements NDArrayEx { */ PtNDArrayEx(PtNDArray parent) { this.array = parent; - indexer = new PtNDArrayIndexer(array.getManager()); + this.manager = array.getManager(); + indexer = new PtNDArrayIndexer(manager); } /** {@inheritDoc} */ @Override public PtNDArray rdiv(Number n) { - return rdiv(array.getManager().create(n)); + return rdiv(manager.create(n)); } /** {@inheritDoc} */ @@ -276,10 +278,10 @@ public void adamUpdate( boolean lazyUpdate) { // TODO: Lazy update not used JniUtils.adamUpdate( - (PtNDArray) inputs.get(0), - (PtNDArray) inputs.get(1), - (PtNDArray) inputs.get(2), - (PtNDArray) inputs.get(3), + manager.from(inputs.get(0)), + manager.from(inputs.get(1)), + manager.from(inputs.get(2)), + manager.from(inputs.get(3)), learningRate, weightDecay, rescaleGrad, @@ -288,7 +290,7 @@ public void adamUpdate( beta2, epsilon); // call zero-grad - JniUtils.zeroGrad((PtNDArray) weights.singletonOrThrow()); + JniUtils.zeroGrad(manager.from(weights.singletonOrThrow())); } /** {@inheritDoc} */ @@ -333,16 +335,16 @@ public void sgdUpdate( boolean lazyUpdate) { // TODO: Lazy update not used JniUtils.sgdUpdate( - (PtNDArray) inputs.get(0), - (PtNDArray) inputs.get(1), - (momentum == 0f) ? null : (PtNDArray) inputs.get(2), + manager.from(inputs.get(0)), + manager.from(inputs.get(1)), + (momentum == 0f) ? null : manager.from(inputs.get(2)), learningRate, weightDecay, rescaleGrad, clipGrad, momentum); // call zero-grad - JniUtils.zeroGrad((PtNDArray) weights.singletonOrThrow()); + JniUtils.zeroGrad(manager.from(weights.singletonOrThrow())); } /** {@inheritDoc} */ @@ -357,9 +359,9 @@ public NDList convolution( int groups) { return new NDList( JniUtils.convolution( - (PtNDArray) input, - (PtNDArray) weight, - (PtNDArray) bias, + manager.from(input), + manager.from(weight), + manager.from(bias), stride, padding, dilation, @@ -383,7 +385,8 @@ public NDList deconvolution( /** {@inheritDoc} */ @Override public NDList linear(NDArray input, NDArray weight, NDArray bias) { - return new NDList(JniUtils.linear((PtNDArray) input, (PtNDArray) weight, (PtNDArray) bias)); + return new NDList( + JniUtils.linear(manager.from(input), manager.from(weight), manager.from(bias))); } /** {@inheritDoc} */ @@ -394,8 +397,8 @@ public NDList embedding(NDArray input, NDArray weight, SparseFormat sparseFormat } return new NDList( JniUtils.embedding( - (PtNDArray) input, - (PtNDArray) weight, + manager.from(input), + manager.from(weight), sparseFormat.equals(SparseFormat.COO))); } @@ -408,7 +411,7 @@ public NDList prelu(NDArray input, NDArray alpha) { /** {@inheritDoc} */ @Override public NDList dropout(NDArray input, float rate, boolean training) { - return new NDList(JniUtils.dropout((PtNDArray) input, rate, training)); + return new NDList(JniUtils.dropout(manager.from(input), rate, training)); } /** {@inheritDoc} */ @@ -417,10 +420,10 @@ public NDList layerNorm( NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) { return new NDList( JniUtils.layerNorm( - (PtNDArray) input, + manager.from(input), normalizedShape, - (PtNDArray) gamma, - (PtNDArray) beta, + manager.from(gamma), + manager.from(beta), eps)); } /** {@inheritDoc} */ @@ -440,11 +443,11 @@ public NDList batchNorm( if (axis == -1) { return new NDList( JniUtils.batchNorm( - (PtNDArray) input, - (PtNDArray) runningMean, - (PtNDArray) runningVar, - (PtNDArray) gamma, - (PtNDArray) beta, + manager.from(input), + manager.from(runningMean), + manager.from(runningVar), + manager.from(gamma), + manager.from(beta), training, // momentum is defined differently in PyTorch 1f - momentum, @@ -457,11 +460,11 @@ public NDList batchNorm( result = result.swapAxes(1, axis); result = JniUtils.batchNorm( - (PtNDArray) result, - (PtNDArray) runningMean, - (PtNDArray) runningVar, - (PtNDArray) gamma, - (PtNDArray) beta, + manager.from(result), + manager.from(runningMean), + manager.from(runningVar), + manager.from(gamma), + manager.from(beta), training, // momentum is defined differently in PyTorch 1f - momentum, @@ -487,8 +490,8 @@ public NDList rnn( boolean bidirectional, boolean batchFirst) { return JniUtils.rnn( - (PtNDArray) input, - (PtNDArray) state, + manager.from(input), + manager.from(state), params, hasBiases, numLayers, @@ -512,8 +515,8 @@ public NDList gru( boolean bidirectional, boolean batchFirst) { return JniUtils.gru( - (PtNDArray) input, - (PtNDArray) state, + manager.from(input), + manager.from(state), params, hasBiases, numLayers, @@ -536,7 +539,7 @@ public NDList lstm( boolean bidirectional, boolean batchFirst) { return JniUtils.lstm( - (PtNDArray) input, + manager.from(input), states, params, hasBiases, @@ -551,7 +554,7 @@ public NDList lstm( @Override public PtNDArray resize(int width, int height, int interpolation) { // create subManager to help close intermediate NDArray - try (NDManager subManager = array.getManager().newSubManager()) { + try (NDManager subManager = manager.newSubManager()) { array.attach(subManager); NDArray result = array; if (result.isEmpty()) { @@ -567,7 +570,7 @@ public PtNDArray resize(int width, int height, int interpolation) { result = result.transpose(0, 3, 1, 2); result = JniUtils.interpolate( - (PtNDArray) result, + manager.from(result), new long[] {height, width}, getInterpolationMode(interpolation), false) @@ -626,7 +629,7 @@ public PtNDArray where(NDArray condition, NDArray other) { throw new UnsupportedOperationException( "condition and self shape mismatch, broadcast is not supported"); } - return JniUtils.where((PtNDArray) condition, array, (PtNDArray) other); + return JniUtils.where(manager.from(condition), array, manager.from(other)); } /** {@inheritDoc} */ @@ -634,7 +637,10 @@ public PtNDArray where(NDArray condition, NDArray other) { public PtNDArray stack(NDList arrays, int axis) { PtNDArray[] srcArray = new PtNDArray[arrays.size() + 1]; srcArray[0] = array; - System.arraycopy(arrays.toArray(new NDArray[0]), 0, srcArray, 1, arrays.size()); + int i = 1; + for (NDArray arr : arrays) { + srcArray[i++] = manager.from(arr); + } return JniUtils.stack(srcArray, axis); } @@ -645,7 +651,10 @@ public PtNDArray concat(NDList list, int axis) { PtNDArray[] srcArray = new PtNDArray[list.size() + 1]; srcArray[0] = array; - System.arraycopy(list.toArray(new NDArray[0]), 0, srcArray, 1, list.size()); + int i = 1; + for (NDArray arr : list) { + srcArray[i++] = manager.from(arr); + } return JniUtils.cat(srcArray, axis); } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index d1ecd903ad8..6959e11b421 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -47,7 +47,7 @@ public ByteBuffer allocateDirect(int capacity) { /** {@inheritDoc} */ @Override public PtNDArray from(NDArray array) { - if (array instanceof PtNDArray) { + if (array == null || array instanceof PtNDArray) { return (PtNDArray) array; } return create(array.toByteBuffer(), array.getShape(), array.getDataType()); diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java index 12c03e739ae..cc923e63955 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDManager.java @@ -51,7 +51,7 @@ public ByteBuffer allocateDirect(int capacity) { /** {@inheritDoc} */ @Override public TfNDArray from(NDArray array) { - if (array instanceof TfNDArray) { + if (array == null || array instanceof TfNDArray) { return (TfNDArray) array; } return create(array.toByteBuffer(), array.getShape(), array.getDataType()); diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtNDManager.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtNDManager.java index 6691d7552c3..3fb4c5d70db 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtNDManager.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtNDManager.java @@ -55,7 +55,7 @@ public ByteBuffer allocateDirect(int capacity) { /** {@inheritDoc} */ @Override public TrtNDArray from(NDArray array) { - if (array instanceof TrtNDArray) { + if (array == null || array instanceof TrtNDArray) { return (TrtNDArray) array; } return create(array.toByteBuffer(), array.getShape(), array.getDataType()); diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java index fae622ddf46..5ba7228a3ba 100644 --- a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java +++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDManager.java @@ -49,7 +49,7 @@ public ByteBuffer allocateDirect(int capacity) { /** {@inheritDoc} */ @Override public TfLiteNDArray from(NDArray array) { - if (array instanceof TfLiteNDArray) { + if (array == null || array instanceof TfLiteNDArray) { return (TfLiteNDArray) array; } return create(array.toByteBuffer(), array.getShape(), array.getDataType());