Add a prototype of MX format training and inference #264
need to add license and fix CI
first round of feedback on docs and ci related stuff - will do another pass
```python
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, DTYPE_FP4
x = torch.randn(...)
```
put a functioning snippet that people can copy paste
```python
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

m = Model(...)
```
same comment on a functional snippet
Force-pushed from aef63f9 to 0c84b1a.
ok, CI is green, going to address the other comments now
Stamping for now, will need a day or more to read the mx spec and don't wanna block your PR until then
I can wait for review, would rather only land once people are ok with the code.
Force-pushed from 455f148 to 77541c5.
Thanks! Really enjoyed reviewing this. some minor nits but we should be good to merge
README.md
@@ -99,6 +99,7 @@ To learn more try out our APIs, you can check out API examples in
3. Support for lower precision [dtypes](./torchao/dtypes) such as
    - [nf4](https://github.com/pytorch/ao/blob/main/torchao/dtypes/nf4tensor.py) which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst) without writing custom Triton or CUDA code
    - [uint4](https://github.com/pytorch/ao/blob/main/torchao/dtypes/uint4.py)
    - [MX](https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats) implementing the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf), prototype as the hardware support is not available yet
Worth expanding to mention that MX includes fp8/6/4 and int8, since MX is still new terminology
import torch

# This is conceptually an enum of non-core dtypes
# if someone has time to verify torch.compile compatibility, it could be made
is the comment intending to say that torch.compile is breaking on enum, or that in the future torch.compile support can be checked AND independently this could be made into an enum?
Indeed, I feel like an enum would make this significantly easier to read, because you could conceptually print every row in the spec
The weird thing is that some of these dtypes are in core (the float8 ones) and some aren't (float6/float4/mx's spec of int8, etc). I think it would be nice to have a clean structure unifying all of that, I just haven't had the time. Definitely open for someone (or future me) to improve this.
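To make the enum idea in this thread concrete, here is a minimal sketch in plain Python. The class name `MXElemDtype`, its properties, and the helper `format_rows` are hypothetical and are not part of torchao. The exponent and mantissa bit counts are the ones encoded in each format name (for example, E5M2). MX int8 is left out because it is not a floating-point layout, and torch.compile compatibility is not addressed here.

```python
from enum import Enum


class MXElemDtype(Enum):
    """Hypothetical enum of MX floating-point element formats (not a torchao API)."""

    # value = (exponent_bits, mantissa_bits), as encoded in the format name
    FP8_E5M2 = (5, 2)
    FP8_E4M3 = (4, 3)
    FP6_E3M2 = (3, 2)
    FP6_E2M3 = (2, 3)
    FP4_E2M1 = (2, 1)

    @property
    def exponent_bits(self) -> int:
        return self.value[0]

    @property
    def mantissa_bits(self) -> int:
        return self.value[1]

    @property
    def total_bits(self) -> int:
        # One sign bit plus the exponent and mantissa bits.
        return 1 + self.exponent_bits + self.mantissa_bits


def format_rows() -> list:
    """Build one human-readable row per element format."""
    return [
        f"{d.name}: 1 sign + {d.exponent_bits} exponent + "
        f"{d.mantissa_bits} mantissa = {d.total_bits} bits"
        for d in MXElemDtype
    ]


for row in format_rows():
    print(row)
```

Iterating over the enum gives one row per format, which is the "print every row in the spec" property mentioned above.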
def compute_error(x, y):
    Ps = torch.norm(x)  # noqa: TOR101
    Pn = torch.norm(x - y)  # noqa: TOR101
    return 20 * torch.log10(Ps / Pn)
There's already a util for this exact function in the code, somewhere in gptq IIRC, so can we put this in torchao/utils.py instead?
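For readers unfamiliar with the metric, `compute_error` in the diff is a signal-to-quantization-noise ratio (SQNR) in decibels: 20 * log10(||x|| / ||x - y||). The sketch below reproduces the same formula with the standard library only, on plain Python lists rather than tensors, so it is an illustration of the math and not the torchao utility.

```python
import math


def compute_error(x, y):
    """SQNR in dB between a reference sequence x and an approximation y.

    Mirrors the formula in the diff: 20 * log10(||x|| / ||x - y||).
    Raises ZeroDivisionError if y equals x exactly (zero noise).
    """
    signal = math.sqrt(sum(a * a for a in x))
    noise = math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))
    return 20 * math.log10(signal / noise)


# A 10% error on the only nonzero element gives roughly 20 dB.
print(compute_error([1.0, 0.0], [0.9, 0.0]))
```

Higher values mean the approximation is closer to the reference; each additional 20 dB corresponds to a 10x smaller error norm.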
### MXTensor

These are casts between fp32/bf16 and MX formats, implemented in native PyTorch.
Btw the spec didn't seem too prescriptive around what the source dtype should be
fp32 and bf16 is what we have today, we can make it clearer that other dtypes can be added in the future
def get_fp_scale(scale_e8m0):
    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
    # TODO(later): it would be nice if there was a way to do the 2^x operation
maybe this is helpful https://pytorch.org/docs/stable/generated/torch.ldexp.html
makes sense! I will punt this to a future person, this shouldn't be that important for e2e performance.
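As a scalar illustration of the `ldexp` suggestion, the sketch below decodes an E8M0 scale byte into a float with `math.ldexp`, which computes `x * 2**i` without going through a general power function. It is not the code from the PR: the function name `e8m0_to_float` is hypothetical, and the bias of 127 and the NaN encoding of 255 are assumptions taken from the OCP MX spec. `torch.ldexp`, linked above, is the elementwise tensor counterpart.

```python
import math

# Assumed constants from the OCP MX spec for the 8-bit exponent-only scale.
E8M0_EXPONENT_BIAS = 127
E8M0_NAN = 255


def e8m0_to_float(scale_e8m0: int) -> float:
    """Decode a biased E8M0 scale byte into the power of two it represents."""
    if not 0 <= scale_e8m0 <= 255:
        raise ValueError("E8M0 scale must fit in one unsigned byte")
    if scale_e8m0 == E8M0_NAN:
        return math.nan
    # ldexp(1.0, e) returns 1.0 * 2**e.
    return math.ldexp(1.0, scale_e8m0 - E8M0_EXPONENT_BIAS)


for byte in (126, 127, 130):
    print(byte, e8m0_to_float(byte))
```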
        return g, None, None


@torch._dynamo.allow_in_graph
There's a public API torch.compiler.allow_in_graph
- also curious why this was needed
this is necessary for compile to fully support training, and this line is copy-pasta from float8_experimental. Ideally, while these two products are in different codebases, I'm hoping for these kinds of issues to get fixed in float8_experimental first and be copied here. Once we unify it will be easier.
    return _f4_or_f6_unpacked_to_f32(x, DTYPE_FP6_E3M2)


if has_triton():
was the inductor codegen note adequate? Wondering if we can eventually remove this
current codegen was slow, tracked in pytorch/pytorch#124002.
    print("\n")


if __name__ == "__main__":
wdyt about renaming this file to have spec in the name? I quite like it, and we can recommend people cross-reference the text spec with your code in the main README
Force-pushed from 77541c5 to ad704f0.
Summary:
The MX numerical formats are new low precision formats with recent acceptance into the OCP spec:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
This PR adds a reference native PyTorch implementation of training and inference primitives for using MX accelerated matrix multiplications. Currently, we use a reference layout (scale and raw data stored separately) and an emulated matrix multiplication.
Test Plan:
```
// tests
pytest -s test/prototype/mx_formats/*
// benchmarks
python torchao/prototype/mx_formats/benchmarks/bench_qdq.py
```
Force-pushed from ad704f0 to 4425d0d.
@msaroufim this needs a review again, since I think this repo is set up to re-require reviews after changes. All of the feedback has been either addressed or explained why not addressed right now.
and thank you for the review!
Summary:
- Removed Int8DynActInt4Weight code
- Use torchao to achieve the same
Test Plan:
python export.py --quant '{"linear:a8w4dq" : {"groupsize": 128}}' --checkpoint-path stories110M.pt --params-path params.json --output-pte-path /tmp/stories110m_a8w4dq.pte
Run ./build/cmake-out/runner_et /tmp/stories110m_a8w4dq.pte -z /tmp/tokenizer.bin -n 200 -t 0
Summary:
The MX numerical formats are new low precision formats with recent acceptance into the OCP spec:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
This PR adds a reference native PyTorch implementation of training and inference primitives for using MX accelerated matrix multiplications. Currently, we use a reference layout (scale and raw data stored separately) and an emulated matrix multiplication.
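The "reference layout" mentioned in the summary can be pictured with the stdlib-only sketch below, which keeps one shared power-of-two scale per block separately from the scaled element values. It is a simplified illustration under stated assumptions and not the PR's implementation: the function names are hypothetical, the block size of 32 and E8M0 bias of 127 are taken from the OCP MX spec, the shared exponent follows the spec's floor(log2(max|v|)) - emax_elem rule, and the final cast of each scaled element to fp8/fp6/fp4 is deliberately omitted.

```python
import math

BLOCK_SIZE = 32           # assumed: block size from the OCP MX spec
E8M0_EXPONENT_BIAS = 127  # assumed: bias of the shared E8M0 scale


def to_mx_reference_layout(values, elem_emax, block_size=BLOCK_SIZE):
    """Return (scales_e8m0, scaled_blocks), stored as two separate lists.

    elem_emax is the largest exponent of the target element format
    (for example 8 for FP8 E4M3). Elements are only rescaled here;
    rounding them to the low-precision format is left out.
    """
    if len(values) % block_size != 0:
        raise ValueError("length must be a multiple of block_size")
    scales, blocks = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        max_abs = max(abs(v) for v in block)
        if max_abs == 0.0:
            shared_exp = -E8M0_EXPONENT_BIAS  # all-zero block: smallest scale
        else:
            # frexp gives max_abs = m * 2**e with 0.5 <= m < 1,
            # so floor(log2(max_abs)) == e - 1.
            _, e = math.frexp(max_abs)
            shared_exp = (e - 1) - elem_emax
            shared_exp = max(-127, min(127, shared_exp))
        scales.append(shared_exp + E8M0_EXPONENT_BIAS)
        blocks.append([math.ldexp(v, -shared_exp) for v in block])
    return scales, blocks


def from_mx_reference_layout(scales, blocks):
    """Multiply each block by its shared scale to recover the values."""
    out = []
    for scale, block in zip(scales, blocks):
        out.extend(math.ldexp(v, scale - E8M0_EXPONENT_BIAS) for v in block)
    return out


values = [float(i - 16) for i in range(32)]
scales, blocks = to_mx_reference_layout(values, elem_emax=8)
print(scales, max(abs(v) for v in blocks[0]))
```

In this example the largest magnitude in the block is 16, so the shared exponent is 4 - 8 = -4 and the largest scaled element lands at 2**8, the top exponent of the assumed element format.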