Support stage2 for gradient merge. #47711

Merged (13 commits), Nov 17, 2022
File 1:
@@ -418,17 +418,6 @@ def cleanup():
)
)

if self._dp_group and self._dp_group.nranks > 1:
assert (
not self._reduce_overlap
), 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
# TODO(wuhuachao):after the new communication lib upgrading, overlapping the comm of dp + stage2.
dist.all_reduce(
tensor=param.grad,
group=self._dp_group,
sync_op=True,
)

# Clear the task flow and trigger callback to clear the redundant gradient
# self._clear_task_flow()

@@ -485,17 +474,6 @@ def cleanup():
)
)

if self._dp_group and self._dp_group.nranks > 1:
assert (
not self._reduce_overlap
), 'dp + stage2 hybrid parallel only Synchronize due to the new communication lib.'
# TODO(wuhuachao):after the new communication lib upgrading, overlapping the comm of dp + stage2.
dist.all_reduce(
tensor=grad_storage.buffer,
group=self._dp_group,
sync_op=True,
)

cleanup()

# Clear the task flow and trigger callback to clear the redundant gradient
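
Both hunks above drop the immediate, synchronous dp-group all_reduce from the reduce hooks: with gradient merge, the cross-replica reduction only needs to happen once per optimizer step, which the new _dp_allreduce method introduced below takes over. For orientation, here is a minimal sketch of how the two communication groups involved in dp + stage2 can be laid out; the degrees, loop, and variable names are assumptions for illustration, not this PR's API:

# Illustrative sketch only: one possible layout of sharding (stage2) and
# data-parallel groups. Degrees and names are assumptions, not this PR's code.
import paddle.distributed as dist

dist.init_parallel_env()
world_size = dist.get_world_size()            # assume 8 ranks for this sketch
rank = dist.get_rank()
sharding_degree = 4
dp_degree = world_size // sharding_degree     # 2 replicas of the sharded states

sharding_group, dp_group = None, None
# Every rank must participate in every new_group call.
for i in range(dp_degree):
    # Ranks that shard one replica of the optimizer states among themselves.
    ranks = list(range(i * sharding_degree, (i + 1) * sharding_degree))
    g = dist.new_group(ranks)
    if rank in ranks:
        sharding_group = g
for j in range(sharding_degree):
    # Ranks that hold the same shard in different replicas; they all_reduce merged grads.
    ranks = list(range(j, world_size, sharding_degree))
    g = dist.new_group(ranks)
    if rank in ranks:
        dp_group = g

In this layout each rank belongs to one sharding group (which shards optimizer states and receives reduced gradient shards) and one dp group (which connects the ranks holding the same shard across replicas).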
@@ -648,8 +626,34 @@ def _rank_buffer_size(self, buffer_max_size, model_size):
)
return rank_buffer_size

def _dp_allreduce(self):
# do dp allreduce here for gradient merge.
if self._dp_group and self._dp_group.nranks > 1:
for dtype in self._grad_storages.keys():
for rank, g in sorted(
self._grad_storages[dtype].items(), key=lambda x: x[0]
):
if g.destination == self._rank:
assert g.buffer._is_initialized()
dist.all_reduce(
tensor=g.buffer,
group=self._dp_group,
sync_op=True,
)
for param in self._trainable_params:
if param.name in self._param_grads and param.grad is not None:
dst_rank = self._trainable_param2rank[param.name]
if dst_rank == self._rank:
dist.all_reduce(
tensor=param.grad,
group=self._dp_group,
sync_op=True,
)
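
Note the g.destination == self._rank and dst_rank == self._rank guards: after the stage2 reduce, each gradient shard lives on its destination rank, so only that rank needs to all_reduce it across the dp group. For readers unfamiliar with gradient merge itself, here is a hedged, self-contained sketch of the pattern this supports, accumulating gradients over several micro-batches and paying one dp all_reduce per real optimizer step; model, acc_steps, and the plain parameter loop are illustrative assumptions, not this class's internals:

import paddle
import paddle.distributed as dist

def train_with_gradient_merge(model, opt, batches, dp_group, acc_steps=4):
    # Accumulate gradients for `acc_steps` micro-batches, then all_reduce once.
    for i, (x, y) in enumerate(batches):
        loss = paddle.nn.functional.mse_loss(model(x), y)
        loss.backward()                          # grads accumulate across micro-steps
        if (i + 1) % acc_steps == 0:
            for p in model.parameters():         # one dp all_reduce per optimizer step
                if p.grad is not None:
                    dist.all_reduce(p.grad, group=dp_group, sync_op=True)
            opt.step()
            opt.clear_grad()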

def _redefine_opt_step(self):
grad_func = self._grad_scale
dp_allreduce_func = self._dp_allreduce

for opt in self._sharding_optimizers:
opt_step = opt.step

@@ -658,6 +662,8 @@ def _opt_step(self):
# Wait for the last reduce task. This wait must before grad scale function.
assert self._comm_task is not None
self._comm_task.wait()

dp_allreduce_func()
Contributor: The dp-group all_reduce would be better placed after grad_scale.

Contributor Author: DONE

grad_func()
opt_step()
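
The hunk above shows _redefine_opt_step wiring the deferred dp all_reduce into the optimizer through a closure-wrapping pattern: the original step is captured and replaced by a wrapper that waits for the last sharding reduce, runs the dp all_reduce and gradient scaling, and only then performs the real update. A condensed, illustrative sketch of that pattern as a standalone function (not the class's exact code):

def redefine_opt_step(opt, dp_allreduce_func, grad_scale_func, comm_task):
    original_step = opt.step

    def wrapped_step():
        comm_task.wait()        # wait for the last sharding-group reduce task
        dp_allreduce_func()     # cross-replica all_reduce of the merged grads
        grad_scale_func()       # rescale gradients (the review above suggests running this first)
        original_step()         # run the original optimizer update

    opt.step = wrapped_step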

File 2:
@@ -23,7 +23,6 @@
from paddle.fluid import layers
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import dygraph_only
from paddle.distributed import fleet, ParallelMode


class Taskflow:
@@ -245,18 +244,8 @@ def unscale_method(self, optimizer):
self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

hcg = fleet.fleet._hcg if hasattr(fleet.fleet, "_hcg") else None
hybrid_parallel = (
hcg is not None
and hcg.get_parallel_mode() is not ParallelMode.DATA_PARALLEL
)

paddle.distributed.all_reduce(
is_found_inf,
op=paddle.distributed.ReduceOp.MAX,
group=hcg.get_check_parallel_group()
if hybrid_parallel
else optimizer._group,
is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
)
self._found_inf = is_found_inf.numpy()[0]
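
The simplified call passes group=None, so the MAX all_reduce of the inf/nan flag runs over the default (global) communication group instead of a hand-picked hybrid-parallel group, and every rank agrees on whether to skip the step. A minimal sketch of that check in isolation; the surrounding scaler plumbing is omitted, and local_overflow is a stand-in for this rank's unscale result:

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
# Assume `local_overflow` was set while unscaling this rank's gradients.
local_overflow = False
found_inf = paddle.to_tensor([1 if local_overflow else 0], dtype="int32")
# group=None -> default global group: if any rank saw inf/nan, all ranks get 1.
dist.all_reduce(found_inf, op=dist.ReduceOp.MAX, group=None)
skip_step = bool(found_inf.numpy()[0])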

File 3:
@@ -148,11 +148,6 @@ def test_sharding_api():

output_dir = tempfile.mkdtemp()

# test sharding + dp, just for test
dp_group = paddle.distributed.new_group(
list(range(paddle.distributed.get_world_size()))
)

# fp16
stage2_params = train_mlp(
mlp1,