diff --git a/nemo/lightning/ckpt_utils.py b/nemo/lightning/ckpt_utils.py
index ae1fe520a119..fa588092497a 100644
--- a/nemo/lightning/ckpt_utils.py
+++ b/nemo/lightning/ckpt_utils.py
@@ -33,12 +33,6 @@ def idempotent_path_append(base_dir: Union[str, Path], suffix) -> Path:
     return base_dir
 
 
-def ckpt_to_weights_subdir(filepath: Union[str, Path]) -> Path:
-    """Given an input checkpoint filepath, clean it using `ckpt_to_dir` and then return the weights subdirectory."""
-    base_dir = ckpt_to_dir(filepath=filepath)
-    return idempotent_path_append(base_dir, WEIGHTS_PATH)
-
-
 def ckpt_to_context_subdir(filepath: Union[str, Path]) -> Path:
     """Given an input checkpoint filepath, clean it using `ckpt_to_dir` and then return the context subdirectory."""
     base_dir = ckpt_to_dir(filepath=filepath)
diff --git a/nemo/lightning/fabric/fabric.py b/nemo/lightning/fabric/fabric.py
index b1ca867cab83..60eb518a1e42 100644
--- a/nemo/lightning/fabric/fabric.py
+++ b/nemo/lightning/fabric/fabric.py
@@ -22,7 +22,7 @@
 from torch import nn
 from typing_extensions import Self, override
 
-from nemo.lightning.ckpt_utils import ckpt_to_context_subdir, ckpt_to_weights_subdir
+from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
 from nemo.lightning.io.mixin import IOMixin, serialization, track_io
 
 if TYPE_CHECKING:
@@ -83,7 +83,7 @@ def load_model(
         model = context.model
         dist_model = self.setup_module(model)
 
-        self.load(ckpt_to_weights_subdir(path), {"state_dict": dist_model})
+        self.load(path, {"state_dict": dist_model})
 
         return dist_model
 
diff --git a/nemo/lightning/io/connector.py b/nemo/lightning/io/connector.py
index fd7b814fe730..2ccb9bb1b1fe 100644
--- a/nemo/lightning/io/connector.py
+++ b/nemo/lightning/io/connector.py
@@ -22,7 +22,7 @@
 from filelock import FileLock, Timeout
 from pytorch_lightning.trainer.states import TrainerFn
 
-from nemo.lightning.ckpt_utils import ckpt_to_context_subdir, ckpt_to_weights_subdir
+from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
 
 # Dynamically inherit from the correct Path subclass based on the operating system.
 if os.name == 'nt':
@@ -198,7 +198,7 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer, dump_io: bool = True
         trainer.strategy.setup(trainer)
         output_path = Path(output_path)
         output_path.mkdir(parents=True, exist_ok=True)
-        trainer.save_checkpoint(ckpt_to_weights_subdir(output_path))
+        trainer.save_checkpoint(output_path)
 
         if getattr(trainer.strategy, "async_save", False):
             trainer.strategy.checkpoint_io.maybe_finalize_save_checkpoint(blocking=True)
diff --git a/nemo/lightning/io/pl.py b/nemo/lightning/io/pl.py
index 1a7880e38492..10ed52b136c2 100644
--- a/nemo/lightning/io/pl.py
+++ b/nemo/lightning/io/pl.py
@@ -37,7 +37,7 @@
 from torch import nn
 from typing_extensions import Self, override
 
-from nemo.lightning.ckpt_utils import ckpt_to_dir
+from nemo.lightning.ckpt_utils import WEIGHTS_PATH, ckpt_to_dir
 from nemo.lightning.io.capture import IOProtocol
 from nemo.lightning.io.mixin import IOMixin
 
@@ -78,6 +78,26 @@ def construct_extra(cls, trainer: pl.Trainer) -> Dict[str, Any]:
         return extra
 
 
+def ckpt_to_weights_subdir(filepath: Union[str, Path], is_saving) -> Path:
+    """Given an input checkpoint filepath, clean it using `ckpt_to_dir` and then return the weights subdirectory, if it exists."""
+    filepath = ckpt_to_dir(filepath=filepath)
+    base_dir = filepath
+    assert isinstance(base_dir, Path)
+    if base_dir.parts[-1] != WEIGHTS_PATH:
+        maybe_base_dir = base_dir / WEIGHTS_PATH
+        if maybe_base_dir.is_dir() or is_saving:
+            base_dir = maybe_base_dir
+    ## handle adapter paths
+    if hasattr(base_dir, "base_model_path") and base_dir.base_model_path.parts[-1] != WEIGHTS_PATH:
+        maybe_base_model_path = base_dir.base_model_path / WEIGHTS_PATH
+        if maybe_base_model_path.is_dir() or is_saving:
+            base_dir.base_model_path = base_dir.base_model_path / WEIGHTS_PATH
+    if is_saving:
+        assert base_dir.parts[-1] == WEIGHTS_PATH
+        assert base_dir.parent == Path(filepath)
+    return base_dir
+
+
 class MegatronCheckpointIO(AsyncCompatibleCheckpointIO, IOMixin):
     """CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load`
     to save and load checkpoints respectively, common for most use cases.
@@ -132,7 +152,8 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
                 f" storage_options, but {storage_options=} was provided."
                 f" Ignoring given storage_options"
             )
-        checkpoint_dir = ckpt_to_dir(path)
+        checkpoint_dir = ckpt_to_weights_subdir(path, is_saving=True)
+
         fs = get_filesystem(checkpoint_dir)
         if fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir):
             logging.info(f'Distributed checkpoint at path {checkpoint_dir} already exists, skipping saving')
@@ -180,6 +201,11 @@ def load_checkpoint(
         if not fs.isdir(path):
             raise ValueError(f"Distributed checkpoints should be a directory. Found: {path}.")
 
+        # Load from ckpt_path/weights (new format) if it exists
+        path = ckpt_to_weights_subdir(path, is_saving=False)
+        if hasattr(path, "base_model_path") and not path.base_model_path.exists():
+            path.base_model_path = path.base_model_path.parent
+
         if self.save_ckpt_format == 'zarr' and self.load_directly_on_device:
             from megatron.core.dist_checkpointing.strategies.tensorstore import TensorStoreLoadShardedStrategy
 
diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py
index cffa8b9275ff..b384976d82bd 100644
--- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -58,7 +58,6 @@ class ModelCheckpoint(PTLModelCheckpoint):
     """
 
     UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"
-    WEIGHTS_PATH = "weights"
 
     def __init__(
         self,
@@ -438,7 +437,6 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
 
         # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
         # if anything goes wrong during checkpointing, we should be able to detect that data is incomplete.
-        ckpt_filepath = ckpt_to_dir(filepath) / ModelCheckpoint.WEIGHTS_PATH
         self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
 
         ema_callback = self._ema_callback(trainer)
@@ -455,15 +453,15 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
             if self.async_save:
                 raise ValueError('async_save with EMA not supported')
             with ema_callback.save_original_optimizer_state(trainer):
-                super()._save_checkpoint(trainer, ckpt_filepath)
+                super()._save_checkpoint(trainer, filepath)
 
             # save EMA copy of the model as well.
             with ema_callback.save_ema_model(trainer):
-                rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}")
-                ckpt_filepath = self._ema_format_filepath(ckpt_filepath)
+                rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
+                filepath = self._ema_format_filepath(filepath)
                 if self.verbose:
-                    rank_zero_info(f"Saving EMA weights to separate checkpoint {ckpt_filepath}")
-                super()._save_checkpoint(trainer, ckpt_filepath)
+                    rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
+                super()._save_checkpoint(trainer, filepath)
             self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
         else:
             ## Determine whether to include optimizer states in the checkpoint
@@ -489,7 +487,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
                 self.deferred_ckpts_to_remove.append([])
             else:
                 storage_options = None
-            trainer.save_checkpoint(ckpt_filepath, save_weights_only, storage_options=storage_options)
+            trainer.save_checkpoint(filepath, save_weights_only, storage_options=storage_options)
 
             if self.always_save_context and is_global_rank_zero():
                 TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context", yaml_attrs=["model"])
@@ -598,11 +596,11 @@ def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None:
         }
 
         checkpoint_filepaths = {f.resolve() for f in checkpoint_dir.rglob("*.ckpt")}
-        for ckpt_filepath in checkpoint_filepaths:
-            possible_marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(ckpt_filepath)
+        for filepath in checkpoint_filepaths:
+            possible_marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(filepath)
             if possible_marker_path in existing_marker_filepaths:
-                logging.warning(f'Removing unfinished checkpoint: {ckpt_filepath}')
-                os.remove(ckpt_filepath)
+                logging.warning(f'Removing unfinished checkpoint: {filepath}')
+                os.remove(filepath)
 
         # some directories might be distributed checkpoints, we remove these if they have a unfinished marker
         all_dirpaths = {d.resolve() for d in checkpoint_dir.glob("*") if d.is_dir()}
diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py
index 906cbd6e450e..5336615a4a38 100644
--- a/nemo/lightning/pytorch/callbacks/peft.py
+++ b/nemo/lightning/pytorch/callbacks/peft.py
@@ -28,7 +28,7 @@
 
 from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME
 from nemo.lightning.io.mixin import IOMixin
-from nemo.lightning.io.pl import ckpt_to_dir
+from nemo.lightning.io.pl import ckpt_to_dir, ckpt_to_weights_subdir
 from nemo.lightning.megatron_parallel import MegatronParallel
 from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform
 from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
@@ -346,7 +346,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio
 
         if is_global_rank_zero():
             metadata = {"model_ckpt_path": str(self.model_ckpt_path)}
-            base_dir = ckpt_to_dir(path)
+            base_dir = ckpt_to_weights_subdir(path, is_saving=True)
             base_dir.mkdir(parents=True, exist_ok=True)
             adapter_meta_path = base_dir / ADAPTER_META_FILENAME
             with open(adapter_meta_path, "w") as f:
diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py
index e99be666ec04..8a0147a4613a 100644
--- a/nemo/lightning/pytorch/strategies/megatron_strategy.py
+++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -57,7 +57,6 @@
 
 from nemo.core.optim.mcore_optim import McoreDistributedOptimizer
 from nemo.lightning import _strategy_lib, io
-from nemo.lightning.ckpt_utils import ckpt_to_weights_subdir
 from nemo.lightning.megatron_parallel import (
     CallbackConnector,
     MegatronParallel,
@@ -703,13 +702,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path], selective_restore:
         if self.lightning_module.optimizers(use_pl_optimizer=False):
             sharded_state_dict["optimizer"] = [self.optimizer_sharded_state_dict(is_loading=True)]
 
-        # Load from ckpt_path/weights (new format) if it exists, otherwise load from ckpt_path (legacy format)
-        load_dir = ckpt_to_weights_subdir(checkpoint_path)
-        if not load_dir.exists():
-            load_dir = checkpoint_path
-        if isinstance(load_dir, AdapterPath) and not load_dir.base_model_path.exists():
-            load_dir.base_model_path = load_dir.base_model_path.parent
-        checkpoint = self.checkpoint_io.load_checkpoint(load_dir, sharded_state_dict=sharded_state_dict)
+        checkpoint = self.checkpoint_io.load_checkpoint(checkpoint_path, sharded_state_dict=sharded_state_dict)
 
         return checkpoint
 
diff --git a/tests/collections/llm/bitexact/mixtral/run.sh b/tests/collections/llm/bitexact/mixtral/run.sh
index c32dbbc95b98..0fe9e331b18a 100644
--- a/tests/collections/llm/bitexact/mixtral/run.sh
+++ b/tests/collections/llm/bitexact/mixtral/run.sh
@@ -43,4 +43,4 @@ python3 /workspace/tests/collections/llm/bitexact/mixtral/pretrain_mini_mixtral.
 
 # Compare outputs
 python3 /workspace/tests/collections/llm/bitexact/mixtral/compare_ckpts.py \
-  "$NEMO_OUTPUT_PATH/checkpoints/--None=0.0000-epoch=0/" "$MCORE_OUTPUT_PATH/iter_0000010/"
+  "$NEMO_OUTPUT_PATH/checkpoints/--None=0.0000-epoch=0/weights" "$MCORE_OUTPUT_PATH/iter_0000010/"
diff --git a/tests/collections/llm/megatron_mixtral_pretraining.py b/tests/collections/llm/megatron_mixtral_pretraining.py
index 82188f75351e..b4c5b960e0a7 100644
--- a/tests/collections/llm/megatron_mixtral_pretraining.py
+++ b/tests/collections/llm/megatron_mixtral_pretraining.py
@@ -158,7 +158,7 @@ def main(args):
     )
 
     # Confirm checkpoint directory structure
-    output_path = Path(args.experiment_dir) / "checkpoints/--None=0.0000-epoch=0/"
+    output_path = Path(args.experiment_dir) / "checkpoints/--None=0.0000-epoch=0/weights"
    assert output_path.exists(), f"Expected {output_path} to exist"
     assert output_path.is_dir(), f"Expected {output_path} to be a directory"
     output_files = ['__0_0.distcp', '__0_1.distcp', 'common.pt', 'metadata.json', '.metadata']
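For reviewers, a minimal usage sketch of the relocated helper, not part of the patch: it assumes a NeMo install with this change applied, that WEIGHTS_PATH is still "weights", and the checkpoint directory name below is hypothetical.

    from pathlib import Path

    # ckpt_to_weights_subdir now lives in nemo.lightning.io.pl (moved by this PR)
    from nemo.lightning.io.pl import ckpt_to_weights_subdir

    # Hypothetical checkpoint directory produced by ModelCheckpoint.
    ckpt_dir = Path("/results/checkpoints/epoch=0-step=100")

    # On save, MegatronCheckpointIO resolves the subdirectory itself, so weights
    # always land in <ckpt_dir>/weights.
    save_dir = ckpt_to_weights_subdir(ckpt_dir, is_saving=True)
    print(save_dir)  # .../epoch=0-step=100/weights

    # On load, "weights" is only appended when that subdirectory exists, so legacy
    # checkpoints with weights stored at the top level keep loading unchanged.
    load_dir = ckpt_to_weights_subdir(ckpt_dir, is_saving=False)
    print(load_dir)  # .../weights if present, otherwise the original directory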