diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 72b6447da2fa..7024b93d6820 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -944,6 +944,10 @@ def all_gather_dp_groups(partitioned_param_groups, dp_process_group, start_align partition_id = dist.get_rank(group=dp_process_group[group_id]) dp_world_size = dist.get_world_size(group=dp_process_group[group_id]) + if dp_world_size == 1: + # With a single-rank dp group no optimizer states are shared, so there is nothing to gather. + # Pipeline parallelism with bf16 calls this by default even when the dp size is 1. + continue num_shards = max(1, partitioned_params[partition_id].numel() * dp_world_size // allgather_bucket_size) shard_size = partitioned_params[partition_id].numel() // num_shards