diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index e00903a982e..74dd50f0b96 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -216,6 +216,24 @@ default byte[] encode() { */ boolean hasGradient(); + /** + * Returns an NDArray equal to this that stops gradient propagation through it. + * + * @return an NDArray equal to this that stops gradient propagation through it + */ + NDArray stopGradient(); + + /** + * Returns an NDArray equal to this that magnifies the gradient propagated to this by a + * constant. + * + * @param scale how much to magnify the gradient propagated to this + * @return an NDArray equal to this that magnifies the gradient propagated to this by a constant + */ + default NDArray scaleGradient(double scale) { + return this.mul(scale).add(this.stopGradient().mul(1 - scale)); + } + /** * Returns the size of this {@code NDArray} along a given axis. * diff --git a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java index 0634ee9a381..9ee1254587d 100644 --- a/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java +++ b/dlr/dlr-engine/src/main/java/ai/djl/dlr/engine/DlrNDArray.java @@ -154,6 +154,12 @@ public boolean hasGradient() { return false; } + /** {@inheritDoc} */ + @Override + public NDArray stopGradient() { + throw new UnsupportedOperationException("Not supported for DLR"); + } + /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer() { diff --git a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index bd4c8becf7c..616c5747d74 100644 --- a/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -268,6 +268,11 @@ public boolean hasGradient() { return 
hasGradient; } + @Override + public NDArray stopGradient() { + return manager.invoke("stop_gradient", this, null); + } + /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer() { diff --git a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java index fd8f0ba210d..ab485883598 100644 --- a/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java +++ b/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtNDArray.java @@ -165,6 +165,12 @@ public boolean hasGradient() { return false; } + /** {@inheritDoc} */ + @Override + public NDArray stopGradient() { + throw new UnsupportedOperationException("Not supported for ONNX Runtime"); + } + /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer() { diff --git a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java index 95d290d9ce8..e6918826561 100644 --- a/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java +++ b/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpNDArray.java @@ -177,6 +177,12 @@ public boolean hasGradient() { return false; } + /** {@inheritDoc} */ + @Override + public NDArray stopGradient() { + throw new UnsupportedOperationException("Not supported"); + } + /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer() { diff --git a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 97b830af3f1..2c5a37d7365 100644 --- a/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -199,6 +199,12 @@ public boolean hasGradient() { return 
hasGradient; } + /** {@inheritDoc} */ + @Override + public NDArray stopGradient() { + throw new UnsupportedOperationException("Not supported"); + } + /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer() { diff --git a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index f137e6b676e..2739ef3db00 100644 --- a/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -178,6 +178,12 @@ public boolean hasGradient() { return false; } + /** {@inheritDoc} */ + @Override + public NDArray stopGradient() { + throw new UnsupportedOperationException("Not implemented"); + } + /** {@inheritDoc} */ @Override public double[] toDoubleArray() { diff --git a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDArray.java b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDArray.java index 6f4e11c7bd2..86f9983379a 100644 --- a/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDArray.java +++ b/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteNDArray.java @@ -192,6 +192,12 @@ public boolean hasGradient() { return false; } + /** {@inheritDoc} */ + @Override + public NDArray stopGradient() { + throw new UnsupportedOperationException("Not supported for TFLite"); + } + /** {@inheritDoc} */ @Override public ByteBuffer toByteBuffer() {