Handle restrict checkpoint loading per model instead of family (#1043)
cbalioglu authored Feb 19, 2025
1 parent 4a9e836 commit a3ec8d4
Showing 8 changed files with 77 additions and 33 deletions.
4 changes: 4 additions & 0 deletions src/fairseq2/assets/cards/models/nllb-200.yaml
@@ -219,24 +219,28 @@ name: nllb-200_dense_1b
base: nllb-200
model_arch: nllb_dense_1b
checkpoint: "https://tinyurl.com/nllb200dense1bcheckpoint"
restrict: false

---

name: nllb-200_dense_3b
base: nllb-200
model_arch: nllb_dense_3b
checkpoint: "https://tinyurl.com/nllb200dense3bcheckpoint"
restrict: false

---

name: nllb-200_dense_distill_1b
base: nllb-200
model_arch: nllb_dense_1b
checkpoint: "https://tinyurl.com/nllb200densedst1bcheckpoint"
restrict: false

---

name: nllb-200_dense_distill_600m
base: nllb-200
model_arch: nllb_dense_600m
checkpoint: "https://tinyurl.com/nllb200densedst600mcheckpoint"
restrict: false
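These `restrict: false` entries are the per-model override named in the commit title: each legacy fairseq checkpoint that still needs full, unrestricted unpickling now says so on its own asset card instead of inheriting that behavior from its model family. A minimal sketch of what the flag ultimately controls, using a stand-in checkpoint file so the snippet is self-contained (the real file comes from the card's `checkpoint` URL):

```python
import torch

# Stand-in for a downloaded checkpoint; in fairseq2 it comes from the card's "checkpoint" URL.
torch.save({"weight": torch.zeros(2, 2)}, "nllb_checkpoint.pt")

# Simplified view of a parsed card; the real code reads card.field("restrict").as_(bool).
card = {"checkpoint": "nllb_checkpoint.pt", "restrict": False}

restrict = card.get("restrict", True)  # a missing field falls back to the family default (True)

# TorchTensorLoader.load() forwards this value as torch.load's weights_only flag, so
# restrict=False means full pickle loading is allowed for this one model only.
state = torch.load(card["checkpoint"], map_location="cpu", weights_only=restrict)
```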
2 changes: 2 additions & 0 deletions src/fairseq2/assets/cards/models/s2t_conformer.yaml
@@ -10,6 +10,7 @@ model_arch: conformer_medium
task: translation
target_langs: [de]
checkpoint: "https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_de/abs_asr_pt_avg_last_10_checkpoint.pt"
restrict: false
tokenizer: "https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_vocab_char.zip;path=spm_char.model"
tokenizer_family: s2t_transformer

@@ -24,5 +25,6 @@ model_config:
task: translation
target_langs: [de]
checkpoint: "https://dl.fbaipublicfiles.com/fairseq/conformer/covost2/en_de/rel_pos_asr_pt_avg_last_10_checkpoint.pt"
restrict: false
tokenizer: "https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_vocab_char.zip;path=spm_char.model"
tokenizer_family: s2t_transformer
5 changes: 5 additions & 0 deletions src/fairseq2/assets/cards/models/s2t_transformer.yaml
@@ -13,6 +13,7 @@ model_config:
task: transcription
target_langs: [en]
checkpoint: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_asr_transformer_s.pt"
restrict: false
tokenizer: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_asr_vocab_unigram5000.zip;path=spm_unigram_5000.model"
tokenizer_family: s2t_transformer

@@ -28,6 +29,7 @@ model_config:
task: transcription
target_langs: [en]
checkpoint: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_asr_transformer_s.pt"
restrict: false
tokenizer: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_asr_vocab_unigram5000.zip;path=spm_unigram_5000.model"
tokenizer_family: s2t_transformer

@@ -39,6 +41,7 @@ model_arch: medium
task: transcription
target_langs: [en]
checkpoint: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_joint_asr_transformer_m.pt"
restrict: false
tokenizer: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_joint_asr_vocab_unigram10000.zip;path=spm_unigram_10000.model"
tokenizer_family: s2t_transformer

@@ -54,6 +57,7 @@ model_config:
task: translation
target_langs: [de]
checkpoint: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_st_transformer_s.pt"
restrict: false
tokenizer: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_st_vocab_unigram8000.zip;path=spm_unigram_8000.model"
tokenizer_family: s2t_transformer

@@ -65,5 +69,6 @@ model_arch: medium
task: translation
target_langs: [de, es, fr, it, nl, pt, ro, ru]
checkpoint: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_multilingual_st_transformer_m.pt"
restrict: false
tokenizer: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_multilingual_st_vocab_unigram10000.zip;path=spm_unigram_10000.model"
tokenizer_family: s2t_transformer
8 changes: 6 additions & 2 deletions src/fairseq2/checkpoint/_manager.py
@@ -407,7 +407,9 @@ def maybe_load_part(filename: str) -> dict[str, object]:
file = step_dir.joinpath(filename)

try:
part = self._tensor_loader.load(file, map_location=CPU)
part = self._tensor_loader.load(
file, map_location=CPU, restrict=False
)
except FileNotFoundError:
part = {}
except TensorLoadError as ex:
@@ -483,7 +485,9 @@ def load_metadata(self, step_nr: int) -> dict[str, object] | None:
)

try:
metadata = self._tensor_loader.load(metadata_file, map_location=CPU)
metadata = self._tensor_loader.load(
metadata_file, map_location=CPU, restrict=False
)
except FileNotFoundError:
metadata = None
except TensorLoadError as ex:
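The checkpoint manager keeps loading its own files without restriction, but the `restrict=False` now travels with each `load()` call rather than with a loader constructed as unrestricted (see `recipes/common/_checkpoint.py` below). The flag matters because training checkpoints can carry arbitrary pickled Python objects that PyTorch's `weights_only` unpickler rejects; a small self-contained illustration of that failure mode, with a hypothetical `TrainConfig` class standing in for such an object:

```python
from dataclasses import dataclass

import torch


@dataclass
class TrainConfig:  # hypothetical stand-in for any custom object a training run might pickle
    lr: float = 1e-4


torch.save({"step_nr": 10, "config": TrainConfig()}, "metadata.pt")

try:
    torch.load("metadata.pt", map_location="cpu", weights_only=True)  # restricted load
except Exception as ex:
    print(f"restricted load rejected the custom class: {ex}")

# Unrestricted load, i.e. what restrict=False requests in the hunks above.
state = torch.load("metadata.pt", map_location="cpu", weights_only=False)
```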
57 changes: 45 additions & 12 deletions src/fairseq2/models/_handler.py
@@ -64,7 +64,14 @@ def load(

@abstractmethod
def load_from_path(
self, path: Path, model_name: str, config: object, gangs: Gangs, dtype: DataType
self,
path: Path,
model_name: str,
config: object,
gangs: Gangs,
dtype: DataType,
*,
restrict: bool = True,
) -> Module: ...

@abstractmethod
@@ -100,18 +107,22 @@ class AbstractModelHandler(ModelHandler):
_default_arch: str
_asset_download_manager: AssetDownloadManager
_tensor_loader: TensorLoader
_restrict: bool

def __init__(
self,
configs: ConfigProvider[object],
default_arch: str,
asset_download_manager: AssetDownloadManager,
tensor_loader: TensorLoader,
*,
restrict: bool = True,
) -> None:
self._configs = configs
self._default_arch = default_arch
self._asset_download_manager = asset_download_manager
self._tensor_loader = tensor_loader
self._restrict = restrict

@final
@override
@@ -268,6 +279,19 @@ def load(
if gangs.tp.size != num_shards:
raise ShardedModelLoadError(model_name, num_shards, gangs.tp.size)

# Load the checkpoint.
try:
checkpoint_uri = card.field("checkpoint").as_uri()
except AssetCardError as ex:
raise model_asset_card_error(model_name) from ex

shard_idx = gangs.tp.rank if num_shards > 1 else None

path = self._asset_download_manager.download_checkpoint(
checkpoint_uri, model_name, shard_idx=shard_idx
)

# Load the configuration.
if config is None:
try:
config = self.load_config(card)
@@ -280,20 +304,17 @@
else:
has_custom_config = True

# Load the checkpoint.
try:
checkpoint_uri = card.field("checkpoint").as_uri()
restrict = card.field("restrict").as_(bool)
except AssetCardFieldNotFoundError:
restrict = None
except AssetCardError as ex:
raise model_asset_card_error(model_name) from ex

shard_idx = gangs.tp.rank if num_shards > 1 else None

path = self._asset_download_manager.download_checkpoint(
checkpoint_uri, model_name, shard_idx=shard_idx
)

try:
return self.load_from_path(path, model_name, config, gangs, dtype)
return self.load_from_path(
path, model_name, config, gangs, dtype, restrict=restrict
)
except FileNotFoundError:
raise ModelLoadError(
model_name, f"The '{model_name}' model cannot be found at the '{path}' path." # fmt: skip
@@ -309,16 +330,28 @@ def load(
@final
@override
def load_from_path(
self, path: Path, model_name: str, config: object, gangs: Gangs, dtype: DataType
self,
path: Path,
model_name: str,
config: object,
gangs: Gangs,
dtype: DataType,
*,
restrict: bool | None = None,
) -> Module:
if gangs.root.device.type == "meta":
raise ValueError(
"`gangs` must be on a real device, but is on the meta device instead."
)

if restrict is None:
restrict = self._restrict

with load_with_sdp_gang(gangs): # Required for ShardedTensor
try:
checkpoint = self._tensor_loader.load(path, map_location=CPU)
checkpoint = self._tensor_loader.load(
path, map_location=CPU, restrict=restrict
)
except TensorLoadError as ex:
raise ModelLoadError(
model_name, f"The checkpoint of the '{model_name}' model cannot be loaded. See the nested exception for details." # fmt: skip
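Pieced together from the hunks above, the lookup order is now: the card's optional `restrict` field, read in `load()` and handed to `load_from_path()`, which falls back to the family-level default stored on the handler (`restrict=True` unless the handler was constructed otherwise). This per-model fallback is also what lets `setup/_models.py` below drop the family-wide `unsafe_tensor_loader` and give every handler the same `StandardTensorLoader`. A compact sketch of the fallback, with hypothetical values rather than real cards:

```python
def effective_restrict(card_value: bool | None, family_default: bool = True) -> bool:
    # load() maps a missing "restrict" field to None; load_from_path() then
    # substitutes the family default kept in self._restrict.
    return family_default if card_value is None else card_value


assert effective_restrict(None) is True        # card has no "restrict" field
assert effective_restrict(False) is False      # card opts out, e.g. nllb-200_dense_1b
assert effective_restrict(True, family_default=False) is True  # a card can also opt back in
```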
2 changes: 1 addition & 1 deletion src/fairseq2/recipes/common/_checkpoint.py
@@ -21,7 +21,7 @@ def create_checkpoint_manager(

file_system = context.file_system

tensor_loader = TorchTensorLoader(file_system, restrict=False)
tensor_loader = TorchTensorLoader(file_system)
tensor_dumper = TorchTensorDumper(file_system)

return FileCheckpointManager(
12 changes: 5 additions & 7 deletions src/fairseq2/setup/_models.py
@@ -47,16 +47,14 @@
Wav2Vec2AsrModelHandler,
register_wav2vec2_asr_configs,
)
from fairseq2.utils.file import StandardTensorLoader, TorchTensorLoader
from fairseq2.utils.file import StandardTensorLoader


def _register_models(context: RuntimeContext) -> None:
asset_download_manager = context.asset_download_manager

tensor_loader = StandardTensorLoader(context.file_system)

unsafe_tensor_loader = TorchTensorLoader(context.file_system, restrict=False)

registry = context.get_registry(ModelHandler)

handler: ModelHandler
@@ -124,7 +122,7 @@ def _register_models(context: RuntimeContext) -> None:
default_arch = "medium"

handler = S2TTransformerModelHandler(
configs, default_arch, asset_download_manager, unsafe_tensor_loader
configs, default_arch, asset_download_manager, tensor_loader
)

registry.register(handler.family, handler)
@@ -137,7 +135,7 @@ def _register_models(context: RuntimeContext) -> None:
default_arch = "base"

handler = TransformerModelHandler(
configs, default_arch, asset_download_manager, unsafe_tensor_loader
configs, default_arch, asset_download_manager, tensor_loader
)

registry.register(handler.family, handler)
@@ -163,7 +161,7 @@ def _register_models(context: RuntimeContext) -> None:
default_arch = "base"

handler = Wav2Vec2ModelHandler(
configs, default_arch, asset_download_manager, unsafe_tensor_loader
configs, default_arch, asset_download_manager, tensor_loader
)

registry.register(handler.family, handler)
@@ -176,7 +174,7 @@ def _register_models(context: RuntimeContext) -> None:
default_arch = "base_10h"

handler = Wav2Vec2AsrModelHandler(
configs, default_arch, asset_download_manager, unsafe_tensor_loader
configs, default_arch, asset_download_manager, tensor_loader
)

registry.register(handler.family, handler)
20 changes: 9 additions & 11 deletions src/fairseq2/utils/file.py
@@ -179,7 +179,7 @@ class TensorLoader(ABC):

@abstractmethod
def load(
self, path: Path, *, map_location: MapLocation = None
self, path: Path, *, map_location: MapLocation = None, restrict: bool = True
) -> dict[str, object]:
"""
:param path:
@@ -205,15 +205,13 @@ def dump(self, data: Mapping[str, object], path: Path) -> None:
@final
class TorchTensorLoader(TensorLoader):
_file_system: FileSystem
_restrict: bool

def __init__(self, file_system: FileSystem, restrict: bool = True) -> None:
def __init__(self, file_system: FileSystem) -> None:
self._file_system = file_system
self._restrict = restrict

@override
def load(
self, path: Path, *, map_location: MapLocation = None
self, path: Path, *, map_location: MapLocation = None, restrict: bool = True
) -> dict[str, object]:
with warnings.catch_warnings():
warnings.filterwarnings(
@@ -234,7 +232,7 @@ def load_error() -> TensorLoadError:

try:
data: dict[str, object] = torch.load(
fp, map_location, weights_only=self._restrict # type: ignore[arg-type]
fp, map_location, weights_only=restrict # type: ignore[arg-type]
)
except (RuntimeError, OSError, PickleError) as ex:
raise load_error() from ex
@@ -290,7 +288,7 @@ def __init__(self, file_system: FileSystem) -> None:

@override
def load(
self, path: Path, *, map_location: MapLocation = None
self, path: Path, *, map_location: MapLocation = None, restrict: bool = True
) -> dict[str, object]:
try:
from safetensors import safe_open # type: ignore[import-not-found]
@@ -361,9 +359,9 @@ def __init__(self, file_system: FileSystem) -> None:

@override
def load(
self, path: Path, *, map_location: MapLocation = None
self, path: Path, *, map_location: MapLocation = None, restrict: bool = True
) -> dict[str, object]:
def has_files(extension: str) -> bool:
def has_file(extension: str) -> bool:
try:
next(iter(self._file_system.glob(path, f"*{extension}")))
except OSError as ex:
@@ -385,7 +383,7 @@ def has_files(extension: str) -> bool:
) from ex

if is_dir:
if not has_files(".safetensors"):
if not has_file(".safetensors"):
raise TensorLoadError(
path, f"The '{path}' directory does not contain any supported tensor files." # fmt: skip
)
@@ -396,7 +394,7 @@ def has_files(extension: str) -> bool:
else:
loader = self._default_tensor_loader

return loader.load(path, map_location=map_location)
return loader.load(path, map_location=map_location, restrict=restrict)


class TensorLoadError(Exception):
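With `restrict` moved from `TorchTensorLoader.__init__` to a per-call keyword on `TensorLoader.load()`, one loader instance can serve both restricted and unrestricted loads. A self-contained stand-in for the new contract (deliberately not fairseq2's `TorchTensorLoader`, which also wraps file-system access, warning filters, and error handling):

```python
from pathlib import Path

import torch


class TinyTensorLoaderSketch:
    """Minimal stand-in for the new TensorLoader contract: restrict is per call, not per instance."""

    def load(self, path: Path, *, map_location=None, restrict: bool = True) -> dict[str, object]:
        return torch.load(path, map_location=map_location, weights_only=restrict)


torch.save({"weight": torch.ones(2)}, "demo.pt")

loader = TinyTensorLoaderSketch()

safe = loader.load(Path("demo.pt"), map_location="cpu")                    # restrict defaults to True
legacy = loader.load(Path("demo.pt"), map_location="cpu", restrict=False)  # per-call opt-out
```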
