
[low-bit optim] Fix load state dict when device is different #1021

Merged
4 commits merged into pytorch:main from the optim_serialization branch on Oct 6, 2024

Conversation

gau-nernst
Collaborator

@gau-nernst gau-nernst commented Oct 5, 2024

In optim.load_state_dict(state_dict), if the optim dtype != the state_dict dtype, aten._to_copy.default is called. This PR simply implements this op and adds appropriate tests.

Update: In PyTorch pre-2.4, calling .to(device, dtype) does not dispatch aten._to_copy.default when the dtype is the same but the device is different. Thus, I have to manually override the .to() method instead. This is only done for PyTorch pre-2.4. FP8 is not affected since FP8 CUDA requires PyTorch 2.4 anyway. We can remove this hack once we drop 2.3 support.
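For context, a minimal sketch of how one can observe which aten op a `.to(...)` call dispatches to (this is generic PyTorch, not part of the PR; `LogAten` is just an illustrative name):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LogAten(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)  # e.g. aten._to_copy.default
        return func(*args, **(kwargs or {}))

x = torch.randn(4)
with LogAten():
    x.to(torch.bfloat16)  # dtype change: dispatches aten._to_copy.default
    x.to("cpu")           # same device and dtype: returns self, nothing is dispatched
```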


pytorch-bot bot commented Oct 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1021

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2c82fc6 with merge base 9e2a253:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 5, 2024
@gau-nernst gau-nernst requested a review from msaroufim October 5, 2024 02:11

for p1, p2 in zip(model.parameters(), model2.parameters()):
    torch.testing.assert_close(p2, p1)

@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle")

small typo, availablle -> available

Suggested change
@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle")
@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")

@@ -109,6 +109,7 @@ def step(self, closure=None):

# this will work with any optim state tensor subclass that implements aten.lerp.Scalar and aten.copy_.default
# and param tensor subclass that implements aten.add_.Tensor, and aten.addcdiv_.default
# NOTE: should we cast inputs to FP32 to ensure computations are always in FP32?

I override a few methods to cast the input to the weight dtype because the Flux model occasionally upcasts things to FP32 inside LayerNorm. Are you saying FP32 is more correct than BF16?

Collaborator Author


For the optimizer step, internal calculations should be done in FP32 to ensure accurate results.
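As a rough illustration of what "internal calculations in FP32" means here (this is not the torchao kernel; bias correction is omitted and all names are generic Adam terms):

```python
import torch

def adam_step_fp32(p, grad, exp_avg, exp_avg_sq,
                   lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    # Upcast to FP32 so the math runs in full precision even when
    # p / exp_avg / exp_avg_sq are stored in BF16 or 8-bit.
    g = grad.float()
    m = exp_avg.float().lerp_(g, 1 - betas[0])
    v = exp_avg_sq.float().lerp_(g * g, 1 - betas[1])
    p32 = p.float().addcdiv_(m, v.sqrt().add_(eps), value=-lr)
    # Write results back in the original (possibly low-precision) dtypes.
    exp_avg.copy_(m)
    exp_avg_sq.copy_(v)
    p.copy_(p32)
```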

-    return float_data.view(self._shape).to(dtype)
+    if output_dtype is not None:
+        float_data = float_data.to(output_dtype)
+    return float_data.view(self._shape)

Will we encounter non-contiguous tensors here? If so, view cannot be used; reshape must be used instead.

Collaborator Author


No, it is unlikely we will have non-contiguous tensors here, because this handles the internal data of the tensor subclass, not the output of another layer in a model.
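For anyone following along, the view vs. reshape distinction discussed above is plain PyTorch behaviour, unrelated to this PR's subclass:

```python
import torch

x = torch.randn(4, 6)
t = x.t()                    # transpose makes the tensor non-contiguous
print(t.is_contiguous())     # False
print(t.reshape(24).shape)   # OK: reshape copies when a view is impossible
try:
    t.view(24)               # RuntimeError: incompatible size/stride for view
except RuntimeError as e:
    print("view failed:", e)
```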

@gau-nernst gau-nernst merged commit c187f87 into pytorch:main Oct 6, 2024
17 checks passed
@gau-nernst gau-nernst deleted the optim_serialization branch October 6, 2024 09:21
jainapurva pushed a commit that referenced this pull request Oct 15, 2024
* fix serialization

* fix pytorch 2.3

* fix typo

* update note
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* Add OPENAI_API_VERSION constant to routes

* Add seed, temperature, max_tokens and system_fingerprint parameters to request/response (pytorch#1016)