update for untrainable params for stage3. (PaddlePaddle#48577)
wuhuachaocoding authored Dec 6, 2022
1 parent 5fdb1ef commit 125b08c
Showing 4 changed files with 196 additions and 3 deletions.
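
Context for this change: "untrainable params" are parameters whose stop_gradient flag is set, so their trainable attribute is False; the new test added below freezes a Linear layer's weights exactly this way. A minimal hedged sketch of how such a parameter arises (the standalone layer here is illustrative, not part of the commit, and it assumes Paddle's usual trainable/stop_gradient relationship):

import paddle
from paddle import nn

frozen = nn.Linear(4096, 4096)
# Marking a parameter with stop_gradient makes it untrainable; under stage3 it
# must still be sharded and gathered for forward, but master-weight creation and
# gradient allreduce should be skipped for it, which is what this commit handles.
frozen.weight.stop_gradient = True
frozen.bias.stop_gradient = True
assert not frozen.weight.trainable  # trainable mirrors (not stop_gradient)
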
@@ -346,7 +346,7 @@ def _add_manage_info(trainable_param):
 
         current_params = list()
         for p in current_layer_params:
-            if p.trainable and p._numel() > self._segment_size:
+            if p._numel() > self._segment_size:
                 current_params.append(_add_manage_info(p))
             elif p.trainable:
                 self._unslice_params.add(_UnsliceParam(p))
@@ -430,7 +430,11 @@ def _param_storage(self, param, buffer_size):
         param.status = "part"
 
         # Updata optimizer master weights
-        if param.dtype == Type.fp16.value and not self._offload:
+        if (
+            param.trainable
+            and param.dtype == Type.fp16.value
+            and not self._offload
+        ):
             master_tensor = paddle.cast(param.fw_storage, Type.fp32.value)
             master_tensor.name = param.name
             self._optim._master_weights[param.fw_storage.name] = master_tensor
@@ -599,6 +603,9 @@ def _register_backward_hooks(self):
     def _get_allreduce_fn(self, param):
         @paddle.autograd.no_grad()
         def allreduce_(*_):
+            assert (
+                param.trainable
+            ), "the param must be trainable for grad allreduced"
             if param.name in self._task_flow.full_grad.keys():
                 full_grad = self._task_flow.full_grad[param.name]
                 # Only support sync allreduce current rank's layer now
@@ -962,6 +969,8 @@ def _allgather_buffer(
 @paddle.autograd.no_grad()
 def _create_params_grad(trainable_params, param2buffer_size, task_flow):
     for param in trainable_params:
+        if not param.trainable:
+            continue
         if param.name in task_flow.full_grad.keys():
             continue
         assert isinstance(param2buffer_size[param.name], int)
python/paddle/distributed/sharding/group_sharded.py: 4 changes (3 additions & 1 deletion)
@@ -140,7 +140,9 @@ def check_dtype(param):
 
     params_fp16 = list(filter(check_dtype, model.parameters()))
     if scaler is None and len(params_fp16) > 0:
-        raise ValueError("Please enter the correct scaler.")
+        logger_.warning(
+            "the input of scaler is None, please ensure the logic of your scaler outside is same as GroupShardedScaler."
+        )
     # convert model/optimizer/scaler
     if level in ['os', 'os_g']:
         logger_.info("*" * 30)
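
A note on the group_sharded.py change above: when fp16 parameters are present and no scaler is supplied, the API previously raised a ValueError; it now only logs a warning and leaves loss scaling to the caller. A hedged usage sketch of that path, mirroring the "just test warning" block in the new test below (the single Linear model is illustrative):

import paddle
from paddle import nn
from paddle.distributed.sharding import group_sharded_parallel

paddle.distributed.init_parallel_env()
model = paddle.amp.decorate(models=nn.Linear(4096, 4096), level="O2")
optimizer = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=model.parameters(), multi_precision=True
)
# No scaler is passed: this used to raise ValueError; now it only warns, and the
# caller is expected to run scaler logic equivalent to GroupShardedScaler.
model, optimizer, scaler = group_sharded_parallel(
    model=model, optimizer=optimizer, level="p_g_os"
)
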
@@ -0,0 +1,178 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np

import paddle
from paddle import nn
from paddle.distributed.sharding import group_sharded_parallel
from paddle.fluid.framework import _test_eager_guard

paddle.seed(2022)
np.random.seed(2022)


class Model(nn.Layer):
    def __init__(self):
        super(Model, self).__init__()
        self.first_stage = nn.Linear(4096, 4096, bias_attr=False)
        self.center_stage = nn.Linear(4096, 4096)
        self.center_stage.weight.stop_gradient = True
        self.center_stage.bias.stop_gradient = True
        self.final_stage = nn.Linear(4096, 2, bias_attr=False)

    def forward(self, x):
        x = self.first_stage(x)
        x = self.center_stage(x)
        x = self.final_stage(x)
        return x


def optimizer_setting(model, use_multi_precision):
    optimizer = paddle.optimizer.AdamW(
        learning_rate=0.001,
        parameters=model.parameters(),
        multi_precision=use_multi_precision,
    )
    return optimizer


def train_mlp(
    model,
    shard_level="p_g_os",
    use_multi_precision=False,
    output_dir="",
    amp_level='O1',
    sync_buffers=False,
    use_sharding=True,
    data=None,
):
    optimizer = optimizer_setting(
        model=model, use_multi_precision=use_multi_precision
    )
    if use_multi_precision:
        model = paddle.amp.decorate(models=model, level=amp_level)

    scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

    if use_sharding:
        model, optimizer, scaler = group_sharded_parallel(
            model=model,
            optimizer=optimizer,
            level=shard_level,
            scaler=scaler,
            sync_buffers=sync_buffers,
        )

    res_loss = []
    for i in range(20):
        model.train()
        img = data[i]
        with paddle.amp.auto_cast(use_multi_precision, level=amp_level):
            out = model(img)
            avg_loss = out.mean()

        res_loss.append(avg_loss.item())

        if not use_multi_precision:
            avg_loss.backward()
            optimizer.step()
        else:
            scaler.scale(avg_loss).backward()
            scaler.step(optimizer)
            scaler.update()

        optimizer.clear_grad()

    return res_loss


def test_sharding_api():
    paddle.distributed.init_parallel_env()

    # just test warning
    model = Model()
    model = paddle.amp.decorate(models=model, level="O2")
    optimizer = optimizer_setting(model=model, use_multi_precision=True)
    model, optimizer, scaler = group_sharded_parallel(
        model=model,
        optimizer=optimizer,
        level="p_g_os",
    )

    data = [paddle.randn([8, 4096]) for i in range(20)]

    model = Model()
    sd3_model = Model()
    sd3_model.set_state_dict(model.state_dict())

    # dp fp32
    dp_fp32_loss = train_mlp(
        model, use_multi_precision=False, use_sharding=False, data=data
    )

    # stage3 fp32
    sd3_fp32_loss = train_mlp(
        sd3_model,
        shard_level="p_g_os",
        use_multi_precision=False,
        use_sharding=True,
        data=data,
    )

    print("dp_fp32_loss: ", dp_fp32_loss)
    print("sd3_fp32_loss: ", sd3_fp32_loss)

    for i in range(len(dp_fp32_loss)):
        np.testing.assert_allclose(
            np.array(dp_fp32_loss[i]),
            np.array(sd3_fp32_loss[i]),
            rtol=1e-8,
            atol=1e-8,
        )

    model = Model()
    sd3_model = Model()
    sd3_model.set_state_dict(model.state_dict())

    # dp fp16
    dp_fp16_loss = train_mlp(
        model, use_multi_precision=True, use_sharding=False, data=data
    )

    # stage3 fp16
    sd3_fp16_loss = train_mlp(
        sd3_model,
        shard_level="p_g_os",
        use_multi_precision=True,
        use_sharding=True,
        data=data,
    )

print("dp_fp316_loss: ", dp_fp32_loss)
print("sd3_fp32_loss: ", sd3_fp32_loss)

    for i in range(len(dp_fp16_loss)):
        np.testing.assert_allclose(
            np.array(dp_fp16_loss[i]),
            np.array(sd3_fp16_loss[i]),
            rtol=1e-5,
            atol=1e-5,
        )


if __name__ == '__main__':
    with _test_eager_guard():
        test_sharding_api()
@@ -27,6 +27,10 @@ class TestDygraphGroupSharded(TestMultipleGpus):
     def test_dygraph_group_sharded(self):
         self.run_mnist_2gpu('dygraph_group_sharded_api_eager.py')
 
+    # check stage3 for some functions.
+    def test_dygraph_group_sharded_stage3(self):
+        self.run_mnist_2gpu('dygraph_group_sharded_stage3_eager.py')
+
 
 if __name__ == "__main__":
     unittest.main()
