From 444d04313675abf20010970051b9a9b8538e7d6b Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 28 Jun 2023 14:34:06 +0800 Subject: [PATCH 1/2] make FLAGS_force_align_vpp_grad_sum_order default to false --- .../dygraph_optimizer/hybrid_parallel_optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py index f0cbbd9278345..c5d2dd39a93de 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py @@ -49,7 +49,7 @@ def __init__(self, clip, hcg): self.not_sharding_stage1 = True self._vpp_chunk_num = None self._force_align_vpp_grad_sum_order = distutils.util.strtobool( - os.getenv('FLAGS_force_align_vpp_grad_sum_order', '1') + os.getenv('FLAGS_force_align_vpp_grad_sum_order', '0') ) def _get_vpp_chunk_num(self, params_grads): From 357a65c0191bc311525e504a716c75ad49057841 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 28 Jun 2023 15:39:11 +0800 Subject: [PATCH 2/2] polish code --- .../dygraph_optimizer/hybrid_parallel_optimizer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py index c5d2dd39a93de..bbac67dde4d44 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py @@ -168,9 +168,10 @@ def _add_sum_squares(self, sum_squares): @no_grad() def _dygraph_clip(self, params_grads): - chunk_num = self._get_vpp_chunk_num(params_grads) - if chunk_num > 0 and self._force_align_vpp_grad_sum_order: - return self._vpp_dygraph_clip(params_grads, chunk_num) + if self._force_align_vpp_grad_sum_order: + chunk_num = self._get_vpp_chunk_num(params_grads) + if chunk_num > 0: + return self._vpp_dygraph_clip(params_grads, chunk_num) sum_square_dist_fp16 = [] sum_square_dist_bf16 = []