add set_lr & get_lr for stage2 optimizer. (#48857)
wuhuachaocoding authored Dec 9, 2022
1 parent 39ffef0 commit 8f1e24d
Showing 2 changed files with 11 additions and 0 deletions.
@@ -580,6 +580,13 @@ def __impl__(x, y):

return __impl__

def set_lr(self, lr):
super().set_lr(lr)
self._optim.set_lr(lr)

def get_lr(self):
return self._optim.get_lr()

@paddle.autograd.no_grad()
def _broadcast_params_overlap_forward(self):
# Exchange all the shards with the other ranks,
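The two new methods follow a simple delegation pattern: set_lr first updates the wrapper's own learning-rate state via super().set_lr(lr), then forwards the value to the wrapped inner optimizer (self._optim), while get_lr reads straight through to the inner optimizer. A minimal, runnable sketch of that pattern, using a hypothetical stand-in class rather than the real stage-2 optimizer (presumably GroupShardedOptimizerStage2):

import paddle

# Stand-in wrapper illustrating the delegation added in this commit.
# The attribute name _optim matches the diff; the class itself is
# hypothetical, not Paddle's actual implementation.
class _LrDelegatingWrapper:
    def __init__(self, inner_optimizer):
        self._optim = inner_optimizer

    def set_lr(self, lr):
        # The real class also calls super().set_lr(lr) here so the
        # wrapper's own learning-rate state stays in sync.
        self._optim.set_lr(lr)

    def get_lr(self):
        # Read through to the inner optimizer's current learning rate.
        return self._optim.get_lr()

linear = paddle.nn.Linear(4, 4)
inner = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters())
opt = _LrDelegatingWrapper(inner)

opt.set_lr(0.05)
print(opt.get_lr())  # 0.05, taken from the inner optimizer

Keeping both copies of the learning rate consistent is the point of the override: without it, set_lr on the wrapper would silently diverge from the inner optimizer that actually applies updates.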
@@ -100,6 +100,10 @@ def train_mlp(
dp_group=dp_group,
)

    # Exercise set_lr/get_lr just for test coverage.
if shard_level == "os_g":
optimizer.set_lr(optimizer.get_lr())

train_reader = paddle.batch(
reader_decorator(), batch_size=batch_size, drop_last=True
)
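The test sets the learning rate to its own current value, a no-op that still covers both new code paths. In real training, the same API lets callers adjust the learning rate of a stage-2 sharded optimizer mid-run. A hedged usage sketch, assuming a distributed launch (e.g. via python -m paddle.distributed.launch) and Paddle's public group_sharded_parallel API:

import paddle
from paddle.distributed import init_parallel_env
from paddle.distributed.sharding import group_sharded_parallel

init_parallel_env()
model = paddle.nn.Linear(8, 8)
opt = paddle.optimizer.AdamW(learning_rate=0.001,
                             parameters=model.parameters())

# Level "os_g" shards optimizer state and gradients (stage 2);
# the returned optimizer is the sharded wrapper.
model, opt, _ = group_sharded_parallel(model, opt, level="os_g")

# Decay the learning rate mid-training; with this commit the call is
# forwarded to the wrapped inner optimizer instead of raising.
opt.set_lr(opt.get_lr() * 0.5)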
