Add sparse marlin AQT layout #621

Merged 8 commits on Sep 6, 2024

@Diogo-V Diogo-V commented Aug 6, 2024

Summary

Introduces a Layout for AQT class called Sparse Marlin 2:4.

E2E Results

Screenshot 2024-09-06 at 8 57 06 AM

Tests

test/sparsity/test_marlin.py

Considerations for reviewers

  1. To solve the issue with torch.compile, I had to convert the sections here and here from numpy to torch. With torch.uint32 it threw an error saying rshift_cpu is not supported for uint32; using torch.int64 does not seem to cause any problems.
  2. I was able to completely remove the dependency from numpy, so we won't need to add it as a dependency
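
To illustrate why a wider signed integer type sidesteps the unsigned 32-bit problem described above, here is a minimal sketch in plain Python (no torch or numpy, so it runs anywhere). The helper names `pack_nibbles` and `unpack_nibbles` and the lowest-nibble-first order are assumptions made for illustration; this is not the PR's packing code and not the Marlin layout.

```python
def pack_nibbles(values):
    """Pack eight 4-bit values into one 32-bit word, lowest nibble first."""
    assert len(values) == 8 and all(0 <= v < 16 for v in values)
    word = 0
    for i, v in enumerate(values):
        word |= v << (4 * i)
    return word


def unpack_nibbles(word):
    """Recover the eight 4-bit values with right shifts and a mask."""
    return [(word >> (4 * i)) & 0xF for i in range(8)]


full_word = pack_nibbles([15] * 8)          # every bit of the 32-bit word set
fits_in_int32 = full_word <= 2**31 - 1      # False: overflows a signed 32-bit int
fits_in_int64 = full_word <= 2**63 - 1      # True: a signed 64-bit int has room
```

A fully populated word equals 0xFFFFFFFF, which does not fit in a signed 32-bit integer but fits in a signed 64-bit one. That is presumably why torch.int64 works as a carrier type for the shifts when torch.uint32 is not supported.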

Notes

  • Closes Add 2:4 sparse marlin kernels to torchao #549
  • Since the deadline for the torchAO announcement is tomorrow, and in case this work is meant to be part of it, I will be around today to make any changes this PR needs before it gets merged.

Feel free to let me know if there is anything that needs to be refactored!

pytorch-bot bot commented Aug 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/621

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 333a88f with merge base a246d87 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

Hi @Diogo-V!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 6, 2024
@@ -0,0 +1,49 @@
/*
Member

Make sure to also create a test here https://github.com/pytorch/ao/tree/main/test for the cuda kernel and a test for an end to end flow https://github.com/pytorch/ao/tree/main/test/sparsity

@supriyar supriyar requested a review from jcaip August 7, 2024 20:45
@Diogo-V Diogo-V mentioned this pull request Aug 13, 2024
3 tasks
@msaroufim msaroufim marked this pull request as ready for review August 14, 2024 19:16
@msaroufim msaroufim marked this pull request as draft August 14, 2024 19:55
@msaroufim
Member

so you should have perms to trigger CI, keep in mind that CI won't run if you have merge conflicts like now

Contributor

@jcaip jcaip left a comment

Looks like this is coming along nicely @Diogo-V ! Left a couple comments

BTW, once you get the subclass hooked up we should try this on SAM:
https://github.com/pytorch/ao/blob/main/torchao/_models/sam/eval_combo.py#L286

I'm putting together a collection of ViT sparsification results for our PTC poster and would love to add your result and give you a shout out.

Also, feel free to @ me, I may not be the most responsive otherwise.


class TestQuantSparseMarlin(TestCase):

@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_3, reason="pytorch 2.3+ feature")
Contributor

You don't need a skip here AFAIK, we can add it back if CI fails though

Contributor Author

Done

torchao/dtypes/affine_quantized_tensor.py
self.layout_type = layout_type
self.initial_shape = initial_shape

@classmethod
Contributor

FYI Jerry merged in some helper functions for utils, so you can do this a little nicer like this:

from torchao.dtypes.utils import (

return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)


def _mask_creator(tensor):
Contributor

I think we should move these to a general sparsity/utils.py.

Contributor Author

Done
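
For context on what a 2:4 mask helper computes, here is a minimal sketch in plain Python. It assumes the common magnitude-based rule (keep the two largest-magnitude values in every group of four); the function name `two_four_mask` is hypothetical and this is not the `_mask_creator` implementation under review.

```python
def two_four_mask(row):
    """Return a 0/1 mask that keeps the 2 largest-magnitude values in each group of 4."""
    assert len(row) % 4 == 0, "row length must be a multiple of 4"
    mask = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two entries with the largest absolute value in this group.
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        mask.extend(1 if i in keep else 0 for i in range(4))
    return mask


row = [0.1, -0.9, 0.5, 0.2, 3, 1, -2, 0]
mask = two_four_mask(row)
pruned = [v * m for v, m in zip(row, mask)]
```

Every group of four in `mask` contains exactly two ones, which is the structure that 2:4 semi-structured sparse kernels rely on.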

# This function converts dense matrix into sparse semi-structured
# representation, producing "compressed" matrix, in the layout used by
# CUTLASS backend, and corresponding metadata matrix.
def _sparse_semi_structured_from_dense_cutlass(dense):
Contributor

I believe that these are the same functions that are available here: https://github.com/pytorch/pytorch/blob/main/torch/sparse/_semi_structured_conversions.py

So we don't have to copy them over.

Contributor Author

It seems some of them have small differences:

Do you want me to leave it as is or should I change anything?

Contributor

cc @alexsamardzic I'm assuming they may have changed some stuff for int4 support. n00b question here, would it make sense to upstream some of the changes that they made?

@Diogo-V let's leave it for now, we can revisit once Alex responds.

Collaborator

Yes - it would be good to have this upstream, but it would be good then also to have corresponding tests in test/test_sparse_semi_structured.py extended accordingly.

Contributor

Would you have bandwidth to update that? Otherwise I can take a stab after I'm done with pytorch/pytorch#132928

Collaborator

I'll make a note; if taking on this, I'd also fully extend CUTLASS variant of SparseSemiStructuredTensor for int4 support. Am at the moment working on stuff related to mixed data types MM, so can't promise exactly if/when I'll have cycles for this - sorry about that.

@Diogo-V
Contributor Author

Diogo-V commented Aug 16, 2024

@jcaip - Thank you for the comments! I am going to continue working on this today and will take care of them.
I think I have most things hooked together. There seems to be a weird bug somewhere but I will grind through it.

Once we merge this, I’d be happy to add it to SAM and run the benchmarks!

Also, feel free to let me know if there is anything that I can help out on for the PTC poster. I managed to finish mine earlier this week and so I have a bit more time 🙂

@Diogo-V
Contributor Author

Diogo-V commented Aug 16, 2024

so you should have perms to trigger CI, keep in mind that CI won't run if you have merge conflicts like now

Awesome! I am going to resolve them and try it out once I am at the computer

@Diogo-V Diogo-V force-pushed the 549-sparse-marlin-kernel branch 2 times, most recently from fecb1f8 to 7f9c65a Compare August 20, 2024 22:53
from transformers import AutoTokenizer, LlamaForCausalLM

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
name = "meta-llama/Llama-2-7b-hf"
Contributor

You can use "nm-testing/SparseLlama-3-8B-pruned_50.2of4" as a checkpoint


# Marlin Sparse op dispatch registration
@MarlinSparseAQTLayout.implements(aten.detach.default)
def block_sparse_detach(func, types, args, kwargs):
Contributor

nit: I think @jerryzh168 has been using def _( for all of these, you should do the same.

@Diogo-V
Contributor Author

Diogo-V commented Aug 20, 2024

Hi, I did a bit more work on this PR these last few days and, if possible, I would like to ask for another preliminary review. It is not finished yet because there is still a bug somewhere that causes llama2 to output some weird text (I am using llama2 to validate everything E2E). I will be working on it over the next couple of days:

image

I added a file in root of the project called wip_test_llama.py just to make it easier for you in case you want to check it. This file will be removed before the PR is merged.

The main difference from last time is that instead of using the code from the original repo I went with the code that they have in the nm-vllm repo. Three reasons for it: the code was easier to follow, had better APIs and their kernel implementation also supports fp8 making it fairly straightforward to add support for it in the future if we want to.

There are two parts of the code that I am not confident are correct. To make them easier to find, I added a comment above each code section with the text NOTE(reviewers):

  1. In here, I am not sure if I should transpose it and/or if it is even correct how I did it. It was done to match the in_features, out_features format that was in the original repo and because I was having some trouble with the shapes 😅
  2. In here I am not sure if I am handling the reshape of the input tensor correctly.

Let me know if you find something that needs refactoring!

Aside from that, I have a question:

  • There is a section of code that is using numpy. I tried converting it to torch but it was throwing the error Promotion for uint16, uint32, uint64 types is not supported, attempted to promote UInt32 and Int. Is there a way for me to write this part of the code in native torch that would work? If not, then we will probably need to add numpy as a dependency to torchao before this gets merged. How should I proceed on this?

tagging for visibility: @msaroufim @jcaip

@jcaip
Contributor

jcaip commented Aug 20, 2024

Thanks for the update @Diogo-V! I'll take a more thorough pass through later tonight

@jcaip jcaip force-pushed the 549-sparse-marlin-kernel branch from 7f9c65a to e5390a6 Compare August 22, 2024 03:11
@Diogo-V
Contributor Author

Diogo-V commented Aug 22, 2024

Hi @jcaip, just saw that you did a commit getting torch.compile to work. Thank you!
Later tonight I will be working on getting the failing tests to pass. Is there anything else that you feel like is still missing/needs refactoring and I could take a stab at?

@jcaip
Contributor

jcaip commented Aug 22, 2024

@Diogo-V Thanks for your help, and good job landing the op!

I took a look at your code and I do think there might be an issue in the way we hook up the marlin gemm op. I noticed that changing the batch dimension (32, 16, 14096) for the input breaks the test, which shouldn't be the case. The compile test is also failing for me.

I don't know exactly where the problem is, but the general gist is that we have XW implemented for the marlin op, whereas for linear we actually calculate xW', which is why I think you need to transpose int_data and scales in from_float. Somewhere along the line I think we're passing in the wrong dimensions, but I need to do more debugging today.

Could you open a PR with just the marlin op (without the affine quantized tensor implementation) for now? You can git squash all your commits and then remove the AQT files with: https://stackoverflow.com/questions/12481639/remove-file-from-latest-commit. We can merge that in first, since I'll be on PTO next week. cc @jerryzh168 @msaroufim

I think the compile debugging will be a bit more involved, so don't want you to get blocked on that. Will update if I find the bug.

@Diogo-V
Contributor Author

Diogo-V commented Aug 22, 2024

Yeah totally! Will open the PR in a bit

I am not super acquainted with the compile process (have been learning it the past few weeks). Do you have any resources that would be good for me to take a look at so that I can help out more and get past these issues in future PRs?

This weekend, I will take another stab at this and try to get the tests to pass. I would also be happy to pass this work off to someone else if you feel like it would be too tight of a schedule for the poster.

@Diogo-V Diogo-V mentioned this pull request Aug 22, 2024
3 tasks
@Diogo-V Diogo-V changed the title [WIP]: Add 2:4 sparse marlin kernels [WIP]: Add sparse marlin AQT layout Aug 22, 2024
@jcaip
Contributor

jcaip commented Aug 23, 2024

Hey @Diogo-V

I am not super acquainted with the compile process (have been learning it the past few weeks). Do you have any resources that would be good for me to take a look at so that I can help out more and get past these issues in future PRs?

https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html is a good tutorial. I would actually recommend not debugging with compile, until you first fix the batch size=32 eager test case, since it'll be a lot easier to be able to debug in eager.

This weekend, I will take another stab at this and try to get the tests to pass. I would also be happy to pass this work off to someone else if you feel like it would be too tight of a schedule for the poster.

Thanks, that would be awesome! It would definitely help to have another person debug this. Let me try to explain a bit more about what I think is going on here:

On CUDA, tensors are represented as just blobs of data, and we keep track of the strides / offsets / dimensions of the tensor separately in metadata. This is nice because we can do something like a transpose without needing to modify our existing blob of data; we just change how we access it via the strides / offsets.
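
The point about strides can be sketched in plain Python without torch. The helpers `contiguous_strides` and `offset` are hypothetical names used only for this illustration; real tensors track the same information internally.

```python
def contiguous_strides(shape):
    """Row-major strides (in elements) for a tensor of the given shape."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def offset(index, strides):
    """Position of a multi-dimensional index inside the flat data blob."""
    return sum(i * s for i, s in zip(index, strides))


data = list(range(6))                  # flat blob backing a 2x3 row-major matrix
shape = (2, 3)
strides = contiguous_strides(shape)    # (3, 1)

# A transpose only swaps the metadata; `data` itself is untouched.
t_shape, t_strides = shape[::-1], strides[::-1]    # (3, 2) and (1, 3)

# Element [i][j] of the transposed view is element [j][i] of the original.
same_element = data[offset((2, 1), t_strides)] == data[offset((1, 2), strides)]

# The transposed view is not contiguous: a fresh 3x2 tensor would have strides (2, 1).
is_contiguous = t_strides == contiguous_strides(t_shape)
```

A kernel that receives the raw data together with swapped dimensions, but assumes a row-major layout, would read the wrong elements in this situation.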

The Marlin kernel also takes in these m, k, n values, likely to make the kernel implementation easier. I think somewhere along the line we are flubbing a transpose: we switch up our m, k, n dimensions (because of the transpose) but still pass in the contiguous data of the non-transposed tensor. I think this mismatch is the cause of our error.

This is especially confusing because for 2:4 sparse support, the tensor cores really only support the first operand being sparse, which we can work around by using transpose properties when the second operand is the sparse one.
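
The transpose property in question is xW' = (Wx')'. A minimal plain-Python sketch with small integer matrices follows; the helpers `matmul` and `transpose` and the example values are illustrative assumptions, not code from this PR.

```python
def matmul(a, b):
    """Naive matrix product of two lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(a):
    return [list(row) for row in zip(*a)]


x = [[1, 2, 3, 4], [5, 6, 7, 8]]                  # dense activations, 2x4
w = [[1, 0, 0, 2], [0, 3, 4, 0], [5, 0, 6, 0]]    # 2:4 sparse weights, 3x4

# What a linear layer computes: dense operand first, sparse operand second.
linear_out = matmul(x, transpose(w))

# Same result with the sparse matrix as the first operand, then transposed back.
sparse_first = transpose(matmul(w, transpose(x)))
```

Both expressions give the same values, but the second one produces its result as a transpose, which is where the layout of the output can start to differ from the reference.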

But note how we are now returning a transposed matrix, which means it is not contiguous, hence why it is in column major format. I think this row/column major format mismatch, or something related is the cause of our compile / batch_size=n bug. To make sure, we can check that both the marlin output and the reference output share the same strides in our unit test: assert sparse_result.stride() == reference_result.stride()

Just a tip, feel free to do whatever you feel is best: I would gradually work up from your test_marlin_24 test here, by writing a nn.Linear layer that outputs the same values as your reference mm implementation.

Basically my thinking is that we go from mm unit test -> nn.Linear unit test -> e2e test. Let me know if you have any questions, I'll try to answer them the best I can.

As far as deadlines - just FYI: 9/6 is the deadline for our torchAO announcement, and 9/16 is the deadline for making updates to our PTC poster.

@Diogo-V
Contributor Author

Diogo-V commented Aug 23, 2024

https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html is a good tutorial. I would actually recommend not debugging with compile, until you first fix the batch size=32 eager test case, since it'll be a lot easier to be able to debug in eager.

Thank you very much! I will start off by tackling the batch size issue before moving onto the compile test as you suggested.

On CUDA, tensors are represented as just blobs of data (...) unit test: assert sparse_result.stride() == reference_result.stride()

Yep, I agree with you on this. Thanks for sharing your insights about the current problem! I think I have a better idea of how to debug this after reading your explanation 🙂

Just a tip, feel free to do whatever you feel is best: I would gradually work up from your test_marlin_24 test here, by writing a nn.Linear layer that outputs the same values as your reference mm implementation.

Basically my thinking is that we go from mm unit test -> nn.Linear unit test -> e2e test. Let me know if you have any questions, I'll try to answer them the best I can.

I think that suggestion is pretty good. I will follow that flow to debug this issue.

As far as deadlines - just FYI: 9/6 is the deadline for our torchAO announcement, and 9/16 is the deadline for making updates to our PTC poster.

Thanks for sharing the deadlines. I was getting a bit concerned because I know that this is a time sensitive task and I don't want to risk missing an important date. I will try to surface updates and blocks as often as I can.

int_data, scale, zero_point = self.get_plain()
layout_type = self.get_layout_type()
return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})"
# This is a hack, torch.compile tries to trace the __repr__ function which then calls `dequantize` function, causing an error.
Contributor

can these changes be reverted now with dynamo.disable decorator?

Contributor Author

@Diogo-V Diogo-V Aug 23, 2024

yeah, will take care of that. let me just take care of finding this bug first

@Diogo-V Diogo-V force-pushed the 549-sparse-marlin-kernel branch 2 times, most recently from 9bcc422 to 694e2d1 Compare August 26, 2024 09:16
@Diogo-V Diogo-V force-pushed the 549-sparse-marlin-kernel branch from e78d93f to 552603d Compare September 5, 2024 22:34
Diogo-V and others added 6 commits September 5, 2024 22:35
fix: namespace of common modules

chore: remove not needed test file

fix: op name being registered

chore: can compile the cuda kernel

fix: segmentation fault

chore: wip - paste test code just to check if everything passes

feat: wip - adding layout. unpack not working

fix: circular import

feat: wip - can almost revert

feat: can unpack. just needs cleanup

chore: improve layout code

chore: wip - mm needs work

feat: wip - something seems wrong

fix: e2e test

feat: wip - add group param

fix: unpack weights

feat: marlin is implemented and correct

chore: rebase

chore: remove old import

feat: use int4 instead of dequantizing

chore: remove unused fn

feat: add checks and validation

feat: add new kernel and refactor code (#1)

* feat: wip - adding new kernel

* feat: wip - continue working on the unpack

* feat: wip - working on unpacking

* feat: remove old op

* feat: more code changes

* chore: remove old code

* feat: more code

* chore: more code changes

* chore: more code changes

* feat: add more documentation

* fix: dataclass

* feat: add more docs

* feat: remove assert

chore: block 8 bits

chore: update comment

feat: refactor dispatch

chore: add validation on group size

chore: wip - working on fixing unpack

feat: add small readme with sources

feat: add checks

feat: tests pass & can execute llama2
* wip

* feat: wip
@Diogo-V Diogo-V force-pushed the 549-sparse-marlin-kernel branch from 552603d to ea70c74 Compare September 5, 2024 22:35
@jcaip
Contributor

jcaip commented Sep 5, 2024

I tried running this with SAM but was running into a CUDA graph error, however with the 24_sparse_hf_example, I'm seeing:

Tokens/second: 134.497 for 2:4 sparse marlin 
Tokens/second: 111.182 for our int4 tinygemm version 
Tokens/second: ~80 for 2:4 sparse
Tokens/second: ~75 for compile baseline
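
As a quick check of what these figures imply, the relative speedups can be computed directly. This is a small arithmetic sketch; the dictionary keys are shortened labels, and the two approximate figures ("~80", "~75") are treated as exact for the calculation.

```python
# Tokens/second figures quoted in the comment above.
tokens_per_second = {
    "2:4 sparse marlin": 134.497,
    "int4 tinygemm": 111.182,
    "2:4 sparse": 80.0,         # reported as "~80"
    "compile baseline": 75.0,   # reported as "~75"
}


def speedup(fast, slow):
    """Throughput ratio between two configurations."""
    return tokens_per_second[fast] / tokens_per_second[slow]


for name in ("int4 tinygemm", "2:4 sparse", "compile baseline"):
    print(f"sparse marlin vs {name}: {speedup('2:4 sparse marlin', name):.2f}x")
```

This works out to roughly 1.21x over the int4 tinygemm version, 1.68x over 2:4 sparse alone, and 1.79x over the compile baseline.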

🚀 🚀 🚀 cc @Diogo-V

I'll write this up and update the README @jerryzh168 for BE day tomorrow. It would be best if we could hook up into @HDCharles LLaMa benchmarks for consistency.

@jerryzh168
Contributor

I had previously gotten around this issue by installing the nightly version as you had suggested. Is this the correct thing to do or should I try something else?

only test this in nightly sounds good if that's the issue

@jerryzh168
Contributor

jerryzh168 commented Sep 6, 2024

@jerryzh168 - Just addressed your review and I have a question:

You suggested adding a test in the file test_sparse_api.py that would be similar to those there. Wouldn't that be repeating the tests implemented in test_marlin.py?

Was this the kind of test that you were expecting to add on that file or did you have something else on your mind?

    def test_sparse_marlin(self):
        input = torch.rand((256, 256)).half().cuda()
        model = (
            nn.Sequential(
                nn.Linear(256, 1024),
                nn.Linear(1024, 256),
            )
            .half()
            .cuda()
        )

        apply_fake_sparsity(model)
        model_copy = copy.deepcopy(model)

        # Quantized
        quantize_(model_copy.bfloat16(), int4_weight_only())
        dense_result = model_copy(input.bfloat16()).half()

        # Sparse + quantized
        quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
        sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

Let me know if there is anything else that needs to be updated!

yeah looks good to me, this is trying to test the e2e API so it's slightly different from test_marlin.py I think

Contributor

@jerryzh168 jerryzh168 left a comment

looks good to me, please move the test we are discussing to test_sparse_api.py and make sure CI passes before merging

@@ -1219,6 +1413,47 @@ def _linear_fp_act_fp8_weight_impl(
).reshape(out_shape)


def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
Contributor

cc @jcaip now we support registering layout type and these things in a separate file, maybe we can create a folder for all sparse layouts and move these

Contributor

Yeah, that sounds like a good idea.

@Diogo-V
Contributor Author

Diogo-V commented Sep 6, 2024

@jerryzh168 - Should be good to go now :)

@jcaip
Contributor

jcaip commented Sep 6, 2024

Perfect, let me add my benchmarking updates to this branch so it's easier for Andrew to cherry-pick the whole thing.

BTW I'm seeing 226 tok/s for LLama3 - a 25% increase from 180 tok/s for our tiny gemm int4 kernel.
The perplexity is bad, we'll have to use the neuralmagic checkpoint for a fair comparison, but the weights are in a different format than the LLama3 checkpoint and I don't want to block this on that.

@Diogo-V
Contributor Author

Diogo-V commented Sep 6, 2024

@jcaip - Yeah, I also noticed that sometimes the model's outputs were not super great. My simple sanity check was comparing against a quantize_(fake_sparsity(model), int4_wo(layout=Marlin...)) version and they were equal most of the time.

Let me know if there is anything else that I might be of help to get this across today and feel free to ping me on discord (Diogo-V on #cuda-mode) if you run into any issues!

@jcaip
Contributor

jcaip commented Sep 6, 2024

Yes, for context, I am seeing the same text generation output on the hf example with the neuralmagic 2:4 sparse model when I run quantize_(model, int4_wo()) and quantize_(model, int4_wo(layout=SparseMarlinLayout())), so I don't think it's an implementation issue. I'm guessing we may also need to use GPTQ for the best accuracy numbers.

But I'm very pleased with the speedup you've achieved (+25%), that's a great result :)

@Diogo-V
Contributor Author

Diogo-V commented Sep 6, 2024

I wouldn't mind taking a look at adding support for GPTQ with Sparse Marlin if you feel like it would be a good thing to work on afterwards. I would have some time for it after the 21st of September.

@jcaip jcaip merged commit 65d86c6 into pytorch:main Sep 6, 2024
17 checks passed
andrewor14 pushed a commit that referenced this pull request Sep 6, 2024
* feat: starting layout implementation

fix: namespace of common modules

chore: remove not needed test file

fix: op name being registered

chore: can compile the cuda kernel

fix: segmentation fault

chore: wip - paste test code just to check if everything passes

feat: wip - adding layout. unpack not working

fix: circular import

feat: wip - can almost revert

feat: can unpack. just needs cleanup

chore: improve layout code

chore: wip - mm needs work

feat: wip - something seems wrong

fix: e2e test

feat: wip - add group param

fix: unpack weights

feat: marlin is implemented and correct

chore: rebase

chore: remove old import

feat: use int4 instead of dequantizing

chore: remove unused fn

feat: add checks and validation

feat: add new kernel and refactor code (#1)

* feat: wip - adding new kernel

* feat: wip - continue working on the unpack

* feat: wip - working on unpacking

* feat: remove old op

* feat: more code changes

* chore: remove old code

* feat: more code

* chore: more code changes

* chore: more code changes

* feat: add more documentation

* fix: dataclass

* feat: add more docs

* feat: remove assert

chore: block 8 bits

chore: update comment

feat: refactor dispatch

chore: add validation on group size

chore: wip - working on fixing unpack

feat: add small readme with sources

feat: add checks

feat: tests pass & can execute llama2

* compile kind of working

* fix: batching and layout outputs correct results

* fix: torch.compile

* wip

* feat: wip

* chore: cleanup

* chore: review

* chore: review v2

* update benchmarks + README

---------

Co-authored-by: Jesse Cai <jcjessecai@gmail.com>
andrewor14 added a commit that referenced this pull request Sep 6, 2024
andrewor14 added a commit that referenced this pull request Sep 6, 2024
andrewor14 added a commit that referenced this pull request Sep 6, 2024
HDCharles pushed a commit that referenced this pull request Sep 9, 2024
jainapurva pushed a commit that referenced this pull request Sep 9, 2024
jainapurva pushed a commit that referenced this pull request Sep 9, 2024
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add 2:4 sparse marlin kernels to torchao
6 participants