Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[timeseries] add probability distribution support for timeseries #2025

Merged
merged 14 commits into from
Sep 20, 2022
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

package ai.djl.timeseries.distribution;

import ai.djl.ndarray.NDArray;

/**
 * A distribution obtained by applying an affine map to a random variable that follows another
 * distribution.
 */
public class AffineTransformed extends Distribution {

    private Distribution baseDistribution;
    private NDArray loc;
    private NDArray scale;

    /**
     * Constructs an {@code AffineTransformed}.
     *
     * <p>The result models Y = scale * X + loc, with X distributed according to {@code
     * baseDistribution}. A {@code null} {@code loc} is replaced by zeros and a {@code null} {@code
     * scale} by ones, both shaped like the mean of {@code baseDistribution}.
     *
     * @param baseDistribution the distribution of the untransformed variable X
     * @param loc the additive offset of the affine map, or {@code null} for no offset
     * @param scale the multiplicative factor of the affine map, or {@code null} for no scaling
     */
    public AffineTransformed(Distribution baseDistribution, NDArray loc, NDArray scale) {
        this.baseDistribution = baseDistribution;

        if (loc == null) {
            this.loc = baseDistribution.mean().zerosLike();
        } else {
            this.loc = loc;
        }

        if (scale == null) {
            this.scale = baseDistribution.mean().onesLike();
        } else {
            this.scale = scale;
        }
    }

    /** {@inheritDoc} */
    @Override
    public NDArray logProb(NDArray target) {
        // Map the observation back into the space of the base distribution.
        NDArray basePoint = invertTransform(target);

        // Change of variables: subtract log|scale|, the log absolute Jacobian determinant of the
        // forward map, from the base log-density.
        NDArray logAbsScale = scale.broadcast(basePoint.getShape()).abs().log();
        NDArray correction = logAbsScale.mul(-1);

        return baseDistribution.logProb(basePoint).add(correction);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray sample(int numSamples) {
        return applyTransform(baseDistribution.sample(numSamples));
    }

    /** {@inheritDoc} */
    @Override
    public NDArray mean() {
        // The affine map is linear, so the mean is the transformed mean of the base distribution.
        return applyTransform(baseDistribution.mean());
    }

    /** Applies the forward map {@code x * scale + loc}. */
    private NDArray applyTransform(NDArray x) {
        NDArray scaled = x.mul(scale);
        return scaled.add(loc);
    }

    /** Applies the inverse map {@code (y - loc) / scale}. */
    private NDArray invertTransform(NDArray y) {
        NDArray centered = y.sub(loc);
        return centered.div(scale);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

package ai.djl.timeseries.distribution;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/** Base type for the probability distributions defined in this package. */
public abstract class Distribution {

    /** Value passed to {@link #sample(int)} by the no-argument {@link #sample()} overload. */
    private static final int NO_SAMPLE_DIMENSION = 0;

    /**
     * Evaluates the log of the probability density/mass function at {@code target}.
     *
     * @param target an {@link NDArray} of shape (*batch_shape, *event_shape)
     * @return an {@link NDArray} of shape (batch_shape) holding the log-density of each event in
     *     {@code target}
     */
    public abstract NDArray logProb(NDArray target);

    /**
     * Draws {@code numSamples} samples from the distribution.
     *
     * <p>The output gains a new leading dimension of size {@code numSamples}.
     *
     * @param numSamples how many samples to draw
     * @return an {@link NDArray} of shape (num_samples, *batch_shape, *target_shape)
     */
    public abstract NDArray sample(int numSamples);

    /**
     * Draws a sample without adding a leading sample dimension.
     *
     * <p>This is shorthand for {@code sample(0)}.
     *
     * @return the sampled {@link NDArray}
     */
    public NDArray sample() {
        return sample(NO_SAMPLE_DIMENSION);
    }

    /**
     * Returns the mean of the distribution.
     *
     * @return the mean of the distribution
     */
    public abstract NDArray mean();

    /**
     * Base builder for builders of {@link Distribution} subclasses.
     *
     * <p>It stores the distribution arguments together with an optional affine scale and
     * location; concrete builders decide how those values are used in {@link #build()}.
     *
     * @param <T> the concrete builder type, returned by the fluent setters
     */
    public abstract static class DistributionBuilder<T extends DistributionBuilder<T>> {
        protected NDList distrArgs;
        protected NDArray scale;
        protected NDArray loc;

        /**
         * Returns this builder typed as the concrete builder class.
         *
         * @return this builder
         */
        protected abstract T self();

        /**
         * Builds the {@code Distribution} described by this builder.
         *
         * @return the {@code Distribution}
         */
        public abstract Distribution build();

        /**
         * Sets the arguments of the probability distribution.
         *
         * @param distrArgs an {@link NDList} of distribution arguments, each named after the
         *     parameter it provides
         * @return this builder
         */
        public T setDistrArgs(NDList distrArgs) {
            this.distrArgs = distrArgs;
            return self();
        }

        /**
         * Sets the affine location of the probability distribution.
         *
         * @param loc the affine location
         * @return this builder
         */
        public T optLoc(NDArray loc) {
            this.loc = loc;
            return self();
        }

        /**
         * Sets the affine scale of the probability distribution.
         *
         * @param scale the affine scale
         * @return this builder
         */
        public T optScale(NDArray scale) {
            this.scale = scale;
            return self();
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

package ai.djl.timeseries.distribution;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.timeseries.distribution.output.DistributionOutput;
import ai.djl.training.loss.Loss;

/**
 * {@code DistributionLoss} is the negative log-likelihood of the labels under a predicted
 * distribution.
 *
 * <p>The likelihood is obtained from {@link Distribution#logProb(NDArray)} evaluated at the label.
 */
public class DistributionLoss extends Loss {

    private DistributionOutput distrOutput;

    /**
     * Creates a loss that scores labels against the distribution produced by {@code distrOutput}.
     *
     * @param name the name of the loss
     * @param distrOutput the {@link DistributionOutput} that supplies the distribution builder
     */
    public DistributionLoss(String name, DistributionOutput distrOutput) {
        super(name);
        this.distrOutput = distrOutput;
    }

    /**
     * {@inheritDoc}
     *
     * <p>{@code predictions} is handed to the distribution builder as its arguments; entries named
     * {@code "scale"} and {@code "loc"}, when present, are additionally set as the affine scale
     * and location. When an entry named {@code "loss_weights"} is present, the result is the
     * weighted sum of the negative log-likelihood divided by the sum of the weights (clamped to a
     * minimum of 1); otherwise the element-wise negative log-likelihood is returned as is.
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        Distribution distribution = buildDistribution(predictions);
        NDArray negLogLikelihood = distribution.logProb(labels.singletonOrThrow()).mul(-1);

        if (!predictions.contains("loss_weights")) {
            return negLogLikelihood;
        }

        NDArray weights = predictions.get("loss_weights");
        // Positions with zero weight are forced to exactly zero rather than multiplied by zero.
        // NOTE(review): presumably this keeps NaN/inf losses at masked positions out of the sum.
        NDArray weighted =
                NDArrays.where(
                        weights.neq(0),
                        negLogLikelihood.mul(weights),
                        negLogLikelihood.zerosLike());
        // Clamp the denominator to at least 1 so an all-zero weight array cannot divide by zero.
        NDArray normalizer = weights.sum().maximum(1.);
        return weighted.sum().div(normalizer);
    }

    /** Configures a builder from {@code predictions} and builds the distribution. */
    private Distribution buildDistribution(NDList predictions) {
        Distribution.DistributionBuilder<?> distrBuilder = distrOutput.distributionBuilder();
        distrBuilder.setDistrArgs(predictions);
        if (predictions.contains("scale")) {
            distrBuilder.optScale(predictions.get("scale"));
        }
        if (predictions.contains("loc")) {
            distrBuilder.optLoc(predictions.get("loc"));
        }
        return distrBuilder.build();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

package ai.djl.timeseries.distribution;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.util.Preconditions;

/**
 * Negative binomial distribution.
 *
 * <p>The distribution of the number of successes in a sequence of independent Bernoulli trials.
 *
 * <p>Two arguments for this distribution. {@code mu} mean of the distribution, {@code alpha} the
 * inverse number of negative Bernoulli trials to stop
 */
public final class NegativeBinomial extends Distribution {

    private final NDArray mu;
    private final NDArray alpha;

    /**
     * Creates a distribution from the {@code "mu"} and {@code "alpha"} entries of the builder's
     * distribution arguments, used as they are (no scaling is applied here).
     *
     * @param builder the builder holding the distribution arguments
     */
    NegativeBinomial(Builder builder) {
        this(builder.distrArgs.get("mu"), builder.distrArgs.get("alpha"));
    }

    /**
     * Creates a distribution from explicit parameters.
     *
     * @param mu the mean of the distribution
     * @param alpha the second parameter of the distribution
     */
    private NegativeBinomial(NDArray mu, NDArray alpha) {
        this.mu = mu;
        this.alpha = alpha;
    }

    /** {@inheritDoc} */
    @Override
    public NDArray logProb(NDArray target) {
        NDArray alphaInv = alpha.getNDArrayInternal().rdiv(1);
        NDArray alphaTimesMu = alpha.mul(mu);

        // log p(x) = x * log(alpha*mu / (1 + alpha*mu)) - (1/alpha) * log(1 + alpha*mu)
        //            + lgamma(x + 1/alpha) - lgamma(x + 1) - lgamma(1/alpha)
        return target.mul(alphaTimesMu.div(alphaTimesMu.add(1)).log())
                .sub(alphaInv.mul(alphaTimesMu.add(1).log()))
                .add(target.add(alphaInv).gammaln())
                .sub(target.add(1.).gammaln())
                .sub(alphaInv.gammaln());
    }

    /** {@inheritDoc} */
    @Override
    public NDArray sample(int numSamples) {
        NDManager manager = mu.getManager();
        // A positive numSamples adds a leading axis of that size; otherwise the parameters are
        // used with their own shape.
        NDArray expandedMu = numSamples > 0 ? mu.expandDims(0).repeat(0, numSamples) : mu;
        NDArray expandedAlpha = numSamples > 0 ? alpha.expandDims(0).repeat(0, numSamples) : alpha;

        // Gamma-Poisson mixture: a Poisson draw whose rate is a gamma draw parameterised by
        // r = 1 / alpha and theta = alpha * mu.
        NDArray r = expandedAlpha.getNDArrayInternal().rdiv(1f);
        NDArray theta = expandedAlpha.mul(expandedMu);
        return manager.samplePoisson(manager.sampleGamma(r, theta));
    }

    /** {@inheritDoc} */
    @Override
    public NDArray mean() {
        return mu;
    }

    /**
     * Creates a builder to build a {@code NegativeBinomial}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** The builder to construct a {@code NegativeBinomial}. */
    public static final class Builder extends DistributionBuilder<Builder> {

        /**
         * {@inheritDoc}
         *
         * <p>The distribution arguments must contain entries named {@code "mu"} and {@code
         * "alpha"}. When a scale has been set, {@code mu} is multiplied by it. The affine location
         * is not used by this builder. The supplied argument list is left unmodified, so calling
         * this method repeatedly yields equivalent distributions.
         */
        @Override
        public Distribution build() {
            Preconditions.checkArgument(
                    distrArgs.contains("mu"), "NegativeBinomial's args must contain mu.");
            Preconditions.checkArgument(
                    distrArgs.contains("alpha"), "NegativeBinomial's args must contain alpha.");

            NDArray mu = distrArgs.get("mu");
            NDArray alpha = distrArgs.get("alpha");
            // We cannot scale using the affine transformation since negative binomial should return
            // integers. Instead we scale the parameters.
            if (scale != null) {
                // The scaled mean lives in a fresh array; the caller's NDList is not touched, so
                // a second build() does not apply the scale twice.
                mu = mu.mul(scale);
                mu.setName("mu");
            }
            return new NegativeBinomial(mu, alpha);
        }

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }
    }
}
Loading