support stage2 for gradient merge. (#47711)
wuhuachaocoding authored Nov 17, 2022
1 parent 460d504 · commit c20eb7a
Showing 3 changed files with 29 additions and 39 deletions.
File 1 of 3:
@@ -418,17 +418,6 @@ def cleanup():
)
)

if self._dp_group and self._dp_group.nranks > 1:
assert (
not self._reduce_overlap
), 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
# TODO(wuhuachao):after the new communication lib upgrading, overlapping the comm of dp + stage2.
dist.all_reduce(
tensor=param.grad,
group=self._dp_group,
sync_op=True,
)

# Clear the task flow and trigger callback to clear the redundant gradient
# self._clear_task_flow()

@@ -485,17 +474,6 @@ def cleanup():
)
)

if self._dp_group and self._dp_group.nranks > 1:
assert (
not self._reduce_overlap
), 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
# TODO(wuhuachao):after the new communication lib upgrading, overlapping the comm of dp + stage2.
dist.all_reduce(
tensor=grad_storage.buffer,
group=self._dp_group,
sync_op=True,
)

cleanup()

# Clear the task flow and trigger callback to clear the redundant gradient
@@ -648,8 +626,34 @@ def _rank_buffer_size(self, buffer_max_size, model_size):
)
return rank_buffer_size

def _dp_allreduce(self):
# do dp allreduce here for gradient merge.
if self._dp_group and self._dp_group.nranks > 1:
for dtype in self._grad_storages.keys():
for rank, g in sorted(
self._grad_storages[dtype].items(), key=lambda x: x[0]
):
if g.destination == self._rank:
assert g.buffer._is_initialized()
dist.all_reduce(
tensor=g.buffer,
group=self._dp_group,
sync_op=True,
)
for param in self._trainable_params:
if param.name in self._param_grads and param.grad is not None:
dst_rank = self._trainable_param2rank[param.name]
if dst_rank == self._rank:
dist.all_reduce(
tensor=param.grad,
group=self._dp_group,
sync_op=True,
)

def _redefine_opt_step(self):
grad_func = self._grad_scale
dp_allreduce_func = self._dp_allreduce

for opt in self._sharding_optimizers:
opt_step = opt.step

@@ -658,7 +662,9 @@ def _opt_step(self):
# Wait for the last reduce task. This wait must before grad scale function.
assert self._comm_task is not None
self._comm_task.wait()

grad_func()
dp_allreduce_func()
opt_step()

opt.step = MethodType(_opt_step, opt)
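
The rebinding above is the heart of the change: the data-parallel all-reduce for gradient merge now runs inside the redefined optimizer step, replacing the per-gradient dp all-reduce calls deleted in the first two hunks. Below is a minimal, framework-free sketch of that MethodType rebinding pattern (the comm-task wait is omitted); the Optimizer class and the hook callables are illustrative placeholders, not Paddle APIs.

    from types import MethodType


    class Optimizer:
        """Placeholder standing in for a sharding optimizer."""

        def step(self):
            print("original optimizer step")


    def redefine_opt_step(opt, grad_func, dp_allreduce_func):
        opt_step = opt.step  # keep a reference to the original bound method

        def _opt_step(self):
            grad_func()          # scale the merged gradients
            dp_allreduce_func()  # all-reduce across the data-parallel group
            opt_step()           # then run the original optimizer step

        opt.step = MethodType(_opt_step, opt)


    opt = Optimizer()
    redefine_opt_step(
        opt,
        grad_func=lambda: print("grad scale"),
        dp_allreduce_func=lambda: print("dp all-reduce for gradient merge"),
    )
    opt.step()  # grad scale -> dp all-reduce for gradient merge -> original step
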
File 2 of 3:
@@ -23,7 +23,6 @@
from paddle.fluid import layers
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import dygraph_only
from paddle.distributed import fleet, ParallelMode


class Taskflow:
@@ -245,18 +244,8 @@ def unscale_method(self, optimizer):
self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

hcg = fleet.fleet._hcg if hasattr(fleet.fleet, "_hcg") else None
hybrid_parallel = (
hcg is not None
and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
)

paddle.distributed.all_reduce(
is_found_inf,
op=paddle.distributed.ReduceOp.MAX,
group=hcg.get_check_parallel_group()
if hybrid_parallel
else optimizer._group,
is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
)
self._found_inf = is_found_inf.numpy()[0]

File 3 of 3:
@@ -148,11 +148,6 @@ def test_sharding_api():

output_dir = tempfile.mkdtemp()

# test sharding + dp, just for test
dp_group = paddle.distributed.new_group(
list(range(paddle.distributed.get_world_size()))
)

# fp16
stage2_params = train_mlp(
mlp1,