-
Notifications
You must be signed in to change notification settings - Fork 207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
GaLore and fused kernel prototypes #95
Conversation
Hi @jeromeku! Thank you for your pull request and welcome to our community. Action RequiredIn order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few nits on packaging and minor questions on the kernels, will do another pass to review the kernels properly - let's ensure the tests run in CI and if a T4 machine is not enough then we need to get a beefier GPU asap
prototype/triton/README.md
Outdated
#### TODO | ||
|
||
- Common quant / dequant kernels for popular quantization frameworks | ||
- [ ] GPTQ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We already have code for GPTQ https://github.com/pytorch-labs/ao/blob/main/torchao/quantization/GPTQ.py
Bits and bytes is also interesting granted we also have kernels for QLoRA that are codegened here as well
HQQ would be nice to add
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok -- will work on kernels for bitsandbytes
AdamW8bit
then onto HQQ.
HQQ's 4 and 8-bit should be easy to adapt; 1, 2, and 3-bit might require some additional preprocessing to optimize.
Their current CUDA dequant implementations can definitely be optimized. Will work on re-implementing in CUDA
and triton
.
Also am looking into how to decomposing Marlin
kernel design into reusable building blocks for optimized quant inference.
prototype/galore/setup.py
Outdated
@@ -0,0 +1,20 @@ | |||
from setuptools import setup |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can rework the exisitng setup.py to package your kernels into the core package - Happy to credit you in the files directly and/or README
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
np -- whatever is easiest
prototype/cutlass/README.md
Outdated
@@ -0,0 +1,3 @@ | |||
# Cutlass Quant | |||
|
|||
### Pythonic tools for defining `cutlass` kernels and quantization ops |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @andrewor14 who has also been thinking about CUTLASS in the context of #86
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cutlass
is very neat.
Cutlass 3.x
and theCuTe
framework that it introduces has many useful primitives and patterns for defining bespoke kernels of relevance (mixed type GEMM, MoE, etc.), though it is targeted primarily atsm90+
architectures.- The
2.x
api has limited support for sub-byte mixed type quant kernels (without preprocessing weights to custom format -- I believepytorch
already has this integrated undertorch.quantization._quantized_conversions
).
Currently working on using Cutlass 3.x
/ CuTe
to adapt / improve pre-Hopper
kernels useful for quant ops. Would love to also test on Hopper
but unfortunately don't have access to H100.
prototype/galore/README.md
Outdated
|
||
- [ ] Implement `FusedGaLoreOptimizer` | ||
- [ ] `Cutlass` - given fixed GEMM shape, experiment with `Cutlass` GEMMs (`split-k`, `stream-k`, fast `tensorops`). Interestingly, profiling `torch.matmul` for down projection shows that `cuBlas` dispatches to a `Cutlass` kernel of shape `128x128x16`. | ||
- [ ] Repeat with `AdamW8bit` - pure `triton` implementation of `bitsandbytes` `AdamW8bit` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes this would be very helpful
prototype/galore/README.md
Outdated
#### Installation | ||
|
||
``` | ||
pip install --editable . |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mentioned this already but we can package prototype under its own namespace in ao as opposed to its own package
3. normalized `grad` is projected to full rank --> additional matmul | ||
4. `params` are updated with the normalized full rank grad | ||
|
||
#### Implementation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
appreciated this ncie simple explanation
print(f"Finished benchmark, results saved to {save_path}") | ||
|
||
|
||
if __name__ == "__main__": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's use the unittest format simiilarly to other tests in the repo, lmk if you need help here
logger = logging.getLogger(__file__) | ||
|
||
|
||
class Autotuner(KernelInterface): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cpuhrsch we have a generic kernel auto tuner now right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tweaked the Autotuner
to print additional info such as pruned configs, best config, cache hit, etc.
|
||
#### Next Steps | ||
|
||
- [ ] Implement `FusedGaLoreOptimizer` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Generally for next steps I'd rather they get mentioned in a github issue vs docs
a = a.to(AB_DTYPE) | ||
b = b.to(AB_DTYPE) | ||
if fp8_fast_accum: | ||
acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe I'm missing something dumb but how this an fp8 accum?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is directly from the triton
matmul
implementation.
In the launcher, the constexpr
AB_DTYPE
gets set to None
if a
and b
are fp8
type. I'm guessing triton
's underlying tl.dot
implementation is overloaded to handle this case, which is probably why the signature differs slightly from the non-fp8
case: tl.dot(a, b, acc, ...)
vs. tl.dot(a, b, ...)
where acc
is passed as an additional arg in the former. Need to dig a bit further to confirm.
|
||
# make the smaller matrix always to be orthogonal matrix | ||
if type == "right": | ||
A = U[:, :rank] @ torch.diag(s[:rank]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is unused; there's a PR upstream: jiaweizzhao/GaLore#18
Same comment on line 28 as well.
@@ -0,0 +1,65 @@ | |||
import logging |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note to self: boths tests pass but we need to decide whether we want benchmarks as tests or just accuracy checks
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks! |
Hi @jeromeku sorry for the delay, it took me a while to figure out the right way to review this but here's what I'm thinking we need to do to merge this. And feel free to reach out to me directly on Discord if you want to pair program any of this
|
Hi @jeromeku, glad to see a GaLore folder popping up in torchao :D I'm brainstorming ideas about supporting GaLore better from the core PyTorch side (e.g., making it work with distributed + checkpointing, using torch.compile instead, seeing if there's a way to generalize the GaLore technique across all optimizers without needing to create a separate optimizer for each). I see your PR introduces performant kernels, which is awesome, and I'd like to know if you want to collaborate on implementing something together. I notice your next steps are about adding more optimizers and doing analysis with torch.compile so there may be some opportunity to jam together there--I'm happy to discuss more on Slack or discord if this is something you're interested in. BTW, review-wise, I would push for @msaroufim points 6 + 5 especially! |
Good to meet you @janeyx99! Great
Yes - agree that it makes sense to generalize I have most of the pieces for a pure Working on points 5 & 6 per review though need to take care of some other stuff first. cc @msaroufim |
Updates:
Next steps:
|
Hey @jeromeku I do wanna make sure we merge something of yours, the roadmap is ambitious and the right one but I'd suggest breaking it apart this way For this PR
Next PR
|
|
test/galore/README.md
Outdated
| median | 516.3 | 403.6 | 0.0 | 0.0 | 75.7 | 272.8 | 0.0 | 18.1 | | ||
| max | 595.0 | 403.6 | 0.3 | 6.6 | 1,336.0 | 395.3 | 312.9 | 173.6 | | ||
|
||
- The `optimizer state` is indeed smaller for the `GaLoreAdamW` optimizer. Interestingly, the `Parameter` sizes balloons in the `GaLore` optimizer, likely due to extra data copies. Admittedly, the implementation is only a reference (per original repo) and leaves much room for optimization. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thank you for noting this, how are you thinking about next steps here? One hope I had for merging this is for other repos to take a dependency like https://github.com/pytorch/torchtune
Also maybe @janeyx99 has some idea of what might be going wrong
test/galore/configs/llama_1b.json
Outdated
"transformers_version": "4.28.1", | ||
"use_cache": true, | ||
"vocab_size": 32000 | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just a heads up we can't merge binary files in the git repo so this includes html, pt files and most of the json files that were merged in
@@ -0,0 +1,70 @@ | |||
import pandas as pd |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this is a nice utility can make it standalone outside of galore context
|
||
|
||
@contextmanager | ||
def nsys_profiler(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought this would use the nvidia nsys profiler
@@ -0,0 +1,72 @@ | |||
from functools import partial |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if possible I'd rather we not have notebooks merged in, just turn the code into a python file and let's check that in
PR is looking good, I think we're close to merging this - main things that need to be fixed
@jeromeku do you mind if I make changes to your PR directly? It might help make the review process smoother |
@msaroufim No worries - feel free to make whatever changes necessary. |
Made the following changes:
|
Additional conditions for skipping tests to avoid CI failure. Rename files as they are not actual tests but profiling tools to avoid triggering CI runs.
TL;DR
Benchmark notesConfirming the benchmark script works, on an H100 on my end I get On benchmark script I get
So indeed things are fastest for the hybrid approach, fused seems slower than eager and compile is fast but not fastest probably because I didn't enable tensor cores I made a minor change to the way the flag is set there which is recommended over using - torch.backends.cuda.matmul.allow_tf32 = allow_tf32
+ if allow_tf32:
+ torch.set_float32_matmul_precision('high') On nightlies this gives some cuda graph errors we can fix at a later time - not urgent. But it does highlight the importance of running these benchmark scripts in CI regularly. I'll make a PR myself to run everything in
If we allow tf32 I instead get torch.compile being universally faster, how should I read this? That it's best to express Galore in python code and run torch.compile? I'm fine if the answer ends up we need the fused kernels so we can also be faster in eager
API notesI did confirm first that dir(torchao.prototype.galore)
['__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'adam_downproj_fused', 'adam_step', 'custom_autotune', 'fused_adam_mm_launcher', 'kernels', 'matmul', 'quant', 'triton_adam_launcher', 'triton_dequant_blockwise', 'triton_mm_launcher', 'triton_quantize_blockwise'] Which was making me wonder how would someone use these kernels exactly, the answer was in Specifically here, the function is named def make_data(M, N, rank, dtype):
grad = torch.randn(M, N, device="cuda", dtype=dtype)
params = torch.randn(M, N, device="cuda", dtype=dtype)
galore_proj = GaLoreProjector(rank=rank)
galore_proj.update_orthogonal_matrix(grad)
if M >= N:
exp_avg = torch.randn(M, rank, device="cuda", dtype=dtype)
else:
exp_avg = torch.randn(rank, N, device="cuda", dtype=dtype)
exp_avg2 = exp_avg**2
return exp_avg, exp_avg2, grad, galore_proj.ortho_matrix, params Regardless as I kept going I noticed Tests
I was however able to confirm that the memory reductions are there
I also feel like it'll be challenging for us to accept bitsandbytes as a core dependency so we can relegate a lot of the 8 bit work to a tutorial for now as opposed to core functionality in the library I also see some code duplication for example test_fused_kernels.py and test_galore_downproj are very useful and easy to read thank you! I could verify on my end that the test work and they are picked up in CI
DocsI appreciate the roadmap discusion here torchao/prototype/README.md but for now let's keep it in a github issue it's easier to discuss longer term work there Thank you for clarifying the differences between the fused and hybrid implementations The Galore 8 bit adam is a cherry on top but as previously mentioned I don't think we're ready to take on a new depdency like bits and bytes so you can keep this as a tutorial but would suggest removing bits and bytes specific code from the PR I will say the most important function of the docs will be to communicate how you want people to use this work, so far it seems like the KernelsYou're much better than me at Triton lol so will let you decide how you want to surface things or change things here for better perf. As long as the kernels are correct I think it's fine to merge the code as is because we can always go through rounds of profiling and improvements. There's some minor nits I have here mostly around removing commented code I also think there's some dead code like OptimizerThis code was a joy to read, felt like a tutorial. Makes me wonder why you don't expose a public for a Galore optimizer as the main way for people to consume your work. EDIT: You do in the memory profile scripts |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Amazing work, there's a bunch of cleanup I'll follow up on later but nothing that should block landing this
* initial commit * add placeholders for cutlass and triton * update readme * fix versions * minor text edits * clean up * add triton bnb quant kernel and test * add notes on triton quant kernel * refactor code structure * add galore downproj test * refactor test utils * add fused kernel tests * add fused benchmark * add dequant kernel * update docs * add galore memory test * add adamw8bit * fix README * clean up binaries * remove notebook, add instructions to README * remove sample data * Update galore tests Skip tests if no GPU * rename galore docs * More test edits Additional conditions for skipping tests to avoid CI failure. Rename files as they are not actual tests but profiling tools to avoid triggering CI runs. * decrease fused matmul parametrizations * remove long-running tests * remove tf32 test for now --------- Co-authored-by: Mark Saroufim <marksaroufim@meta.com>
Prototype Kernels and Utils
Currently:
GaLore
GaLore
memory efficient training.TODO:
triton
triton
kernels for quantized training and inferencecutlass
cutlass
kernels and otherquant
ops@msaroufim