support sharding stage2 + mp hybrid_parallel. (#47535)
* support sharding stage2 + mp hybrid_parallel.

* fix the group of check_nan_inf.

* update hcg.
wuhuachaocoding authored Nov 3, 2022
1 parent 605bc00 commit 818132a
Showing 2 changed files with 27 additions and 4 deletions.
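
For orientation, here is a rough usage sketch of what this commit enables: group-sharded stage 2 running inside a fleet hybrid-parallel (model-parallel) setup. The degrees, model, and optimizer below are placeholder assumptions, not code from the commit, and the exact wiring in real training scripts may differ.

    # Hedged sketch: enable MP + sharding through hybrid_configs, then wrap the
    # optimizer with stage-2 ("os_g") group sharding inside the sharding group.
    import paddle
    from paddle.distributed import fleet
    from paddle.distributed.sharding import group_sharded_parallel

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,        # placeholder degrees for a 4-GPU job
        "mp_degree": 2,
        "pp_degree": 1,
        "sharding_degree": 2,
    }
    fleet.init(is_collective=True, strategy=strategy)
    hcg = fleet.get_hybrid_communicate_group()

    model = paddle.nn.Linear(1024, 1024)  # placeholder model
    optimizer = paddle.optimizer.AdamW(
        learning_rate=1e-4,
        parameters=model.parameters(),
        # ClipGradByGlobalNorm is what triggers the clip-grad branch changed below.
        grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
    )
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    # Stage 2 = shard optimizer states and gradients ("os_g") across the sharding group.
    model, optimizer, scaler = group_sharded_parallel(
        model, optimizer, level="os_g", scaler=scaler,
        group=hcg.get_sharding_parallel_group(),
    )

With ClipGradByGlobalNorm configured and a non-data-parallel hcg, the stage-2 optimizer now swaps in HybridParallelClipGrad, as the first hunk below shows.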
@@ -31,6 +31,11 @@
 from paddle.fluid import core
 from paddle.optimizer import Optimizer
 from paddle.fluid.clip import ClipGradByGlobalNorm
+from paddle.distributed import fleet, ParallelMode
+
+HybridParallelClipGrad = (
+    fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer.HybridParallelClipGrad
+)
 from paddle.distributed.collective import (
     _get_global_group,
     broadcast,
@@ -157,9 +162,18 @@ def __init__(
                 "While using ClipGradByGlobalNorm in GroupShardedOptimizerStage2, the grad clip of original optimizer will be changed."
             )
 
-            self._optim._grad_clip = GroupShardedClipGrad(
-                self._optim._grad_clip, paddle.get_device(), self._group
-            )
+            hcg = fleet.fleet._hcg if hasattr(fleet.fleet, "_hcg") else None
+            if (
+                hcg
+                and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
+            ):
+                self._optim._grad_clip = HybridParallelClipGrad(
+                    self._optim._grad_clip, hcg
+                )
+            else:
+                self._optim._grad_clip = GroupShardedClipGrad(
+                    self._optim._grad_clip, paddle.get_device(), self._group
+                )
         if self._optim._parameter_list and isinstance(
             self._optim._parameter_list[0], dict
         ):
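
Why the switch matters, stated in a hedged way: under model parallelism each rank holds only a slice of the partitioned parameters, so a global-norm clip computed from local gradients alone would under-count the true norm; the squared norms have to be accumulated across the model-parallel group before the clip scale is applied. A minimal conceptual sketch of that accumulation (an illustration under stated assumptions, not the library's implementation; the function name and group argument are hypothetical):

    import paddle

    def hybrid_global_grad_norm(local_sq_norm, mp_group):
        # Sum the squared local gradient norms over the model-parallel group,
        # then take the square root to obtain the true global norm.
        sq = paddle.to_tensor([local_sq_norm], dtype="float32")
        paddle.distributed.all_reduce(
            sq, op=paddle.distributed.ReduceOp.SUM, group=mp_group
        )
        return float(paddle.sqrt(sq))

    # The clip scale max_norm / max(global_norm, max_norm) is then applied to each grad.

Real hybrid clipping also has to avoid double-counting parameters that are replicated on every MP rank; the sketch ignores that detail.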
@@ -23,6 +23,7 @@
 from paddle.fluid import layers
 from paddle.fluid.dygraph import to_variable
 from paddle.fluid.framework import dygraph_only
+from paddle.distributed import fleet, ParallelMode
 
 
 class Taskflow:
@@ -244,10 +245,18 @@ def unscale_method(self, optimizer):
     self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
     is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
 
+    hcg = fleet.fleet._hcg if hasattr(fleet.fleet, "_hcg") else None
+    hybrid_parallel = (
+        hcg is not None
+        and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
+    )
+
     paddle.distributed.all_reduce(
         is_found_inf,
         op=paddle.distributed.ReduceOp.MAX,
-        group=optimizer._group,
+        group=hcg.get_check_parallel_group()
+        if hybrid_parallel
+        else optimizer._group,
     )
     self._found_inf = is_found_inf.numpy()[0]
 
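For context on the second bullet of the commit message ("fix the group of check_nan_inf"): under hybrid parallel, a NaN/Inf found on one model-parallel rank must be propagated to every rank that takes part in the same optimizer step, which is what hcg.get_check_parallel_group() covers; optimizer._group is only the sharding group. A hedged standalone illustration of that group selection (assumes fleet.init() was called with hybrid_configs; the fallback group below is a simplification, not the commit's code):

    import paddle
    from paddle.distributed import fleet, ParallelMode

    hcg = fleet.get_hybrid_communicate_group()
    found_inf = paddle.to_tensor([1], dtype="int32")  # pretend this rank saw a NaN

    group = (
        hcg.get_check_parallel_group()  # spans sharding and model-parallel ranks
        if hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
        else None                       # simplification: default/global group
    )
    paddle.distributed.all_reduce(
        found_inf, op=paddle.distributed.ReduceOp.MAX, group=group
    )
    # After the MAX reduce, every rank in the group agrees on whether to skip the step.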
