Skip to content

Commit

Permalink
fix_sharding_grad_clip
Browse files Browse the repository at this point in the history
  • Loading branch information
Baibaifan committed Mar 16, 2022
1 parent 7ced301 commit e4ad66e
Showing 1 changed file with 3 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _dygraph_clip(self, params_grads):
global_norm_fp16 = paddle.cast(
global_norm_fp16, dtype=paddle.float32)

# global norm of non-distributed FP16 params_and_grads for slice parameter
# global norm of non-distributed FP16 params_and_grads for unslice parameter
if len(unslice_params_fp16) == 0:
global_unslice_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
else:
Expand All @@ -104,21 +104,20 @@ def _dygraph_clip(self, params_grads):
[0.], dtype=paddle.float32)
global_norm_fp32 = layers.reduce_sum(global_norm_fp32)

# global norm of non-distributed FP32 params_and_grads for slice parameter
# global norm of non-distributed FP32 params_and_grads for unslice parameter
global_unslice_fp32 = layers.concat(unslice_params_fp32) if len(
unslice_params_fp32) != 0 else paddle.to_tensor(
[0.], dtype=paddle.float32)
global_unslice_fp32 = layers.reduce_sum(global_unslice_fp32)
global_unslice_var = global_unslice_fp16 + global_unslice_fp32

global_norm_var = global_norm_fp16 + global_norm_fp32
global_norm_var = global_norm_fp16 + global_norm_fp32 + 1.0 / self._group.nranks * global_unslice_var

# add all reduce to get global norm of distributed params_and_grads
dev_id = int(self._device.split(":")[1])
with device_guard(dev_id, "gpu"):
paddle.distributed.all_reduce(global_norm_var, group=self._group)

global_norm_var += global_unslice_var
global_norm_var = layers.sqrt(global_norm_var)
max_global_norm = layers.fill_constant(
shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
Expand Down

1 comment on commit e4ad66e

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulations! Your pull request passed all required CI. You can ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.