diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py index 19d3245982a96..ee8c576826ca9 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py @@ -86,6 +86,11 @@ def __init__(self, # Default information self._optim = optim + # sharing stage 2 comm overlap flag + self._comm_overlap = False + # record the last task used for comm overlap for sharding stage 2 + self._comm_task = None + assert hasattr(self._optim, "_master_weights" ), "Must use optimizer with _master_weights attribute" @@ -157,6 +162,17 @@ def _sync_params_and_buffers(self): group=self._group, sync_op=True) + def _update_task(self, task): + if self._comm_overlap: + assert task is not None + # Only track of the last reduce task. + # Since all tasks are on the same stream, only need to wait the last one. + # After waiting for the last reduce task, all reduce tasks before have already finished. + self._comm_task = task + + def _set_comm_overlap(self, comm_overlap): + self._comm_overlap = comm_overlap + def _generate_master_params(self, trainable_params): if self.offload: for param in trainable_params: @@ -364,7 +380,8 @@ def step(self): """ A wrapper for Optimizer's step function to finish the update operation of the optimizer. """ - + # This method won't be called directly by opt.step()! + # The _redefine_opt_step() in class GroupShardedStage2 will wrap this function. if self.offload: params_list = [self.offload_params.buffer] diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py index d47f50b292db2..b8b350d28bf79 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage2.py @@ -100,6 +100,9 @@ def __init__( for optim in self._sharding_optimizers: self._all_params.extend(list(optim.local_params)) + # sharing stage 2 comm overlap flag + self._comm_overlap = False + self._trainable_params = [] self._grad_reduced = [] self._trainable_param2rank = {} @@ -306,6 +309,18 @@ def _clear_counters(self): for grad_storage in self._grad_storage_list: grad_storage.reset_checked_in() + def _set_comm_overlap(self, comm_overlap): + # Hacky way to not add an extra parameter to the `group_sharded_parallel` funct. + # User should use this like: + # model, optimizer, scaler = group_sharded_parallel(...) + # model._set_comm_overlap(True) + self._comm_overlap = comm_overlap + if self._comm_overlap: + assert len( + self._sharding_optimizers + ) == 1, "Only support comm overlap strategy for single optimizer" + self._sharding_optimizers[0]._set_comm_overlap(comm_overlap) + def _get_reduce_fn(self, index, param, dst_rank): """ There are two ways to reduce gradient. @@ -337,11 +352,12 @@ def cleanup(): del tmp_grad param.clear_gradient(False) - # Synchronize the reduce parameter gradient - collective.reduce(tensor=param.grad, - dst=self._group.ranks[dst_rank], - group=self._group) - # TODO (Baibaifan) Asynchronous the reduce parameter gradient + # Synchronize the reduce parameter gradient asynchronize + self._sharding_optimizers[0]._update_task( + collective.reduce(tensor=param.grad, + dst=self._group.ranks[dst_rank], + group=self._group, + sync_op=not self._comm_overlap)) # Clear the task flow and trigger callback to clear the redundant gradient # self._clear_task_flow() @@ -385,12 +401,13 @@ def cleanup(): # Reduce the bucket grad_storage.sent = True - # Synchronize the reduce parameter gradient - collective.reduce( - tensor=grad_storage.buffer, - dst=self._group.ranks[grad_storage.destination], - group=self._group) - # TODO (Baibaifan) Asynchronous the reduce parameter gradient + # Synchronize the reduce parameter gradient asynchronize + self._sharding_optimizers[0]._update_task( + collective.reduce( + tensor=grad_storage.buffer, + dst=self._group.ranks[grad_storage.destination], + group=self._group, + sync_op=not self._comm_overlap)) cleanup() @@ -528,6 +545,10 @@ def _redefine_opt_step(self): opt_step = opt.step def _opt_step(self): + if self._comm_overlap: + # Wait for the last reduce task. This wait must before grad scale function. + assert self._comm_task is not None + self._comm_task.wait() grad_func() opt_step() diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py new file mode 100644 index 0000000000000..7208372b6c085 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage2_comm_overlap.py @@ -0,0 +1,241 @@ +# -*- coding: UTF-8 -*- + +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import numpy as np +import argparse +import tempfile +import ast +import time +import paddle +import paddle.fluid as fluid +from paddle.fluid.dygraph.nn import Linear +from paddle.distributed import fleet +from paddle.fluid.dygraph import nn +from paddle.fluid.framework import _test_eager_guard + +from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import GroupShardedOptimizerStage2 +from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import GroupShardedStage2 + +seed = 2022 +epoch = 2 +linear_size = 1000 + +np.random.seed(seed) +paddle.seed(seed) + + +class MLP(fluid.Layer): + + def __init__(self, linear_size=1000, param_attr=None, bias_attr=None): + super(MLP, self).__init__() + + self._linear1 = Linear(linear_size, linear_size) + self._linear2 = Linear(linear_size, linear_size) + self._linear3 = Linear(linear_size, 10) + + def forward(self, inputs): + y = self._linear1(inputs) + y = self._linear2(y) + y = self._linear3(y) + return y + + +def reader_decorator(linear_size=1000): + + def __reader__(): + for _ in range(100): + img = np.random.rand(linear_size).astype('float32') + label = np.ones(1).astype('int64') + yield img, label + + return __reader__ + + +def optimizer_setting(model, use_pure_fp16, opt_group=False): + clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) + optimizer = paddle.optimizer.AdamW(parameters=[{ + "params": model.parameters(), + }] if opt_group else model.parameters(), + learning_rate=0.001, + weight_decay=0.00001, + grad_clip=clip, + multi_precision=use_pure_fp16) + + return optimizer + + +def train_mlp(model, + sharding_stage, + batch_size=100, + use_pure_fp16=False, + accumulate_grad=False, + opt_group=False, + save_model=False, + test_minimize=False): + if sharding_stage != "dp": + group = paddle.distributed.new_group([0, 1], backend="nccl") + if opt_group: + optimizer = optimizer_setting(model=model, + use_pure_fp16=use_pure_fp16, + opt_group=opt_group) + else: + optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16) + + if sharding_stage == 2: + optimizer = GroupShardedOptimizerStage2( + params=optimizer._parameter_list, optim=optimizer, group=group) + model = GroupShardedStage2(model, + optimizer, + group=group, + buffer_max_size=2**21) + model._set_comm_overlap(True) + else: + model = paddle.DataParallel(model) + + # check optimizer.minimize() error + if test_minimize: + try: + optimizer.minimize() + except: + print( + "====== Find sharding_stage2_optimizer.minimize() error ======") + return + + train_reader = paddle.batch(reader_decorator(), + batch_size=batch_size, + drop_last=True) + + train_loader = paddle.io.DataLoader.from_generator(capacity=32, + use_double_buffer=True, + iterable=True, + return_list=True, + use_multiprocess=True) + train_loader.set_sample_list_generator(train_reader) + + if sharding_stage == 2: + model.to(device="gpu") + + for eop in range(epoch): + model.train() + + for batch_id, data in enumerate(train_loader()): + img, label = data + label.stop_gradient = True + img.stop_gradient = True + + out = model(img) + loss = paddle.nn.functional.cross_entropy(input=out, label=label) + + avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32)) + if batch_size == 20: + avg_loss = avg_loss / 5 + avg_loss.backward() + + if not accumulate_grad: + optimizer.step() + optimizer.clear_grad() + + if accumulate_grad: + optimizer.step() + optimizer.clear_grad() + + if save_model: + return model, optimizer + return model.parameters() + + +def test_dp_stage2(): + paddle.distributed.init_parallel_env() + mlp = MLP() + state_dict = mlp.state_dict() + mlp1 = MLP() + mlp2 = MLP() + mlp3 = MLP() + mlp4 = MLP() + mlp5 = MLP() + mlp6 = MLP() + mlp7 = MLP() + mlp1.set_state_dict(state_dict) + mlp2.set_state_dict(state_dict) + mlp3.set_state_dict(state_dict) + mlp4.set_state_dict(state_dict) + mlp5.set_state_dict(state_dict) + mlp6.set_state_dict(state_dict) + mlp7.set_state_dict(state_dict) + + # DP VS stage2 + dp_params = train_mlp(mlp1, + sharding_stage="dp", + use_pure_fp16=False, + opt_group=False) + stage2_params = train_mlp(mlp2, + sharding_stage=2, + use_pure_fp16=False, + opt_group=False) + for i in range(len(dp_params)): + np.testing.assert_allclose(dp_params[i].numpy(), + stage2_params[i].numpy(), + rtol=1e-6) + + # stage2 accumulate grad + stage2_params = train_mlp(mlp3, sharding_stage=2, accumulate_grad=True) + stage2_accumulate_grad = train_mlp(mlp4, + sharding_stage=2, + batch_size=20, + accumulate_grad=True) + for i in range(len(stage2_params)): + np.testing.assert_allclose(stage2_params[i].numpy(), + stage2_accumulate_grad[i].numpy(), + rtol=1e-5, + atol=1e-5) + + # stage2 param list VS param group + stage2_params = train_mlp(mlp5, + sharding_stage=2, + use_pure_fp16=False, + opt_group=True) + for i in range(len(dp_params)): + np.testing.assert_allclose(dp_params[i].numpy(), + stage2_params[i].numpy(), + rtol=1e-6) + + # save/load model + output_dir = tempfile.mkdtemp() + model_file = os.path.join(output_dir, "model.pdmodel") + optimizer_file = os.path.join(output_dir, "model.pdopt") + model_stage2, optimizer_stage2 = train_mlp(mlp6, + sharding_stage=2, + use_pure_fp16=False, + opt_group=False, + save_model=True) + paddle.save(model_stage2.state_dict(), model_file) + paddle.save(optimizer_stage2.state_dict(), optimizer_file) + m_state_dict = paddle.load(model_file) + opt_state_dict = paddle.load(optimizer_file) + model_stage2.set_state_dict(m_state_dict) + optimizer_stage2.set_state_dict(opt_state_dict) + shutil.rmtree(output_dir) + + # check optimizer.minimize() error + train_mlp(mlp7, sharding_stage=2, test_minimize=True) + return + + +if __name__ == '__main__': + with _test_eager_guard(): + test_dp_stage2() diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_sharding_stage2.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_sharding_stage2.py index 9d842d8719fe3..aeeae15fe0653 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_sharding_stage2.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_sharding_stage2.py @@ -33,6 +33,9 @@ def test_dygraph_sharding_stage2_offload(self): self.run_mnist_2gpu('dygraph_sharding_stage2_offload.py', eager_mode=False) + def test_dygraph_sharding_stage2_with_comm_overlap(self): + self.run_mnist_2gpu('dygraph_group_sharded_stage2_comm_overlap.py') + if __name__ == "__main__": os.environ["FLAGS_enable_eager_mode"] = "1"