Skip to content

Commit

Permalink
[HybridParallel] Fix bug of check_inf in fleet_base.py (#36651)
Browse files Browse the repository at this point in the history
* fix bug of check_inf

* fix allreduce
  • Loading branch information
haohongxiang authored Oct 25, 2021
1 parent 50778ad commit 59d8b8c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
8 changes: 4 additions & 4 deletions python/paddle/distributed/fleet/base/fleet_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1586,16 +1586,16 @@ def unscale_method(self, optimizer):
_C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
param_grads_fp32,
temp_found_inf_fp32)

self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

# TODO(shenliang03): Since the dp allreduce in the optimizer runs
# after the gradscaler, check_finite needs to synchronize global
# information. In the future, we should use check_group to speed this up.
paddle.distributed.all_reduce(
paddle.to_tensor(
[self._found_inf], dtype="int32"),
op=paddle.distributed.ReduceOp.MAX,
group=None)
is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None)
self._found_inf = is_found_inf.numpy()[0]

# Only tensor_parallel and pipeline_parallel need to modify scaler
if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,14 @@ def _apply_collective_grads(parameters, comm_group):
nranks = paddle.distributed.get_world_size(
) if comm_group is None else comm_group.nranks
div_factor = paddle.to_tensor(nranks, dtype=coalesced_grad.dtype)
paddle.distributed.all_reduce(coalesced_grad, group=comm_group)
paddle.fluid.framework._dygraph_tracer().trace_op(
type="elementwise_div",
inputs={'X': coalesced_grad,
'Y': div_factor},
outputs={'Out': coalesced_grad},
attrs={'axis': -1})

paddle.distributed.all_reduce(coalesced_grad, group=comm_group)

_split_tensors(coalesced_grads_and_vars)


Expand Down

0 comments on commit 59d8b8c

Please sign in to comment.