Skip to content

Commit

Permalink
Refactor load in checkpoint connector (Lightning-AI#4593)
Browse files Browse the repository at this point in the history
* Refactor load step commentaries

* Refactor hpc ckpt suffix acquisition

* Refactor restore/hpc_load match

* Refactor hpc load trial

* Refactor checkpoint dir check

* Refactor unneeded function nest

* Refactor nested If

* Refactor duplicated cache clear

* Refactor attempt flow with if/elif

* Fix pip8

* Refactor hook commentary

Co-authored-by: chaton <thomas@grid.ai>

* Fix pep8

* Refactor hpc load checkpoint path acquisition

* Fix pip8

* Fix doc

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Refactor None Union type with Optional

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
  • Loading branch information
5 people authored Dec 13, 2020
1 parent 398f122 commit 16feb51
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 62 deletions.
116 changes: 58 additions & 58 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
# limitations under the License.

import os
from pathlib import Path
import re
from typing import Union, Optional

import torch

import pytorch_lightning
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType, OMEGACONF_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType, OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
Expand Down Expand Up @@ -52,16 +54,17 @@ def restore_weights(self, model: LightningModule):
if self.trainer.on_gpu:
torch.cuda.empty_cache()

# if script called from hpc resubmit, load weights
did_restore_hpc_weights = self.restore_hpc_weights_if_needed(model)
# 1. Attempt to restore states from HPC checkpoint
dir_path_hpc = str(self.trainer.weights_save_path)
max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
if max_suffix is not None:
checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
self.hpc_load(checkpoint_path, self.trainer.on_gpu)
rank_zero_info(f'restored hpc model from: {checkpoint_path}')

# clear cache after restore
if self.trainer.on_gpu:
torch.cuda.empty_cache()

if not did_restore_hpc_weights:
if self.trainer.resume_from_checkpoint is not None:
self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)
# 2. Attempt to restore states from `resume_from_checkpoint` file
elif self.trainer.resume_from_checkpoint is not None:
self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)

# wait for all to catch up
self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights')
Expand All @@ -72,24 +75,14 @@ def restore_weights(self, model: LightningModule):

def restore(self, checkpoint_path: str, on_gpu: bool):
"""
Load model/training states from the checkpoint file through file-read and state-restore.
Also restores all training state like:
- epoch
- callbacks
- schedulers
- optimizer
In detail, check return value description of `dump_checkpoint`
Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
All restored states are listed in return value description of `dump_checkpoint`.
"""

# if on_gpu:
# checkpoint = torch.load(checkpoint_path)
# else:
# load on CPU first
# read a checkpoint dictionary object from the checkpoint file at `checkpoint_path`
# read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

# restore states from the checkpoint dictionary object
# load model state
# acquire the model
model = self.trainer.get_model()

# restore model and datamodule state
Expand All @@ -106,14 +99,14 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None:
Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
"""

# give the datamodule a chance to load something
# restore datamodule states
if self.trainer.datamodule is not None:
self.trainer.datamodule.on_load_checkpoint(checkpoint)

# give model a chance to restore something
# hook: give user access to checkpoint if needed.
model.on_load_checkpoint(checkpoint)

# restore the state_dict on the model
# restore model state_dict
model.load_state_dict(checkpoint['state_dict'])

def restore_training_state(self, checkpoint):
Expand Down Expand Up @@ -187,23 +180,6 @@ def restore_training_state(self, checkpoint):
for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
scheduler['scheduler'].load_state_dict(lrs_state)

def restore_hpc_weights_if_needed(self, model: LightningModule):
"""If there is a set of hpc weights, use as signal to restore model."""
did_restore = False

# look for hpc weights
folderpath = str(self.trainer.weights_save_path)
fs = get_filesystem(folderpath)
if fs.exists(folderpath):
files = [os.path.basename(f['name']) for f in fs.listdir(folderpath)]
hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]

# if hpc weights exist restore model
if len(hpc_weight_paths) > 0:
self.hpc_load(folderpath, self.trainer.on_gpu)
did_restore = True
return did_restore

# ----------------------------------
# PRIVATE OPS
# ----------------------------------
Expand All @@ -216,7 +192,8 @@ def hpc_save(self, folderpath: str, logger):
# save logger to make sure we get all the metrics
logger.save()

ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
max_suffix = self.max_ckpt_in_folder(folderpath)
ckpt_number = (max_suffix if max_suffix is not None else 0) + 1

fs.makedirs(folderpath, exist_ok=True)
filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')
Expand Down Expand Up @@ -333,36 +310,52 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:

return checkpoint

def hpc_load(self, folderpath, on_gpu):
filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))
def hpc_load(self, checkpoint_path: str, on_gpu: bool):
"""
Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc.
All restored states are listed in return value description of `dump_checkpoint`.
"""

# load on CPU first
checkpoint = pl_load(filepath, map_location=lambda storage, loc: storage)
# read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

# load model state
# acquire the model
model = self.trainer.get_model()

# restore states from 'PyTorch-Lightning checkpoint' dictionary object
# restore model and datamodule state
self.restore_model_state(model, checkpoint)

if self.trainer.root_gpu is not None:
model.cuda(self.trainer.root_gpu)

# load training state (affects trainer only)
# restore training state
self.restore_training_state(checkpoint)

# call model hook
# call hpc specific hook
model.on_hpc_load(checkpoint)

log.info(f'restored hpc model from: {filepath}')
def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
"""List up files in `dir_path` with name_key, then yield maximum suffix number.
Args:
dir_path: path of directory which may contain files whose name include `name_key`
Returns:
None if no-corresponding-file else maximum suffix number
"""

# check directory existence
fs = get_filesystem(dir_path)
if not fs.exists(dir_path):
return None

def max_ckpt_in_folder(self, path, name_key='ckpt_'):
fs = get_filesystem(path)
files = [os.path.basename(f["name"]) for f in fs.listdir(path)]
# check corresponding file existence
files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)]
files = [x for x in files if name_key in x]
if len(files) == 0:
return 0
return None

# extract suffix number
ckpt_vs = []
for name in files:
name = name.split(name_key)[-1]
Expand All @@ -371,6 +364,13 @@ def max_ckpt_in_folder(self, path, name_key='ckpt_'):

return max(ckpt_vs)

def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
"""Get path of maximum-epoch checkpoint in the folder."""

max_suffix = self.max_ckpt_in_folder(folder_path)
ckpt_number = max_suffix if max_suffix is not None else 0
return f'{folder_path}/hpc_ckpt_{ckpt_number}.ckpt'

def save_checkpoint(self, filepath, weights_only: bool = False):
"""Save model/training states as a checkpoint file through state-dump and file-write.
Expand Down
6 changes: 4 additions & 2 deletions tests/base/develop_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,11 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, wi
trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \
trainer.init_optimizers(pretrained_model)

# test HPC loading / saving
# test HPC saving
trainer.checkpoint_connector.hpc_save(save_dir, logger)
trainer.checkpoint_connector.hpc_load(save_dir, on_gpu=on_gpu)
# test HPC loading
checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir)
trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)


def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50):
Expand Down
6 changes: 4 additions & 2 deletions tests/models/data/horovod/train_default_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@ def run_test_from_config(trainer_options):
for dataloader in test_loaders:
run_prediction(dataloader, pretrained_model)

# test HPC loading / saving
# test HPC saving
trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger)
trainer.checkpoint_connector.hpc_load(ckpt_path, on_gpu=args.on_gpu)
# test HPC loading
checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path)
trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=args.on_gpu)

if args.on_gpu:
trainer = Trainer(gpus=1, accelerator='horovod', max_epochs=1)
Expand Down

0 comments on commit 16feb51

Please sign in to comment.