-
Notifications
You must be signed in to change notification settings - Fork 654
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create EasyHpo to simplify training with hyperparameters (#488)
* Create EasyHpo to simplify training with hyperparameters Change-Id: If0896dadba41bf84109344fa814c8039c81694dc * Update api/src/main/java/ai/djl/training/hyperparameter/EasyHpo.java Co-authored-by: Frank Liu <frankfliu2000@gmail.com> * Update api/src/main/java/ai/djl/training/hyperparameter/EasyHpo.java Co-authored-by: Frank Liu <frankfliu2000@gmail.com> * Update api/src/main/java/ai/djl/training/hyperparameter/package-info.java Co-authored-by: Frank Liu <frankfliu2000@gmail.com> * Fix TrainingConfig change Change-Id: I2f7e74182c972ba82f4a92bb85c42c3f64467b52 Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
- Loading branch information
Showing
3 changed files
with
261 additions
and
98 deletions.
There are no files selected for viewing
161 changes: 161 additions & 0 deletions
161
api/src/main/java/ai/djl/training/hyperparameter/EasyHpo.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
/* | ||
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.training.hyperparameter; | ||
|
||
import ai.djl.Model; | ||
import ai.djl.metric.Metrics; | ||
import ai.djl.ndarray.types.Shape; | ||
import ai.djl.training.EasyTrain; | ||
import ai.djl.training.Trainer; | ||
import ai.djl.training.TrainingConfig; | ||
import ai.djl.training.TrainingResult; | ||
import ai.djl.training.dataset.Dataset; | ||
import ai.djl.training.dataset.RandomAccessDataset; | ||
import ai.djl.training.hyperparameter.optimizer.HpORandom; | ||
import ai.djl.training.hyperparameter.optimizer.HpOptimizer; | ||
import ai.djl.training.hyperparameter.param.HpSet; | ||
import ai.djl.translate.TranslateException; | ||
import ai.djl.util.Pair; | ||
import java.io.IOException; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
/** Helper for easy training with hyperparameters. */ | ||
public abstract class EasyHpo { | ||
|
||
private static final Logger logger = LoggerFactory.getLogger(EasyHpo.class); | ||
|
||
/** | ||
* Fits the model given the implemented abstract methods. | ||
* | ||
* @return the best model and training results | ||
* @throws IOException for various exceptions depending on the dataset | ||
* @throws TranslateException if there is an error while processing input | ||
*/ | ||
public Pair<Model, TrainingResult> fit() throws IOException, TranslateException { | ||
|
||
// get training and validation dataset | ||
RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN); | ||
RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST); | ||
|
||
HpSet hyperParams = setupHyperParams(); | ||
HpOptimizer hpOptimizer = new HpORandom(hyperParams); | ||
|
||
final int hyperparameterTests = numHyperParameterTests(); | ||
|
||
for (int i = 0; i < hyperparameterTests; i++) { | ||
HpSet hpVals = hpOptimizer.nextConfig(); | ||
Pair<Model, TrainingResult> trained = train(hpVals, trainingSet, validateSet); | ||
trained.getKey().close(); | ||
float loss = trained.getValue().getValidateLoss(); | ||
hpOptimizer.update(hpVals, loss); | ||
logger.info( | ||
"--------- hp test {}/{} - Loss {} - {}", i, hyperparameterTests, loss, hpVals); | ||
} | ||
|
||
HpSet bestHpVals = hpOptimizer.getBest().getKey(); | ||
Pair<Model, TrainingResult> trained = train(bestHpVals, trainingSet, validateSet); | ||
TrainingResult result = trained.getValue(); | ||
|
||
Model model = trained.getKey(); | ||
saveModel(model, result); | ||
return trained; | ||
} | ||
|
||
private Pair<Model, TrainingResult> train( | ||
HpSet hpVals, RandomAccessDataset trainingSet, RandomAccessDataset validateSet) | ||
throws IOException, TranslateException { | ||
|
||
// Construct neural network | ||
Model model = buildModel(hpVals); | ||
|
||
// setup training configuration | ||
TrainingConfig config = setupTrainingConfig(hpVals); | ||
|
||
try (Trainer trainer = model.newTrainer(config)) { | ||
trainer.setMetrics(new Metrics()); | ||
|
||
// initialize trainer with proper input shape | ||
trainer.initialize(inputShape(hpVals)); | ||
|
||
EasyTrain.fit(trainer, numEpochs(hpVals), trainingSet, validateSet); | ||
|
||
TrainingResult result = trainer.getTrainingResult(); | ||
return new Pair<>(model, result); | ||
} | ||
} | ||
|
||
/** | ||
* Returns the initial hyperparameters. | ||
* | ||
* @return the initial hyperparameters | ||
*/ | ||
protected abstract HpSet setupHyperParams(); | ||
|
||
/** | ||
* Returns the dataset to train with. | ||
* | ||
* @param usage the usage of the dataset | ||
* @return the dataset to train with | ||
* @throws IOException if the dataset could not be loaded | ||
*/ | ||
protected abstract RandomAccessDataset getDataset(Dataset.Usage usage) throws IOException; | ||
|
||
/** | ||
* Returns the {@link ai.djl.training.TrainingConfig} to use to train each hyperparameter set. | ||
* | ||
* @param hpVals the hyperparameters to train with | ||
* @return the {@link ai.djl.training.TrainingConfig} to use to train each hyperparameter set | ||
*/ | ||
protected abstract TrainingConfig setupTrainingConfig(HpSet hpVals); | ||
|
||
/** | ||
* Builds the {@link Model} and {@link ai.djl.nn.Block} to train. | ||
* | ||
* @param hpVals the hyperparameter values to use for the model | ||
* @return the model to train | ||
*/ | ||
protected abstract Model buildModel(HpSet hpVals); | ||
|
||
/** | ||
* Returns the input shape for the model. | ||
* | ||
* @param hpVals the hyperparameter values for the model | ||
* @return returns the model input shape | ||
*/ | ||
protected abstract Shape inputShape(HpSet hpVals); | ||
|
||
/** | ||
* Returns the number of epochs to train for the current hyperparameter set. | ||
* | ||
* @param hpVals the current hyperparameter set | ||
* @return the number of epochs | ||
*/ | ||
protected abstract int numEpochs(HpSet hpVals); | ||
|
||
/** | ||
* Returns the number of hyperparameter sets to train with. | ||
* | ||
* @return the number of hyperparameter sets to train with | ||
*/ | ||
protected abstract int numHyperParameterTests(); | ||
|
||
/** | ||
* Saves the best hyperparameter set. | ||
* | ||
* @param model the model to save | ||
* @param result the training result for training with this model's hyperparameters | ||
* @throws IOException if the model could not be saved | ||
*/ | ||
protected void saveModel(Model model, TrainingResult result) throws IOException {} | ||
} |
18 changes: 18 additions & 0 deletions
18
api/src/main/java/ai/djl/training/hyperparameter/package-info.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
/* | ||
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
|
||
/** | ||
* Contains utilities to train, describe, and manipulate {@link | ||
* ai.djl.training.hyperparameter.param.Hyperparameter}s. | ||
*/ | ||
package ai.djl.training.hyperparameter; |
Oops, something went wrong.