loss function fix
KexinFeng committed Nov 9, 2022
1 parent 808b6d4 commit 33b1181
Showing 4 changed files with 27 additions and 26 deletions.
9 changes: 5 additions & 4 deletions api/src/main/java/ai/djl/training/loss/QuantileL1Loss.java
@@ -18,12 +18,13 @@

 /**
  * {@code QuantileL1Loss} calculates the Weighted Quantile Loss between labels and predictions. It
- * is useful for regression problems where you wish to estimate a particular quantile. For example,
+ * is useful in regression problems to target the best-fit line at a particular quantile. E.g.,
  * to target the P90, instantiate {@code new QuantileL1Loss("P90", 0.90)}. Basically, what this loss
- * function does is to focus on a centain persentile of the data. Eg. q=0.5 is the original default
+ * function does is to focus on a certain percentile of the data. E.g. q=0.5 is the original default
  * case of regression, meaning the best-fit line lies in the center. When q=0.9, the best-fit line
- * will lie above the center; and, if \partial forecast / \partial w are the same, then exactly 0.9
- * of total data points will lie below the best-fit line.
+ * will lie above the center. By differentiating the loss function, the optimal solution will yield
+ * the result that, for some special cases like those where \partial forecast / \partial w are
+ * uniform, exactly 0.9 of total data points will lie below the best-fit line.
  *
  * <pre>
  * def quantile_loss(target, forecast, q):
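For context, the weighted quantile loss computed by this class is the standard pinball loss scaled by 2, i.e. quantile_loss(target, forecast, q) = 2 * sum(|(forecast - target) * ((target <= forecast) - q)|), which is exactly the expression the example files below stop computing by hand. A minimal usage sketch (the data values are illustrative, not taken from this commit):

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.loss.Loss;

public class QuantileLossSketch {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray target = manager.create(new float[] {1f, 2f, 3f, 4f});
            NDArray forecast = manager.create(new float[] {1.2f, 2.5f, 2.8f, 4.1f});
            // With q = 0.9, under-forecasting (target above the line) is penalized
            // 9x more than over-forecasting, pushing the best-fit line up to the P90.
            NDArray loss =
                    Loss.quantileL1Loss(0.90f)
                            .evaluate(new NDList(target), new NDList(forecast));
            System.out.println(loss.getFloat());
        }
    }
}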
(second changed file)
@@ -23,6 +23,7 @@
 import ai.djl.ndarray.NDArrays;
 import ai.djl.ndarray.NDList;
 import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.DataType;
 import ai.djl.ndarray.types.Shape;
 import ai.djl.repository.Artifact;
 import ai.djl.repository.MRL;
@@ -33,6 +34,7 @@
 import ai.djl.timeseries.SampleForecast;
 import ai.djl.timeseries.TimeSeriesData;
 import ai.djl.timeseries.dataset.FieldName;
+import ai.djl.training.loss.Loss;
 import ai.djl.training.util.ProgressBar;
 import ai.djl.translate.DeferredTranslatorFactory;
 import ai.djl.translate.TranslateException;
@@ -86,9 +88,10 @@ public static Map<String, Float> predict()
         // Then add the setting `.optRepository(repository)` to the builder below
         M5Dataset dataset = M5Dataset.builder().setManager(manager).build();

-        // The modelUrl can be replaced by local model path. E.g.,
-        // String modelUrl = "rootPath/deepar.zip";
+        // Note that, for a model exported from MXNet, the tensor shape of the `begin_state` may be problematic, as indicated in this [issue](https://github.com/deepjavalibrary/djl/issues/2106#issuecomment-1295703321). As described there, you need to "change every begin_state's shape to (-1, 40)".
         String modelUrl = "djl://ai.djl.mxnet/deepar/0.0.1/m5forecast";
+        // To use a local model, do:
+        // String modelUrl = "rootPath/deepar.zip";
         int predictionLength = 4;
         Criteria<TimeSeriesData, Forecast> criteria =
                 Criteria.builder()
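A side note on the local-model comment above: the rest of this builder is collapsed in this view, so the following completion is only a sketch using the usual DJL Criteria API (the specific .set/.opt calls are assumptions, not part of this diff):

Criteria<TimeSeriesData, Forecast> criteria =
        Criteria.builder()
                .setTypes(TimeSeriesData.class, Forecast.class)
                // Per the comment above, a local path such as "rootPath/deepar.zip"
                // can be substituted for the djl:// URL.
                .optModelUrls(modelUrl)
                .optTranslatorFactory(new DeferredTranslatorFactory())
                .optProgress(new ProgressBar())
                .build();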
@@ -319,15 +322,10 @@ public Map<String, Float> getMetricsPerTs(

             for (float quantile : quantiles) {
                 NDArray forecastQuantile = forecast.quantile(quantile);
-
-                NDArray quantileLoss =
-                        forecastQuantile
-                                .sub(gtTarget)
-                                .mul(gtTarget.lte(forecastQuantile).sub(quantile))
-                                .abs()
-                                .sum()
-                                .mul(2);
-                NDArray quantileCoverage = gtTarget.lt(forecastQuantile).mean();
+                NDArray quantileLoss = Loss.quantileL1Loss(quantile)
+                        .evaluate(new NDList(gtTarget), new NDList(forecastQuantile));
+                NDArray quantileCoverage = gtTarget.lt(forecastQuantile)
+                        .toType(DataType.FLOAT32, false).mean();
                 retMetrics.put(
                         String.format("QuantileLoss[%.2f]", quantile), quantileLoss.getFloat());
                 retMetrics.put(
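Two things change in this hunk: the hand-rolled pinball-loss expression is replaced by the shared Loss.quantileL1Loss evaluator, and the coverage metric gains an explicit cast, since lt() produces a boolean NDArray that presumably cannot be averaged directly on every engine. A standalone sketch of the coverage computation (illustrative values):

try (NDManager manager = NDManager.newBaseManager()) {
    NDArray gtTarget = manager.create(new float[] {1f, 2f, 3f, 4f});
    NDArray forecastQuantile = manager.create(new float[] {1.5f, 1.5f, 3.5f, 4.5f});
    // Boolean mask {true, false, true, true} -> floats {1, 0, 1, 1} -> mean 0.75,
    // i.e. 75% of ground-truth points fall below this quantile forecast.
    NDArray coverage =
            gtTarget.lt(forecastQuantile).toType(DataType.FLOAT32, false).mean();
    System.out.println(coverage.getFloat());
}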
(third changed file)
@@ -23,6 +23,7 @@
 import ai.djl.ndarray.NDArrays;
 import ai.djl.ndarray.NDList;
 import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.DataType;
 import ai.djl.ndarray.types.Shape;
 import ai.djl.repository.Artifact;
 import ai.djl.repository.MRL;
@@ -34,6 +35,8 @@
 import ai.djl.timeseries.TimeSeriesData;
 import ai.djl.timeseries.dataset.FieldName;
 import ai.djl.timeseries.translator.DeepARTranslator;
+import ai.djl.training.evaluator.Coverage;
+import ai.djl.training.loss.Loss;
 import ai.djl.training.util.ProgressBar;
 import ai.djl.translate.DeferredTranslatorFactory;
 import ai.djl.translate.TranslateException;
@@ -327,15 +330,10 @@ public Map<String, Float> getMetricsPerTs(

             for (float quantile : quantiles) {
                 NDArray forecastQuantile = forecast.quantile(quantile);
-
-                NDArray quantileLoss =
-                        forecastQuantile
-                                .sub(gtTarget)
-                                .mul(gtTarget.lte(forecastQuantile).sub(quantile))
-                                .abs()
-                                .sum()
-                                .mul(2);
-                NDArray quantileCoverage = gtTarget.lt(forecastQuantile).mean();
+                NDArray quantileLoss = Loss.quantileL1Loss(quantile)
+                        .evaluate(new NDList(gtTarget), new NDList(forecastQuantile));
+                NDArray quantileCoverage = gtTarget.lt(forecastQuantile)
+                        .toType(DataType.FLOAT32, false).mean();
                 retMetrics.put(
                         String.format("QuantileLoss[%.2f]", quantile), quantileLoss.getFloat());
                 retMetrics.put(
(fourth changed file: DeepARTranslator.java)
@@ -47,7 +47,7 @@ public class DeepARTranslator extends BaseTimeSeriesTranslator {
     private boolean useFeatStaticCat;
     private int historyLength;

-    private static final String[] PRED_INPUT_FIELDS = {
+    private static final String[] MX_PRED_INPUT_FIELDS = {
         FieldName.FEAT_STATIC_CAT.name(),
         FieldName.FEAT_STATIC_REAL.name(),
         "PAST_" + FieldName.FEAT_TIME.name(),
@@ -164,7 +164,11 @@ public NDList processInput(TranslatorContext ctx, TimeSeriesData input) {
                         0,
                         input);

-        input = Field.selectField(PT_PRED_INPUT_FIELDS, input);
+        if ("PyTorch".equals(manager.getEngine().getEngineName())) {
+            input = Field.selectField(PT_PRED_INPUT_FIELDS, input);
+        } else {
+            input = Field.selectField(MX_PRED_INPUT_FIELDS, input);
+        }

         return input.toNDList();
     }
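The rename from PRED_INPUT_FIELDS to MX_PRED_INPUT_FIELDS makes the engine dispatch explicit: PyTorch and MXNet exports of DeepAR expect different input field sets, so processInput now selects the field list from the engine that backs the manager. The same pattern in isolation (a minimal sketch; the real field arrays live in DeepARTranslator):

import ai.djl.ndarray.NDManager;

final class EngineDispatchSketch {
    // Pick a field list based on the active engine, as processInput does above.
    static String[] selectFields(NDManager manager, String[] ptFields, String[] mxFields) {
        return "PyTorch".equals(manager.getEngine().getEngineName()) ? ptFields : mxFields;
    }
}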
