Group sharded stage3 amp 02 (PaddlePaddle#57934)
liuzhenhai93 authored Oct 10, 2023
1 parent dd6f5da commit 30c2544
Showing 7 changed files with 331 additions and 16 deletions.
@@ -176,6 +176,9 @@ def __init__(
if "grad_clip" in item.keys():
item["grad_clip"] = self._optim._grad_clip

# check main_grad
self._check_main_grad()

# Synchronous all ranks models
if pertrain_sync_models:
self._sync_params_and_buffers()
@@ -203,6 +206,16 @@ def __init__(
self._redefine_opt_step()
self._redefine_opt_clear()

def _check_main_grad(self):
self.use_main_grad = None
for param in self._layer.parameters():
if self.use_main_grad is None and hasattr(param, "main_grad"):
self.use_main_grad = True
if self.use_main_grad:
assert hasattr(
param, "main_grad"
), "Params have different main grad attributes."
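
A note on the convention this check relies on: with master-grad mixed precision, each low-precision parameter is expected to carry an fp32 main_grad accumulator, and the check requires that either every parameter has one or none does. Below is a minimal, framework-free sketch of that invariant; FakeParam and check_main_grad are illustrative stand-ins, not Paddle API.

class FakeParam:
    def __init__(self, name, has_main_grad):
        self.name = name
        if has_main_grad:
            self.main_grad = [0.0]  # stands in for the fp32 accumulator tensor

def check_main_grad(params):
    # mirrors the logic of _check_main_grad above
    use_main_grad = None
    for p in params:
        if use_main_grad is None and hasattr(p, "main_grad"):
            use_main_grad = True
        if use_main_grad:
            assert hasattr(
                p, "main_grad"
            ), f"{p.name} is missing main_grad while other params have it."
    return bool(use_main_grad)

print(check_main_grad([FakeParam("w", True), FakeParam("b", True)]))    # True
print(check_main_grad([FakeParam("w", False), FakeParam("b", False)]))  # False
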

@paddle.autograd.no_grad()
def _sync_params_and_buffers(self):
"""
@@ -235,8 +248,11 @@ def _clear_gradients(self):
assert hasattr(
param, "fw_storage"
), f"Find {param.name} don't have fw_storage attribute."

param.fw_storage.clear_gradient(False)
if self.use_main_grad:
param.fw_storage.main_grad._clear()
param.fw_storage.main_grad = None
else:
param.fw_storage.clear_gradient(False)
param.bw_storage._clear()
param.bw_storage = None
# 2.Handle unslice param
@@ -245,7 +261,12 @@ def _clear_gradients(self):
grad_storage.buffer.zero_()
else:
for param in list(self._unslice_params):
param.clear_gradient(False)
if self.use_main_grad:
param.main_grad._clear()
param.main_grad = None
else:
param.clear_gradient(False)

if (
self._default_device
in paddle.device.get_all_custom_device_type()
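
The two hunks above split gradient clearing into a main_grad path and the original clear_gradient(False) path, for both sliced parameters (fw_storage/bw_storage) and unsliced ones. A simplified sketch of the branching, with plain attributes in place of Paddle tensors (helper names are illustrative only):

def clear_sliced_param(param, use_main_grad):
    # sliced params keep their forward shard in fw_storage and the rank-local
    # gradient slice in bw_storage
    if use_main_grad:
        param.fw_storage.main_grad = None   # stands in for main_grad._clear() + None
    else:
        param.fw_storage.grad = None        # stands in for clear_gradient(False)
    param.bw_storage = None                 # drop the rank-local gradient slice

def clear_unsliced_param(param, use_main_grad):
    if use_main_grad:
        param.main_grad = None              # release the fp32 accumulator
    else:
        param.grad = None                   # stands in for clear_gradient(False)
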
@@ -350,7 +371,9 @@ def _handle_unslice_params(self):
if param.dtype not in self._grad_storages.keys():
self._grad_storages[param.dtype] = GradStorage(
buffer_size[param.dtype],
dtype=param.dtype,
dtype=param.dtype
if not self.use_main_grad
else paddle.float32,
device=self._default_device,
destination=self._rank,
parm2align=self._unslice_params2align,
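
With the dtype switch above, the fused GradStorage buffer for unsliced parameters is allocated in float32 whenever main_grad is in use, since it has to hold fp32 accumulators rather than low-precision gradients. A rough illustration of the resulting buffer size (pure Python; the per-element byte counts are the usual bf16/fp16/fp32 sizes):

BYTES_PER_ELEM = {"bfloat16": 2, "float16": 2, "float32": 4}

def grad_buffer_bytes(numel, param_dtype, use_main_grad):
    # mirror the dtype selection: fp32 storage when main_grad is used
    dtype = "float32" if use_main_grad else param_dtype
    return numel * BYTES_PER_ELEM[dtype]

print(grad_buffer_bytes(1024, "bfloat16", use_main_grad=False))  # 2048 bytes
print(grad_buffer_bytes(1024, "bfloat16", use_main_grad=True))   # 4096 bytes
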
@@ -596,8 +619,11 @@ def _update_params(self):
), f"Find {param.name} don't have fw_storage attribute"

param.fw_storage = _TensorWrapper(param)
assert param.fw_storage.grad is None
param.fw_storage._copy_gradient_from(param.bw_storage)
if self.use_main_grad:
param.fw_storage.main_grad = param.bw_storage
else:
assert param.fw_storage.grad is None
param.fw_storage._copy_gradient_from(param.bw_storage)
update_list.append(param)

# 2.Handle unslice param
@@ -617,9 +643,13 @@ def _update_params(self):

for grad_storage in self._grad_storages.values():
for p in grad_storage._params:
tmp_g = _device2cpu(p.grad, convert_dtype=True)
p.clear_gradient(False)
p._copy_gradient_from(tmp_g)
if self.use_main_grad:
tmp_g = _device2cpu(p.main_grad, convert_dtype=True)
p.main_grad = tmp_g
else:
tmp_g = _device2cpu(p.grad, convert_dtype=True)
p.clear_gradient(False)
p._copy_gradient_from(tmp_g)
del tmp_g
grad_storage.buffer._clear()
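
In the updated _update_params, the main_grad path hands the rank-local bw_storage buffer to fw_storage as its main_grad without an extra copy, while the offload branch moves the fp32 main_grad to CPU and swaps the reference instead of converting and copying into .grad. An illustrative sketch with stand-in objects (not Paddle API):

from types import SimpleNamespace

def hand_off(fw_storage, bw_storage, use_main_grad):
    if use_main_grad:
        fw_storage.main_grad = bw_storage    # reuse the fp32 buffer as-is
    else:
        assert fw_storage.grad is None
        fw_storage.grad = bw_storage         # stands in for _copy_gradient_from

fw = SimpleNamespace(grad=None, main_grad=None)
bw = [0.25, -0.5]                            # stands in for the reduced grad slice
hand_off(fw, bw, use_main_grad=True)
print(fw.main_grad is bw)                    # True: no additional copy is made
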

@@ -650,6 +680,7 @@ def get_all_parameters(self, convert2cpu=False):
if convert2cpu:
for param in trainable_params:
t_flow.full_param[param.name][0]._share_buffer_to(param)
del t_flow.full_param[param.name]

# a _allgather_buffer call should be matched with a _release_param call later,
# but the _allgather_buffer call here has no match.
@@ -708,7 +739,11 @@ def allreduce_(*_):
param.bw_storage,
full_grad._slice(start, end).detach().clone(),
)
param.clear_gradient(False)

if self.use_main_grad:
param.main_grad = None
else:
param.clear_gradient(False)
del self._task_flow.full_grad[param.name]

if param.name in self._task_flow.full_param.keys():
Expand All @@ -726,6 +761,7 @@ def allreduce_(*_):
del self._task_flow.full_param[param.name]

if self._offload:
# revert back to cpu for offload update
param.fw_storage._clear_data()
param.master_weight._share_buffer_to(param.fw_storage)
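
In the reduce hook above, each rank combines its [start, end) slice of the reduced full gradient into bw_storage and then drops the per-parameter buffer, main_grad or grad depending on the mode. A plain-Python sketch of that bookkeeping; the call that merges bw_storage with the slice sits above the visible hunk, so the sketch simply assigns the slice (lists stand in for tensors, names are illustrative):

from types import SimpleNamespace

def keep_local_slice(param, full_grad, start, end, use_main_grad):
    # keep only this rank's portion of the reduced gradient
    param.bw_storage = list(full_grad[start:end])  # _slice().detach().clone() stand-in
    if use_main_grad:
        param.main_grad = None                     # release the fp32 buffer
    else:
        param.grad = None                          # clear_gradient(False) stand-in

p = SimpleNamespace(bw_storage=None, main_grad=[0.0] * 4, grad=None)
keep_local_slice(p, full_grad=[1.0, 2.0, 3.0, 4.0], start=1, end=3, use_main_grad=True)
print(p.bw_storage)                                # [2.0, 3.0]
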

@@ -929,11 +965,14 @@ class TaskFlow:

def __init__(
self,
full_param={},
full_grad={},
use_calc={},
callback=None,
):
self.full_param = {}
self.full_grad = {}
self.use_calc = {}
self.full_param = full_param
self.full_grad = full_grad
self.use_calc = use_calc
self.callback = callback
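
One Python detail worth keeping in mind with the new TaskFlow signature: default argument values are evaluated once, at function definition time, so TaskFlow objects constructed without explicit dicts share the same full_param, full_grad, and use_calc dictionaries. A small self-contained demonstration of that semantics (TaskFlowSketch is a stand-in, not the real class):

class TaskFlowSketch:
    def __init__(self, full_param={}, full_grad={}, use_calc={}, callback=None):
        self.full_param = full_param
        self.full_grad = full_grad
        self.use_calc = use_calc
        self.callback = callback

a = TaskFlowSketch()
b = TaskFlowSketch()
a.full_param["w0"] = "gathered"
print(b.full_param)                  # {'w0': 'gathered'} -- the default dict is shared
print(a.full_param is b.full_param)  # True
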


@@ -1014,6 +1053,7 @@ def _allgather_buffer(
continue

if offload:
# convert to device for collective comm
param.fw_storage = _cpu2device(param)

buffer_size = param2buffer_size[param.name]
@@ -1046,17 +1086,22 @@
@paddle.autograd.no_grad()
def _create_params_grad(trainable_params, param2buffer_size, task_flow):
for param in trainable_params:
use_main_grad = hasattr(param, "main_grad")
if not param.trainable:
continue
if param.name in task_flow.full_grad.keys():
continue
assert isinstance(param2buffer_size[param.name], int)
temp_grad = paddle.zeros(
[param2buffer_size[param.name]], dtype=param.dtype
[param2buffer_size[param.name]],
dtype=param.dtype if not use_main_grad else paddle.float32,
)
temp_tensor = temp_grad._slice(0, param._numel())
temp_tensor.get_tensor()._set_dims(param.shape)
param._copy_gradient_from(temp_tensor)
if use_main_grad:
param.main_grad = temp_tensor
else:
param._copy_gradient_from(temp_tensor)
del temp_tensor
task_flow.full_grad[param.name] = temp_grad
return task_flow
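
The changes to _create_params_grad allocate the padded flat gradient buffer in float32 when main_grad is in use, then attach the reshaped leading slice to the parameter as main_grad instead of copying it in as .grad. A framework-free sketch of the buffer/view relationship, with NumPy standing in for Paddle tensors (function and argument names are illustrative):

import numpy as np

def create_param_grad(param_shape, buffer_size, use_main_grad, param_dtype=np.float16):
    dtype = np.float32 if use_main_grad else param_dtype
    temp_grad = np.zeros(buffer_size, dtype=dtype)       # padded flat buffer
    numel = int(np.prod(param_shape))
    grad_view = temp_grad[:numel].reshape(param_shape)   # _slice + _set_dims stand-in
    return temp_grad, grad_view

full, view = create_param_grad((2, 3), buffer_size=8, use_main_grad=True)
view += 1.0
print(full)  # the flat buffer reflects the update because grad_view is a view into it
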
@@ -252,7 +252,7 @@ def fused_allreduce_gradients(parameter_list, hcg):
scale = 1.0
if dp_enabled:
group = hcg.get_data_parallel_group()
scale = group.nranks
scale = scale / group.nranks
if sep_enabled:
sep_group = hcg.get_sep_parallel_group()
dp_sep_group = hcg.get_dp_sep_parallel_group()
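
A worked illustration of the scale fix above, under the assumption (the rest of the function is not shown in this hunk) that scale is applied multiplicatively to the summed, all-reduced gradients: starting from scale = 1.0 and dividing by each group's size keeps the result an average even when several parallel groups contribute.

def combined_scale(*group_sizes):
    scale = 1.0
    for nranks in group_sizes:
        scale = scale / nranks
    return scale

print(combined_scale(4))     # data-parallel group of 4          -> 0.25
print(combined_scale(4, 2))  # dp group of 4 plus sep group of 2 -> 0.125
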
15 changes: 15 additions & 0 deletions test/collective/fleet/CMakeLists.txt
@@ -134,6 +134,21 @@ if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT)
set_tests_properties(test_dygraph_sharding_stage3_for_eager PROPERTIES TIMEOUT
"350")
endif()
if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT)
bash_test_modules(
test_dygraph_sharding_stage3_bf16
START_BASH
../../legacy_test/dist_test.sh
TIMEOUT
"200"
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=22038;NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python"
)
set_tests_properties(test_dygraph_sharding_stage3_bf16 PROPERTIES TIMEOUT
"200")
endif()
if(WITH_NCCL)
if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT)
py_test_modules(