-
Notifications
You must be signed in to change notification settings - Fork 198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Clean up FP6-LLM #304
Clean up FP6-LLM #304
Conversation
gau-nernst
commented
Jun 3, 2024
- Remove original FP6 quantization code (qtorch and C++ bit-packing)
- Replace FP32<->FP6 dtype conversion with @vkuzo's implementation for MX dtypes
- I also migrate some of my FP32->FP6 rounding test cases to MX custom cast test.
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/304
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit 6f8e7e9 with merge base 000a0fd (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
I have two questions:
|
A question for @msaroufim. Is there a guideline when we should or should not decorate a function with Update: FP32->FP6_E3M2 (8192,8192) matrix (main branch) - benchmark with
CUDA is memory-bound so the implementation does not matter much (as long as it is correct). For CPU, your implementation is faster, especially with torch.compile (and faster than my C++ implementation). Though I found that CPU benchmark results tend to vary greatly across CPUs... from functools import partial
import torch
import pandas as pd
from torch.utils.benchmark import Timer
from torchao.prototype.mx_formats.custom_cast import f32_to_f6_e3m2_unpacked
from torchao.dtypes.float6_e3m2 import _to_float6_e3m2_pt
def benchmark(f, *args):
measurement = Timer(
stmt="f(*args)",
globals={"f": f, "args": args},
).blocked_autorange()
return measurement.median * 1000
if __name__ == "__main__":
M = 8192
N = 8192
fp32_weight = torch.randn(M, N)
fp32_weight_cuda = fp32_weight.cuda()
functions = [
("_to_float6_e3m2_pt", partial(_to_float6_e3m2_pt, no_bit_packing=True)),
("f32_to_f6_e3m2_unpacked", f32_to_f6_e3m2_unpacked),
]
results = []
for name, f in functions:
results.append(["CPU", "eager", name, benchmark(f, fp32_weight)])
results.append(["CUDA", "eager", name, benchmark(f, fp32_weight_cuda)])
results.append(["CPU", "compile", name, benchmark(torch.compile(f), fp32_weight)])
results.append(["CUDA", "compile", name, benchmark(torch.compile(f), fp32_weight_cuda)])
df = pd.DataFrame(results, columns=["device", "mode", "op", "time (ms)"])
df = df.sort_values(["device", "mode"], ascending=[True, False])
print(df.to_markdown(index=False)) |
So for Windows the main issue is I would say overall everything should be compilable, the cold start problems is indeed annoying and is actively being worked, there are some broader plans that have been shared though https://dev-discuss.pytorch.org/t/how-to-bring-compile-time-down-to-zero-our-plans-and-direction-may-14th-edition/2089 Regarding dynamic shapes the way I iterate through things is first eliminate graph breaks then recompilations, this has been my goto guide https://github.com/pytorch/pytorch/blob/main/docs/source/torch.compiler_troubleshooting.rst Also just FYI we removed the requirement to have branches up to date before merge, there was a breaking change in PyTorch that was just reverted so please rebase your changes to get rid of CI flakes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that's a lot of deletions 🗡️
* override load from state dict * fix prefix * migrate to mx primitive * remove unneeded code * comment out test * remove * add rounding test for f6_e3m2 * update tests * remove openmp flag * update benchmark script * test negative number * remove qtorch dep * fix type casting * add view * fix strange pytest behavior * only skip tests requiring PyTorch 2.4 * remove weight loading magic
* eval and GPTQ work Summary: fleshing out the eval code so it works reliably, adding ci, adding gptq. fixed defaults for eval/gptq so they generally working meaningfully without being specified. note, we need a better way to save/load gptq models since they take so long to quantize. I tried using .so but it doesn't seem to work reliably. also added eval and gptq to ci. Test Plan: python eval.py --checkpoint-path checkpoints/$MODEL_REPO/model.pth \ --device cuda --dtype bfloat16 python eval.py --checkpoint-path checkpoints/$MODEL_REPO/model.pth \ --dtype bfloat16 --device cuda \ --quant '{"linear:int4" : {"groupsize" : 32} }' \ --compile python eval.py --checkpoint-path checkpoints/$MODEL_REPO/model.pth \ --dtype bfloat16 --device cuda \ --quant '{"linear:int4" : {"groupsize" : 32} }' python eval.py --checkpoint-path checkpoints/$MODEL_REPO/model.pth \ --dtype bfloat16 --device cuda \ --quant '{"linear:int4-gptq" : {"groupsize" : 32} }' ...running... Reviewers: Subscribers: Tasks: Tags: * fix language in help doc Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * declare scales_and_zeros --------- Co-authored-by: HDCharles <charlesdavidhernandez@gmail.com>