Skip to content

Commit

Permalink
optimizer sharding paramters (#39581)
Browse files Browse the repository at this point in the history
  • Loading branch information
Baibaifan authored Feb 17, 2022
1 parent 1f7f856 commit 18c6f40
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def __init__(self,
params,
optim,
group=None,
broadcast_fp16=False,
offload=False,
device="gpu",
pertrain_sync_models=True,
**kw):

super().__init__(optim._learning_rate, params, kw)
Expand Down Expand Up @@ -98,8 +98,12 @@ def __init__(self,

self.world_size = self.group.nranks
self.rank = self.group.rank
self._global_root_rank = 0

# Synchronous all ranks models
if pertrain_sync_models:
self._sync_params_and_buffers()

self.broadcast_fp16 = broadcast_fp16
self.param_storages = {} # {dtype: {rank: InternalStorage}}

if isinstance(self._optim._grad_clip, ClipGradByGlobalNorm):
Expand Down Expand Up @@ -132,6 +136,22 @@ def __init__(self,
# Update optimizer parameters and adjust parameter storage and use according to rank.
self._update_opt_status()

@paddle.no_grad()
def _sync_params_and_buffers(self):
"""
Sync all model states for all ranks
"""

for p in self._local_params:
dist.broadcast(
p,
src=self._global_root_rank,
group=self.group,
use_calc_stream=True)

# Multi stream operation will be supported later
dist.wait(tensor=p, group=self.group, use_calc_stream=True)

def _generate_master_params(self, trainable_params):
if self.offload:
for param in trainable_params:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,10 @@ def __init__(
sharding_optimizer,
group=None,
sync_buffers=False,
pertrain_sync_models=True,
buffer_max_size=2**23, #8MB
auto_refresh_trainable=True,
device="gpu",
use_grad_storage=True,
accumulate_grads=False):
use_grad_storage=True):
super().__init__()

# training options
Expand All @@ -81,9 +79,6 @@ def __init__(
self._sync_buffers = sync_buffers
self._auto_refresh_trainable = auto_refresh_trainable

# Gradient accumulation, Gradient flip
self._accumulate_grads = accumulate_grads

# Communication related attributes
self._group = dist.new_group(_get_global_group()
.ranks) if group is None else group
Expand Down Expand Up @@ -128,16 +123,11 @@ def __init__(
# Set backward pass hooks
self._bw_hooks = []

# Synchronous all ranks models
if pertrain_sync_models:
self._sync_params_and_buffers()

# Set tasks flow
self._tasks_flow = deque()

# Define optimizer step and clear_grad
if self._accumulate_grads:
self._redefine_opt_step()
self._redefine_opt_step()
self._redefine_opt_clear()

def forward(self, *inputs, **kwargs):
Expand Down Expand Up @@ -313,9 +303,6 @@ def reduce(*_):

# Change reduce information
self._grad_reduced[index] = False
if not self._accumulate_grads:
param.grad.scale_(scale=self._world_size_scaling)
param._reset_grad_inplace_version(True)

# Clear the gradient that does not belong to the current rank through the callback function
def cleanup():
Expand Down Expand Up @@ -362,11 +349,6 @@ def reduce(*_):
if grad_storage.all_checked_in:
assert grad_storage.buffer is not None

# Normalize all ranks grad_storage
if not self._accumulate_grads:
grad_storage.buffer.scale_(
scale=self._world_size_scaling)

# Clearing up the grad_storage buffer
def cleanup():
if dst_rank != self._rank:
Expand Down Expand Up @@ -432,22 +414,6 @@ def _setup_backward_hooks(self):
self._bw_hooks.append(
param._register_backward_hook(reduce_function))

@paddle.no_grad()
def _sync_params_and_buffers(self):
"""
Sync all model states for all ranks
"""

for t in self._layer.parameters():
dist.broadcast(
t,
src=self._global_root_rank,
group=self._group,
use_calc_stream=True)

# Multi stream operation will be supported later
dist.wait(tensor=t, group=self._group, use_calc_stream=True)

def _setup_use_grad_storage(self):
"""
Integrate the parameters gradient into a continuous memory according to rank, and support the update of training parameters.
Expand Down Expand Up @@ -555,8 +521,6 @@ def _rank_buffer_size(self, buffer_max_size, model_size):
return rank_buffer_size

def _redefine_opt_step(self):
if not self._accumulate_grads:
return
grad_func = self._grad_scale
for opt in self._sharding_optimizers:
opt_step = opt.step
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def __init__(self,
device="gpu",
segment_size=2**15,
pertrain_sync_models=True,
accumulate_grads=False,
offload=False,
sync_comm=False):
super().__init__()
Expand All @@ -82,7 +81,6 @@ def __init__(self,
self._layer = layer
self._default_device = device
self.__sync_buffers = sync_buffers
self._accumulate_grads = accumulate_grads
self._offload = offload
self._sync_comm = sync_comm
# segmentation size
Expand Down Expand Up @@ -190,6 +188,7 @@ def _clear_gradients(self):
param.fw_storage.clear_gradient(False)
param.fw_storage._gradient_set_empty(False)
param.bw_storage._clear()
param.bw_storage = None
# 2.Handle unslice param
if not self._offload:
for grad_storage in self._grad_storages.values():
Expand Down Expand Up @@ -446,13 +445,12 @@ def _update_params(self):
param,
"fw_storage"), "Find {} don't have fw_storage attribute".format(
param.name)

if self._accumulate_grads:
if self._offload:
with device_guard(device="cpu"):
param.bw_storage.scale_(scale=self._world_size_scaling)
else:
# Gradient average
if self._offload:
with device_guard(device="cpu"):
param.bw_storage.scale_(scale=self._world_size_scaling)
else:
param.bw_storage.scale_(scale=self._world_size_scaling)
param.fw_storage = _VarBaseWrapper(param)
assert param.fw_storage.grad is None
param.fw_storage._copy_gradient_from(param.bw_storage)
Expand Down Expand Up @@ -526,17 +524,14 @@ def _get_allreduce_fn(self, param):
def reduce(*_):
if param.name in self._task_flow.full_grad.keys():
full_grad = self._task_flow.full_grad[param.name]
if not self._accumulate_grads:
full_grad.scale_(scale=self._world_size_scaling)
# Only support sync allreduce current rank's layer now
dist.all_reduce(
tensor=full_grad, group=self._group, use_calc_stream=True)
dist.wait(
tensor=full_grad, group=self._group, use_calc_stream=True)

start, end = self._param2buffer[param.name][self._rank]
if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value(
).get_tensor()._is_initialized():
if param.bw_storage is None:
param.bw_storage = core.VarBase(
full_grad._slice(start, end)).detach().clone()
if self._offload:
Expand Down
15 changes: 8 additions & 7 deletions python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2

seed = 2021
seed = 2022
epoch = 2
linear_size = 1000

Expand Down Expand Up @@ -105,11 +105,7 @@ def train_mlp(model,
params=model.parameters(), optim=optimizer, group=group)

model = ShardingStage2(
model,
optimizer,
group=group,
buffer_max_size=2**21,
accumulate_grads=batch_size == 20)
model, optimizer, group=group, buffer_max_size=2**21)
else:
optimizer = fleet.distributed_optimizer(optimizer)
model = fleet.distributed_model(model)
Expand Down Expand Up @@ -140,6 +136,8 @@ def train_mlp(model,
loss = paddle.nn.functional.cross_entropy(input=out, label=label)

avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
if batch_size == 20:
avg_loss = avg_loss / 5
avg_loss.backward()

if not accumulate_grad:
Expand All @@ -166,6 +164,7 @@ def test_dp_stage2():
mlp4.set_state_dict(state_dict)
mlp5.set_state_dict(state_dict)

# DP VS stage2
dp_params = train_mlp(
mlp1, sharding_stage="dp", use_pure_fp16=False, opt_group=False)
stage2_params = train_mlp(
Expand All @@ -174,7 +173,8 @@ def test_dp_stage2():
np.testing.assert_allclose(
dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)

stage2_params = train_mlp(mlp3, sharding_stage=2)
# stage2 accumulate grad
stage2_params = train_mlp(mlp3, sharding_stage=2, accumulate_grad=True)
stage2_accumulate_grad = train_mlp(
mlp4, sharding_stage=2, batch_size=20, accumulate_grad=True)
for i in range(len(stage2_params)):
Expand All @@ -184,6 +184,7 @@ def test_dp_stage2():
rtol=1e-5,
atol=1e-5)

# stage2 param list VS param group
stage2_params = train_mlp(
mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=True)
for i in range(len(dp_params)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,12 @@ def train_mlp(model, offload=False):
optimizer = optimizer_setting(model=model, use_pure_fp16=True)

model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
scaler = ShardingScaler(scaler)

optimizer = ShardingOptimizerStage2(
params=model.parameters(), optim=optimizer, offload=offload)
model = ShardingStage2(
model, optimizer, buffer_max_size=2**21, accumulate_grads=False)
model = ShardingStage2(model, optimizer, buffer_max_size=2**21)

train_reader = paddle.batch(
reader_decorator(linear_size), batch_size=batch_size, drop_last=True)
Expand Down
14 changes: 3 additions & 11 deletions python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,10 @@ def train_mlp(model,
optimizer = ShardingOptimizerStage2(
params=model.parameters(), optim=optimizer, group=group)
model = ShardingStage2(
model,
optimizer,
group=group,
buffer_max_size=2**21,
accumulate_grads=batch_size == 20)
model, optimizer, group=group, buffer_max_size=2**21)
elif sharding_stage == 3:
model = ShardingStage3(
model,
optimizer=optimizer,
group=group,
accumulate_grads=batch_size == 20,
sync_comm=recompute)
model, optimizer=optimizer, group=group, sync_comm=recompute)

# check optimizer.minimize() error
if test_minimize:
Expand Down Expand Up @@ -231,7 +223,7 @@ def test_stage2_stage3():
stage2_params[i].numpy(),
stage3_params[i].numpy(),
rtol=1e-4,
atol=1e-4)
atol=1e-3)

# fp16 recompute
stage3_params = train_mlp(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,7 @@ def train_mlp(model,
scaler = ShardingScaler(scaler)

model = ShardingStage3(
model,
optimizer=optimizer,
group=group,
offload=offload,
accumulate_grads=accumulate_grad)
model, optimizer=optimizer, group=group, offload=offload)

train_reader = paddle.batch(
reader_decorator(), batch_size=batch_size, drop_last=True)
Expand Down

0 comments on commit 18c6f40

Please sign in to comment.