diff --git a/src/fairseq2/assets/_metadata_provider.py b/src/fairseq2/assets/_metadata_provider.py index 24f3b35cc..663c63004 100644 --- a/src/fairseq2/assets/_metadata_provider.py +++ b/src/fairseq2/assets/_metadata_provider.py @@ -119,7 +119,7 @@ def __init__( @override def _load_cache(self) -> dict[str, dict[str, object]]: - path = self._file_system.resolve(self._path) + path = self._path cache = {} @@ -139,7 +139,14 @@ def cache_file(file: Path, source: str) -> None: cache[name] = metadata - if self._file_system.is_dir(path): + try: + is_dir = self._file_system.is_dir(path) + except OSError as ex: + raise AssetMetadataError( + f"The '{path}' path cannot be accessed. See the nested exception for details." + ) from ex + + if is_dir: source = f"directory:{path}" def on_error(ex: OSError) -> NoReturn: diff --git a/src/fairseq2/checkpoint/__init__.py b/src/fairseq2/checkpoint/__init__.py index 221d3462e..04d27a12b 100644 --- a/src/fairseq2/checkpoint/__init__.py +++ b/src/fairseq2/checkpoint/__init__.py @@ -6,11 +6,14 @@ from __future__ import annotations +from fairseq2.checkpoint._manager import CheckpointDeleteError as CheckpointDeleteError from fairseq2.checkpoint._manager import CheckpointError as CheckpointError +from fairseq2.checkpoint._manager import CheckpointLoadError as CheckpointLoadError from fairseq2.checkpoint._manager import CheckpointManager as CheckpointManager from fairseq2.checkpoint._manager import ( CheckpointNotFoundError as CheckpointNotFoundError, ) +from fairseq2.checkpoint._manager import CheckpointSaveError as CheckpointSaveError from fairseq2.checkpoint._manager import FileCheckpointManager as FileCheckpointManager from fairseq2.checkpoint._metadata_provider import ( FileCheckpointMetadataProvider as FileCheckpointMetadataProvider, diff --git a/src/fairseq2/checkpoint/_manager.py b/src/fairseq2/checkpoint/_manager.py index 112d873fb..f18c22891 100644 --- a/src/fairseq2/checkpoint/_manager.py +++ b/src/fairseq2/checkpoint/_manager.py @@ -8,7 +8,7 @@ import warnings from abc import ABC, abstractmethod -from collections.abc import Iterator, Mapping, Set +from collections.abc import Iterable, Iterator, Mapping, Set from contextlib import AbstractContextManager, nullcontext from pathlib import Path from typing import final @@ -211,26 +211,28 @@ def save_checkpoint_card( "model_config": unstructure(config), } + if tokenizer_name is not None: + metadata["tokenizer_ref"] = tokenizer_name + if self._num_shards != 1: metadata["num_shards"] = self._num_shards - metadata["tokenizer_ref"] = tokenizer_name + metadata_file = self._checkpoint_dir.joinpath("model.yaml") + + def save_error() -> CheckpointError: + return CheckpointError( + f"The model metadata cannot be saved to the '{metadata_file}' file. See the nested exception for details." + ) try: self._file_system.make_directory(self._checkpoint_dir) except OSError as ex: - raise CheckpointError( - f"The '{self._checkpoint_dir}' directory cannot be created. See the nested exception for details." - ) from ex - - metadata_file = self._checkpoint_dir.joinpath("model.yaml") + raise save_error() from ex try: self._yaml_dumper.dump(metadata, metadata_file) except OSError as ex: - raise CheckpointError( - f"The model metadata cannot be saved to the '{metadata_file}' file. See the nested exception for details." - ) from ex + raise save_error() from ex self._gangs.root.barrier() @@ -242,8 +244,8 @@ def begin_checkpoint(self, step_nr: int) -> None: try: self.delete_checkpoint(step_nr, missing_ok=True) except CheckpointError as ex: - raise CheckpointError( - f"The previous checkpoint of training step {step_nr} cannot be deleted. See the nested exception for details." + raise CheckpointSaveError( + step_nr, f"The previous checkpoint of training step {step_nr} cannot be deleted. See the nested exception for details." # fmt: skip ) from ex if self._gangs.root.rank == 0: @@ -252,8 +254,8 @@ def begin_checkpoint(self, step_nr: int) -> None: try: self._file_system.make_directory(tmp_step_dir) except OSError as ex: - raise CheckpointError( - f"The temporary '{tmp_step_dir}' checkpoint directory of training step {step_nr} cannot be created. See the nested exception for details." + raise CheckpointSaveError( + step_nr, f"The temporary '{tmp_step_dir}' checkpoint directory of training step {step_nr} cannot be created. See the nested exception for details." # fmt: skip ) from ex self._gangs.root.barrier() @@ -300,8 +302,8 @@ def model_replicated() -> bool: {model_key: state_dict, "model_key": model_key}, model_file ) except TensorDumpError as ex: - raise CheckpointError( - f"The replicated model state of training step {step_nr} cannot be saved to the '{model_file}' file. See the nested exception for details." + raise CheckpointSaveError( + step_nr, f"The replicated model state of training step {step_nr} cannot be saved to the '{model_file}' file. See the nested exception for details." # fmt: skip ) from ex self._gangs.root.barrier() @@ -328,8 +330,8 @@ def model_replicated() -> bool: try: self._tensor_dumper.dump(replicated_part, replicated_file) except TensorDumpError as ex: - raise CheckpointError( - f"The replicated checkpoint state of training step {step_nr} cannot be saved to the '{replicated_file}' file. See the nested exception for details." + raise CheckpointSaveError( + step_nr, f"The replicated checkpoint state of training step {step_nr} cannot be saved to the '{replicated_file}' file. See the nested exception for details." # fmt: skip ) from ex else: if "*" in replicated_keys: @@ -357,8 +359,8 @@ def model_replicated() -> bool: try: self._tensor_dumper.dump(rank_part, rank_file) except TensorDumpError as ex: - raise CheckpointError( - f"The checkpoint state of training step {step_nr} cannot be saved to the '{rank_file}' file. See the nested exception for details." + raise CheckpointSaveError( + step_nr, f"The checkpoint state of training step {step_nr} cannot be saved to the '{rank_file}' file. See the nested exception for details." # fmt: skip ) from ex self._gangs.root.barrier() @@ -378,8 +380,8 @@ def save_metadata(self, metadata: Mapping[str, object]) -> None: try: self._tensor_dumper.dump(metadata, metadata_file) except TensorDumpError as ex: - raise CheckpointError( - f"The checkpoint metadata of training step {step_nr} cannot be saved to the '{metadata_file}' file. See the nested exception for details." + raise CheckpointSaveError( + step_nr, f"The checkpoint metadata of training step {step_nr} cannot be saved to the '{metadata_file}' file. See the nested exception for details." # fmt: skip ) from ex self._gangs.root.barrier() @@ -391,16 +393,22 @@ def save_score(self, score: float | None, lower_better: bool = False) -> None: if self._gangs.root.rank == 0 and score is not None: score_file = self._checkpoint_dir.joinpath(f"step_{step_nr}.tmp/score.txt") - fp = self._file_system.open_text(score_file, mode=FileMode.WRITE) + def save_error() -> CheckpointError: + return CheckpointSaveError( + step_nr, f"The checkpoint score of training step {step_nr} cannot be saved to the '{score_file}' file. See the nested exception for details." # fmt: skip + ) + + try: + fp = self._file_system.open_text(score_file, mode=FileMode.WRITE) + except OSError as ex: + raise save_error() from ex direction = "-" if lower_better else "" try: fp.write(f"{direction}{score}\n") except OSError as ex: - raise CheckpointError( - f"The checkpoint score of training step {step_nr} cannot be saved to the '{score_file}' file. See the nested exception for details." - ) from ex + raise save_error() from ex finally: fp.close() @@ -438,8 +446,8 @@ def save_consolidated_fsdp_model(self, model: Module) -> None: {"model": state_dict, "model_key": "model"}, model_file ) except TensorDumpError as ex: - raise CheckpointError( - f"The consolidated FSDP model of training step {step_nr} cannot be saved to the '{model_file}' file. See the nested exception for details." + raise CheckpointSaveError( + step_nr, f"The consolidated FSDP model of training step {step_nr} cannot be saved to the '{model_file}' file. See the nested exception for details." # fmt: skip ) from ex self._gangs.root.barrier() @@ -456,8 +464,8 @@ def commit_checkpoint(self) -> None: try: self._file_system.move(tmp_step_dir, step_dir) except OSError as ex: - raise CheckpointError( - f"The temporary '{tmp_step_dir}' checkpoint directory of training step {step_nr} cannot be committed. See the nested exception for details." + raise CheckpointSaveError( + step_nr, f"The temporary '{tmp_step_dir}' checkpoint directory of training step {step_nr} cannot be committed. See the nested exception for details." # fmt: skip ) from ex self._gangs.root.barrier() @@ -492,8 +500,8 @@ def load_part(filename: str) -> dict[str, object]: except FileNotFoundError: part = {} except TensorLoadError as ex: - raise CheckpointError( - f"The '{filename}' checkpoint file of training step {step_nr} cannot be loaded. See the nested exception for details." + raise CheckpointLoadError( + step_nr, f"The '{filename}' checkpoint file of training step {step_nr} cannot be loaded. See the nested exception for details." # fmt: skip ) from ex self._gangs.root.barrier() @@ -526,7 +534,7 @@ def load_part(filename: str) -> dict[str, object]: if not checkpoint: raise CheckpointNotFoundError( - f"The checkpoint of training step {step_nr} is not found." + step_nr, f"The checkpoint of training step {step_nr} is not found." # fmt: skip ) return checkpoint @@ -535,7 +543,7 @@ def load_part(filename: str) -> dict[str, object]: def load_last_checkpoint(self) -> tuple[int, dict[str, object]]: step_numbers = self.get_step_numbers() if not step_numbers: - raise CheckpointNotFoundError("No checkpoint found.") + raise CheckpointNotFoundError(-1, "No checkpoint found.") # fmt: skip last_step_nr = step_numbers[-1] @@ -554,8 +562,8 @@ def load_metadata(self, step_nr: int) -> dict[str, object] | None: except FileNotFoundError: metadata = None except TensorLoadError as ex: - raise CheckpointError( - f"The checkpoint metadata of training step {step_nr} cannot be loaded from the '{metadata_file}' file. See the nested exception for details." + raise CheckpointLoadError( + step_nr, f"The checkpoint metadata of training step {step_nr} cannot be loaded from the '{metadata_file}' file. See the nested exception for details." # fmt: skip ) from ex self._gangs.root.barrier() @@ -572,18 +580,27 @@ def delete_checkpoint( # Delete the temporary checkpoint directory if it exists. tmp_step_dir = step_dir.with_suffix(".tmp") + def delete_error() -> CheckpointError: + return CheckpointDeleteError( + step_nr, f"The temporary '{tmp_step_dir}' checkpoint directory of training step {step_nr} cannot be deleted. See the nested exception for details." # fmt: skip + ) + try: self._file_system.remove_directory(tmp_step_dir) + except FileNotFoundError: + pass except OSError as ex: - if not isinstance(ex, FileNotFoundError): - raise CheckpointError( - f"The temporary '{tmp_step_dir}' checkpoint directory of training step {step_nr} cannot be deleted. See the nested exception for details." - ) + raise delete_error() from ex - if not self._file_system.exists(step_dir): + try: + step_exists = self._file_system.exists(step_dir) + except OSError as ex: + raise delete_error() from ex + + if not step_exists: if not missing_ok: - raise CheckpointNotFoundError( - f"The '{step_dir}' checkpoint directory of training step {step_nr} is not found." + raise CheckpointDeleteError( + step_nr, f"The '{step_dir}' checkpoint directory of training step {step_nr} is not found." # fmt: skip ) self._gangs.root.barrier() @@ -591,30 +608,43 @@ def delete_checkpoint( return if preserve_model: + + def iter_torch_files() -> Iterable[Path]: + try: + for pt_file in self._file_system.glob(step_dir, "*.pt"): + if self._file_system.is_dir(pt_file): + continue + + yield pt_file + except OSError as ex: + raise CheckpointDeleteError( + step_nr, f"The '{step_dir}' checkpoint directory of training step {step_nr} cannot be traversed. See the nested exception for details." # fmt: skip + ) from ex + # Delete all PyTorch tensor files except 'model.X.pt' files that # represent a (consolidated) model. - for pt_file in self._file_system.glob(step_dir, "*.pt"): - if self._file_system.is_dir(pt_file): - continue - + for pt_file in iter_torch_files(): if pt_file.stem.startswith("model"): continue try: self._file_system.remove(pt_file) + except FileNotFoundError: + pass except OSError as ex: - if not isinstance(ex, FileNotFoundError): - raise CheckpointError( - f"The '{pt_file}' checkpoint file of training step {step_nr} cannot be deleted. See the nested exception for details." - ) + raise CheckpointDeleteError( + step_nr, f"The '{pt_file}' checkpoint file of training step {step_nr} cannot be deleted. See the nested exception for details." # fmt: skip + ) from ex else: try: self._file_system.remove_directory(step_dir) + except FileNotFoundError: + pass except OSError as ex: - if not missing_ok or not isinstance(ex, FileNotFoundError): - raise CheckpointError( - f"The '{step_dir}' checkpoint directory of training step {step_nr} cannot be deleted. See the nested exception for details." - ) + if not missing_ok: + raise CheckpointDeleteError( + step_nr, f"The '{step_dir}' checkpoint directory of training step {step_nr} cannot be deleted. See the nested exception for details." # fmt: skip + ) from ex self._gangs.root.barrier() @@ -660,14 +690,20 @@ def _load_scores(self, step_numbers: list[int]) -> list[tuple[float, int]]: for step_nr in step_numbers: score_file = self._checkpoint_dir.joinpath(f"step_{step_nr}/score.txt") - fp = self._file_system.open_text(score_file) + def load_error() -> CheckpointError: + return CheckpointError( + f"The score of training step {step_nr} cannot be loaded from the '{score_file}' file. See the nested exception for details." + ) + + try: + fp = self._file_system.open_text(score_file) + except OSError as ex: + raise load_error() from ex try: line = fp.readline() except OSError as ex: - raise CheckpointError( - f"The score of training step {step_nr} cannot be loaded from the '{score_file}' file. See the nested exception for details." - ) from ex + raise load_error() from ex finally: fp.close() @@ -723,5 +759,32 @@ class CheckpointError(Exception): pass -class CheckpointNotFoundError(CheckpointError): +class CheckpointSaveError(CheckpointError): + step_nr: int + + def __init__(self, step_nr: int, message: str) -> None: + super().__init__(message) + + self.step_nr = step_nr + + +class CheckpointLoadError(CheckpointError): + step_nr: int + + def __init__(self, step_nr: int, message: str) -> None: + super().__init__(message) + + self.step_nr = step_nr + + +class CheckpointNotFoundError(CheckpointLoadError): pass + + +class CheckpointDeleteError(CheckpointError): + step_nr: int + + def __init__(self, step_nr: int, message: str) -> None: + super().__init__(message) + + self.step_nr = step_nr diff --git a/src/fairseq2/checkpoint/_metadata_provider.py b/src/fairseq2/checkpoint/_metadata_provider.py index 954953360..2a0f3570f 100644 --- a/src/fairseq2/checkpoint/_metadata_provider.py +++ b/src/fairseq2/checkpoint/_metadata_provider.py @@ -7,7 +7,7 @@ from __future__ import annotations from pathlib import Path -from typing import final +from typing import Iterable, final from typing_extensions import override @@ -86,46 +86,62 @@ def add_checkpoint_metadata(name: str, step_nr: int) -> None: scores = [] - try: - for step_dir in self._file_system.glob(self._checkpoint_dir, "step_*"): - if not self._file_system.is_dir(step_dir): - continue + def iter_step_dirs() -> Iterable[Path]: + try: + for step_dir in self._file_system.glob(self._checkpoint_dir, "step_*"): + if not self._file_system.is_dir(step_dir): + continue - try: - step_nr = int(step_dir.name[5:]) - except ValueError: - continue + yield step_dir + except OSError as ex: + raise AssetMetadataError( + f"The '{self._checkpoint_dir}' base checkpoint directory cannot be traversed. See the nested exception for details." + ) from ex + + for step_dir in iter_step_dirs(): + try: + step_nr = int(step_dir.name[5:]) + except ValueError: + continue + + add_checkpoint_metadata(f"checkpoint_step_{step_nr}@", step_nr) + + max_step_nr = max(max_step_nr, step_nr) - add_checkpoint_metadata(f"checkpoint_step_{step_nr}@", step_nr) + # Load score. + score_file = step_dir.joinpath("score.txt") - max_step_nr = max(max_step_nr, step_nr) + def load_error() -> AssetMetadataError: + return AssetMetadataError( + f"The score of the training step {step_nr} cannot be loaded from the '{score_file}' file. See the nested exception for details." + ) - # Load score. - score_file = step_dir.joinpath("score.txt") - if self._file_system.exists(score_file): + try: + score_exists = self._file_system.exists(score_file) + except OSError as ex: + raise load_error() from ex + + if score_exists: + try: fp = self._file_system.open_text(score_file) + except OSError as ex: + raise load_error() from ex - try: - line = fp.readline() - except OSError as ex: - raise AssetMetadataError( - f"The score of the training step {step_nr} cannot be loaded from the '{score_file}' file. See the nested exception for details." - ) from ex - finally: - fp.close() - - try: - score = float(line) - except ValueError: - raise AssetMetadataError( - f"The score of the training step {step_nr} cannot be parsed as a floating-point number." - ) from None - - scores.append((score, step_nr)) - except OSError as ex: - raise AssetMetadataError( - f"The base '{self._checkpoint_dir}' checkpoint directory cannot be traversed. See the nested exception for details." - ) from ex + try: + line = fp.readline() + except OSError as ex: + raise load_error() from ex + finally: + fp.close() + + try: + score = float(line) + except ValueError: + raise AssetMetadataError( + f"The score of the training step {step_nr} cannot be parsed as a floating-point number." + ) from None + + scores.append((score, step_nr)) if max_step_nr == -1: return @@ -146,6 +162,14 @@ def add_checkpoint_metadata(name: str, step_nr: int) -> None: def _load_tokenizer(self, cache: dict[str, dict[str, object]]) -> None: metadata_file = self._checkpoint_dir.joinpath("tokenizer.yaml") - if self._file_system.exists(metadata_file): + + try: + tokenizer_exists = self._file_system.exists(metadata_file) + except OSError as ex: + raise AssetMetadataError( + f"The '{metadata_file}' path cannot be accessed. See the nested exception for details." + ) from ex + + if tokenizer_exists: for name, metadata in self._metadata_file_loader.load(metadata_file): cache[name] = metadata diff --git a/src/fairseq2/cli/commands/recipe.py b/src/fairseq2/cli/commands/recipe.py index dc239ebc8..55de87873 100644 --- a/src/fairseq2/cli/commands/recipe.py +++ b/src/fairseq2/cli/commands/recipe.py @@ -315,7 +315,7 @@ def run(self, context: RuntimeContext, args: Namespace) -> None: if not args.output_dir: raise MissingOutputDirectoryError("`args.output_dir` must be specified.") - output_dir = self._env_bootstrapper.run( + sweep_output_dir = self._env_bootstrapper.run( args.preset, config, args.output_dir, @@ -323,7 +323,7 @@ def run(self, context: RuntimeContext, args: Namespace) -> None: sweep_format=args.sweep_format, ) - self._runner.run(context, config, output_dir) + self._runner.run(context, config, sweep_output_dir) class MissingOutputDirectoryError(ValueError): diff --git a/src/fairseq2/recipes/common.py b/src/fairseq2/recipes/common.py index 9c8d8a814..b4442688d 100644 --- a/src/fairseq2/recipes/common.py +++ b/src/fairseq2/recipes/common.py @@ -134,14 +134,27 @@ def register_extra_asset_paths( metadata_file_loader = StandardMetadataFileLoader(yaml_loader) + def access_error(path: Path) -> SetupError: + return SetupError( + f"The '{path}' path cannot be accessed. See the nested exception for details." + ) + path = config_section.extra_path if path is not None: - if not file_system.exists(path): + try: + path_exists = file_system.exists(path) + except OSError as ex: + raise access_error(path) from ex + + if not path_exists: log.warning("The '{}' path pointed to by the `extra_asset_card_path` configuration does not exist.", path) # fmt: skip return - path = file_system.resolve(path) + try: + path = file_system.resolve(path) + except OSError as ex: + raise access_error(path) from ex context.asset_store.user_metadata_providers.append( FileAssetMetadataProvider(path, file_system, metadata_file_loader) @@ -150,12 +163,21 @@ def register_extra_asset_paths( path = config_section.checkpoint_dir if path is not None: metadata_file = path.joinpath("model.yaml") - if not file_system.exists(metadata_file): + + try: + metadata_exists = file_system.exists(metadata_file) + except OSError as ex: + raise access_error(metadata_file) from ex + + if not metadata_exists: log.warning("The checkpoint metadata file (model.yaml) is not found under the '{}' directory. Make sure that the `checkpoint_search_dir` configuration points to the base checkpoint directory used during training.", path) # fmt: skip return - path = file_system.resolve(path) + try: + path = file_system.resolve(path) + except OSError as ex: + raise access_error(path) from ex context.asset_store.user_metadata_providers.append( FileCheckpointMetadataProvider(path, file_system, metadata_file_loader) diff --git a/src/fairseq2/recipes/logging.py b/src/fairseq2/recipes/logging.py index 7b73fe99f..ccbad7546 100644 --- a/src/fairseq2/recipes/logging.py +++ b/src/fairseq2/recipes/logging.py @@ -76,7 +76,12 @@ def _setup_aten_logging(self, log_file: Path) -> None: aten_log_file = log_file.parent.joinpath("aten", log_file.name) - self._file_system.make_directory(aten_log_file.parent) + try: + self._file_system.make_directory(aten_log_file.parent) + except OSError as ex: + raise SetupError( + f"The '{aten_log_file.parent}' ATen log directory cannot be created. See the nested exception for details." + ) from ex _enable_aten_logging(aten_log_file) @@ -89,7 +94,12 @@ def _setup_nccl_logging(self, log_file: Path) -> None: nccl_log_file = log_file.parent.joinpath("nccl", log_file.name) - self._file_system.make_directory(nccl_log_file.parent) + try: + self._file_system.make_directory(nccl_log_file.parent) + except OSError as ex: + raise SetupError( + f"The '{nccl_log_file.parent}' NCCL log directory cannot be created. See the nested exception for details." + ) from ex os.environ["NCCL_DEBUG"] = "INFO" os.environ["NCCL_DEBUG_FILE"] = str(nccl_log_file) diff --git a/src/fairseq2/recipes/runner.py b/src/fairseq2/recipes/runner.py index 81ec7dba7..35f4ec949 100644 --- a/src/fairseq2/recipes/runner.py +++ b/src/fairseq2/recipes/runner.py @@ -199,6 +199,13 @@ def run( world_size, preset, unstructured_config, sweep_format ) + try: + output_dir = self._file_system.resolve(output_dir) + except OSError as ex: + raise SetupError( + f"The '{output_dir}' path cannot be accessed. See the nested exception for details." + ) from ex + sweep_output_dir = output_dir.joinpath(sweep_tag) try: @@ -212,7 +219,7 @@ def run( sweep_output_dir.joinpath("logs/rank_{rank}.log") ) - log.info("The log files are stored under the '{}' directory.", sweep_output_dir) + log.info("Log files stored under {}.", sweep_output_dir) log_config(log, unstructured_config) @@ -276,7 +283,14 @@ def read( # Update the configuration with `--config-file`. if config_files: for config_file in chain.from_iterable(config_files): - if not self._file_system.is_file(config_file): + try: + is_file = self._file_system.is_file(config_file) + except OSError as ex: + raise SetupError( + f"The '{config_file}' configuration file cannot be accessed. See the nested exception for details." + ) from ex + + if not is_file: raise ConfigFileNotFoundError(config_file) try: diff --git a/src/fairseq2/utils/file.py b/src/fairseq2/utils/file.py index 0ada17c37..b09e4c6ab 100644 --- a/src/fairseq2/utils/file.py +++ b/src/fairseq2/utils/file.py @@ -214,7 +214,15 @@ def load( with catch_warnings(): warnings.simplefilter("ignore") # Suppress noisy FSDP warnings. - fp = self._file_system.open(path) + def load_error() -> TensorLoadError: + return TensorLoadError( + f"The '{path}' tensor file cannot be loaded. See the nested exception for details." + ) + + try: + fp = self._file_system.open(path) + except OSError as ex: + raise load_error() from ex try: data: dict[str, object] = torch.load( @@ -223,9 +231,7 @@ def load( except FileNotFoundError: raise except (RuntimeError, OSError, PickleError) as ex: - raise TensorLoadError( - f"The '{path}' tensor file cannot be loaded. See the nested exception for details." - ) from ex + raise load_error() from ex finally: fp.close() @@ -244,14 +250,20 @@ def dump(self, data: Mapping[str, object], path: Path) -> None: with catch_warnings(): warnings.simplefilter("ignore") # Suppress noisy FSDP warnings. - fp = self._file_system.open(path, mode=FileMode.WRITE) + def dump_error() -> TensorDumpError: + return TensorDumpError( + f"The '{path}' tensor file cannot be dumped. See the nested exception for details.", + ) + + try: + fp = self._file_system.open(path, mode=FileMode.WRITE) + except OSError as ex: + raise dump_error() from ex try: torch.save(data, fp) except (RuntimeError, OSError, PickleError) as ex: - raise TensorDumpError( - f"The '{path}' tensor file cannot be dumped. See the nested exception for details.", - ) from ex + raise dump_error() from ex finally: fp.close() @@ -282,10 +294,21 @@ def load( "Safetensors only supports `torch.device` and `str` for the `map_location` parameter." ) - if self._file_system.is_dir(path): - file_iter = self._file_system.glob(path, "*.safetensors") + try: + is_dir = self._file_system.is_dir(path) + except OSError as ex: + raise TensorLoadError( + f"The '{path}' path cannot be accessed. See the nested exception for details." + ) from ex + + if is_dir: + try: + files = list(self._file_system.glob(path, "*.safetensors")) + except OSError as ex: + raise TensorLoadError( + f"The '{path}' directory cannot be traversed. See the nested exception for details." + ) from ex - files = list(file_iter) if not files: raise TensorLoadError( f"No Safetensors file found under the '{path}' directory." @@ -329,11 +352,13 @@ def __init__(self, file_system: FileSystem) -> None: def load( self, path: Path, *, map_location: MapLocation = None ) -> dict[str, object]: - def has_files(path: Path, extension: str) -> bool: - file_iter = self._file_system.glob(path, f"*{extension}") - + def has_files(extension: str) -> bool: try: - next(iter(file_iter)) + next(iter(self._file_system.glob(path, f"*{extension}"))) + except OSError as ex: + raise TensorLoadError( + f"The '{path}' directory cannot be traversed. See the nested exception for details." + ) from ex except StopIteration: return False @@ -341,8 +366,15 @@ def has_files(path: Path, extension: str) -> bool: loader: TensorLoader - if self._file_system.is_dir(path): - if not has_files(path, ".safetensors"): + try: + is_dir = self._file_system.is_dir(path) + except OSError as ex: + raise TensorLoadError( + f"The '{path}' path cannot be accessed. See the nested exception for details." + ) from ex + + if is_dir: + if not has_files(".safetensors"): raise TensorLoadError( f"The '{path}' directory does not contain any supported tensor files." )