
Remove optimizer step on initialization #5104

Merged
merged 5 commits into master on Feb 11, 2024
Conversation

tohtana (Contributor) commented Feb 8, 2024

All ZeRO 1/2/3 stages call the optimizer's `step()` during initialization. This increments a counter inside the optimizer and makes parameter updates differ from what normal PyTorch usage would produce. This PR removes the `step()` call from initialization and lazily configures some internal state (linking *hp_params*) after the first real `step()` call.
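As a rough illustration (plain PyTorch, not DeepSpeed code), the sketch below shows how one extra `step()` taken during initialization shifts Adam's internal step counter, so the first real update uses a different bias correction than it would under normal PyTorch usage:

```python
import torch

# Two identical parameters and optimizers; one optimizer takes an extra step()
# during "initialization", as the pre-fix ZeRO stages did.
p_plain = torch.nn.Parameter(torch.ones(3))
p_extra = torch.nn.Parameter(torch.ones(3))
opt_plain = torch.optim.Adam([p_plain], lr=0.1)
opt_extra = torch.optim.Adam([p_extra], lr=0.1)

# Extra step() at init: zero gradients leave the parameter unchanged,
# but Adam's internal step counter still advances.
p_extra.grad = torch.zeros_like(p_extra)
opt_extra.step()

# First real training step with identical gradients on both parameters.
for p, opt in ((p_plain, opt_plain), (p_extra, opt_extra)):
    p.grad = torch.full_like(p, 0.5)
    opt.step()

# The counters differ (1 vs 2), so the bias-corrected updates differ too.
print(opt_plain.state[p_plain]["step"], opt_extra.state[p_extra]["step"])
print(torch.allclose(p_plain, p_extra))  # False
```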

@tohtana tohtana marked this pull request as ready for review February 9, 2024 22:28
@tohtana tohtana added this pull request to the merge queue Feb 10, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Feb 10, 2024
@tohtana tohtana added this pull request to the merge queue Feb 11, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Feb 11, 2024
@tohtana tohtana merged commit 1817980 into master Feb 11, 2024
12 checks passed
mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request Feb 17, 2024
github-merge-queue bot pushed a commit that referenced this pull request Mar 13, 2024
This PR fixes the following two points regarding checkpoint loading.

- Load optimizer states
With [this PR](#5104), we removed the optimizer's `step()` call on initialization. This made DeepSpeed's parameter updates match PyTorch's normal behavior. However, the optimizer state no longer contains its keys when we load a checkpoint. For legacy/elastic checkpoints, that PR changed the checkpoint loaders to create the keys and buffers on loading, but the loader for universal checkpoints still relies on keys being present in the optimizer state. As a result, loading a universal checkpoint fails. This PR fixes the loader to find the optimizer state keys from a given checkpoint.

- Resume step count
2943e6a
The checkpoint loader for a universal checkpoint resumes the optimizer's step count only when the param group already has `step`. But some optimizers create the key `step` in a param group at the first call of `step()` (e.g. Apex [Fused Adam](https://github.com/NVIDIA/apex/blob/810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c/apex/optimizers/fused_adam.py#L154)). In that case, the step count is not restored. This PR changes this behavior to always set the step count in a param group (see the sketch after this description for the general idea). This PR also stops incrementing the step count when loading; I couldn't see why the step count needs to be incremented in my small example, but we may need a discussion to consider various cases.
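A minimal, hypothetical sketch of the second point (the function name `restore_step_count` and the checkpoint layout are illustrative assumptions, not the actual DeepSpeed loader API): always write the resumed step count into each param group, since optimizers like Apex FusedAdam only create the `step` key on their first `step()` call.

```python
from typing import Any, Dict

def restore_step_count(optimizer, checkpoint_state: Dict[str, Any]) -> None:
    """Unconditionally restore the step count from a loaded checkpoint.

    Guarding with `if "step" in group` would silently skip optimizers
    (e.g. Apex FusedAdam) that only create the key on their first step(),
    losing the resumed step count.
    """
    saved_step = checkpoint_state["step"]  # assumed key in this hypothetical checkpoint dict
    for group in optimizer.param_groups:
        group["step"] = saved_step  # always set, even if the key does not exist yet
        # For optimizers that track a per-parameter step (e.g. torch.optim.Adam),
        # keep that counter consistent as well.
        for p in group["params"]:
            state = optimizer.state.get(p)
            if state and "step" in state:
                state["step"] = saved_step
```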
rraminen pushed a commit to ROCm/DeepSpeed that referenced this pull request May 9, 2024
rraminen pushed a commit to ROCm/DeepSpeed that referenced this pull request May 9, 2024
dbyoung18 pushed a commit to dbyoung18/DeepSpeed that referenced this pull request Jun 11, 2024