diff --git a/cli/conf/finetune/default.yaml b/cli/conf/finetune/default.yaml
index 9c33817..c17c80a 100644
--- a/cli/conf/finetune/default.yaml
+++ b/cli/conf/finetune/default.yaml
@@ -10,6 +10,7 @@ run_name: ???
 seed: 0
 tf32: true
 compile: false # set to mode: default, reduce-overhead, max-autotune
+ckpt_path: null
 trainer:
   _target_: lightning.Trainer
   accelerator: auto
diff --git a/cli/conf/pretrain/default.yaml b/cli/conf/pretrain/default.yaml
index 33cf519..8bb5e0d 100644
--- a/cli/conf/pretrain/default.yaml
+++ b/cli/conf/pretrain/default.yaml
@@ -10,6 +10,7 @@ run_name: ???
 seed: 0
 tf32: true
 compile: false # set to mode: default, reduce-overhead, max-autotune
+ckpt_path: null # set to "last" to resume training
 trainer:
   _target_: lightning.Trainer
   accelerator: auto
diff --git a/cli/train.py b/cli/train.py
index ee2aa3b..aa073cc 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -142,6 +142,7 @@ def main(cfg: DictConfig):
     trainer.fit(
         model,
         datamodule=DataModule(cfg, train_dataset, val_dataset),
+        ckpt_path=cfg.ckpt_path,
     )
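
For reference, a minimal self-contained sketch (not the repo's actual model or data classes) of how the new cfg.ckpt_path value is forwarded to Lightning's Trainer.fit: with ckpt_path=None training starts from scratch exactly as before this change, while "last" or an explicit .ckpt path makes Lightning restore the weights, optimizer state, and epoch/step counters before continuing. The ToyModule, dataset, and config below are illustrative placeholders only.

import torch
import lightning as L
from torch.utils.data import DataLoader, TensorDataset
from omegaconf import OmegaConf


class ToyModule(L.LightningModule):
    """Placeholder module standing in for the real model built in cli/train.py."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    # Hypothetical config mirroring the new default.yaml key.
    cfg = OmegaConf.create({"ckpt_path": None})  # or "last", or "path/to/epoch=3.ckpt"

    dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    trainer = L.Trainer(max_epochs=2, accelerator="auto", default_root_dir="outputs")

    # ckpt_path=None keeps the old behavior; any other value resumes from that checkpoint.
    trainer.fit(ToyModule(), DataLoader(dataset, batch_size=16), ckpt_path=cfg.ckpt_path)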