From 04d7d7d13c2ceac0c6625bff1a9d50bd8f2132cb Mon Sep 17 00:00:00 2001 From: Carkham <1302112560@qq.com> Date: Thu, 15 Sep 2022 14:27:29 +0800 Subject: [PATCH 01/14] feature: add distribution and loss function --- .../timeseries/distribution/Distribution.java | 71 +++++++++++ .../distribution/DistributionLoss.java | 60 ++++++++++ .../distribution/NegativeBinomial.java | 79 ++++++++++++ .../djl/timeseries/distribution/StudentT.java | 68 +++++++++++ .../distribution/output/ArgProj.java | 112 ++++++++++++++++++ .../output/DistributionOutput.java | 73 ++++++++++++ .../output/NegativeBinomialOutput.java | 49 ++++++++ .../distribution/output/StudentTOutput.java | 52 ++++++++ 8 files changed, 564 insertions(+) create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java new file mode 100644 index 00000000000..056d95a315c --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java @@ -0,0 +1,71 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; + +public abstract class Distribution { + + /** + * Compute the log of the probability density/mass function evaluated at target + * + * @param target {@link NDArray} of shape (*batch_shape, *event_shape) + * @return Tensor of shape (batch_shape) containing the probability log-density for each event + * in target + */ + public abstract NDArray logProb(NDArray target); + + /** + * Draw samples from the distribution. + * + *

The first dimension of the output will be numSamples. + * + * @param numSamples Number of samples to be drawn + * @return a {@link NDArray} has shape (num_samples, *batch_shape, *target_shape) + */ + public abstract NDArray sample(int numSamples); + + /** + * Return the mean of the distribution. + * + * @return the mean of the distribution + */ + public abstract NDArray mean(); + + public abstract static class DistributionBuilder> { + protected NDList distrArgs; + protected NDArray scale; + protected NDArray loc; + + public T setDistrArgs(NDList distrArgs) { + this.distrArgs = distrArgs; + return self(); + } + + public T optScale(NDArray scale) { + this.scale = scale; + return self(); + } + + public T optLoc(NDArray loc) { + this.loc = loc; + return self(); + } + + public abstract Distribution build(); + + protected abstract T self(); + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java new file mode 100644 index 00000000000..5a1f7800ef3 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java @@ -0,0 +1,60 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; +import ai.djl.ndarray.NDList; +import ai.djl.timeseries.distribution.output.DistributionOutput; +import ai.djl.training.loss.Loss; + +public class DistributionLoss extends Loss { + + private DistributionOutput distrOutput; + + /** + * Base class for metric with abstract update methods. + * + * @param name The display name of the Loss + */ + public DistributionLoss(String name, DistributionOutput distrOutput) { + super(name); + this.distrOutput = distrOutput; + } + + /** {@inheritDoc} */ + @Override + public NDArray evaluate(NDList labels, NDList predictions) { + Distribution.DistributionBuilder builder = distrOutput.distributionBuilder(); + builder.setDistrArgs(predictions); + if (predictions.contains("scale")) { + builder.optScale(predictions.get("scale")); + } + if (predictions.contains("loc")) { + builder.optLoc(predictions.get("loc")); + } + + NDArray loss = builder.build().logProb(labels.singletonOrThrow()).mul(-1); + + if (predictions.contains("loss_weights")) { + NDArray lossWeights = predictions.get("loss_weights"); + NDArray weightedValue = NDArrays.where( + lossWeights.neq(0), loss.mul(lossWeights), loss.zerosLike() + ); + NDArray sumWeights = lossWeights.sum(new int[]{1}).maximum(1.); + loss = weightedValue.sum(new int[]{1}).div(sumWeights); + } + return loss; + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java new file mode 100644 index 00000000000..e04b244f325 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java @@ -0,0 +1,79 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.util.Preconditions; + +public final class NegativeBinomial extends Distribution { + + private NDArray mu; + private NDArray alpha; + + NegativeBinomial(Builder builder) { + mu = builder.distrArgs.get("mu"); + alpha = builder.distrArgs.get("alpha"); + } + + @Override + public NDArray logProb(NDArray target) { + + NDArray alphaInv = alpha.getNDArrayInternal().rdiv(1); + NDArray alphaTimesMu = alpha.mul(mu); + + return target + .mul(alphaTimesMu.div(alphaTimesMu.add(1)).log()) + .sub(alphaInv.mul(alphaTimesMu.add(1).log())) + .add(target.add(alphaInv).gammaln()) + .sub(target.add(1.).gammaln()) + .sub(alphaInv.gammaln()); + } + + @Override + public NDArray sample(int numSamples) { + NDManager manager = mu.getManager(); + NDArray expandedMu = mu.expandDims(0).repeat(0, numSamples); + NDArray expandedAlpha = alpha.expandDims(0).repeat(0, numSamples); + + NDArray r = expandedAlpha.getNDArrayInternal().rdiv(1f); + NDArray theta = expandedAlpha.mul(expandedMu); + return manager.samplePoisson(manager.sampleGamma(r, theta)); + } + + @Override + public NDArray mean() { + return mu; + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder extends DistributionBuilder { + + @Override + public Distribution build() { + Preconditions.checkArgument(distrArgs.contains("mu"), "NegativeBinomial's args must contain mu."); + Preconditions.checkArgument(distrArgs.contains("alpha"), "NegativeBinomial's args must contain alpha."); + return new NegativeBinomial(this); + } + + @Override + protected Builder self() { + return this; + } + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java new file mode 100644 index 00000000000..2933411f3f0 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java @@ -0,0 +1,68 @@ +package ai.djl.timeseries.distribution; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrays; +import ai.djl.ndarray.NDManager; +import ai.djl.util.Preconditions; + +public class StudentT extends Distribution { + + private NDArray mu; + private NDArray sigma; + private NDArray nu; + + StudentT(Builder builder) { + mu = builder.distrArgs.get("mu"); + sigma = builder.distrArgs.get("sigma"); + nu = builder.distrArgs.get("nu"); + } + + @Override + public NDArray logProb(NDArray target) { + NDArray nup1Half = nu.add(1.).div(2.); + NDArray part1 = nu.getNDArrayInternal().rdiv(1.).mul(target.sub(mu).div(sigma).square()); + + NDArray z = nup1Half.gammaln() + .sub(nu.sub(2.).gammaln()) + .sub(nu.mul(Math.PI).log().mul(0.5)) + .sub(sigma.log()); + + return z.sub(nup1Half).mul(part1.add(1.).log()); + } + + @Override + public NDArray sample(int numSamples) { + NDManager manager = mu.getManager(); + NDArray expandedMu = mu.expandDims(0).repeat(0, numSamples); + NDArray expandedSigma = sigma.expandDims(0).repeat(0, numSamples); + NDArray expandedNu = nu.expandDims(0).repeat(0, numSamples); + + NDArray gammas = manager.sampleGamma(expandedNu.div(2.), expandedNu.mul(expandedSigma.square()).getNDArrayInternal().rdiv(2.)); + return manager.sampleNormal(expandedMu, gammas.sqrt().getNDArrayInternal().rdiv(1.)); + } + + @Override + public NDArray mean() { + return NDArrays.where(nu.gt(1.0), mu, mu.getManager().full(mu.getShape(), Float.NaN)); + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder extends DistributionBuilder { + + @Override + public Distribution build() { + Preconditions.checkArgument(distrArgs.contains("mu"), "StudentTl's args must contain mu."); + Preconditions.checkArgument(distrArgs.contains("sigma"), "StudentTl's args must contain sigma."); + Preconditions.checkArgument(distrArgs.contains("nu"), "StudentTl's args must contain nu."); + return new StudentT(this); + } + + @Override + protected Builder self() { + return this; + } + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java new file mode 100644 index 00000000000..0f94358a0d3 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java @@ -0,0 +1,112 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution.output; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.AbstractBlock; +import ai.djl.nn.Block; +import ai.djl.nn.core.Linear; +import ai.djl.training.ParameterStore; +import ai.djl.util.Pair; +import ai.djl.util.PairList; +import ai.djl.util.Preconditions; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +public final class ArgProj extends AbstractBlock { + + private Block domainMap; + private List proj; + + ArgProj(Builder builder) { + proj = new ArrayList<>(); + for (Pair entry : builder.argsDim) { + proj.add( + addChildBlock( + String.format("%s_distr_%s", builder.prefix, entry.getKey()), + Linear.builder().setUnits(entry.getValue()).build())); + } + domainMap = + addChildBlock(String.format("%s_domain_map", builder.prefix), builder.domainMap); + } + + @Override + protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) { + for (Block block : proj) { + block.initialize(manager, dataType, inputShapes); + } + } + + /** {@inheritDoc} */ + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + NDList paramsUnbounded = new NDList(); + for (Block block : proj) { + paramsUnbounded.add( + block.forward(parameterStore, inputs, training, params).singletonOrThrow()); + } + return domainMap.forward(parameterStore, paramsUnbounded, training, params); + } + + /** {@inheritDoc} */ + @Override + public Shape[] getOutputShapes(Shape[] inputShapes) { + Shape[] projOutShapes = new Shape[proj.size()]; + for (int i = 0; i < proj.size(); i++) { + projOutShapes[i] = proj.get(i).getOutputShapes(inputShapes)[0]; + } + return domainMap.getOutputShapes(projOutShapes); + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + private PairList argsDim; + private Function domainMap; + private String prefix = ""; + + public Builder setArgsDim(PairList argsDim) { + this.argsDim = argsDim; + return this; + } + + public Builder setDomainMap(Function domainMap) { + this.domainMap = domainMap; + return this; + } + + public Builder optPrefix(String prefix) { + this.prefix = prefix; + return this; + } + + public ArgProj build() { + Preconditions.checkArgument(argsDim != null, "must specify dim args"); + Preconditions.checkArgument(domainMap != null, "must specify domain PairList function"); + return new ArgProj(this); + } + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java new file mode 100644 index 00000000000..a032d6552db --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java @@ -0,0 +1,73 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution.output; + +import ai.djl.ndarray.NDList; +import ai.djl.timeseries.distribution.Distribution; +import ai.djl.timeseries.distribution.DistributionLoss; +import ai.djl.util.PairList; + +public abstract class DistributionOutput { + + protected PairList argsDim; + private float valueInSupport = 0f; + + public float getValueInSupport() { + return valueInSupport; + } + + /** + * Return the corresponding projection block based on the args dimension of different + * ditributions. + * + * @return the corresponding projection block + */ + public ArgProj getArgsProj() { + return ArgProj.builder().setArgsDim(argsDim).setDomainMap(this::domainMap).build(); + } + + /** + * Return the corresponding projection block based on the args dimension of different + * ditributions. + * + * @param prefix the prefix string of projection layer block + * @return the corresponding projection block + */ + public ArgProj getArgsProj(String prefix) { + return ArgProj.builder() + .setArgsDim(argsDim) + .setDomainMap(this::domainMap) + .optPrefix(prefix) + .build(); + } + + /** + * Convert arguments to the right shape and domain. The domain depends on the type of + * distribution, while the correct shape is obtained by reshaping the trailing axis in such a + * way that the returned tensors define a distribution of the right event_shape. + * + *

This function is usually used as the lambda of the Lambda block. + * + * @param arrays the arguments + * @return converted arguments + */ + public abstract NDList domainMap(NDList arrays); + + /** + * Return the associated {@code DistributionBuilder}, given the collection of constructor arguments and, optionally, a scale tensor. + * + * @return the associated {@code DistributionBuilder} + */ + public abstract Distribution.DistributionBuilder distributionBuilder(); +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java new file mode 100644 index 00000000000..9eab777de65 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java @@ -0,0 +1,49 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution.output; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.timeseries.distribution.Distribution; +import ai.djl.timeseries.distribution.DistributionLoss; +import ai.djl.timeseries.distribution.NegativeBinomial; +import ai.djl.util.PairList; + +public final class NegativeBinomialOutput extends DistributionOutput { + + public NegativeBinomialOutput() { + argsDim = new PairList<>(2); + argsDim.add("mu", 1); + argsDim.add("alpha", 1); + } + + /** {@inheritDoc} */ + @Override + public NDList domainMap(NDList arrays) { + NDArray mu = arrays.get(0); + NDArray alpha = arrays.get(1); + mu = mu.getNDArrayInternal().softPlus().squeeze(-1); + alpha = alpha.getNDArrayInternal().softPlus().squeeze(-1); + // TODO: make setName() must be implemented + mu.setName("mu"); + alpha.setName("alpha"); + return new NDList(mu, alpha); + } + + /** {@inheritDoc} */ + @Override + public Distribution.DistributionBuilder distributionBuilder() { + return NegativeBinomial.builder(); + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java new file mode 100644 index 00000000000..0791629de54 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java @@ -0,0 +1,52 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution.output; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.timeseries.distribution.Distribution; +import ai.djl.timeseries.distribution.DistributionLoss; +import ai.djl.util.PairList; + +public class StudentTOutput extends DistributionOutput { + + public StudentTOutput() { + argsDim = new PairList<>(3); + argsDim.add("mu", 1); + argsDim.add("sigma", 1); + argsDim.add("nu", 1); + } + + /** {@inheritDoc} */ + @Override + public NDList domainMap(NDList arrays) { + NDArray mu = arrays.get(0); + NDArray sigma = arrays.get(1); + NDArray nu = arrays.get(2); + mu = mu.squeeze(-1); + sigma = sigma.getNDArrayInternal().softPlus().squeeze(-1); + nu = nu.getNDArrayInternal().softPlus().add(2.).squeeze(-1); + // TODO: make setName() must be implemented + mu.setName("mu"); + sigma.setName("sigma"); + nu.setName("nu"); + return new NDList(mu, sigma, nu); + } + + /** {@inheritDoc} */ + @Override + public Distribution.DistributionBuilder distributionBuilder() { + return null; + } +} From 2e748f0da8c10022ca4ca82bd37ddf63e97dfcd3 Mon Sep 17 00:00:00 2001 From: Carkham <1302112560@qq.com> Date: Fri, 16 Sep 2022 14:12:14 +0800 Subject: [PATCH 02/14] bug fix and add unit test --- .../djl/timeseries/distribution/StudentT.java | 4 +- .../distribution/DistributionTest.java | 48 +++++++++++++++++++ .../timeseries/distribution/package-info.java | 1 + 3 files changed, 51 insertions(+), 2 deletions(-) create mode 100644 extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java create mode 100644 extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java index 2933411f3f0..5379f96fa7a 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java @@ -23,11 +23,11 @@ public NDArray logProb(NDArray target) { NDArray part1 = nu.getNDArrayInternal().rdiv(1.).mul(target.sub(mu).div(sigma).square()); NDArray z = nup1Half.gammaln() - .sub(nu.sub(2.).gammaln()) + .sub(nu.div(2.).gammaln()) .sub(nu.mul(Math.PI).log().mul(0.5)) .sub(sigma.log()); - return z.sub(nup1Half).mul(part1.add(1.).log()); + return z.sub(nup1Half.mul(part1.add(1.).log())); } @Override diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java new file mode 100644 index 00000000000..fffaba4cafb --- /dev/null +++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java @@ -0,0 +1,48 @@ +package ai.djl.timeseries.distribution; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.testing.Assertions; +import org.testng.annotations.Test; + +public class DistributionTest { + + @Test + public void testNegativeBinomial() { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray mu = manager.create(new float[]{1000f, 1f}); + NDArray alpha = manager.create(new float[]{1f, 2f}); + mu.setName("mu"); + alpha.setName("alpha"); + Distribution negativeBinomial = NegativeBinomial + .builder() + .setDistrArgs(new NDList(mu, alpha)) + .build(); + + NDArray expected = manager.create(new float[]{-6.9098f, -1.6479f}); + NDArray real = negativeBinomial.logProb(manager.create(new float[]{1f, 1f})); + Assertions.assertAlmostEquals(real, expected); + } + } + + @Test + public void testStudentT() { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray mu = manager.create(new float[]{1000f, -1000f}); + NDArray sigma = manager.create(new float[]{1f, 2f}); + NDArray nu = manager.create(new float[]{4.2f, 3f}); + mu.setName("mu"); + sigma.setName("sigma"); + nu.setName("nu"); + Distribution studentT = StudentT + .builder() + .setDistrArgs(new NDList(mu, sigma, nu)) + .build(); + + NDArray expected = manager.create(new float[]{-0.9779f, -1.6940f}); + NDArray real = studentT.logProb(manager.create(new float[]{1000f, -1000f})); + Assertions.assertAlmostEquals(real, expected); + } + } +} diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java new file mode 100644 index 00000000000..3835506616d --- /dev/null +++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java @@ -0,0 +1 @@ +package ai.djl.timeseries.distribution; \ No newline at end of file From 78c992acce1b2080765b44b0cb679a903e63ec6e Mon Sep 17 00:00:00 2001 From: Carkham <1302112560@qq.com> Date: Fri, 16 Sep 2022 15:07:30 +0800 Subject: [PATCH 03/14] add comments and unit test --- .../timeseries/distribution/Distribution.java | 27 +++++++++++- .../distribution/DistributionLoss.java | 19 +++++---- .../distribution/NegativeBinomial.java | 29 ++++++++++--- .../djl/timeseries/distribution/StudentT.java | 42 +++++++++++++++---- .../distribution/output/ArgProj.java | 35 +++++++++++++++- .../output/DistributionOutput.java | 19 ++++++--- .../output/NegativeBinomialOutput.java | 6 ++- .../distribution/output/StudentTOutput.java | 5 ++- .../distribution/DistributionTest.java | 31 +++++++------- 9 files changed, 165 insertions(+), 48 deletions(-) diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java index 056d95a315c..31358b185e6 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java @@ -16,6 +16,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; +/** An abstract class representing probability distribution. */ public abstract class Distribution { /** @@ -23,7 +24,7 @@ public abstract class Distribution { * * @param target {@link NDArray} of shape (*batch_shape, *event_shape) * @return Tensor of shape (batch_shape) containing the probability log-density for each event - * in target + * in target */ public abstract NDArray logProb(NDArray target); @@ -44,21 +45,45 @@ public abstract class Distribution { */ public abstract NDArray mean(); + /** + * A builder to extend for all classes extend the {@link Distribution}. + * + * @param the concrete builder type + */ public abstract static class DistributionBuilder> { protected NDList distrArgs; protected NDArray scale; protected NDArray loc; + /** + * Set the appropriate arguments for the probability distribution. + * + * @param distrArgs a {@link NDList} containing distribution args named after the parameter + * name + * @return this builder + */ public T setDistrArgs(NDList distrArgs) { this.distrArgs = distrArgs; return self(); } + /** + * Set the affine scale for the probability distribution. + * + * @param scale the affine scale + * @return this builder + */ public T optScale(NDArray scale) { this.scale = scale; return self(); } + /** + * Set the affine location of the probability + * + * @param loc the affine location + * @return this builder + */ public T optLoc(NDArray loc) { this.loc = loc; return self(); diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java index 5a1f7800ef3..cc3b12bb3bc 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java @@ -19,14 +19,20 @@ import ai.djl.timeseries.distribution.output.DistributionOutput; import ai.djl.training.loss.Loss; +/** + * {@code DistributionLoss} calculates loss for a given distribution + * + *

Distribution Loss is calculated by {@link Distribution#logProb(NDArray)} at label point + */ public class DistributionLoss extends Loss { private DistributionOutput distrOutput; /** - * Base class for metric with abstract update methods. + * Calculates Distribution Loss between the label and distribution. * - * @param name The display name of the Loss + * @param name the name of the loss + * @param distrOutput the {@link DistributionOutput} to construct the target distribution */ public DistributionLoss(String name, DistributionOutput distrOutput) { super(name); @@ -49,11 +55,10 @@ public NDArray evaluate(NDList labels, NDList predictions) { if (predictions.contains("loss_weights")) { NDArray lossWeights = predictions.get("loss_weights"); - NDArray weightedValue = NDArrays.where( - lossWeights.neq(0), loss.mul(lossWeights), loss.zerosLike() - ); - NDArray sumWeights = lossWeights.sum(new int[]{1}).maximum(1.); - loss = weightedValue.sum(new int[]{1}).div(sumWeights); + NDArray weightedValue = + NDArrays.where(lossWeights.neq(0), loss.mul(lossWeights), loss.zerosLike()); + NDArray sumWeights = lossWeights.sum(new int[] {1}).maximum(1.); + loss = weightedValue.sum(new int[] {1}).div(sumWeights); } return loss; } diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java index e04b244f325..5851136ae6f 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java @@ -14,10 +14,17 @@ package ai.djl.timeseries.distribution; import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.util.Preconditions; +/** + * Negative binomial distribution. + * + *

The distribution of the number of successes in a sequence of independent Bernoulli trials. + * + *

Two arguments for this distribution. {@code mu} mean of the distribution, {@code alpha} the + * inverse number of negative Bernoulli trials to stop + */ public final class NegativeBinomial extends Distribution { private NDArray mu; @@ -28,20 +35,21 @@ public final class NegativeBinomial extends Distribution { alpha = builder.distrArgs.get("alpha"); } + /** {@inheritDoc} */ @Override public NDArray logProb(NDArray target) { NDArray alphaInv = alpha.getNDArrayInternal().rdiv(1); NDArray alphaTimesMu = alpha.mul(mu); - return target - .mul(alphaTimesMu.div(alphaTimesMu.add(1)).log()) + return target.mul(alphaTimesMu.div(alphaTimesMu.add(1)).log()) .sub(alphaInv.mul(alphaTimesMu.add(1).log())) .add(target.add(alphaInv).gammaln()) .sub(target.add(1.).gammaln()) .sub(alphaInv.gammaln()); } + /** {@inheritDoc} */ @Override public NDArray sample(int numSamples) { NDManager manager = mu.getManager(); @@ -53,24 +61,35 @@ public NDArray sample(int numSamples) { return manager.samplePoisson(manager.sampleGamma(r, theta)); } + /** {@inheritDoc} */ @Override public NDArray mean() { return mu; } + /** + * Creates a builder to build a {@code NegativeBinomial}. + * + * @return a new builder + */ public static Builder builder() { return new Builder(); } + /** The builder to construct a {@code NegativeBinomial}. */ public static final class Builder extends DistributionBuilder { + /** {@inheritDoc} */ @Override public Distribution build() { - Preconditions.checkArgument(distrArgs.contains("mu"), "NegativeBinomial's args must contain mu."); - Preconditions.checkArgument(distrArgs.contains("alpha"), "NegativeBinomial's args must contain alpha."); + Preconditions.checkArgument( + distrArgs.contains("mu"), "NegativeBinomial's args must contain mu."); + Preconditions.checkArgument( + distrArgs.contains("alpha"), "NegativeBinomial's args must contain alpha."); return new NegativeBinomial(this); } + /** {@inheritDoc} */ @Override protected Builder self() { return this; diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java index 5379f96fa7a..22090c34e96 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java @@ -5,8 +5,14 @@ import ai.djl.ndarray.NDManager; import ai.djl.util.Preconditions; +/** + * Student's t-test distribution. + * + *

Three arguments for this distribution. {@code mu} mean of the distribution, {@code sigma} the + * standard deviations (scale), {@code nu} degrees of freedom. + */ public class StudentT extends Distribution { - + private NDArray mu; private NDArray sigma; private NDArray nu; @@ -17,19 +23,22 @@ public class StudentT extends Distribution { nu = builder.distrArgs.get("nu"); } + /** {@inheritDoc} */ @Override public NDArray logProb(NDArray target) { NDArray nup1Half = nu.add(1.).div(2.); NDArray part1 = nu.getNDArrayInternal().rdiv(1.).mul(target.sub(mu).div(sigma).square()); - NDArray z = nup1Half.gammaln() - .sub(nu.div(2.).gammaln()) - .sub(nu.mul(Math.PI).log().mul(0.5)) - .sub(sigma.log()); + NDArray z = + nup1Half.gammaln() + .sub(nu.div(2.).gammaln()) + .sub(nu.mul(Math.PI).log().mul(0.5)) + .sub(sigma.log()); return z.sub(nup1Half.mul(part1.add(1.).log())); } + /** {@inheritDoc} */ @Override public NDArray sample(int numSamples) { NDManager manager = mu.getManager(); @@ -37,29 +46,44 @@ public NDArray sample(int numSamples) { NDArray expandedSigma = sigma.expandDims(0).repeat(0, numSamples); NDArray expandedNu = nu.expandDims(0).repeat(0, numSamples); - NDArray gammas = manager.sampleGamma(expandedNu.div(2.), expandedNu.mul(expandedSigma.square()).getNDArrayInternal().rdiv(2.)); + NDArray gammas = + manager.sampleGamma( + expandedNu.div(2.), + expandedNu.mul(expandedSigma.square()).getNDArrayInternal().rdiv(2.)); return manager.sampleNormal(expandedMu, gammas.sqrt().getNDArrayInternal().rdiv(1.)); } + /** {@inheritDoc} */ @Override public NDArray mean() { return NDArrays.where(nu.gt(1.0), mu, mu.getManager().full(mu.getShape(), Float.NaN)); } + /** + * Creates a builder to build a {@code NegativeBinomial}. + * + * @return a new builder + */ public static Builder builder() { return new Builder(); } + /** The builder to construct a {@code NegativeBinomial}. */ public static final class Builder extends DistributionBuilder { + /** {@inheritDoc} */ @Override public Distribution build() { - Preconditions.checkArgument(distrArgs.contains("mu"), "StudentTl's args must contain mu."); - Preconditions.checkArgument(distrArgs.contains("sigma"), "StudentTl's args must contain sigma."); - Preconditions.checkArgument(distrArgs.contains("nu"), "StudentTl's args must contain nu."); + Preconditions.checkArgument( + distrArgs.contains("mu"), "StudentTl's args must contain mu."); + Preconditions.checkArgument( + distrArgs.contains("sigma"), "StudentTl's args must contain sigma."); + Preconditions.checkArgument( + distrArgs.contains("nu"), "StudentTl's args must contain nu."); return new StudentT(this); } + /** {@inheritDoc} */ @Override protected Builder self() { return this; diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java index 0f94358a0d3..72e6aae3bb8 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java @@ -13,7 +13,6 @@ package ai.djl.timeseries.distribution.output; -import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; @@ -30,6 +29,7 @@ import java.util.List; import java.util.function.Function; +/** A Block that can be used to project from a dense layer to distribution arguments */ public final class ArgProj extends AbstractBlock { private Block domainMap; @@ -47,8 +47,10 @@ public final class ArgProj extends AbstractBlock { addChildBlock(String.format("%s_domain_map", builder.prefix), builder.domainMap); } + /** {@inheritDoc} */ @Override - protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) { + protected void initializeChildBlocks( + NDManager manager, DataType dataType, Shape... inputShapes) { for (Block block : proj) { block.initialize(manager, dataType, inputShapes); } @@ -79,30 +81,59 @@ public Shape[] getOutputShapes(Shape[] inputShapes) { return domainMap.getOutputShapes(projOutShapes); } + /** + * Creates a builder to build a {@code ArgProj}. + * + * @return a new builder + */ public static Builder builder() { return new Builder(); } + /** The Builder to construct a {@code ArgProj} type of {@link Block}. */ public static final class Builder { private PairList argsDim; private Function domainMap; private String prefix = ""; + /** + * Set the arguments dimensions of distribution + * + * @param argsDim the arguments dimension + * @return this builder + */ public Builder setArgsDim(PairList argsDim) { this.argsDim = argsDim; return this; } + /** + * Set the domain map function + * + * @param domainMap the domain map function + * @return this builder + */ public Builder setDomainMap(Function domainMap) { this.domainMap = domainMap; return this; } + /** + * Set the block name prefix + * + * @param prefix the prefix + * @return this builder + */ public Builder optPrefix(String prefix) { this.prefix = prefix; return this; } + /** + * Build a {@link ArgProj} block. + * + * @return the {@link ArgProj} block. + */ public ArgProj build() { Preconditions.checkArgument(argsDim != null, "must specify dim args"); Preconditions.checkArgument(domainMap != null, "must specify domain PairList function"); diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java index a032d6552db..a2925015f6e 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java @@ -15,21 +15,29 @@ import ai.djl.ndarray.NDList; import ai.djl.timeseries.distribution.Distribution; -import ai.djl.timeseries.distribution.DistributionLoss; import ai.djl.util.PairList; +/** A class to construct a distribution given the output of a network. */ public abstract class DistributionOutput { protected PairList argsDim; private float valueInSupport = 0f; + /** + * A float that will have a valid numeric value when computing the log-loss of the corresponding + * distribution. + * + *

By default {@code 0f}. This value will be used when padding data series. + * + * @return the valueInSupport + */ public float getValueInSupport() { return valueInSupport; } /** - * Return the corresponding projection block based on the args dimension of different - * ditributions. + * Return the corresponding projection block based on the arguments dimension of different + * distributions. * * @return the corresponding projection block */ @@ -38,7 +46,7 @@ public ArgProj getArgsProj() { } /** - * Return the corresponding projection block based on the args dimension of different + * Return the corresponding projection block based on the arguments dimension of different * ditributions. * * @param prefix the prefix string of projection layer block @@ -65,7 +73,8 @@ public ArgProj getArgsProj(String prefix) { public abstract NDList domainMap(NDList arrays); /** - * Return the associated {@code DistributionBuilder}, given the collection of constructor arguments and, optionally, a scale tensor. + * Return the associated {@code DistributionBuilder}, given the collection of constructor + * arguments and, optionally, a scale tensor. * * @return the associated {@code DistributionBuilder} */ diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java index 9eab777de65..610621fc6c8 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java @@ -16,12 +16,16 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.timeseries.distribution.Distribution; -import ai.djl.timeseries.distribution.DistributionLoss; import ai.djl.timeseries.distribution.NegativeBinomial; import ai.djl.util.PairList; +/** + * {@code NegativeBinomialOutput} is a {@link DistributionOutput} that for negative binomial + * distribution. + */ public final class NegativeBinomialOutput extends DistributionOutput { + /** Construct a negative binomial output with two arguments, {@code mu} and {@code alpha}. */ public NegativeBinomialOutput() { argsDim = new PairList<>(2); argsDim.add("mu", 1); diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java index 0791629de54..156c83f71b3 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java @@ -16,11 +16,14 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.timeseries.distribution.Distribution; -import ai.djl.timeseries.distribution.DistributionLoss; import ai.djl.util.PairList; +/** + * {@code StudentTOutput} is a {@link DistributionOutput} that for Student's t-test distribution. + */ public class StudentTOutput extends DistributionOutput { + /** Construct a negative binomial output with two arguments, {@code mu} and {@code sigma}. */ public StudentTOutput() { argsDim = new PairList<>(3); argsDim.add("mu", 1); diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java index fffaba4cafb..8ec50da5dbd 100644 --- a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java +++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java @@ -4,6 +4,7 @@ import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.testing.Assertions; + import org.testng.annotations.Test; public class DistributionTest { @@ -11,17 +12,15 @@ public class DistributionTest { @Test public void testNegativeBinomial() { try (NDManager manager = NDManager.newBaseManager()) { - NDArray mu = manager.create(new float[]{1000f, 1f}); - NDArray alpha = manager.create(new float[]{1f, 2f}); + NDArray mu = manager.create(new float[] {1000f, 1f}); + NDArray alpha = manager.create(new float[] {1f, 2f}); mu.setName("mu"); alpha.setName("alpha"); - Distribution negativeBinomial = NegativeBinomial - .builder() - .setDistrArgs(new NDList(mu, alpha)) - .build(); + Distribution negativeBinomial = + NegativeBinomial.builder().setDistrArgs(new NDList(mu, alpha)).build(); - NDArray expected = manager.create(new float[]{-6.9098f, -1.6479f}); - NDArray real = negativeBinomial.logProb(manager.create(new float[]{1f, 1f})); + NDArray expected = manager.create(new float[] {-6.9098f, -1.6479f}); + NDArray real = negativeBinomial.logProb(manager.create(new float[] {1f, 1f})); Assertions.assertAlmostEquals(real, expected); } } @@ -29,19 +28,17 @@ public void testNegativeBinomial() { @Test public void testStudentT() { try (NDManager manager = NDManager.newBaseManager()) { - NDArray mu = manager.create(new float[]{1000f, -1000f}); - NDArray sigma = manager.create(new float[]{1f, 2f}); - NDArray nu = manager.create(new float[]{4.2f, 3f}); + NDArray mu = manager.create(new float[] {1000f, -1000f}); + NDArray sigma = manager.create(new float[] {1f, 2f}); + NDArray nu = manager.create(new float[] {4.2f, 3f}); mu.setName("mu"); sigma.setName("sigma"); nu.setName("nu"); - Distribution studentT = StudentT - .builder() - .setDistrArgs(new NDList(mu, sigma, nu)) - .build(); + Distribution studentT = + StudentT.builder().setDistrArgs(new NDList(mu, sigma, nu)).build(); - NDArray expected = manager.create(new float[]{-0.9779f, -1.6940f}); - NDArray real = studentT.logProb(manager.create(new float[]{1000f, -1000f})); + NDArray expected = manager.create(new float[] {-0.9779f, -1.6940f}); + NDArray real = studentT.logProb(manager.create(new float[] {1000f, -1000f})); Assertions.assertAlmostEquals(real, expected); } } From 50d9c7adb07925d0671a3a60a07d1393bb6111df Mon Sep 17 00:00:00 2001 From: Carkham <1302112560@qq.com> Date: Fri, 16 Sep 2022 15:23:27 +0800 Subject: [PATCH 04/14] add copyright and package-info --- .../ai/djl/timeseries/distribution/StudentT.java | 13 +++++++++++++ .../djl/timeseries/distribution/package-info.java | 15 +++++++++++++++ .../timeseries/distribution/DistributionTest.java | 13 +++++++++++++ .../djl/timeseries/distribution/package-info.java | 14 ++++++++++++++ 4 files changed, 55 insertions(+) create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/package-info.java diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java index 22090c34e96..8854ee9d165 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java @@ -1,3 +1,16 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + package ai.djl.timeseries.distribution; import ai.djl.ndarray.NDArray; diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/package-info.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/package-info.java new file mode 100644 index 00000000000..5f18cc5f60e --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains classes to support distribution in djl. */ +package ai.djl.timeseries.distribution; diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java index 8ec50da5dbd..051f7f4743b 100644 --- a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java +++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/DistributionTest.java @@ -1,3 +1,16 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + package ai.djl.timeseries.distribution; import ai.djl.ndarray.NDArray; diff --git a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java index 3835506616d..d6d06167d43 100644 --- a/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java +++ b/extensions/timeseries/src/test/java/ai/djl/timeseries/distribution/package-info.java @@ -1 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains tests for the distribution module. */ package ai.djl.timeseries.distribution; \ No newline at end of file From 8048e1bccc75799588b208d1f30909da6cab0045 Mon Sep 17 00:00:00 2001 From: Carkham <1302112560@qq.com> Date: Fri, 16 Sep 2022 15:34:34 +0800 Subject: [PATCH 05/14] style fix --- .../djl/timeseries/distribution/Distribution.java | 9 +++++++-- .../timeseries/distribution/DistributionLoss.java | 2 +- .../timeseries/distribution/output/ArgProj.java | 8 ++++---- .../distribution/output/DistributionOutput.java | 2 +- .../distribution/output/package-info.java | 15 +++++++++++++++ 5 files changed, 28 insertions(+), 8 deletions(-) create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/package-info.java diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java index 31358b185e6..1adcdd4ce77 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java @@ -20,7 +20,7 @@ public abstract class Distribution { /** - * Compute the log of the probability density/mass function evaluated at target + * Compute the log of the probability density/mass function evaluated at target. * * @param target {@link NDArray} of shape (*batch_shape, *event_shape) * @return Tensor of shape (batch_shape) containing the probability log-density for each event @@ -79,7 +79,7 @@ public T optScale(NDArray scale) { } /** - * Set the affine location of the probability + * Set the affine location of the probability. * * @param loc the affine location * @return this builder @@ -89,6 +89,11 @@ public T optLoc(NDArray loc) { return self(); } + /** + * Build a {@code Distribution}. + * + * @return the {@code Distribution} + */ public abstract Distribution build(); protected abstract T self(); diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java index cc3b12bb3bc..30d7a2f35d9 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java @@ -20,7 +20,7 @@ import ai.djl.training.loss.Loss; /** - * {@code DistributionLoss} calculates loss for a given distribution + * {@code DistributionLoss} calculates loss for a given distribution. * *

Distribution Loss is calculated by {@link Distribution#logProb(NDArray)} at label point */ diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java index 72e6aae3bb8..e47c04e1782 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java @@ -29,7 +29,7 @@ import java.util.List; import java.util.function.Function; -/** A Block that can be used to project from a dense layer to distribution arguments */ +/** A Block that can be used to project from a dense layer to distribution arguments. */ public final class ArgProj extends AbstractBlock { private Block domainMap; @@ -97,7 +97,7 @@ public static final class Builder { private String prefix = ""; /** - * Set the arguments dimensions of distribution + * Set the arguments dimensions of distribution. * * @param argsDim the arguments dimension * @return this builder @@ -108,7 +108,7 @@ public Builder setArgsDim(PairList argsDim) { } /** - * Set the domain map function + * Set the domain map function. * * @param domainMap the domain map function * @return this builder @@ -119,7 +119,7 @@ public Builder setDomainMap(Function domainMap) { } /** - * Set the block name prefix + * Set the block name prefix. * * @param prefix the prefix * @return this builder diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java index a2925015f6e..2bf98459b15 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java @@ -21,7 +21,7 @@ public abstract class DistributionOutput { protected PairList argsDim; - private float valueInSupport = 0f; + private float valueInSupport; /** * A float that will have a valid numeric value when computing the log-loss of the corresponding diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/package-info.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/package-info.java new file mode 100644 index 00000000000..8003778cd6a --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains classes to support construct distribution and project arguments. */ +package ai.djl.timeseries.distribution.output; From 06759d181be83651bce87ac8bb16f89357404fc3 Mon Sep 17 00:00:00 2001 From: Carkham <1302112560@qq.com> Date: Sat, 17 Sep 2022 19:48:19 +0800 Subject: [PATCH 06/14] add args array --- .../timeseries/distribution/output/DistributionOutput.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java index 2bf98459b15..81c932221ad 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java @@ -60,6 +60,10 @@ public ArgProj getArgsProj(String prefix) { .build(); } + public String[] getArgsArray() { + return argsDim.keyArray(new String[argsDim.size()]); + } + /** * Convert arguments to the right shape and domain. The domain depends on the type of * distribution, while the correct shape is obtained by reshaping the trailing axis in such a From ad62bd8f3e8a6953098ff902978f19ca979f4d66 Mon Sep 17 00:00:00 2001 From: Carkham <1302112560@qq.com> Date: Sun, 18 Sep 2022 16:33:58 +0800 Subject: [PATCH 07/14] feature: add affinely distribution --- .../distribution/AffineTransformed.java | 74 +++++++++++++++++++ .../distribution/NegativeBinomial.java | 9 +++ .../djl/timeseries/distribution/StudentT.java | 6 +- .../output/DistributionOutput.java | 5 ++ 4 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/AffineTransformed.java diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/AffineTransformed.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/AffineTransformed.java new file mode 100644 index 00000000000..251aea51689 --- /dev/null +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/AffineTransformed.java @@ -0,0 +1,74 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +package ai.djl.timeseries.distribution; + +import ai.djl.ndarray.NDArray; + +/** Represents the distribution of an affinely transformed random variable. */ +public class AffineTransformed extends Distribution { + + private Distribution baseDistribution; + private NDArray loc; + private NDArray scale; + + /** + * Construct a new {@code AffineTransformed} + * + *

This is the distribution of Y = scale * X + loc, where X is a random variable distributed + * according to {@code baseDistribution}. + * + * @param baseDistribution original distribution + * @param loc translation parameter of the affine transformation + * @param scale scaling parameter of the affine transformation + */ + public AffineTransformed(Distribution baseDistribution, NDArray loc, NDArray scale) { + this.baseDistribution = baseDistribution; + this.loc = loc == null ? baseDistribution.mean().zerosLike() : loc; + this.scale = scale == null ? baseDistribution.mean().onesLike() : scale; + } + + /** {@inheritDoc} */ + @Override + public NDArray logProb(NDArray target) { + NDArray x = fInv(target); + NDArray ladj = logAbsDetJac(x); + NDArray lp = ladj.mul(-1); + return baseDistribution.logProb(x).add(lp); + } + + /** {@inheritDoc} */ + @Override + public NDArray sample(int numSamples) { + NDArray sample = baseDistribution.sample(numSamples); + return f(sample); + } + + /** {@inheritDoc} */ + @Override + public NDArray mean() { + return baseDistribution.mean().mul(scale).add(loc); + } + + private NDArray f(NDArray x) { + return x.mul(scale).add(loc); + } + + private NDArray fInv(NDArray y) { + return y.sub(loc).div(scale); + } + + private NDArray logAbsDetJac(NDArray x) { + return scale.broadcast(x.getShape()).abs().log(); + } +} diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java index 5851136ae6f..b3a90c9d94b 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java @@ -86,6 +86,15 @@ public Distribution build() { distrArgs.contains("mu"), "NegativeBinomial's args must contain mu."); Preconditions.checkArgument( distrArgs.contains("alpha"), "NegativeBinomial's args must contain alpha."); + // We cannot scale using the affine transformation since negative binomial should return + // integers. Instead we scale the parameters. + if (scale != null) { + NDArray mu = distrArgs.get("mu"); + mu = mu.mul(scale); + mu.setName("mu"); + distrArgs.remove("mu"); + distrArgs.add(mu); + } return new NegativeBinomial(this); } diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java index 8854ee9d165..1086f052ac1 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java @@ -93,7 +93,11 @@ public Distribution build() { distrArgs.contains("sigma"), "StudentTl's args must contain sigma."); Preconditions.checkArgument( distrArgs.contains("nu"), "StudentTl's args must contain nu."); - return new StudentT(this); + StudentT baseDistr = new StudentT(this); + if (scale == null && loc == null) { + return baseDistr; + } + return new AffineTransformed(baseDistr, loc, scale); } /** {@inheritDoc} */ diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java index 81c932221ad..de3538df489 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/DistributionOutput.java @@ -60,6 +60,11 @@ public ArgProj getArgsProj(String prefix) { .build(); } + /** + * Return an array containing all the argument names. + * + * @return an array containing all the argument names + */ public String[] getArgsArray() { return argsDim.keyArray(new String[argsDim.size()]); } From 15f3b9c6c44edf349910b08c221c528e07cbf797 Mon Sep 17 00:00:00 2001 From: Carkham <13193102258@163.com> Date: Mon, 19 Sep 2022 00:10:52 +0800 Subject: [PATCH 08/14] feature: add sample for zero num --- .../djl/timeseries/distribution/Distribution.java | 14 +++++++++++++- .../timeseries/distribution/NegativeBinomial.java | 4 ++-- .../ai/djl/timeseries/distribution/StudentT.java | 6 +++--- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java index 1adcdd4ce77..739694cb9a7 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/Distribution.java @@ -31,13 +31,25 @@ public abstract class Distribution { /** * Draw samples from the distribution. * - *

The first dimension of the output will be numSamples. + *

This function would expand the dimension of arguments, the first dimension of the output + * will be numSamples. * * @param numSamples Number of samples to be drawn * @return a {@link NDArray} has shape (num_samples, *batch_shape, *target_shape) */ public abstract NDArray sample(int numSamples); + /** + * Draw samples from the distribution. + * + *

This function would not expand the dimension + * + * @return a sampled {@link NDArray} + */ + public NDArray sample() { + return sample(0); + } + /** * Return the mean of the distribution. * diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java index b3a90c9d94b..8565babc244 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/NegativeBinomial.java @@ -53,8 +53,8 @@ public NDArray logProb(NDArray target) { @Override public NDArray sample(int numSamples) { NDManager manager = mu.getManager(); - NDArray expandedMu = mu.expandDims(0).repeat(0, numSamples); - NDArray expandedAlpha = alpha.expandDims(0).repeat(0, numSamples); + NDArray expandedMu = numSamples > 0 ? mu.expandDims(0).repeat(0, numSamples) : mu; + NDArray expandedAlpha = numSamples > 0 ? alpha.expandDims(0).repeat(0, numSamples) : alpha; NDArray r = expandedAlpha.getNDArrayInternal().rdiv(1f); NDArray theta = expandedAlpha.mul(expandedMu); diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java index 1086f052ac1..34168acd601 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/StudentT.java @@ -55,9 +55,9 @@ public NDArray logProb(NDArray target) { @Override public NDArray sample(int numSamples) { NDManager manager = mu.getManager(); - NDArray expandedMu = mu.expandDims(0).repeat(0, numSamples); - NDArray expandedSigma = sigma.expandDims(0).repeat(0, numSamples); - NDArray expandedNu = nu.expandDims(0).repeat(0, numSamples); + NDArray expandedMu = numSamples > 0 ? mu.expandDims(0).repeat(0, numSamples) : mu; + NDArray expandedSigma = numSamples > 0 ? sigma.expandDims(0).repeat(0, numSamples) : sigma; + NDArray expandedNu = numSamples > 0 ? nu.expandDims(0).repeat(0, numSamples) : nu; NDArray gammas = manager.sampleGamma( From 0f39651012301e79dd4cb887a43f76ef034ba124 Mon Sep 17 00:00:00 2001 From: Carkham <13193102258@163.com> Date: Mon, 19 Sep 2022 21:56:47 +0800 Subject: [PATCH 09/14] bug fix: sum the loss --- .../java/ai/djl/timeseries/distribution/DistributionLoss.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java index 30d7a2f35d9..30251e71728 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/DistributionLoss.java @@ -57,8 +57,8 @@ public NDArray evaluate(NDList labels, NDList predictions) { NDArray lossWeights = predictions.get("loss_weights"); NDArray weightedValue = NDArrays.where(lossWeights.neq(0), loss.mul(lossWeights), loss.zerosLike()); - NDArray sumWeights = lossWeights.sum(new int[] {1}).maximum(1.); - loss = weightedValue.sum(new int[] {1}).div(sumWeights); + NDArray sumWeights = lossWeights.sum().maximum(1.); + loss = weightedValue.sum().div(sumWeights); } return loss; } From d8219931c0ba34d9c7c10d5b628804fa96700d5f Mon Sep 17 00:00:00 2001 From: Carkham <13193102258@163.com> Date: Tue, 20 Sep 2022 00:10:13 +0800 Subject: [PATCH 10/14] bug fix: add builder for StudentTOutput and clamp the min value in neg_binomial --- .../distribution/output/NegativeBinomialOutput.java | 4 ++-- .../ai/djl/timeseries/distribution/output/StudentTOutput.java | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java index 610621fc6c8..0131eb946eb 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java @@ -37,8 +37,8 @@ public NegativeBinomialOutput() { public NDList domainMap(NDList arrays) { NDArray mu = arrays.get(0); NDArray alpha = arrays.get(1); - mu = mu.getNDArrayInternal().softPlus().squeeze(-1); - alpha = alpha.getNDArrayInternal().softPlus().squeeze(-1); + mu = mu.getNDArrayInternal().softPlus().maximum(Float.MIN_VALUE).squeeze(-1); + alpha = alpha.getNDArrayInternal().softPlus().maximum(Float.MIN_VALUE).squeeze(-1); // TODO: make setName() must be implemented mu.setName("mu"); alpha.setName("alpha"); diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java index 156c83f71b3..b0d775f99ea 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java @@ -16,6 +16,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.timeseries.distribution.Distribution; +import ai.djl.timeseries.distribution.StudentT; import ai.djl.util.PairList; /** @@ -50,6 +51,6 @@ public NDList domainMap(NDList arrays) { /** {@inheritDoc} */ @Override public Distribution.DistributionBuilder distributionBuilder() { - return null; + return StudentT.builder(); } } From dcf44cf169d55cf36d62a8992aa9c561079fb174 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Mon, 19 Sep 2022 09:21:17 -0700 Subject: [PATCH 11/14] Update extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java --- .../ai/djl/timeseries/distribution/output/StudentTOutput.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java index b0d775f99ea..98ea0771a35 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java @@ -20,7 +20,7 @@ import ai.djl.util.PairList; /** - * {@code StudentTOutput} is a {@link DistributionOutput} that for Student's t-test distribution. + * {@code StudentTOutput} is a {@link DistributionOutput} for the Student's t-test distribution. */ public class StudentTOutput extends DistributionOutput { From 0a4268582ed46a42aa7eb23d491ebdf195548eaa Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Mon, 19 Sep 2022 09:21:23 -0700 Subject: [PATCH 12/14] Update extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java --- .../timeseries/distribution/output/NegativeBinomialOutput.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java index 0131eb946eb..84222fe9741 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/NegativeBinomialOutput.java @@ -20,7 +20,7 @@ import ai.djl.util.PairList; /** - * {@code NegativeBinomialOutput} is a {@link DistributionOutput} that for negative binomial + * {@code NegativeBinomialOutput} is a {@link DistributionOutput} for the negative binomial * distribution. */ public final class NegativeBinomialOutput extends DistributionOutput { From 30804b5f5e8f99dc0ad361e2a9bf310eb4c96c1e Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Mon, 19 Sep 2022 09:21:45 -0700 Subject: [PATCH 13/14] Update extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java --- .../java/ai/djl/timeseries/distribution/output/ArgProj.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java index e47c04e1782..ae3e788bd58 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java @@ -29,7 +29,7 @@ import java.util.List; import java.util.function.Function; -/** A Block that can be used to project from a dense layer to distribution arguments. */ +/** A Block used to map the output of a dense layer to statistical parameters, like mean and standard deviation. It will be used in both training and inference. */ public final class ArgProj extends AbstractBlock { private Block domainMap; From 9349693604ba847562f480b599a992e30ddf7913 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Mon, 19 Sep 2022 09:54:08 -0700 Subject: [PATCH 14/14] format --- .../java/ai/djl/timeseries/distribution/output/ArgProj.java | 5 ++++- .../djl/timeseries/distribution/output/StudentTOutput.java | 4 +--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java index ae3e788bd58..6969666e881 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/ArgProj.java @@ -29,7 +29,10 @@ import java.util.List; import java.util.function.Function; -/** A Block used to map the output of a dense layer to statistical parameters, like mean and standard deviation. It will be used in both training and inference. */ +/** + * A Block used to map the output of a dense layer to statistical parameters, like mean and standard + * deviation. It will be used in both training and inference. + */ public final class ArgProj extends AbstractBlock { private Block domainMap; diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java index 98ea0771a35..a38e7cab5c9 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/distribution/output/StudentTOutput.java @@ -19,9 +19,7 @@ import ai.djl.timeseries.distribution.StudentT; import ai.djl.util.PairList; -/** - * {@code StudentTOutput} is a {@link DistributionOutput} for the Student's t-test distribution. - */ +/** {@code StudentTOutput} is a {@link DistributionOutput} for the Student's t-test distribution. */ public class StudentTOutput extends DistributionOutput { /** Construct a negative binomial output with two arguments, {@code mu} and {@code sigma}. */