
Commit

Materialize empty model on rank 0
cbalioglu committed Feb 15, 2025
1 parent f3e2b59 commit 45f4ff0
Showing 3 changed files with 28 additions and 19 deletions.
24 changes: 15 additions & 9 deletions src/fairseq2/nn/utils/module.py
@@ -376,16 +376,16 @@ def broadcast_module(
     gang: Gang,
     *,
     source_rank: int = 0,
-    broadcast_buffers: bool = True,
+    non_persistent_buffers: bool = False,
     skip_modules: set[Module] | None = None,
 ) -> None:
     """Broadcasts ``module`` to all processes in ``gang``.
 
     :param module: The module to broadcast.
     :param gang: The gang over which to broadcast ``module``.
     :param source_rank: The rank of the source process from which to broadcast.
-    :param broadcast_buffers: If ``True``, broadcasts not only the parameters,
-        but the buffers as well.
+    :param non_persistent_buffers: If ``True``, broadcasts the non-persistent
+        buffers as well.
     :param skip_modules: The set of modules that won't be broadcasted.
     """
     to_device(module, gang.device)
@@ -424,14 +424,20 @@ def collect_tensors(m: Module) -> None:

                 warned = True
 
-        if broadcast_buffers:
-            for buffer in m.buffers(recurse=False):
-                if buffer in memo:
-                    continue
+        for buffer_name, buffer in m.named_buffers(recurse=False):
+            if buffer in memo:
+                continue
 
-                memo.add(buffer)
+            memo.add(buffer)
 
-                tensors.append(buffer.detach())
+            if not non_persistent_buffers:
+                # TODO(balioglu): Surprisingly, PyTorch still does not offer a
+                # public API to check the type of a module buffer. This should
+                # be updated in the future.
+                if buffer_name in m._non_persistent_buffers_set:
+                    continue
+
+            tensors.append(buffer.detach())
 
     collect_tensors(module)

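Note on the buffer handling above: PyTorch records whether a buffer is persistent only in the module's private _non_persistent_buffers_set, which is why the new loop has to fall back to that attribute, as the TODO comment points out. Below is a minimal, self-contained sketch of the distinction; the Example module is purely illustrative and not part of fairseq2.

import torch
from torch import nn


class Example(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        # Persistent buffer: included in state_dict() and broadcast by default.
        self.register_buffer("running_mean", torch.zeros(4))

        # Non-persistent buffer: excluded from state_dict(); with this commit it
        # is broadcast only when non_persistent_buffers=True is passed.
        self.register_buffer("scratch", torch.zeros(4), persistent=False)


m = Example()

# No public accessor exposes buffer persistence, hence the private-set check.
print("scratch" in m._non_persistent_buffers_set)       # True
print("running_mean" in m._non_persistent_buffers_set)  # False
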
21 changes: 12 additions & 9 deletions src/fairseq2/recipes/common/_model.py
@@ -517,15 +517,18 @@ def load(self, recipe_config: object, gangs: Gangs) -> Module:
log.info("Checkpoint found. Loading '{}' model on data parallel rank 0.", model_name) # fmt: skip

try:
if gangs.dp.rank == 0 and saved_model_path is not None:
try:
model = handler.load_from_path(
saved_model_path, model_name, model_config, gangs, dtype
)
except FileNotFoundError:
raise ModelLoadError(
model_name, f"The '{model_name}' model cannot be found at the '{saved_model_path}' path." # fmt: skip
) from None
if gangs.dp.rank == 0:
if saved_model_path is not None:
try:
model = handler.load_from_path(
saved_model_path, model_name, model_config, gangs, dtype
)
except FileNotFoundError:
raise ModelLoadError(
model_name, f"The '{model_name}' model cannot be found at the '{saved_model_path}' path." # fmt: skip
) from None
else:
model = handler.create(model_config, gangs, dtype, meta=False)
else:
model = handler.create(
model_config, gangs, dtype, meta=handler.supports_meta
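
The hunk above is what the commit message refers to: when no checkpoint path is available, data parallel rank 0 now creates the model with real, uninitialized storage (meta=False), while the remaining ranks may still create it on the meta device when the handler supports it. Below is a rough, fairseq2-independent sketch of the meta-versus-materialized distinction; the nn.Linear model is illustrative only and requires a recent PyTorch for the device context manager.

import torch
from torch import nn

# Non-source ranks only need parameter shapes and dtypes, which the meta
# device provides without allocating any storage.
with torch.device("meta"):
    meta_model = nn.Linear(16, 16)

print(meta_model.weight.is_meta)  # True

# A meta tensor carries no data, so it cannot serve as the source of a
# collective broadcast such as the one broadcast_module performs; the source
# rank therefore needs a materialized model. to_empty() gives the parameters
# concrete (uninitialized) storage on a real device.
materialized = meta_model.to_empty(device="cpu")

print(materialized.weight.is_meta)  # False
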
2 changes: 1 addition & 1 deletion src/fairseq2/recipes/common/_ref_model.py
@@ -144,7 +144,7 @@ def load(self, model_name: str, gangs: Gangs, dtype: DataType, mp: bool) -> Mode
                 model = handler.load(card, gangs, dtype, model_config)
             else:
                 model = handler.create(
-                    model_config, gangs, dtype, handler.supports_meta
+                    model_config, gangs, dtype, meta=handler.supports_meta
                 )
         except NotSupportedError as ex:
             raise ModelLoadError(
