From cd7d6b210b69908529aa76fc96750c4df4512cef Mon Sep 17 00:00:00 2001 From: "Liu, Mingzhi" Date: Thu, 17 Aug 2023 08:59:11 +0000 Subject: [PATCH 1/2] skip all-gather --- deepspeed/runtime/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 72b6447da2fa..a2cde172ea69 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -944,6 +944,9 @@ def all_gather_dp_groups(partitioned_param_groups, dp_process_group, start_align partition_id = dist.get_rank(group=dp_process_group[group_id]) dp_world_size = dist.get_world_size(group=dp_process_group[group_id]) + if dp_world_size == 1: + # no groups share optimizer states. + continue num_shards = max(1, partitioned_params[partition_id].numel() * dp_world_size // allgather_bucket_size) shard_size = partitioned_params[partition_id].numel() // num_shards From d2e0741442ea9af0227347dcf777493b21468482 Mon Sep 17 00:00:00 2001 From: "Liu, Mingzhi" Date: Thu, 17 Aug 2023 09:09:05 +0000 Subject: [PATCH 2/2] add notes --- deepspeed/runtime/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index a2cde172ea69..7024b93d6820 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -945,7 +945,8 @@ def all_gather_dp_groups(partitioned_param_groups, dp_process_group, start_align dp_world_size = dist.get_world_size(group=dp_process_group[group_id]) if dp_world_size == 1: - # no groups share optimizer states. + # no groups share optimizer states + # pipeline parallel with bf16 calls this by default even if dp size = 1. continue num_shards = max(1, partitioned_params[partition_id].numel() * dp_world_size // allgather_bucket_size)