diff --git a/megatron/core/dist_checkpointing/validation.py b/megatron/core/dist_checkpointing/validation.py index 8f39ddc052..48e023dc39 100644 --- a/megatron/core/dist_checkpointing/validation.py +++ b/megatron/core/dist_checkpointing/validation.py @@ -364,7 +364,7 @@ def maybe_report_missing_and_unexpected_keys( logger.warning(error_msg) -def _validate_common_state_dict(common_state_dict: CommonStateDict): +def _validate_common_state_dict(common_state_dict: CommonStateDict) -> None: """Validate consistancy across ranks for the common state dict We save the common state dict only on rank 0. We validate to make sure that the common dict is consistant across ranks before saving. @@ -372,10 +372,13 @@ def _validate_common_state_dict(common_state_dict: CommonStateDict): Args: common_state_dict: The common state dict present in all ransk """ - other_rank_state_dicts = [None] * torch.distributed.get_world_size() - torch.distributed.all_gather_object(other_rank_state_dicts, common_state_dict) + + # Gather the common state dict across ranks onto rank 0 for comparison + rank = torch.distributed.get_rank() + other_rank_state_dicts = [None] * torch.distributed.get_world_size() if rank == 0 else None + torch.distributed.gather_object(common_state_dict, other_rank_state_dicts) common_state_dict_diff = {} - if torch.distributed.get_rank() == 0: + if rank == 0: main_rank_state_dict = common_state_dict for rank, rank_state_dict in enumerate(other_rank_state_dicts[1:], 1): only_left, only_right, mismatch = diff(main_rank_state_dict, rank_state_dict)