update for untrainable params for stage3. #48577

Merged
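This change lets GroupSharded stage3 handle models that contain frozen (untrainable) parameters: large untrainable parameters are still segmented, but master weights, the gradient allreduce hook, and gradient buffers are only set up for trainable ones, and a missing scaler for fp16 parameters is downgraded from an error to a warning. A minimal sketch, not part of this PR, of how a parameter becomes untrainable in dygraph; stage3 keys off param.trainable, the inverse of stop_gradient:

import paddle
from paddle import nn

layer = nn.Linear(8, 8)
layer.weight.stop_gradient = True  # freeze the weight, as the new test does for its center stage

# trainable mirrors (not stop_gradient) and is the flag the stage3 changes below check
print(layer.weight.trainable)  # False -> no master weight, no grad allreduce
print(layer.bias.trainable)    # True  -> handled as before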
@@ -346,7 +346,7 @@ def _add_manage_info(trainable_param):

         current_params = list()
         for p in current_layer_params:
-            if p.trainable and p._numel() > self._segment_size:
+            if p._numel() > self._segment_size:
                 current_params.append(_add_manage_info(p))
             elif p.trainable:
                 self._unslice_params.add(_UnsliceParam(p))
@@ -430,7 +430,11 @@ def _param_storage(self, param, buffer_size):
         param.status = "part"
 
         # Updata optimizer master weights
-        if param.dtype == Type.fp16.value and not self._offload:
+        if (
+            param.trainable
+            and param.dtype == Type.fp16.value
+            and not self._offload
+        ):
             master_tensor = paddle.cast(param.fw_storage, Type.fp32.value)
             master_tensor.name = param.name
             self._optim._master_weights[param.fw_storage.name] = master_tensor
@@ -599,6 +603,9 @@ def _register_backward_hooks(self):
     def _get_allreduce_fn(self, param):
         @paddle.autograd.no_grad()
         def allreduce_(*_):
+            assert (
+                param.trainable
+            ), "the param must be trainable for grad allreduced"
             if param.name in self._task_flow.full_grad.keys():
                 full_grad = self._task_flow.full_grad[param.name]
                 # Only support sync allreduce current rank's layer now
@@ -962,6 +969,8 @@ def _allgather_buffer(
 @paddle.autograd.no_grad()
 def _create_params_grad(trainable_params, param2buffer_size, task_flow):
     for param in trainable_params:
+        if not param.trainable:
+            continue
         if param.name in task_flow.full_grad.keys():
             continue
         assert isinstance(param2buffer_size[param.name], int)
4 changes: 3 additions & 1 deletion python/paddle/distributed/sharding/group_sharded.py
@@ -140,7 +140,9 @@ def check_dtype(param):

     params_fp16 = list(filter(check_dtype, model.parameters()))
     if scaler is None and len(params_fp16) > 0:
-        raise ValueError("Please enter the correct scaler.")
+        logger_.warning(
+            "the input of scaler is None, please ensure the logic of your scaler outside is same as GroupShardedScaler."
+        )
     # convert model/optimizer/scaler
     if level in ['os', 'os_g']:
         logger_.info("*" * 30)
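Because the hard error is now only a warning, fp16 training without a scaler becomes the caller's responsibility. A minimal sketch, assuming a multi-GPU launch via paddle.distributed.launch, of the call pattern that avoids the warning by handing a GradScaler to group_sharded_parallel (mirroring the new test below; only the public APIs shown are taken from this PR):

import paddle
from paddle import nn
from paddle.distributed.sharding import group_sharded_parallel

paddle.distributed.init_parallel_env()

model = paddle.amp.decorate(models=nn.Linear(4096, 2), level="O2")
optimizer = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=model.parameters(), multi_precision=True
)
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

# Passing the scaler here avoids the new warning; use the returned,
# sharding-aware scaler for scale/step/update afterwards.
model, optimizer, scaler = group_sharded_parallel(
    model=model, optimizer=optimizer, level="p_g_os", scaler=scaler
)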
@@ -0,0 +1,178 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np

import paddle
from paddle import nn
from paddle.distributed.sharding import group_sharded_parallel
from paddle.fluid.framework import _test_eager_guard

paddle.seed(2022)
np.random.seed(2022)


class Model(nn.Layer):
def __init__(self):
super(Model, self).__init__()
self.first_stage = nn.Linear(4096, 4096, bias_attr=False)
self.center_stage = nn.Linear(4096, 4096)
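        # freeze the middle layer so stage3 must handle untrainable parameters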
self.center_stage.weight.stop_gradient = True
self.center_stage.bias.stop_gradient = True
self.final_stage = nn.Linear(4096, 2, bias_attr=False)

def forward(self, x):
x = self.first_stage(x)
x = self.center_stage(x)
x = self.final_stage(x)
return x


def optimizer_setting(model, use_multi_precision):
optimizer = paddle.optimizer.AdamW(
learning_rate=0.001,
parameters=model.parameters(),
multi_precision=use_multi_precision,
)
return optimizer


def train_mlp(
model,
shard_level="p_g_os",
use_multi_precision=False,
output_dir="",
amp_level='O1',
sync_buffers=False,
use_sharding=True,
data=None,
):
optimizer = optimizer_setting(
model=model, use_multi_precision=use_multi_precision
)
if use_multi_precision:
model = paddle.amp.decorate(models=model, level=amp_level)

scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

if use_sharding:
model, optimizer, scaler = group_sharded_parallel(
model=model,
optimizer=optimizer,
level=shard_level,
scaler=scaler,
sync_buffers=sync_buffers,
)

res_loss = []
for i in range(20):
model.train()
img = data[i]
with paddle.amp.auto_cast(use_multi_precision, level=amp_level):
out = model(img)
avg_loss = out.mean()

res_loss.append(avg_loss.item())

if not use_multi_precision:
avg_loss.backward()
optimizer.step()
else:
scaler.scale(avg_loss).backward()
scaler.step(optimizer)
scaler.update()

optimizer.clear_grad()

return res_loss


def test_sharding_api():
paddle.distributed.init_parallel_env()

# just test warning
model = Model()
model = paddle.amp.decorate(models=model, level="O2")
optimizer = optimizer_setting(model=model, use_multi_precision=True)
model, optimizer, scaler = group_sharded_parallel(
model=model,
optimizer=optimizer,
level="p_g_os",
)

data = [paddle.randn([8, 4096]) for i in range(20)]

model = Model()
sd3_model = Model()
sd3_model.set_state_dict(model.state_dict())

# dp fp32
dp_fp32_loss = train_mlp(
model, use_multi_precision=False, use_sharding=False, data=data
)

# stage3 fp32
sd3_fp32_loss = train_mlp(
sd3_model,
shard_level="p_g_os",
use_multi_precision=False,
use_sharding=True,
data=data,
)

print("dp_fp32_loss: ", dp_fp32_loss)
print("sd3_fp32_loss: ", sd3_fp32_loss)

for i in range(len(dp_fp32_loss)):
np.testing.assert_allclose(
np.array(dp_fp32_loss[i]),
np.array(sd3_fp32_loss[i]),
rtol=1e-8,
atol=1e-8,
)

model = Model()
sd3_model = Model()
sd3_model.set_state_dict(model.state_dict())

# dp fp16
dp_fp16_loss = train_mlp(
model, use_multi_precision=True, use_sharding=False, data=data
)

# stage3 fp16
sd3_fp16_loss = train_mlp(
sd3_model,
shard_level="p_g_os",
use_multi_precision=True,
use_sharding=True,
data=data,
)

print("dp_fp316_loss: ", dp_fp32_loss)
print("sd3_fp32_loss: ", sd3_fp32_loss)

for i in range(len(dp_fp16_loss)):
np.testing.assert_allclose(
np.array(dp_fp16_loss[i]),
np.array(sd3_fp16_loss[i]),
rtol=1e-5,
atol=1e-5,
)


if __name__ == '__main__':
with _test_eager_guard():
test_sharding_api()
@@ -27,6 +27,10 @@ class TestDygraphGroupSharded(TestMultipleGpus):
     def test_dygraph_group_sharded(self):
         self.run_mnist_2gpu('dygraph_group_sharded_api_eager.py')
 
+    # check stage3 for some functions.
+    def test_dygraph_group_sharded_stage3(self):
+        self.run_mnist_2gpu('dygraph_group_sharded_stage3_eager.py')
+
 
 if __name__ == "__main__":
     unittest.main()