Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change Checkpoint callback's save_best_only to save_top_k #128

Merged
merged 50 commits into from
Nov 19, 2019
Merged
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
6b54c1e
docs: enable syntax highlight
Ir1d Aug 13, 2019
49e35b1
feat: change Checkpoint callback's `save_best_only` to `save_top_k`
Ir1d Aug 17, 2019
08f57b6
docs: update docs for save_top_k
Ir1d Aug 17, 2019
4855bf8
revert other files
Ir1d Aug 17, 2019
2f6f784
style: lint for travis-ci
Ir1d Aug 17, 2019
daae566
fix typo
Ir1d Aug 17, 2019
a7da269
make flake8 happy
Ir1d Aug 17, 2019
373cd1d
update according to review
Ir1d Aug 19, 2019
49f78d6
add tests
Ir1d Aug 19, 2019
a34811a
Merge remote-tracking branch 'wf/master' into save_top_k
Ir1d Aug 19, 2019
fbc8a4e
rename func to private
Ir1d Aug 19, 2019
52bfcb7
add doc on `save_top_k == 0`
Ir1d Aug 19, 2019
38afd66
make flake8 happy
Ir1d Aug 20, 2019
ab92212
update according to PR comments
Ir1d Aug 22, 2019
7e8767f
change some f-strings
Ir1d Aug 22, 2019
3079b71
Update pt_callbacks.py
williamFalcon Aug 23, 2019
3bb6e3e
Update test_models.py
williamFalcon Aug 23, 2019
4626e1a
update options
Ir1d Aug 23, 2019
6a47d55
create folders
Ir1d Aug 23, 2019
4b3d5bf
Update test_models.py
williamFalcon Aug 23, 2019
d29eff7
change epoch num
Ir1d Aug 23, 2019
22cfeb5
Merge remote-tracking branch 'origin/save_top_k' into save_top_k
Ir1d Aug 23, 2019
d15b03b
support calling multiple times, add docs and tests
Ir1d Aug 23, 2019
58a8410
update docs
Ir1d Aug 23, 2019
b9a855d
Merge remote-tracking branch 'wf/master' into save_top_k
Ir1d Aug 23, 2019
e1f7a48
Merge remote-tracking branch 'wf/master' into save_top_k
Ir1d Aug 24, 2019
2dd65e9
roll back changes in earlystopping
Ir1d Aug 24, 2019
634ad80
clean test files
Ir1d Aug 24, 2019
1861f7d
rebase upstream
Ir1d Oct 22, 2019
b66b33c
make flake8 happy
Ir1d Oct 22, 2019
aed9ad9
Merge remote-tracking branch 'will/master' into save_top_k
Ir1d Nov 4, 2019
abb9629
fix epoch number
Ir1d Nov 4, 2019
a0c2269
update tests about epoch numbers
Ir1d Nov 4, 2019
4490466
clean debugging code
Ir1d Nov 4, 2019
0e98423
fix testing utils codes
Ir1d Nov 4, 2019
2eab575
fix testing utils codes
Ir1d Nov 4, 2019
27ebd1c
fix testing utils codes
Ir1d Nov 4, 2019
37cfedf
fix testing utils codes
Ir1d Nov 4, 2019
3f49122
change save_dir to tests/tests according to previous lines
Ir1d Nov 4, 2019
585fff0
remove unused overwrite option
Ir1d Nov 4, 2019
94a6c49
make flake8 happy
Ir1d Nov 4, 2019
b983860
change var name as per review
Ir1d Nov 4, 2019
c3f56dd
Merge remote-tracking branch 'will/master' into save_top_k
Ir1d Nov 5, 2019
13803ff
make flake8 happy
Ir1d Nov 5, 2019
7c58669
Merge remote-tracking branch 'will/master' into save_top_k
Ir1d Nov 5, 2019
235e5f1
update property name to work on master
Ir1d Nov 5, 2019
8635612
elaborate in the docs
Ir1d Nov 5, 2019
e59d5bd
update docs as per review
Ir1d Nov 6, 2019
3f05419
Merge branch 'master' into save_top_k
Ir1d Nov 16, 2019
b000db9
revert previous commit
Ir1d Nov 16, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions docs/Trainer/Checkpointing.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
Lightning can automate saving and loading checkpoints.

---

### Model saving
Checkpointing is enabled by default to the current working directory.
To change the checkpoint path pass in :
Ir1d marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -10,13 +11,13 @@ Trainer(default_save_path='/your/path/to/save/checkpoints')

To modify the behavior of checkpointing pass in your own callback.

``` {.python}
```{.python}
from pytorch_lightning.callbacks import ModelCheckpoint

# DEFAULTS used by the Trainer
checkpoint_callback = ModelCheckpoint(
filepath=os.getcwd(),
save_best_only=True,
save_top_k=-1,
Ir1d marked this conversation as resolved.
Show resolved Hide resolved
verbose=True,
monitor='val_loss',
mode='min',
Expand All @@ -27,10 +28,12 @@ trainer = Trainer(checkpoint_callback=checkpoint_callback)
```

---
### Restoring training session

### Restoring training session

You might want to not only load a model but also continue training it. Use this method to
restore the trainer state as well. This will continue from the epoch and global step you last left off.
However, the dataloaders will start from the first batch again (if you shuffled it shouldn't matter).
However, the dataloaders will start from the first batch again (if you shuffled it shouldn't matter).

Lightning will restore the session if you pass a logger with the same version and there's a saved checkpoint.
``` {.python}
Expand All @@ -52,18 +55,19 @@ trainer = Trainer(
trainer.fit(model)
```

The trainer restores:
The trainer restores:

- global_step
- current_epoch
- All optimizers
- All lr_schedulers
- global_step
- current_epoch
- All optimizers
- All lr_schedulers
- Model weights

You can even change the logic of your model as long as the weights and "architecture" of
the system isn't different. If you add a layer, for instance, it might not work.
You can even change the logic of your model as long as the weights and "architecture" of
the system isn't different. If you add a layer, for instance, it might not work.

At a rough level, here's [what happens inside Trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/root_module/model_saving.py#L63):

At a rough level, here's [what happens inside Trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/root_module/model_saving.py#L63):
```python

self.global_step = checkpoint['global_step']
Expand All @@ -79,6 +83,6 @@ lr_schedulers = checkpoint['lr_schedulers']
for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
scheduler.load_state_dict(lrs_state)

# uses the model you passed into trainer
# uses the model you passed into trainer
model.load_state_dict(checkpoint['state_dict'])
```
```
41 changes: 25 additions & 16 deletions docs/examples/Examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ In 99% of cases you want to just copy [one of the examples](https://github.com/w
wget https://raw.githubusercontent.com/williamFalcon/pytorch-lightning/master/pl_examples/new_project_templates/lightning_module_template.py
```

---
### Trainer Example
---

### Trainer Example

** \_\_main__ function**
** \_\_main\_\_ function**

Normally, we want to let the \_\_main__ function start the training.
Inside the main we parse training arguments with whatever hyperparameters we want. Your LightningModule will have a
chance to add hyperparameters.
Normally, we want to let the \_\_main\_\_ function start the training.
Inside the main we parse training arguments with whatever hyperparameters we want. Your LightningModule will have a
chance to add hyperparameters.

```{.python}
from test_tube import HyperOptArgumentParser
Expand All @@ -32,13 +33,15 @@ if __name__ == '__main__':
# train model
main(hyperparams)
```
**Main Function**

**Main Function**

The main function is your entry into the program. This is where you init your model, checkpoint directory, and launch the training.
The main function should have 3 arguments:
- hparams: a configuration of hyperparameters.
The main function should have 3 arguments:

- hparams: a configuration of hyperparameters.
- slurm_manager: Slurm cluster manager object (can be None)
- dict: for you to return any values you want (useful in meta-learning, otherwise set to _)
- dict: for you to return any values you want (useful in meta-learning, otherwise set to \_)

```python
def main(hparams, cluster, results_dict):
Expand All @@ -62,13 +65,15 @@ The __main__ function will start training on your **main** function. If you use
in hyper parameter optimization mode, this main function will get one set of hyperparameters. If you use it as a simple
argument parser you get the default arguments in the argument parser.

So, calling main(hyperparams) runs the model with the default argparse arguments.
So, calling main(hyperparams) runs the model with the default argparse arguments.

```{.python}
main(hyperparams)
```

---
#### CPU hyperparameter search

#### CPU hyperparameter search

```{.python}
# run a grid search over 20 hyperparameter combinations.
Expand All @@ -80,7 +85,9 @@ hyperparams.optimize_parallel_cpu(
```

---
#### Hyperparameter search on a single or multiple GPUs

#### Hyperparameter search on a single or multiple GPUs

```{.python}
# run a grid search over 20 hyperparameter combinations.
hyperparams.optimize_parallel_gpu(
Expand All @@ -92,8 +99,10 @@ hyperparams.optimize_parallel_gpu(
```

---
#### Hyperparameter search on a SLURM HPC cluster
```{.python}

#### Hyperparameter search on a SLURM HPC cluster

```{.python}
def optimize_on_cluster(hyperparams):
# enable cluster training
cluster = SlurmCluster(
Expand Down Expand Up @@ -126,6 +135,6 @@ def optimize_on_cluster(hyperparams):
job_name=job_display_name
)

# run cluster hyperparameter search
# run cluster hyperparameter search
optimize_on_cluster(hyperparams)
```
130 changes: 94 additions & 36 deletions pytorch_lightning/callbacks/pt_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,17 @@ class ModelCheckpoint(Callback):
filepath: string, path to save the model file.
monitor: quantity to monitor.
verbose: verbosity mode, 0 or 1.
save_best_only: if `save_best_only=True`,
the latest best model according to
the quantity monitored will not be overwritten.
save_top_k: if `save_top_k == k`,
Ir1d marked this conversation as resolved.
Show resolved Hide resolved
the best k models according to
the quantity monitored will be saved.
if `save_top_k == 0`, no models are saved.
if `save_top_k == -1`, all models are saved.
Please note that the monitors are checked every `period` epochs.
if `save_top_k >= 2` and the callback is called multiple
times inside an epoch, the name of the saved file will be
appended with a version count starting with `v0`.
mode: one of {auto, min, max}.
If `save_best_only=True`, the decision
If `save_top_k != 0`, the decision
to overwrite the current save file is made
based on either the maximization or the
minimization of the monitored quantity. For `val_acc`,
Expand All @@ -176,17 +182,23 @@ class ModelCheckpoint(Callback):
"""

def __init__(self, filepath, monitor='val_loss', verbose=0,
save_best_only=True, save_weights_only=False,
save_top_k=1, save_weights_only=False,
mode='auto', period=1, prefix=''):
super(ModelCheckpoint, self).__init__()
self.monitor = monitor
self.verbose = verbose
self.filepath = filepath
self.save_best_only = save_best_only
if not os.path.exists(filepath):
os.makedirs(filepath)
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
self.epochs_since_last_save = 0
self.epochs_since_last_check = 0
self.prefix = prefix
self.best_k_models = {}
# {filename: monitor}
self.kth_best_model = ''
self.best = 0

if mode not in ['auto', 'min', 'max']:
warnings.warn(
Expand All @@ -196,66 +208,112 @@ def __init__(self, filepath, monitor='val_loss', verbose=0,

if mode == 'min':
self.monitor_op = np.less
self.best = np.Inf
self.kth_value = np.Inf
self.mode = 'min'
elif mode == 'max':
self.monitor_op = np.greater
self.best = -np.Inf
self.kth_value = -np.Inf
self.mode = 'max'
else:
if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
self.monitor_op = np.greater
self.best = -np.Inf
self.kth_value = -np.Inf
self.mode = 'max'
else:
self.monitor_op = np.less
self.best = np.Inf
self.kth_value = np.Inf
self.mode = 'min'

def save_model(self, filepath, overwrite):
dirpath = '/'.join(filepath.split('/')[:-1])
def _del_model(self, filepath):
dirpath = os.path.dirname(filepath)

# make paths
os.makedirs(os.path.dirname(filepath), exist_ok=True)
os.makedirs(dirpath, exist_ok=True)

if overwrite:
for filename in os.listdir(dirpath):
if self.prefix in filename:
path_to_delete = os.path.join(dirpath, filename)
try:
shutil.rmtree(path_to_delete)
except OSError:
os.remove(path_to_delete)
try:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should only remove files this callback saved. For instance this would remove other checkpoints the user drags in manually or the ones saved by slurm

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what del_model and save_model is intended to be. Through my experiments I noticed that the original implementation simply delete all the models in the corresponding folder. I modified the functions so that they delete only the filepath model. AFAIK, these two functions are called with the exact model filepath.

shutil.rmtree(filepath)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shutil.rmtree has a parameter ignore_errors so there is no need for this try/except...
https://docs.python.org/2/library/shutil.html

except OSError:
os.remove(filepath)

def _save_model(self, filepath):
dirpath = os.path.dirname(filepath)

# make paths
os.makedirs(dirpath, exist_ok=True)

# delegate the saving to the model
self.save_function(filepath)

def check_monitor_top_k(self, current):
less_than_k_models = len(self.best_k_models.keys()) < self.save_top_k
if less_than_k_models:
williamFalcon marked this conversation as resolved.
Show resolved Hide resolved
return True
return self.monitor_op(current, self.best_k_models[self.kth_best_model])

def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epochs_since_last_save += 1
if self.epochs_since_last_save >= self.period:
self.epochs_since_last_save = 0
filepath = '{}/{}_ckpt_epoch_{}.ckpt'.format(self.filepath, self.prefix, epoch + 1)
if self.save_best_only:
self.epochs_since_last_check += 1

if self.save_top_k == 0:
# no models are saved
return
if self.epochs_since_last_check >= self.period:
self.epochs_since_last_check = 0
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
version_cnt = 0
while os.path.isfile(filepath):
# this epoch called before
filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
Ir1d marked this conversation as resolved.
Show resolved Hide resolved
version_cnt += 1

print(filepath)

if self.save_top_k != -1:
current = logs.get(self.monitor)

if current is None:
warnings.warn(
f'Can save best model only with {self.monitor} available,'
' skipping.', RuntimeWarning)
else:
if self.monitor_op(current, self.best):
if self.check_monitor_top_k(current):

# remove kth
if len(self.best_k_models.keys()) == self.save_top_k:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you have a method _del_model which is quite empty, so move this logic about removing k-th model there

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think so :)
_save_model handles a filename to save, and _del_model should do the same, and handles a filename to delete.

delpath = self.kth_best_model
self.best_k_models.pop(self.kth_best_model)
self._del_model(delpath)

self.best_k_models[filepath] = current
if len(self.best_k_models.keys()) == self.save_top_k:
# monitor dict has reached k elements
Ir1d marked this conversation as resolved.
Show resolved Hide resolved
if self.mode == 'min':
self.kth_best_model = max(self.best_k_models, key=self.best_k_models.get)
else:
self.kth_best_model = min(self.best_k_models, key=self.best_k_models.get)
self.kth_value = self.best_k_models[self.kth_best_model]

if self.mode == 'min':
self.best = min(self.best_k_models.values())
else:
self.best = max(self.best_k_models.values())
if self.verbose > 0:
logging.info(
f'\nEpoch {epoch + 1:05d}: {self.monitor} improved'
f' from {self.best:0.5f} to {current:0.5f},',
f' saving model to {filepath}')
self.best = current
self.save_model(filepath, overwrite=True)
f'\nEpoch {epoch:05d}: {self.monitor} reached',
f'{current:0.5f} (best {self.best:0.5f}), saving model to',
f'{filepath} as top {self.save_top_k}')
self._save_model(filepath)

else:
if self.verbose > 0:
logging.info(
f'\nEpoch {epoch + 1:05d}: {self.monitor} did not improve')
f'\nEpoch {epoch:05d}: {self.monitor}',
f'was not in top {self.save_top_k}')

else:
if self.verbose > 0:
logging.info(f'\nEpoch {epoch + 1:05d}: saving model to {filepath}')
self.save_model(filepath, overwrite=False)
logging.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
self._save_model(filepath)


class GradientAccumulationScheduler(Callback):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_a_restore_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_load_model_from_checkpoint():
# correct result and ok accuracy
assert result == 1, 'training failed to complete'
pretrained_model = LightningTestModel.load_from_checkpoint(
os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt")
os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt")
)

# test that hparams loaded correctly
Expand Down
Loading