From 30c2544fc1b7a5c20a89082c6be291e607c008b7 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 10 Oct 2023 13:26:13 +0800 Subject: [PATCH] Group sharded stage3 amp 02 (#57934) --- .../sharding/group_sharded_stage3.py | 75 ++++-- .../fleet/utils/hybrid_parallel_util.py | 2 +- test/collective/fleet/CMakeLists.txt | 15 ++ .../dygraph_group_sharded_stage3_bf16.py | 227 ++++++++++++++++++ .../test_dygraph_sharding_stage3_bf16.py | 26 ++ test/collective/fleet/testslist.csv | 1 + tools/gpups_test.sh | 1 + 7 files changed, 331 insertions(+), 16 deletions(-) create mode 100644 test/collective/fleet/dygraph_group_sharded_stage3_bf16.py create mode 100644 test/collective/fleet/test_dygraph_sharding_stage3_bf16.py diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py index 12c5ac37c8b10b..8a61ab904cb304 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py @@ -176,6 +176,9 @@ def __init__( if "grad_clip" in item.keys(): item["grad_clip"] = self._optim._grad_clip + # check main_grad + self._check_main_grad() + # Synchronous all ranks models if pertrain_sync_models: self._sync_params_and_buffers() @@ -203,6 +206,16 @@ def __init__( self._redefine_opt_step() self._redefine_opt_clear() + def _check_main_grad(self): + self.use_main_grad = None + for param in self._layer.parameters(): + if self.use_main_grad is None and hasattr(param, "main_grad"): + self.use_main_grad = True + if self.use_main_grad: + assert hasattr( + param, "main_grad" + ), "Params have different main grad attributes." + @paddle.autograd.no_grad() def _sync_params_and_buffers(self): """ @@ -235,8 +248,11 @@ def _clear_gradients(self): assert hasattr( param, "fw_storage" ), f"Find {param.name} don't have fw_storage attribute." 
- - param.fw_storage.clear_gradient(False) + if self.use_main_grad: + param.fw_storage.main_grad._clear() + param.fw_storage.main_grad = None + else: + param.fw_storage.clear_gradient(False) param.bw_storage._clear() param.bw_storage = None # 2.Handle unslice param @@ -245,7 +261,12 @@ def _clear_gradients(self): grad_storage.buffer.zero_() else: for param in list(self._unslice_params): - param.clear_gradient(False) + if self.use_main_grad: + param.main_grad._clear() + param.main_grad = None + else: + param.clear_gradient(False) + if ( self._default_device in paddle.device.get_all_custom_device_type() @@ -350,7 +371,9 @@ def _handle_unslice_params(self): if param.dtype not in self._grad_storages.keys(): self._grad_storages[param.dtype] = GradStorage( buffer_size[param.dtype], - dtype=param.dtype, + dtype=param.dtype + if not self.use_main_grad + else paddle.float32, device=self._default_device, destination=self._rank, parm2align=self._unslice_params2align, @@ -596,8 +619,11 @@ def _update_params(self): ), f"Find {param.name} don't have fw_storage attribute" param.fw_storage = _TensorWrapper(param) - assert param.fw_storage.grad is None - param.fw_storage._copy_gradient_from(param.bw_storage) + if self.use_main_grad: + param.fw_storage.main_grad = param.bw_storage + else: + assert param.fw_storage.grad is None + param.fw_storage._copy_gradient_from(param.bw_storage) update_list.append(param) # 2.Handle unslice param @@ -617,9 +643,13 @@ def _update_params(self): for grad_storage in self._grad_storages.values(): for p in grad_storage._params: - tmp_g = _device2cpu(p.grad, convert_dtype=True) - p.clear_gradient(False) - p._copy_gradient_from(tmp_g) + if self.use_main_grad: + tmp_g = _device2cpu(p.main_grad, convert_dtype=True) + p.main_grad = tmp_g + else: + tmp_g = _device2cpu(p.grad, convert_dtype=True) + p.clear_gradient(False) + p._copy_gradient_from(tmp_g) del tmp_g grad_storage.buffer._clear() @@ -650,6 +680,7 @@ def get_all_parameters(self, convert2cpu=False): if convert2cpu: for param in trainable_params: t_flow.full_param[param.name][0]._share_buffer_to(param) + del t_flow.full_param[param.name] # a _allgather_buffer call should be matched with a _release_param call later, # but the _allgather_buffer call here has no match. 
@@ -708,7 +739,11 @@ def allreduce_(*_): param.bw_storage, full_grad._slice(start, end).detach().clone(), ) - param.clear_gradient(False) + + if self.use_main_grad: + param.main_grad = None + else: + param.clear_gradient(False) del self._task_flow.full_grad[param.name] if param.name in self._task_flow.full_param.keys(): @@ -726,6 +761,7 @@ def allreduce_(*_): del self._task_flow.full_param[param.name] if self._offload: + # revert back to cpu for offload update param.fw_storage._clear_data() param.master_weight._share_buffer_to(param.fw_storage) @@ -929,11 +965,14 @@ class TaskFlow: def __init__( self, + full_param={}, + full_grad={}, + use_calc={}, callback=None, ): - self.full_param = {} - self.full_grad = {} - self.use_calc = {} + self.full_param = full_param + self.full_grad = full_grad + self.use_calc = use_calc self.callback = callback @@ -1014,6 +1053,7 @@ def _allgather_buffer( continue if offload: + # convert to device for collective comm param.fw_storage = _cpu2device(param) buffer_size = param2buffer_size[param.name] @@ -1046,17 +1086,22 @@ def _allgather_buffer( @paddle.autograd.no_grad() def _create_params_grad(trainable_params, param2buffer_size, task_flow): for param in trainable_params: + use_main_grad = hasattr(param, "main_grad") if not param.trainable: continue if param.name in task_flow.full_grad.keys(): continue assert isinstance(param2buffer_size[param.name], int) temp_grad = paddle.zeros( - [param2buffer_size[param.name]], dtype=param.dtype + [param2buffer_size[param.name]], + dtype=param.dtype if not use_main_grad else paddle.float32, ) temp_tensor = temp_grad._slice(0, param._numel()) temp_tensor.get_tensor()._set_dims(param.shape) - param._copy_gradient_from(temp_tensor) + if use_main_grad: + param.main_grad = temp_tensor + else: + param._copy_gradient_from(temp_tensor) del temp_tensor task_flow.full_grad[param.name] = temp_grad return task_flow diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py index c68dfeefd2c600..86194c66016b29 100644 --- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py +++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py @@ -252,7 +252,7 @@ def fused_allreduce_gradients(parameter_list, hcg): scale = 1.0 if dp_enabled: group = hcg.get_data_parallel_group() - scale = group.nranks + scale = scale / group.nranks if sep_enabled: sep_group = hcg.get_sep_parallel_group() dp_sep_group = hcg.get_dp_sep_parallel_group() diff --git a/test/collective/fleet/CMakeLists.txt b/test/collective/fleet/CMakeLists.txt index 4e1a2a970d3e95..309acb6164007d 100644 --- a/test/collective/fleet/CMakeLists.txt +++ b/test/collective/fleet/CMakeLists.txt @@ -134,6 +134,21 @@ if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) set_tests_properties(test_dygraph_sharding_stage3_for_eager PROPERTIES TIMEOUT "350") endif() +if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) + bash_test_modules( + test_dygraph_sharding_stage3_bf16 + START_BASH + ../../legacy_test/dist_test.sh + TIMEOUT + "200" + LABELS + "RUN_TYPE=DIST" + ENVS + "PADDLE_DIST_UT_PORT=22038;NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python" + ) + set_tests_properties(test_dygraph_sharding_stage3_bf16 PROPERTIES TIMEOUT + "200") +endif() if(WITH_NCCL) if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) py_test_modules( diff --git a/test/collective/fleet/dygraph_group_sharded_stage3_bf16.py b/test/collective/fleet/dygraph_group_sharded_stage3_bf16.py new file mode 100644 index 
00000000000000..002426e94b0d22 --- /dev/null +++ b/test/collective/fleet/dygraph_group_sharded_stage3_bf16.py @@ -0,0 +1,227 @@ +# -*- coding: UTF-8 -*- + +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import paddle +from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage3 import ( + GroupShardedStage3, +) +from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import ( + GroupShardedScaler, +) +from paddle.distributed.fleet.utils import mix_precision_utils +from paddle.nn import Linear, ReLU + +seed = 2022 +epoch = 2 +linear_size = 1000 + +np.random.seed(seed) +paddle.seed(seed) + + +class MLP(paddle.nn.Layer): + def __init__(self, linear_size=1000): + super().__init__() + + self._linear1 = Linear(linear_size, 4 * linear_size) + self._linear2 = Linear(4 * linear_size, linear_size) + self._linear3 = Linear(linear_size, 10) + self._relu = ReLU() + + def forward(self, inputs): + y = self._linear1(inputs) + y = self._linear2(y) + y = self._linear3(y) + y = self._relu(y) + return y + + +class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples=200, linear_size=1000): + self.num_samples = num_samples + self.linear_size = linear_size + + def __getitem__(self, idx): + img = np.random.rand(self.linear_size).astype('float32') + return img + + def __len__(self): + return self.num_samples + + +def optimizer_setting(model, use_pure_bf16, use_main_grad): + if use_main_grad: + assert use_pure_bf16 + model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16") + optimizer = paddle.optimizer.AdamW( + parameters=model.parameters(), + learning_rate=0.00001, + weight_decay=0.00001, + grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0), + multi_precision=use_pure_bf16, + ) + if use_main_grad: + optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer) + + return optimizer + + +def train_mlp( + model, + sharding_stage, + use_pure_bf16=False, + accumulate_grad=False, + use_main_grad=False, + test_scaler=False, +): + if sharding_stage != "dp": + group = paddle.distributed.new_group([0, 1], backend="nccl") + scaler = None + if test_scaler: + assert sharding_stage == 2 + assert not accumulate_grad + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + scaler = GroupShardedScaler(scaler) + optimizer = optimizer_setting( + model=model, use_pure_bf16=use_pure_bf16, use_main_grad=use_main_grad + ) + if use_pure_bf16: + level = 'O2' + custom_white_list = None + model = paddle.amp.decorate(models=model, dtype="bfloat16", level=level) + else: + level = 'O1' + custom_white_list = [ + "matmul_v2", + "elementwise_add", + "relu", + "reduce_mean", + ] + + paddle.seed(2023) + np.random.seed(2023) + train_loader = paddle.io.DataLoader( + RandomDataset(), + batch_size=100, + shuffle=False, + drop_last=True, + num_workers=0, + ) + + if sharding_stage == 3: + model.to(device="gpu") + + if not use_pure_bf16: + for param in model.parameters(): + t = paddle.cast( + 
paddle.cast(param, dtype='bfloat16'), dtype='float32' + ) + param.set_value(t) + + if sharding_stage == 3: + model = GroupShardedStage3(model, optimizer, group=group) + else: + model = paddle.DataParallel(model) + + losses = [] + for eop in range(epoch): + model.train() + + for batch_id, data in enumerate(train_loader()): + data.stop_gradient = True + + with paddle.amp.auto_cast( + True, + level=level, + dtype="bfloat16", + custom_white_list=custom_white_list, + ): + out = model(data) + loss = paddle.mean(out) + + losses.append(loss) + + if test_scaler: + assert scaler is not None + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.clear_grad() + else: + loss.backward() + if not accumulate_grad: + optimizer.step() + optimizer.clear_grad() + + if accumulate_grad: + optimizer.step() + optimizer.clear_grad() + + return losses + + +def test_stage3_bf16(): + if not paddle.amp.is_bfloat16_supported(): + return + paddle.distributed.init_parallel_env() + mlp = MLP() + state_dict = mlp.state_dict() + + # stage3 bf16 O1 vs stage3 bf16 O2 main_grad + mlp1 = MLP() + mlp2 = MLP() + mlp1.set_state_dict(state_dict) + mlp2.set_state_dict(state_dict) + o1_losses = train_mlp(mlp1, sharding_stage=3, use_pure_bf16=False) + o2_losses = train_mlp( + mlp2, sharding_stage=3, use_pure_bf16=True, use_main_grad=True + ) + for i in range(len(o1_losses)): + o1_32_loss = paddle.cast(o1_losses[i], dtype='float32').detach() + o2_32_loss = paddle.cast(o2_losses[i], dtype='float32').detach() + np.testing.assert_array_equal(o1_32_loss, o2_32_loss) + + # grad accumulation test + mlp3 = MLP() + mlp4 = MLP() + mlp3.set_state_dict(state_dict) + mlp4.set_state_dict(state_dict) + o1_losses_grad_acc = train_mlp( + mlp3, sharding_stage=3, use_pure_bf16=False, accumulate_grad=True + ) + o2_losses_grad_acc = train_mlp( + mlp4, + sharding_stage=3, + use_pure_bf16=True, + use_main_grad=True, + accumulate_grad=True, + ) + for i in range(len(o2_losses_grad_acc)): + o2_loss_grad_acc = paddle.cast( + o2_losses_grad_acc[i], dtype='float32' + ).detach() + o1_loss_grad_acc = paddle.cast( + o1_losses_grad_acc[i], dtype='float32' + ).detach() + np.testing.assert_array_equal(o2_loss_grad_acc, o1_loss_grad_acc) + + return + + +if __name__ == '__main__': + test_stage3_bf16() diff --git a/test/collective/fleet/test_dygraph_sharding_stage3_bf16.py b/test/collective/fleet/test_dygraph_sharding_stage3_bf16.py new file mode 100644 index 00000000000000..f34191d848605b --- /dev/null +++ b/test/collective/fleet/test_dygraph_sharding_stage3_bf16.py @@ -0,0 +1,26 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestDygraphShardingStage3(TestMultipleGpus): + def test_dygraph_sharding_stage3_bf16(self): + self.run_mnist_2gpu('dygraph_group_sharded_stage3_bf16.py') + + +if __name__ == "__main__": + unittest.main() diff --git a/test/collective/fleet/testslist.csv b/test/collective/fleet/testslist.csv index 43dd55c3754b34..664bb0bc8a502d 100644 --- a/test/collective/fleet/testslist.csv +++ b/test/collective/fleet/testslist.csv @@ -11,6 +11,7 @@ test_rnn_dp,,GPU;XPU,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_p test_parallel_dygraph_mp_layers,,GPU,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL test_tcp_store,LINUX;APPLE,,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_dygraph_sharding_stage3_for_eager,,,350,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_dygraph_sharding_stage3_bf16,,,200,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../.., test_communicator_half_async,,,120,DIST,test_runner.py,2,,FLAGS_communicator_send_queue_size=1;FLAGS_communicator_max_merge_var_num=1;http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL test_parallel_dygraph_pipeline_parallel,,GPU,500,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_parallel_dygraph_pipeline_parallel_sync_send,,GPU;XPU,300,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..;PADDLE_P2P_SYNC_SEND=1, diff --git a/tools/gpups_test.sh b/tools/gpups_test.sh index fff44b872461e6..a1e515355c9c09 100644 --- a/tools/gpups_test.sh +++ b/tools/gpups_test.sh @@ -59,6 +59,7 @@ parallel_list="^init_phi_test$|\ ^test_dist_fleet_ps11$|\ ^test_dist_fleet_ps12$|\ ^test_dygraph_sharding_stage2_bf16$|\ +^test_dygraph_sharding_stage3_bf16$|\ ^test_executor_feed_non_tensor$|\ ^test_flash_attention$|\ ^test_fused_adam_op$|\
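
A few reviewer-style notes with small illustrative sketches on the change follow; they are sketches of the ideas involved, not part of the patch.

The new `_check_main_grad` hook decides once, at wrapper construction time, whether the model was prepared with `mix_precision_utils.MixPrecisionLayer` (every parameter then carries a float32 `main_grad`), and asserts that parameters do not mix the two modes. A plain-Python restatement of that invariant, using a hypothetical `all_or_none_main_grad` helper:

    def all_or_none_main_grad(params):
        # Either every parameter carries `main_grad` (MixPrecisionLayer wrapped
        # the whole model) or none of them does; mixing the two is an error.
        flags = [hasattr(p, "main_grad") for p in params]
        return all(flags) or not any(flags)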
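
Once `use_main_grad` is set, the stage-3 hooks route every place that used to touch `param.grad` or `param.fw_storage.grad` through the float32 `main_grad` buffer instead: `_create_params_grad` allocates the gradient slice in float32, `_update_params` hands `bw_storage` over as `main_grad`, and `_clear_gradients` releases `main_grad` rather than calling `clear_gradient`. A minimal, self-contained sketch of the underlying idea (a low-precision weight with an fp32 gradient accumulator); the `ParamWithMainGrad` wrapper below is made up for illustration only and assumes a Paddle build with bfloat16 support:

    import paddle

    class ParamWithMainGrad:
        """Toy stand-in for a bf16 parameter that keeps an fp32 accumulator,
        mirroring the `main_grad` attribute the patch branches on."""

        def __init__(self, value):
            self.value = value      # low-precision weights
            self.main_grad = None   # float32 gradient accumulator

        def accumulate(self, grad):
            if self.main_grad is None:
                self.main_grad = paddle.zeros(self.value.shape, dtype='float32')
            self.main_grad += paddle.cast(grad, 'float32')

        def clear(self):
            # Analogue of the `main_grad._clear(); main_grad = None` path
            # taken in _clear_gradients when use_main_grad is set.
            self.main_grad = None

    w = ParamWithMainGrad(paddle.cast(paddle.ones([4]), 'bfloat16'))
    for g in (paddle.cast(paddle.full([4], 0.1), 'bfloat16'),
              paddle.cast(paddle.full([4], 0.2), 'bfloat16')):
        w.accumulate(g)
    print(w.main_grad)  # fp32 sum a multi_precision optimizer step would consume
    w.clear()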
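
`TaskFlow` now takes its `full_param` / `full_grad` / `use_calc` dicts as constructor arguments with `{}` defaults. One Python detail to keep in mind with that signature: default values are evaluated once at definition time, so any two `TaskFlow` objects built without explicit dicts share the same containers. A tiny, Paddle-independent demonstration of that behavior:

    class Flow:
        def __init__(self, full_param={}):
            self.full_param = full_param   # the default dict is created only once

    a, b = Flow(), Flow()
    a.full_param["w"] = 1
    print(b.full_param)   # {'w': 1} -- both instances see the same dict

This is harmless as long as call sites pass fresh dicts explicitly, but it is worth double-checking where `TaskFlow()` is constructed with the defaults.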
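
The one-line change in `fused_allreduce_gradients` replaces the assignment `scale = group.nranks` with an in-place division, so the data-parallel group size is folded into `scale` as 1/nranks instead of overwriting whatever factor was already there; with both dp and sep enabled the factors can then compose (the sep branch's own handling sits outside the hunk shown here). A rough sketch of the composed averaging factor, with hypothetical group sizes:

    def grad_scale_factor(group_sizes):
        # After the fix, each enabled communication group divides `scale`
        # in place, so the combined factor is 1 / (n1 * n2 * ...).
        scale = 1.0
        for nranks in group_sizes:
            scale = scale / nranks
        return scale

    assert grad_scale_factor([4]) == 0.25       # dp group of 4 ranks
    assert grad_scale_factor([4, 2]) == 0.125   # e.g. dp=4 combined with sep=2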
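
In the new `dygraph_group_sharded_stage3_bf16.py` test, the non-pure-bf16 (O1) run first pushes every fp32 parameter through a bfloat16 round trip, `paddle.cast(paddle.cast(param, dtype='bfloat16'), dtype='float32')`, so that both the O1 run and the O2/main_grad run start from weights exactly representable in bf16 and the per-step losses can be compared with `np.testing.assert_array_equal`. A small sketch of what that round trip does to individual values, assuming a device with bfloat16 support:

    import paddle

    x = paddle.to_tensor([0.1, 1.2345678, 3.0], dtype='float32')
    x_rt = paddle.cast(paddle.cast(x, 'bfloat16'), 'float32')
    # bfloat16 stores only 7 explicit mantissa bits, so values are snapped to
    # the nearest representable number: 3.0 survives exactly, 0.1 does not.
    print(x.numpy())
    print(x_rt.numpy())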