[low-bit optim] Fix load state dict when device is different #1021
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1021
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 2c82fc6 with merge base 9e2a253. This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/prototype/test_low_bit_optim.py (Outdated)

```python
for p1, p2 in zip(model.parameters(), model2.parameters()):
    torch.testing.assert_close(p2, p1)

@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle")
```
small typo, availablle -> available
Suggested change:

```diff
-@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle")
+@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
```
@@ -109,6 +109,7 @@ def step(self, closure=None):

```python
# this will work with any optim state tensor subclass that implements aten.lerp.Scalar and aten.copy_.default
# and param tensor subclass that implements aten.add_.Tensor, and aten.addcdiv_.default
# NOTE: should we cast inputs to FP32 to ensure computations are always in FP32?
```
I override a few methods to cast the input to the weight dtype because the Flux model occasionally upcasts things to FP32 inside layernorm. Are you saying FP32 is more correct than BF16?
For the optimizer step, internal calculations should be done in FP32 to ensure accurate results.
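To make that concrete, here is a hedged sketch (not the PR's actual step implementation; the function name and hyperparameter defaults are illustrative) of an Adam-style single-parameter update written only in terms of those aten ops, with the read-side math upcast to FP32:

```python
import torch


@torch.no_grad()
def adam_step_sketch(p, grad, step, exp_avg, exp_avg_sq,
                     lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # state updates go through aten.lerp.Scalar (and aten.copy_.default when a
    # state dict is loaded), so any optim state subclass with those ops works
    exp_avg.lerp_(grad, 1 - beta1)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2)

    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    # read the state back in FP32 so the sqrt/division happen at full precision
    # even when the state is stored as 8-bit or FP8
    denom = (exp_avg_sq.float().sqrt() / bias_correction2 ** 0.5).add_(eps)

    # the param subclass only needs aten.addcdiv_.default here
    # (aten.add_.Tensor would be used for weight decay, omitted in this sketch)
    p.addcdiv_(exp_avg.float(), denom, value=-lr / bias_correction1)
```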
```diff
-        return float_data.view(self._shape).to(dtype)
+        if output_dtype is not None:
+            float_data = float_data.to(output_dtype)
+        return float_data.view(self._shape)
```
Will we encounter non-contiguous tensors here? If so, view cannot be used; reshape must be used instead.
No, it's unlikely we will have a non-contiguous tensor here, because this handles the internal data of the tensor subclass, not the output of another layer in the model.
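For context, a minimal hedged sketch of the dequantize pattern under discussion; `_FakeQuantizedState` and its fields are hypothetical stand-ins for the real subclass internals:

```python
import torch


class _FakeQuantizedState:
    """Hypothetical container mimicking the subclass's internal buffers."""

    def __init__(self, codes: torch.Tensor, scale: torch.Tensor, shape):
        self._codes = codes.flatten()  # internal storage is always contiguous
        self._scale = scale            # per-tensor scale in this toy example
        self._shape = shape

    def dequantize(self, output_dtype=None):
        float_data = self._codes.float() * self._scale
        if output_dtype is not None:
            float_data = float_data.to(output_dtype)
        # .view() is safe here because float_data comes from the subclass's own
        # contiguous buffer, not from another layer's (possibly strided) output
        return float_data.view(self._shape)
```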
* fix serialization
* fix pytorch 2.3
* fix typo
* update note
In `optim.load_state_dict(state_dict)`, if the optim dtype != the state_dict dtype, `aten._to_copy.default` is called. This PR simply implements this op and adds appropriate tests.

Update: In PyTorch pre-2.4, calling `.to(device, dtype)` will not dispatch `aten._to_copy.default` when the dtype is the same but the device is different. Thus, I have to manually override the `.to()` method instead. This is only done for PyTorch pre-2.4. FP8 is not affected since FP8 CUDA requires PyTorch 2.4 anyway. We can remove this hack once we drop 2.3 support.
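A hedged sketch of the shape of that workaround, assuming a wrapper tensor subclass; the class name, fields, and version guard are illustrative and not the PR's actual code:

```python
import torch

# assumed version guard; torchao has its own version helpers
_TORCH_PRE_2_4 = tuple(int(x) for x in torch.__version__.split(".")[:2]) < (2, 4)


class QuantizedStateSketch(torch.Tensor):
    """Illustrative wrapper subclass standing in for a low-bit optim state tensor."""

    @staticmethod
    def __new__(cls, codes, scale, shape):
        return torch.Tensor._make_wrapper_subclass(
            cls, shape, dtype=torch.float32, device=codes.device
        )

    def __init__(self, codes, scale, shape):
        self.codes = codes  # quantized payload
        self.scale = scale  # per-tensor scale in this toy example

    def dequantize(self):
        return (self.codes.float() * self.scale).view(self.shape)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # the real subclass implements aten._to_copy.default, aten.lerp.Scalar,
        # aten.copy_.default, etc. here; omitted in this sketch
        raise NotImplementedError(f"{func} is not handled in this sketch")

    if _TORCH_PRE_2_4:
        # Pre-2.4, Tensor.to(device) with an unchanged dtype never reaches
        # aten._to_copy.default in __torch_dispatch__, so handle the move here.
        def to(self, *args, **kwargs):
            device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
            if dtype is None or dtype == self.dtype:
                device = device if device is not None else self.device
                return self.__class__(
                    self.codes.to(device), self.scale.to(device), self.shape
                )
            # dtype actually changes: fall back to a plain dequantized tensor
            return self.dequantize().to(*args, **kwargs)
```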