-
Notifications
You must be signed in to change notification settings - Fork 198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Float8 Weight Only and FP8 weight + dynamic activation #740
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/740
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit c771c64 with merge base 05224a9 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
287fc8f
to
6e6462f
Compare
e0b3d31
to
7244571
Compare
862fb22
to
b772ad9
Compare
b772ad9
to
ffeeb9a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks good overall, just had some nit comments inline
ffeeb9a
to
bfd78c0
Compare
bfd78c0
to
c771c64
Compare
@@ -30,14 +30,19 @@ def addmm_float8_unwrapped( | |||
output_scale: Optional[torch.Tensor] = None, | |||
bias: Optional[torch.Tensor] = None, | |||
use_fast_accum: bool = False, | |||
inverse_scale: bool = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does this mean AffineQuantizedTensor
uses scale in the same way Float8Tensor
uses 1 / scale, or something else?
@@ -57,6 +59,7 @@ | |||
from .utils import _get_per_token_block_size | |||
import logging | |||
from .autoquant import autoquant, AutoQuantizableLinearWeight | |||
from torchao.float8.float8_tensor import ScaledMMConfig |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it seems like this PR makes ScaledMMConfig
a public API. Is this intended? If yes, can we move this object to float8/__init__.py
, maybe make a dataclass, ensure it's documented, etc?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah Ill put up a PR
Summary
This PR makes some tweaks to the existing FP8 weight only work flow and adds float8_dynamic_activation_float8_weight to the quantize_ API.
New Apis made public:
TODO
Changes
Memory snapshot compare
With no_grad():
Normal conditions:
FIX
@jerryzh168 If I decorate quantize affine:
And
I get the proper memory usage:
Is it okay If I land these as well or is there some other use case I am missing? I imagine all other input_tenosrs have been ints and havent propgated grads before
Perf micro benchmarks
Runner Script:
https://gist.github.com/drisspg/3baf802f8c8631df2a549840f39e7e0d
Trace: internal_link
Max Autotune
Needs this fix: pytorch/pytorch#134765
Trace: internal_link