Skip to content

Commit

Permalink
Merge branch 'debug-ckpt-oom' into 'main'
Browse files Browse the repository at this point in the history
[dist ckpt] Use gather object instead of all gather object when running consistency check

See merge request ADLR/megatron-lm!2413
  • Loading branch information
ericharper committed Dec 8, 2024
2 parents 44fd429 + e7503a4 commit d677ca3
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions megatron/core/dist_checkpointing/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,18 +364,21 @@ def maybe_report_missing_and_unexpected_keys(
logger.warning(error_msg)


def _validate_common_state_dict(common_state_dict: CommonStateDict):
def _validate_common_state_dict(common_state_dict: CommonStateDict) -> None:
"""Validate consistancy across ranks for the common state dict
We save the common state dict only on rank 0. We validate to make sure that the common dict is consistant across ranks before saving.
Args:
common_state_dict: The common state dict present in all ransk
"""
other_rank_state_dicts = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(other_rank_state_dicts, common_state_dict)

# Gather the common state dict across ranks onto rank 0 for comparison
rank = torch.distributed.get_rank()
other_rank_state_dicts = [None] * torch.distributed.get_world_size() if rank == 0 else None
torch.distributed.gather_object(common_state_dict, other_rank_state_dicts)
common_state_dict_diff = {}
if torch.distributed.get_rank() == 0:
if rank == 0:
main_rank_state_dict = common_state_dict
for rank, rank_state_dict in enumerate(other_rank_state_dicts[1:], 1):
only_left, only_right, mismatch = diff(main_rank_state_dict, rank_state_dict)
Expand Down

0 comments on commit d677ca3

Please sign in to comment.