From 04d7d7d13c2ceac0c6625bff1a9d50bd8f2132cb Mon Sep 17 00:00:00 2001 From: Carkham <1302112560@qq.com> Date: Thu, 15 Sep 2022 14:27:29 +0800 Subject: [PATCH 01/14] feature: add distribution and loss function --- .../timeseries/distribution/Distribution.java | 71 +++++++++++ .../distribution/DistributionLoss.java | 60 ++++++++++ .../distribution/NegativeBinomial.java | 79 ++++++++++++ .../djl/timeseries/distribution/StudentT.java | 68 +++++++++++ .../distribution/output/ArgProj.java | 112 ++++++++++++++++++ .../output/DistributionOutput.java | 73 ++++++++++++ .../output/NegativeBinomialOutput.java | 49 ++++++++ .../distribution/output/StudentTOutput.java | 52 ++++++++ 8 files changed, 564 insertions(+) create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java new file mode 100644 index 00000000000..056d95a315c --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java @@ -0,0 +1,71 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; + +public abstract class Distribution { + + /** + * Compute the log of the probability density/mass function evaluated at target + * + * @param target {@link NDArray} of shape (*batch_shape, *event_shape) + * @return Tensor of shape (batch_shape) containing the probability log-density for each event + * in target + */ + public abstract NDArray logProb(NDArray target); + + /** + * Draw samples from the distribution. + * + *
The first dimension of the output will be numSamples.
+ *
+ * @param numSamples Number of samples to be drawn
+ * @return a {@link NDArray} has shape (num_samples, *batch_shape, *target_shape)
+ */
+ public abstract NDArray sample(int numSamples);
+
+ /**
+ * Return the mean of the distribution.
+ *
+ * @return the mean of the distribution
+ */
+ public abstract NDArray mean();
+
+ public abstract static class DistributionBuilder This function is usually used as the lambda of the Lambda block.
+ *
+ * @param arrays the arguments
+ * @return converted arguments
+ */
+ public abstract NDList domainMap(NDList arrays);
+
+ /**
+ * Return the associated {@code DistributionBuilder}, given the collection of constructor arguments and, optionally, a scale tensor.
+ *
+ * @return the associated {@code DistributionBuilder}
+ */
+ public abstract Distribution.DistributionBuilder> distributionBuilder();
+}
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java
new file mode 100644
index 00000000000..9eab777de65
--- /dev/null
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java
@@ -0,0 +1,49 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+
+package ai.djl.timeseries.distribution.output;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.timeseries.distribution.Distribution;
+import ai.djl.timeseries.distribution.DistributionLoss;
+import ai.djl.timeseries.distribution.NegativeBinomial;
+import ai.djl.util.PairList;
+
+public final class NegativeBinomialOutput extends DistributionOutput {
+
+ public NegativeBinomialOutput() {
+ argsDim = new PairList<>(2);
+ argsDim.add("mu", 1);
+ argsDim.add("alpha", 1);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDList domainMap(NDList arrays) {
+ NDArray mu = arrays.get(0);
+ NDArray alpha = arrays.get(1);
+ mu = mu.getNDArrayInternal().softPlus().squeeze(-1);
+ alpha = alpha.getNDArrayInternal().softPlus().squeeze(-1);
+ // TODO: make setName() must be implemented
+ mu.setName("mu");
+ alpha.setName("alpha");
+ return new NDList(mu, alpha);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Distribution.DistributionBuilder> distributionBuilder() {
+ return NegativeBinomial.builder();
+ }
+}
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java
new file mode 100644
index 00000000000..0791629de54
--- /dev/null
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+
+package ai.djl.timeseries.distribution.output;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.timeseries.distribution.Distribution;
+import ai.djl.timeseries.distribution.DistributionLoss;
+import ai.djl.util.PairList;
+
+public class StudentTOutput extends DistributionOutput {
+
+ public StudentTOutput() {
+ argsDim = new PairList<>(3);
+ argsDim.add("mu", 1);
+ argsDim.add("sigma", 1);
+ argsDim.add("nu", 1);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDList domainMap(NDList arrays) {
+ NDArray mu = arrays.get(0);
+ NDArray sigma = arrays.get(1);
+ NDArray nu = arrays.get(2);
+ mu = mu.squeeze(-1);
+ sigma = sigma.getNDArrayInternal().softPlus().squeeze(-1);
+ nu = nu.getNDArrayInternal().softPlus().add(2.).squeeze(-1);
+ // TODO: make setName() must be implemented
+ mu.setName("mu");
+ sigma.setName("sigma");
+ nu.setName("nu");
+ return new NDList(mu, sigma, nu);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Distribution.DistributionBuilder> distributionBuilder() {
+ return null;
+ }
+}
From 2e748f0da8c10022ca4ca82bd37ddf63e97dfcd3 Mon Sep 17 00:00:00 2001
From: Carkham <1302112560@qq.com>
Date: Fri, 16 Sep 2022 14:12:14 +0800
Subject: [PATCH 02/14] bug fix and add unit test
---
.../djl/timeseries/distribution/StudentT.java | 4 +-
.../distribution/DistributionTest.java | 48 +++++++++++++++++++
.../timeseries/distribution/package-info.java | 1 +
3 files changed, 51 insertions(+), 2 deletions(-)
create mode 100644 extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java
create mode 100644 extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java
index 2933411f3f0..5379f96fa7a 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java
@@ -23,11 +23,11 @@ public NDArray logProb(NDArray target) {
NDArray part1 = nu.getNDArrayInternal().rdiv(1.).mul(target.sub(mu).div(sigma).square());
NDArray z = nup1Half.gammaln()
- .sub(nu.sub(2.).gammaln())
+ .sub(nu.div(2.).gammaln())
.sub(nu.mul(Math.PI).log().mul(0.5))
.sub(sigma.log());
- return z.sub(nup1Half).mul(part1.add(1.).log());
+ return z.sub(nup1Half.mul(part1.add(1.).log()));
}
@Override
diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java
new file mode 100644
index 00000000000..fffaba4cafb
--- /dev/null
+++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java
@@ -0,0 +1,48 @@
+package ai.djl.timeseries.distribution;
+
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.testing.Assertions;
+import org.testng.annotations.Test;
+
+public class DistributionTest {
+
+ @Test
+ public void testNegativeBinomial() {
+ try (NDManager manager = NDManager.newBaseManager()) {
+ NDArray mu = manager.create(new float[]{1000f, 1f});
+ NDArray alpha = manager.create(new float[]{1f, 2f});
+ mu.setName("mu");
+ alpha.setName("alpha");
+ Distribution negativeBinomial = NegativeBinomial
+ .builder()
+ .setDistrArgs(new NDList(mu, alpha))
+ .build();
+
+ NDArray expected = manager.create(new float[]{-6.9098f, -1.6479f});
+ NDArray real = negativeBinomial.logProb(manager.create(new float[]{1f, 1f}));
+ Assertions.assertAlmostEquals(real, expected);
+ }
+ }
+
+ @Test
+ public void testStudentT() {
+ try (NDManager manager = NDManager.newBaseManager()) {
+ NDArray mu = manager.create(new float[]{1000f, -1000f});
+ NDArray sigma = manager.create(new float[]{1f, 2f});
+ NDArray nu = manager.create(new float[]{4.2f, 3f});
+ mu.setName("mu");
+ sigma.setName("sigma");
+ nu.setName("nu");
+ Distribution studentT = StudentT
+ .builder()
+ .setDistrArgs(new NDList(mu, sigma, nu))
+ .build();
+
+ NDArray expected = manager.create(new float[]{-0.9779f, -1.6940f});
+ NDArray real = studentT.logProb(manager.create(new float[]{1000f, -1000f}));
+ Assertions.assertAlmostEquals(real, expected);
+ }
+ }
+}
diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java
new file mode 100644
index 00000000000..3835506616d
--- /dev/null
+++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java
@@ -0,0 +1 @@
+package ai.djl.timeseries.distribution;
\ No newline at end of file
From 78c992acce1b2080765b44b0cb679a903e63ec6e Mon Sep 17 00:00:00 2001
From: Carkham <1302112560@qq.com>
Date: Fri, 16 Sep 2022 15:07:30 +0800
Subject: [PATCH 03/14] add comments and unit test
---
.../timeseries/distribution/Distribution.java | 27 +++++++++++-
.../distribution/DistributionLoss.java | 19 +++++----
.../distribution/NegativeBinomial.java | 29 ++++++++++---
.../djl/timeseries/distribution/StudentT.java | 42 +++++++++++++++----
.../distribution/output/ArgProj.java | 35 +++++++++++++++-
.../output/DistributionOutput.java | 19 ++++++---
.../output/NegativeBinomialOutput.java | 6 ++-
.../distribution/output/StudentTOutput.java | 5 ++-
.../distribution/DistributionTest.java | 31 +++++++-------
9 files changed, 165 insertions(+), 48 deletions(-)
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java
index 056d95a315c..31358b185e6 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java
@@ -16,6 +16,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
+/** An abstract class representing probability distribution. */
public abstract class Distribution {
/**
@@ -23,7 +24,7 @@ public abstract class Distribution {
*
* @param target {@link NDArray} of shape (*batch_shape, *event_shape)
* @return Tensor of shape (batch_shape) containing the probability log-density for each event
- * in target
+ * in target
*/
public abstract NDArray logProb(NDArray target);
@@ -44,21 +45,45 @@ public abstract class Distribution {
*/
public abstract NDArray mean();
+ /**
+ * A builder to extend for all classes extend the {@link Distribution}.
+ *
+ * @param Distribution Loss is calculated by {@link Distribution#logProb(NDArray)} at label point
+ */
public class DistributionLoss extends Loss {
private DistributionOutput distrOutput;
/**
- * Base class for metric with abstract update methods.
+ * Calculates Distribution Loss between the label and distribution.
*
- * @param name The display name of the Loss
+ * @param name the name of the loss
+ * @param distrOutput the {@link DistributionOutput} to construct the target distribution
*/
public DistributionLoss(String name, DistributionOutput distrOutput) {
super(name);
@@ -49,11 +55,10 @@ public NDArray evaluate(NDList labels, NDList predictions) {
if (predictions.contains("loss_weights")) {
NDArray lossWeights = predictions.get("loss_weights");
- NDArray weightedValue = NDArrays.where(
- lossWeights.neq(0), loss.mul(lossWeights), loss.zerosLike()
- );
- NDArray sumWeights = lossWeights.sum(new int[]{1}).maximum(1.);
- loss = weightedValue.sum(new int[]{1}).div(sumWeights);
+ NDArray weightedValue =
+ NDArrays.where(lossWeights.neq(0), loss.mul(lossWeights), loss.zerosLike());
+ NDArray sumWeights = lossWeights.sum(new int[] {1}).maximum(1.);
+ loss = weightedValue.sum(new int[] {1}).div(sumWeights);
}
return loss;
}
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java
index e04b244f325..5851136ae6f 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java
@@ -14,10 +14,17 @@
package ai.djl.timeseries.distribution;
import ai.djl.ndarray.NDArray;
-import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.util.Preconditions;
+/**
+ * Negative binomial distribution.
+ *
+ * The distribution of the number of successes in a sequence of independent Bernoulli trials.
+ *
+ * Two arguments for this distribution. {@code mu} mean of the distribution, {@code alpha} the
+ * inverse number of negative Bernoulli trials to stop
+ */
public final class NegativeBinomial extends Distribution {
private NDArray mu;
@@ -28,20 +35,21 @@ public final class NegativeBinomial extends Distribution {
alpha = builder.distrArgs.get("alpha");
}
+ /** {@inheritDoc} */
@Override
public NDArray logProb(NDArray target) {
NDArray alphaInv = alpha.getNDArrayInternal().rdiv(1);
NDArray alphaTimesMu = alpha.mul(mu);
- return target
- .mul(alphaTimesMu.div(alphaTimesMu.add(1)).log())
+ return target.mul(alphaTimesMu.div(alphaTimesMu.add(1)).log())
.sub(alphaInv.mul(alphaTimesMu.add(1).log()))
.add(target.add(alphaInv).gammaln())
.sub(target.add(1.).gammaln())
.sub(alphaInv.gammaln());
}
+ /** {@inheritDoc} */
@Override
public NDArray sample(int numSamples) {
NDManager manager = mu.getManager();
@@ -53,24 +61,35 @@ public NDArray sample(int numSamples) {
return manager.samplePoisson(manager.sampleGamma(r, theta));
}
+ /** {@inheritDoc} */
@Override
public NDArray mean() {
return mu;
}
+ /**
+ * Creates a builder to build a {@code NegativeBinomial}.
+ *
+ * @return a new builder
+ */
public static Builder builder() {
return new Builder();
}
+ /** The builder to construct a {@code NegativeBinomial}. */
public static final class Builder extends DistributionBuilder Three arguments for this distribution. {@code mu} mean of the distribution, {@code sigma} the
+ * standard deviations (scale), {@code nu} degrees of freedom.
+ */
public class StudentT extends Distribution {
-
+
private NDArray mu;
private NDArray sigma;
private NDArray nu;
@@ -17,19 +23,22 @@ public class StudentT extends Distribution {
nu = builder.distrArgs.get("nu");
}
+ /** {@inheritDoc} */
@Override
public NDArray logProb(NDArray target) {
NDArray nup1Half = nu.add(1.).div(2.);
NDArray part1 = nu.getNDArrayInternal().rdiv(1.).mul(target.sub(mu).div(sigma).square());
- NDArray z = nup1Half.gammaln()
- .sub(nu.div(2.).gammaln())
- .sub(nu.mul(Math.PI).log().mul(0.5))
- .sub(sigma.log());
+ NDArray z =
+ nup1Half.gammaln()
+ .sub(nu.div(2.).gammaln())
+ .sub(nu.mul(Math.PI).log().mul(0.5))
+ .sub(sigma.log());
return z.sub(nup1Half.mul(part1.add(1.).log()));
}
+ /** {@inheritDoc} */
@Override
public NDArray sample(int numSamples) {
NDManager manager = mu.getManager();
@@ -37,29 +46,44 @@ public NDArray sample(int numSamples) {
NDArray expandedSigma = sigma.expandDims(0).repeat(0, numSamples);
NDArray expandedNu = nu.expandDims(0).repeat(0, numSamples);
- NDArray gammas = manager.sampleGamma(expandedNu.div(2.), expandedNu.mul(expandedSigma.square()).getNDArrayInternal().rdiv(2.));
+ NDArray gammas =
+ manager.sampleGamma(
+ expandedNu.div(2.),
+ expandedNu.mul(expandedSigma.square()).getNDArrayInternal().rdiv(2.));
return manager.sampleNormal(expandedMu, gammas.sqrt().getNDArrayInternal().rdiv(1.));
}
+ /** {@inheritDoc} */
@Override
public NDArray mean() {
return NDArrays.where(nu.gt(1.0), mu, mu.getManager().full(mu.getShape(), Float.NaN));
}
+ /**
+ * Creates a builder to build a {@code NegativeBinomial}.
+ *
+ * @return a new builder
+ */
public static Builder builder() {
return new Builder();
}
+ /** The builder to construct a {@code NegativeBinomial}. */
public static final class Builder extends DistributionBuilder By default {@code 0f}. This value will be used when padding data series.
+ *
+ * @return the valueInSupport
+ */
public float getValueInSupport() {
return valueInSupport;
}
/**
- * Return the corresponding projection block based on the args dimension of different
- * ditributions.
+ * Return the corresponding projection block based on the arguments dimension of different
+ * distributions.
*
* @return the corresponding projection block
*/
@@ -38,7 +46,7 @@ public ArgProj getArgsProj() {
}
/**
- * Return the corresponding projection block based on the args dimension of different
+ * Return the corresponding projection block based on the arguments dimension of different
* ditributions.
*
* @param prefix the prefix string of projection layer block
@@ -65,7 +73,8 @@ public ArgProj getArgsProj(String prefix) {
public abstract NDList domainMap(NDList arrays);
/**
- * Return the associated {@code DistributionBuilder}, given the collection of constructor arguments and, optionally, a scale tensor.
+ * Return the associated {@code DistributionBuilder}, given the collection of constructor
+ * arguments and, optionally, a scale tensor.
*
* @return the associated {@code DistributionBuilder}
*/
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java
index 9eab777de65..610621fc6c8 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java
@@ -16,12 +16,16 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.timeseries.distribution.Distribution;
-import ai.djl.timeseries.distribution.DistributionLoss;
import ai.djl.timeseries.distribution.NegativeBinomial;
import ai.djl.util.PairList;
+/**
+ * {@code NegativeBinomialOutput} is a {@link DistributionOutput} that for negative binomial
+ * distribution.
+ */
public final class NegativeBinomialOutput extends DistributionOutput {
+ /** Construct a negative binomial output with two arguments, {@code mu} and {@code alpha}. */
public NegativeBinomialOutput() {
argsDim = new PairList<>(2);
argsDim.add("mu", 1);
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java
index 0791629de54..156c83f71b3 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java
@@ -16,11 +16,14 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.timeseries.distribution.Distribution;
-import ai.djl.timeseries.distribution.DistributionLoss;
import ai.djl.util.PairList;
+/**
+ * {@code StudentTOutput} is a {@link DistributionOutput} that for Student's t-test distribution.
+ */
public class StudentTOutput extends DistributionOutput {
+ /** Construct a negative binomial output with two arguments, {@code mu} and {@code sigma}. */
public StudentTOutput() {
argsDim = new PairList<>(3);
argsDim.add("mu", 1);
diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java
index fffaba4cafb..8ec50da5dbd 100644
--- a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java
+++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java
@@ -4,6 +4,7 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.testing.Assertions;
+
import org.testng.annotations.Test;
public class DistributionTest {
@@ -11,17 +12,15 @@ public class DistributionTest {
@Test
public void testNegativeBinomial() {
try (NDManager manager = NDManager.newBaseManager()) {
- NDArray mu = manager.create(new float[]{1000f, 1f});
- NDArray alpha = manager.create(new float[]{1f, 2f});
+ NDArray mu = manager.create(new float[] {1000f, 1f});
+ NDArray alpha = manager.create(new float[] {1f, 2f});
mu.setName("mu");
alpha.setName("alpha");
- Distribution negativeBinomial = NegativeBinomial
- .builder()
- .setDistrArgs(new NDList(mu, alpha))
- .build();
+ Distribution negativeBinomial =
+ NegativeBinomial.builder().setDistrArgs(new NDList(mu, alpha)).build();
- NDArray expected = manager.create(new float[]{-6.9098f, -1.6479f});
- NDArray real = negativeBinomial.logProb(manager.create(new float[]{1f, 1f}));
+ NDArray expected = manager.create(new float[] {-6.9098f, -1.6479f});
+ NDArray real = negativeBinomial.logProb(manager.create(new float[] {1f, 1f}));
Assertions.assertAlmostEquals(real, expected);
}
}
@@ -29,19 +28,17 @@ public void testNegativeBinomial() {
@Test
public void testStudentT() {
try (NDManager manager = NDManager.newBaseManager()) {
- NDArray mu = manager.create(new float[]{1000f, -1000f});
- NDArray sigma = manager.create(new float[]{1f, 2f});
- NDArray nu = manager.create(new float[]{4.2f, 3f});
+ NDArray mu = manager.create(new float[] {1000f, -1000f});
+ NDArray sigma = manager.create(new float[] {1f, 2f});
+ NDArray nu = manager.create(new float[] {4.2f, 3f});
mu.setName("mu");
sigma.setName("sigma");
nu.setName("nu");
- Distribution studentT = StudentT
- .builder()
- .setDistrArgs(new NDList(mu, sigma, nu))
- .build();
+ Distribution studentT =
+ StudentT.builder().setDistrArgs(new NDList(mu, sigma, nu)).build();
- NDArray expected = manager.create(new float[]{-0.9779f, -1.6940f});
- NDArray real = studentT.logProb(manager.create(new float[]{1000f, -1000f}));
+ NDArray expected = manager.create(new float[] {-0.9779f, -1.6940f});
+ NDArray real = studentT.logProb(manager.create(new float[] {1000f, -1000f}));
Assertions.assertAlmostEquals(real, expected);
}
}
From 50d9c7adb07925d0671a3a60a07d1393bb6111df Mon Sep 17 00:00:00 2001
From: Carkham <1302112560@qq.com>
Date: Fri, 16 Sep 2022 15:23:27 +0800
Subject: [PATCH 04/14] add copyright and package-info
---
.../ai/djl/timeseries/distribution/StudentT.java | 13 +++++++++++++
.../djl/timeseries/distribution/package-info.java | 15 +++++++++++++++
.../timeseries/distribution/DistributionTest.java | 13 +++++++++++++
.../djl/timeseries/distribution/package-info.java | 14 ++++++++++++++
4 files changed, 55 insertions(+)
create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/package-info.java
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java
index 22090c34e96..8854ee9d165 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java
@@ -1,3 +1,16 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+
package ai.djl.timeseries.distribution;
import ai.djl.ndarray.NDArray;
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/package-info.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/package-info.java
new file mode 100644
index 00000000000..5f18cc5f60e
--- /dev/null
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/package-info.java
@@ -0,0 +1,15 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+
+/** Contains classes to support distribution in djl. */
+package ai.djl.timeseries.distribution;
diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java
index 8ec50da5dbd..051f7f4743b 100644
--- a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java
+++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java
@@ -1,3 +1,16 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+
package ai.djl.timeseries.distribution;
import ai.djl.ndarray.NDArray;
diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java
index 3835506616d..d6d06167d43 100644
--- a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java
+++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java
@@ -1 +1,15 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+
+/** Contains tests for the distribution module. */
package ai.djl.timeseries.distribution;
\ No newline at end of file
From 8048e1bccc75799588b208d1f30909da6cab0045 Mon Sep 17 00:00:00 2001
From: Carkham <1302112560@qq.com>
Date: Fri, 16 Sep 2022 15:34:34 +0800
Subject: [PATCH 05/14] style fix
---
.../djl/timeseries/distribution/Distribution.java | 9 +++++++--
.../timeseries/distribution/DistributionLoss.java | 2 +-
.../timeseries/distribution/output/ArgProj.java | 8 ++++----
.../distribution/output/DistributionOutput.java | 2 +-
.../distribution/output/package-info.java | 15 +++++++++++++++
5 files changed, 28 insertions(+), 8 deletions(-)
create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/package-info.java
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java
index 31358b185e6..1adcdd4ce77 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java
@@ -20,7 +20,7 @@
public abstract class Distribution {
/**
- * Compute the log of the probability density/mass function evaluated at target
+ * Compute the log of the probability density/mass function evaluated at target.
*
* @param target {@link NDArray} of shape (*batch_shape, *event_shape)
* @return Tensor of shape (batch_shape) containing the probability log-density for each event
@@ -79,7 +79,7 @@ public T optScale(NDArray scale) {
}
/**
- * Set the affine location of the probability
+ * Set the affine location of the probability.
*
* @param loc the affine location
* @return this builder
@@ -89,6 +89,11 @@ public T optLoc(NDArray loc) {
return self();
}
+ /**
+ * Build a {@code Distribution}.
+ *
+ * @return the {@code Distribution}
+ */
public abstract Distribution build();
protected abstract T self();
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java
index cc3b12bb3bc..30d7a2f35d9 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java
@@ -20,7 +20,7 @@
import ai.djl.training.loss.Loss;
/**
- * {@code DistributionLoss} calculates loss for a given distribution
+ * {@code DistributionLoss} calculates loss for a given distribution.
*
* Distribution Loss is calculated by {@link Distribution#logProb(NDArray)} at label point
*/
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java
index 72e6aae3bb8..e47c04e1782 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java
@@ -29,7 +29,7 @@
import java.util.List;
import java.util.function.Function;
-/** A Block that can be used to project from a dense layer to distribution arguments */
+/** A Block that can be used to project from a dense layer to distribution arguments. */
public final class ArgProj extends AbstractBlock {
private Block domainMap;
@@ -97,7 +97,7 @@ public static final class Builder {
private String prefix = "";
/**
- * Set the arguments dimensions of distribution
+ * Set the arguments dimensions of distribution.
*
* @param argsDim the arguments dimension
* @return this builder
@@ -108,7 +108,7 @@ public Builder setArgsDim(PairList This is the distribution of Y = scale * X + loc, where X is a random variable distributed
+ * according to {@code baseDistribution}.
+ *
+ * @param baseDistribution original distribution
+ * @param loc translation parameter of the affine transformation
+ * @param scale scaling parameter of the affine transformation
+ */
+ public AffineTransformed(Distribution baseDistribution, NDArray loc, NDArray scale) {
+ this.baseDistribution = baseDistribution;
+ this.loc = loc == null ? baseDistribution.mean().zerosLike() : loc;
+ this.scale = scale == null ? baseDistribution.mean().onesLike() : scale;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDArray logProb(NDArray target) {
+ NDArray x = fInv(target);
+ NDArray ladj = logAbsDetJac(x);
+ NDArray lp = ladj.mul(-1);
+ return baseDistribution.logProb(x).add(lp);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDArray sample(int numSamples) {
+ NDArray sample = baseDistribution.sample(numSamples);
+ return f(sample);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public NDArray mean() {
+ return baseDistribution.mean().mul(scale).add(loc);
+ }
+
+ private NDArray f(NDArray x) {
+ return x.mul(scale).add(loc);
+ }
+
+ private NDArray fInv(NDArray y) {
+ return y.sub(loc).div(scale);
+ }
+
+ private NDArray logAbsDetJac(NDArray x) {
+ return scale.broadcast(x.getShape()).abs().log();
+ }
+}
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java
index 5851136ae6f..b3a90c9d94b 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java
@@ -86,6 +86,15 @@ public Distribution build() {
distrArgs.contains("mu"), "NegativeBinomial's args must contain mu.");
Preconditions.checkArgument(
distrArgs.contains("alpha"), "NegativeBinomial's args must contain alpha.");
+ // We cannot scale using the affine transformation since negative binomial should return
+ // integers. Instead we scale the parameters.
+ if (scale != null) {
+ NDArray mu = distrArgs.get("mu");
+ mu = mu.mul(scale);
+ mu.setName("mu");
+ distrArgs.remove("mu");
+ distrArgs.add(mu);
+ }
return new NegativeBinomial(this);
}
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java
index 8854ee9d165..1086f052ac1 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java
@@ -93,7 +93,11 @@ public Distribution build() {
distrArgs.contains("sigma"), "StudentTl's args must contain sigma.");
Preconditions.checkArgument(
distrArgs.contains("nu"), "StudentTl's args must contain nu.");
- return new StudentT(this);
+ StudentT baseDistr = new StudentT(this);
+ if (scale == null && loc == null) {
+ return baseDistr;
+ }
+ return new AffineTransformed(baseDistr, loc, scale);
}
/** {@inheritDoc} */
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java
index 81c932221ad..de3538df489 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java
@@ -60,6 +60,11 @@ public ArgProj getArgsProj(String prefix) {
.build();
}
+ /**
+ * Return an array containing all the argument names.
+ *
+ * @return an array containing all the argument names
+ */
public String[] getArgsArray() {
return argsDim.keyArray(new String[argsDim.size()]);
}
From 15f3b9c6c44edf349910b08c221c528e07cbf797 Mon Sep 17 00:00:00 2001
From: Carkham <13193102258@163.com>
Date: Mon, 19 Sep 2022 00:10:52 +0800
Subject: [PATCH 08/14] feature: add sample for zero num
---
.../djl/timeseries/distribution/Distribution.java | 14 +++++++++++++-
.../timeseries/distribution/NegativeBinomial.java | 4 ++--
.../ai/djl/timeseries/distribution/StudentT.java | 6 +++---
3 files changed, 18 insertions(+), 6 deletions(-)
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java
index 1adcdd4ce77..739694cb9a7 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java
@@ -31,13 +31,25 @@ public abstract class Distribution {
/**
* Draw samples from the distribution.
*
- * The first dimension of the output will be numSamples.
+ * This function would expand the dimension of arguments, the first dimension of the output
+ * will be numSamples.
*
* @param numSamples Number of samples to be drawn
* @return a {@link NDArray} has shape (num_samples, *batch_shape, *target_shape)
*/
public abstract NDArray sample(int numSamples);
+ /**
+ * Draw samples from the distribution.
+ *
+ * This function would not expand the dimension
+ *
+ * @return a sampled {@link NDArray}
+ */
+ public NDArray sample() {
+ return sample(0);
+ }
+
/**
* Return the mean of the distribution.
*
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java
index b3a90c9d94b..8565babc244 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java
@@ -53,8 +53,8 @@ public NDArray logProb(NDArray target) {
@Override
public NDArray sample(int numSamples) {
NDManager manager = mu.getManager();
- NDArray expandedMu = mu.expandDims(0).repeat(0, numSamples);
- NDArray expandedAlpha = alpha.expandDims(0).repeat(0, numSamples);
+ NDArray expandedMu = numSamples > 0 ? mu.expandDims(0).repeat(0, numSamples) : mu;
+ NDArray expandedAlpha = numSamples > 0 ? alpha.expandDims(0).repeat(0, numSamples) : alpha;
NDArray r = expandedAlpha.getNDArrayInternal().rdiv(1f);
NDArray theta = expandedAlpha.mul(expandedMu);
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java
index 1086f052ac1..34168acd601 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java
@@ -55,9 +55,9 @@ public NDArray logProb(NDArray target) {
@Override
public NDArray sample(int numSamples) {
NDManager manager = mu.getManager();
- NDArray expandedMu = mu.expandDims(0).repeat(0, numSamples);
- NDArray expandedSigma = sigma.expandDims(0).repeat(0, numSamples);
- NDArray expandedNu = nu.expandDims(0).repeat(0, numSamples);
+ NDArray expandedMu = numSamples > 0 ? mu.expandDims(0).repeat(0, numSamples) : mu;
+ NDArray expandedSigma = numSamples > 0 ? sigma.expandDims(0).repeat(0, numSamples) : sigma;
+ NDArray expandedNu = numSamples > 0 ? nu.expandDims(0).repeat(0, numSamples) : nu;
NDArray gammas =
manager.sampleGamma(
From 0f39651012301e79dd4cb887a43f76ef034ba124 Mon Sep 17 00:00:00 2001
From: Carkham <13193102258@163.com>
Date: Mon, 19 Sep 2022 21:56:47 +0800
Subject: [PATCH 09/14] bug fix: sum the loss
---
.../java/ai/djl/timeseries/distribution/DistributionLoss.java | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java
index 30d7a2f35d9..30251e71728 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java
@@ -57,8 +57,8 @@ public NDArray evaluate(NDList labels, NDList predictions) {
NDArray lossWeights = predictions.get("loss_weights");
NDArray weightedValue =
NDArrays.where(lossWeights.neq(0), loss.mul(lossWeights), loss.zerosLike());
- NDArray sumWeights = lossWeights.sum(new int[] {1}).maximum(1.);
- loss = weightedValue.sum(new int[] {1}).div(sumWeights);
+ NDArray sumWeights = lossWeights.sum().maximum(1.);
+ loss = weightedValue.sum().div(sumWeights);
}
return loss;
}
From d8219931c0ba34d9c7c10d5b628804fa96700d5f Mon Sep 17 00:00:00 2001
From: Carkham <13193102258@163.com>
Date: Tue, 20 Sep 2022 00:10:13 +0800
Subject: [PATCH 10/14] bug fix: add builder for StudentTOutput and clamp the
min value in neg_binomial
---
.../distribution/output/NegativeBinomialOutput.java | 4 ++--
.../ai/djl/timeseries/distribution/output/StudentTOutput.java | 3 ++-
2 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java
index 610621fc6c8..0131eb946eb 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java
@@ -37,8 +37,8 @@ public NegativeBinomialOutput() {
public NDList domainMap(NDList arrays) {
NDArray mu = arrays.get(0);
NDArray alpha = arrays.get(1);
- mu = mu.getNDArrayInternal().softPlus().squeeze(-1);
- alpha = alpha.getNDArrayInternal().softPlus().squeeze(-1);
+ mu = mu.getNDArrayInternal().softPlus().maximum(Float.MIN_VALUE).squeeze(-1);
+ alpha = alpha.getNDArrayInternal().softPlus().maximum(Float.MIN_VALUE).squeeze(-1);
// TODO: make setName() must be implemented
mu.setName("mu");
alpha.setName("alpha");
diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java
index 156c83f71b3..b0d775f99ea 100644
--- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java
+++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java
@@ -16,6 +16,7 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.timeseries.distribution.Distribution;
+import ai.djl.timeseries.distribution.StudentT;
import ai.djl.util.PairList;
/**
@@ -50,6 +51,6 @@ public NDList domainMap(NDList arrays) {
/** {@inheritDoc} */
@Override
public Distribution.DistributionBuilder> distributionBuilder() {
- return null;
+ return StudentT.builder();
}
}
From dcf44cf169d55cf36d62a8992aa9c561079fb174 Mon Sep 17 00:00:00 2001
From: KexinFeng