From 03ab63c6692d51beed657054ad91fa08f1d9feb1 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Mon, 24 Feb 2020 11:20:46 -0800 Subject: [PATCH] Fix transfer learning jupyter notebook. Change-Id: I0263fc6eab7646fe63f7f10a59119543070b75f0 --- jupyter/TrainingUtils.java | 46 +------------------------------------- 1 file changed, 1 insertion(+), 45 deletions(-) diff --git a/jupyter/TrainingUtils.java b/jupyter/TrainingUtils.java index 37a0cb25b2c..a58152e0693 100644 --- a/jupyter/TrainingUtils.java +++ b/jupyter/TrainingUtils.java @@ -1,9 +1,7 @@ import ai.djl.Model; import ai.djl.training.Trainer; -import ai.djl.training.TrainingListener; import ai.djl.training.dataset.Batch; import ai.djl.training.dataset.Dataset; -import ai.djl.training.util.ProgressBar; import java.io.IOException; import java.nio.file.Paths; @@ -31,7 +29,7 @@ public static void fit( } } // reset training and validation evaluators at end of epoch - trainer.resetEvaluators(); + trainer.endEpoch(); // save model at end of each epoch if (outputDir != null) { Model model = trainer.getModel(); @@ -40,46 +38,4 @@ public static void fit( } } } - - public static TrainingListener getTrainingListener( - ProgressBar trainingProgressBar, ProgressBar validateProgressBar) { - return new SimpleTrainingListener(trainingProgressBar, validateProgressBar); - } - - private static final class SimpleTrainingListener implements TrainingListener { - - private ProgressBar trainingProgressBar; - private ProgressBar validateProgressBar; - private int trainingProgress; - private int validateProgress; - - public SimpleTrainingListener( - ProgressBar trainingProgressBar, ProgressBar validateProgressBar) { - this.trainingProgressBar = trainingProgressBar; - this.validateProgressBar = validateProgressBar; - } - - /** {@inheritDoc} */ - @Override - public void onTrainingBatch() { - if (trainingProgressBar != null) { - trainingProgressBar.update(trainingProgress++); - } - } - - /** {@inheritDoc} */ - @Override - public void onValidationBatch() { - if (validateProgressBar != null) { - validateProgressBar.update(validateProgress++); - } - } - - /** {@inheritDoc} */ - @Override - public void onEpoch() { - trainingProgress = 0; - validateProgress = 0; - } - } }