shard grad reduce #55495

Merged 4 commits on Jul 20, 2023
@@ -13,7 +13,7 @@
# limitations under the License.

######

import os
from functools import reduce

import paddle
@@ -23,6 +23,16 @@
from ...utils.log_util import logger
from ...utils.tensor_fusion_helper import fused_parameters

g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")
g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")

if g_shard_norm_align_dp:
assert (
not g_shard_use_reduce
), "g_shard_norm_align_dp is not support if g_shard_use_reduce is true"


def _is_trainable(param):
return not param.stop_gradient
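
For reference, `FLAGS_shard_use_reduce` switches the sharding gradient synchronization from `all_reduce` to `reduce`, while `FLAGS_shard_norm_align_dp` keeps the old global-norm computation aligned with pure data parallelism; the assertion above allows only one of the two at a time. Below is a minimal sketch of how a launcher might set the flags; only the flag names come from this PR, the script name and GPU list are placeholders.

```python
# Hypothetical launcher-side configuration; flag names come from this PR,
# the script name and GPU list are placeholders.
import os
import subprocess

# Enable the reduce-based gradient sync and disable the dp-aligned norm path.
os.environ["FLAGS_shard_use_reduce"] = "1"
os.environ["FLAGS_shard_norm_align_dp"] = "0"  # must be 0 when reduce is enabled

subprocess.run(
    [
        "python", "-m", "paddle.distributed.launch",
        "--gpus", "0,1",
        "train.py",  # placeholder training script
    ],
    check=True,
)
```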
@@ -203,18 +213,22 @@ def reduce_gradients(self, parameter_list, hcg):
if g_var is not None:
g_var.scale_(1.0 / sharding_nrank)
param_rank = self._param2rank[param.name]
paddle.distributed.all_reduce(
g_var,
group=hcg.get_sharding_parallel_group(),
sync_op=True,
)
# TODO(pangengzheng): change to reduce operation when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
# paddle.distributed.reduce(
# g_var,
# dst=hcg.get_sharding_parallel_group().ranks[param_rank],
# group=hcg.get_sharding_parallel_group(),
# sync_op=True,
# )
if not g_shard_use_reduce:
paddle.distributed.all_reduce(
g_var,
group=hcg.get_sharding_parallel_group(),
sync_op=True,
)
else:
# TODO(pangengzheng): change to reduce operation when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
paddle.distributed.reduce(
g_var,
dst=hcg.get_sharding_parallel_group().ranks[
param_rank
],
group=hcg.get_sharding_parallel_group(),
sync_op=True,
)
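
The motivation for the new branch, roughly: in sharding stage 1 each parameter is updated by exactly one rank, so summing its gradient onto that owning rank with `reduce` is sufficient, whereas `all_reduce` is still needed when every rank must see identical values for the dp-aligned global norm. The two-rank demo below is a hypothetical sketch contrasting the two collectives, not code from the PR; it assumes a launch such as `python -m paddle.distributed.launch --gpus 0,1 demo.py`.

```python
# Hypothetical two-rank demo of the two collectives used above; the tensors,
# the destination rank and the file name are made up for illustration.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()

grad = paddle.to_tensor([float(rank + 1)])

# Old path: every rank ends up holding the full sum.
summed = grad.clone()
dist.all_reduce(summed, sync_op=True)

# New path (FLAGS_shard_use_reduce=1): only the owning rank (here rank 0)
# receives the sum; the other rank keeps its local value.
owned = grad.clone()
dist.reduce(owned, dst=0, sync_op=True)

print(f"rank {rank}: all_reduce -> {summed.numpy()}, reduce -> {owned.numpy()}")
```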

def _sharding_sync_parameters(self):
"""
@@ -294,11 +308,11 @@ def step(self):
if hasattr(param, "main_grad") and param.main_grad is not None:
grad_var = param.main_grad
params_grads.append((param, grad_var))
if hasattr(self._inner_opt._grad_clip, 'not_sharding_stage1'):
self._inner_opt._grad_clip.not_sharding_stage1 = False
params_grads = self._inner_opt._grad_clip(params_grads)
# set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize
self._set_inner_opt_attr('_grad_clip', None)

if g_shard_norm_align_dp:
params_grads = self._inner_opt._grad_clip(params_grads)
# set inner_opt._grad_clip to None to avoid clipping gradients again inside inner_opt._apply_optimize
self._set_inner_opt_attr('_grad_clip', None)
rank_params = (
self._rank2params[self._sharding_rank]
if not self.tensor_fusion
@@ -313,8 +327,9 @@
startup_program=None,
params_grads=update_params_grads,
)
# restore the grad clip
self._set_inner_opt_attr('_grad_clip', origin_clip)
if g_shard_norm_align_dp:
# restore the grad clip
self._set_inner_opt_attr('_grad_clip', origin_clip)

# sync parameters across sharding ranks
self._sharding_sync_parameters()
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

import paddle
@@ -37,13 +39,52 @@

__all__ = []

g_shard_norm_align_dp = int(os.environ.get("FLAGS_shard_norm_align_dp", 1))
logger.info(f"g_shard_norm_align_dp {g_shard_norm_align_dp}")


class HybridParallelClipGrad:
def __init__(self, clip, hcg):
self._clip = clip
self._hcg = hcg
self.not_sharding_stage1 = True

def _global_norm(self, global_norm_var_dist, global_norm_var_not_dist):
# sharding first
sharding_flag = (
self._hcg.get_sharding_parallel_world_size() > 1
and self._hcg.get_data_parallel_world_size() == 1
)
mp_flag = self._hcg.get_model_parallel_world_size() > 1

# add all reduce to get global norm of distributed params_and_grads
if sharding_flag and not g_shard_norm_align_dp:
# norm of mp distributed variable
if mp_flag:
paddle.distributed.all_reduce(
global_norm_var_dist,
group=self._hcg.get_sharding_parallel_group(),
)
# non-distributed vars are reduced over the sharding group here and over the pp group later
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_sharding_parallel_group(),
)
# norm of mp distributed variable
if mp_flag:
# distributed vars should be reduced across the sharding, mp and pp groups
paddle.distributed.all_reduce(
global_norm_var_dist,
group=self._hcg.get_check_parallel_group(sharding_flag),
)

# add all reduce to get global norm of non-distributed params_and_grads in groups of pp
if self._hcg.get_pipe_parallel_world_size() > 1:
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_pipe_parallel_group(),
)
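
The group-wise summation above relies on a simple identity: the squared L2 norms of disjoint gradient shards add up to the squared norm of the full gradient, so each rank contributes its local sum of squares and a single square root is taken afterwards. A single-process sanity check of that identity (illustrative only, using made-up data):

```python
# Single-process check of the identity behind the sharding-group all_reduce:
# sum of per-shard squared norms == squared norm of the full gradient.
import numpy as np

full_grad = np.random.rand(8).astype("float32")
shards = np.split(full_grad, 2)  # pretend two sharding ranks each own half

per_shard_sq = [float(np.sum(s * s)) for s in shards]  # local sums of squares
global_norm = np.sqrt(sum(per_shard_sq))               # after the "all_reduce"

assert np.isclose(global_norm, np.linalg.norm(full_grad), rtol=1e-5)
print(global_norm)
```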

@no_grad()
def _dygraph_clip(self, params_grads):
sum_square_dist_fp16 = []
@@ -157,37 +198,7 @@ def _dygraph_clip(self, params_grads):
+ global_norm_not_dist_fp32
)

# add all reduce to get global norm of distributed params_and_grads
if self._hcg.get_model_parallel_world_size() > 1:
sharding_flag = False
if (
self._hcg.get_sharding_parallel_world_size() > 1
and self._hcg.get_data_parallel_world_size() == 1
):
sharding_flag = True
paddle.distributed.all_reduce(
global_norm_var_dist,
group=self._hcg.get_check_parallel_group(sharding_flag),
)

# add all reduce to get global norm of non-distributed params_and_grads in groups of pp
if self._hcg.get_pipe_parallel_world_size() > 1:
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_pipe_parallel_group(),
)

# In Sharding mode, param and grad is mapping different rank in optimizer.
# ClipGradByGlobalNorm need allreduce to get globol norm
# TODO(pangengzheng): remove the self.not_sharding_stage1 flag when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
if (
self._hcg.get_sharding_parallel_world_size() > 1
and self.not_sharding_stage1
):
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_sharding_parallel_group(),
)
self._global_norm(global_norm_var_dist, global_norm_var_not_dist)

global_norm_var_fp32 = paddle.sqrt(
global_norm_var_dist + global_norm_var_not_dist
@@ -40,6 +40,9 @@

__all__ = []

g_shard_use_reduce = int(os.environ.get("FLAGS_shard_use_reduce", 0))
logger.info(f"g_shard_use_reduce {g_shard_use_reduce}")


# assume only the first stage and last stage need data, and data consumption is ordered
# to be replaced by the real micro dataset from the reader
@@ -295,8 +298,12 @@ def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
assert hasattr(self, "optimizer")
assert hasattr(self.optimizer, "_param2rank")
_param2rank = self.optimizer._param2rank

act = HOOK_ACTION.ALL_REDUCE if dp else HOOK_ACTION.REDUCE
# Note: once sharding fully switches to the reduce operation, this fallback can be removed
act = (
HOOK_ACTION.ALL_REDUCE
if (dp or not g_shard_use_reduce)
else HOOK_ACTION.REDUCE
)

for model in models:
# For virtual pipeline. Will separate parameters in different chunk into
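
To make the new condition concrete, the possible outcomes can be enumerated directly; `pick_action` below is a hypothetical helper mirroring the expression assigned to `act`, not a function in the codebase.

```python
# Hypothetical enumeration of the hook action chosen for each flag combination;
# mirrors the condition `dp or not g_shard_use_reduce` added in this hunk.
from itertools import product

def pick_action(dp: bool, g_shard_use_reduce: bool) -> str:
    return "ALL_REDUCE" if (dp or not g_shard_use_reduce) else "REDUCE"

for dp, use_reduce in product([False, True], repeat=2):
    print(f"dp={dp}, FLAGS_shard_use_reduce={int(use_reduce)}"
          f" -> {pick_action(dp, use_reduce)}")
```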
python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py (13 additions, 15 deletions)
@@ -246,21 +246,19 @@ def add_grad(self, param):
def _comm_grads(self):
assert self._all_params_checked_in

# Note: after sharding change to reduce operation here also need to be updated
# if self._act == HOOK_ACTION.ALL_REDUCE:
# task = paddle.distributed.all_reduce(
# self.grad_storage, group=self._comm_group, sync_op=False
# )
# elif self._act == HOOK_ACTION.REDUCE:
# task = paddle.distributed.reduce(
# self.grad_storage,
# dst=self._dst,
# group=self._comm_group,
# sync_op=False,
# )
task = paddle.distributed.all_reduce(
self.grad_storage, group=self._comm_group, sync_op=False
)
if self._act == HOOK_ACTION.ALL_REDUCE:
task = paddle.distributed.all_reduce(
self.grad_storage, group=self._comm_group, sync_op=False
)

elif self._act == HOOK_ACTION.REDUCE:
task = paddle.distributed.reduce(
self.grad_storage,
dst=self._dst,
group=self._comm_group,
sync_op=False,
)

self._task = task

@imperative_base.no_grad
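
For readers unfamiliar with Paddle's non-blocking collectives: both branches above pass `sync_op=False` and keep the returned task, so the communication can overlap with computation until the task is waited on before the fused gradient buffer is consumed. A standalone sketch of that pattern (illustrative; the buffer contents, file name and GPU list are assumptions), assuming a launch such as `python -m paddle.distributed.launch --gpus 0,1 demo_async.py`:

```python
# Hypothetical sketch of the sync_op=False pattern used by _comm_grads:
# the collective returns a task, and the buffer is safe to read only after wait().
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
buf = paddle.ones([4]) * (dist.get_rank() + 1)

task = dist.all_reduce(buf, sync_op=False)  # kick off the communication
# ... other computation could overlap here ...
task.wait()                                 # now buf holds the summed result
print(dist.get_rank(), buf.numpy())
```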
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus
@@ -20,6 +21,13 @@
class TestHybridParallel(TestMultipleGpus):
# check the sharding logic as well as accuracy compared with single-card mode
def test_hybrid_parallel_sharding_logic(self):
# test shard grad reduce
os.environ["FLAGS_shard_use_reduce"] = "1"
os.environ["FLAGS_shard_norm_align_dp"] = "0"
self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
# test shard grad allreduce
os.environ["FLAGS_shard_use_reduce"] = "0"
os.environ["FLAGS_shard_norm_align_dp"] = "1"
self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')

def test_hybrid_parallel_sharding_tensor_fusion(self):