diff --git a/src/fairseq2/checkpoint/_manager.py b/src/fairseq2/checkpoint/_manager.py index c63242165..21d1bd8bb 100644 --- a/src/fairseq2/checkpoint/_manager.py +++ b/src/fairseq2/checkpoint/_manager.py @@ -179,7 +179,7 @@ def __init__( self._tensor_loader = tensor_loader self._tensor_dumper = tensor_dumper - if gangs.tp.rank > 1: + if gangs.tp.size > 1: self._shard_suffix = f".{gangs.tp.rank}" else: self._shard_suffix = "" @@ -468,15 +468,15 @@ def maybe_with_dp_process_group() -> AbstractContextManager[None]: def load_part(filename: str) -> dict[str, object]: with maybe_with_dp_process_group(): # Required for `ShardedTensor`. + file = step_dir.joinpath(filename) + try: - part = self._tensor_loader.load( - step_dir.joinpath(filename), map_location=CPU - ) + part = self._tensor_loader.load(file, map_location=CPU) except FileNotFoundError: part = {} except TensorLoadError as ex: raise CheckpointLoadError( - step_nr, f"The '{filename}' checkpoint file of training step {step_nr} cannot be loaded. See the nested exception for details." # fmt: skip + step_nr, f"The '{file}' checkpoint file of training step {step_nr} cannot be loaded. See the nested exception for details." # fmt: skip ) from ex self._gangs.root.barrier() diff --git a/src/fairseq2/utils/file.py b/src/fairseq2/utils/file.py index 25f792e79..46fbdea2e 100644 --- a/src/fairseq2/utils/file.py +++ b/src/fairseq2/utils/file.py @@ -229,6 +229,8 @@ def load_error() -> TensorLoadError: try: fp = self._file_system.open(path) + except FileNotFoundError: + raise except OSError as ex: raise load_error() from ex @@ -236,8 +238,6 @@ def load_error() -> TensorLoadError: data: dict[str, object] = torch.load( fp, map_location, weights_only=self._restrict # type: ignore[arg-type] ) - except FileNotFoundError: - raise except (RuntimeError, OSError, PickleError) as ex: raise load_error() from ex finally: