
Implement AWQ quantization support for LLaMA #1032

Merged (95 commits) on Sep 16, 2023

Conversation

@WoosukKwon (Collaborator) commented Sep 13, 2023

This PR adds initial support for AWQ. To minimize the scope, it only covers LLaMA models for now. It is adapted from the great previous PRs #762 and #926 written by @ri938, @casper-hansen, and @julian-q.

NOTE: We need to do some refactoring after this PR. In particular, the weight loading logic and the ParallelLinear layers are messy at the moment.

Tested:

  • python examples/llm_engine_example.py --model casperhansen/vicuna-7b-v1.5-awq --quantization awq
  • python examples/llm_engine_example.py --model casperhansen/vicuna-7b-v1.5-awq --quantization awq -tp 2
  • python examples/llm_engine_example.py --model abhinavkulkarni/lmsys-vicuna-33b-v1.3-w4-g128-awq --quantization awq -tp 2

@WoosukKwon deleted the add_awq_quant_support branch on September 16, 2023 at 07:06
@WoosukKwon (Collaborator, Author)

@casper-hansen

Optimizing GEMM kernels for high throughput is an open problem when it comes to quantized models. My own tests indicated it is fastest below batch size 8 and has performance equal to FP16 at batch size 16.

Thanks for letting us know! Do you think it would be possible to use OpenAI Triton to write a more optimized AWQ kernel?

@TheBloke

Great work on the merge.

I'll start releasing AWQ models this weekend. Will aim to do a couple of hundred by early next week. Primarily Llama 2 models, plus some of the old still-popular Llama 1s.

@WoosukKwon mentioned this pull request on Sep 16, 2023
@casper-hansen (Contributor)

@casper-hansen

Optimizing GEMM kernels for high throughput is an open problem when it comes to quantized models. My own tests indicated it is fastest below batch size 8 and has performance equal to FP16 at batch size 16.

Thanks for letting us know! Do you think it would be possible to use OpenAI Triton to write a more optimized AWQ kernel?

I would absolutely love it if someone wrote a better kernel, whether in Triton or CUDA.

I do believe a Triton kernel is possible; there are just a few obstacles, like how to dequantize and how to run optimized instructions like asm volatile. Triton kernels exist for GPTQ, but it has also been shown that CUDA kernels are faster (ExLlama v2).

@WoosukKwon mentioned this pull request on Sep 16, 2023
@WoosukKwon (Collaborator, Author) commented Sep 16, 2023

@TheBloke

I'll start releasing AWQ models this weekend. Will aim to do a couple of hundred by early next week. Primarily Llama 2 models, plus some of the old still-popular Llama 1s.

Got it. Thanks for letting us know! Actually, there's no blocker to adding AWQ support for other model types. Once you upload the AWQ models, we can add support for them accordingly.

@casper-hansen (Contributor)

@TheBloke

I'll start releasing AWQ models this weekend. Will aim to do a couple of hundred by early next week. Primarily Llama 2 models, plus some of the old still-popular Llama 1s.

Got it. Thanks for letting us know! Actually, there's no blocker to adding AWQ support for other model types. Once you upload the AWQ models, we can add support for them accordingly.

If you want something to test with, I have repositories for Falcon and MPT. I am sure TheBloke will put out many more models soon, though.

https://huggingface.co/casperhansen/mpt-7b-8k-chat-awq
https://huggingface.co/casperhansen/falcon-7b-awq

@WoosukKwon (Collaborator, Author)

@casper-hansen

I do believe a Triton kernel is possible; there are just a few obstacles, like how to dequantize and how to run optimized instructions like asm volatile. Triton kernels exist for GPTQ, but it has also been shown that CUDA kernels are faster (ExLlama v2).

Got it. So you mean the dequantization logic is not easy to implement in Triton, right? Hmm... I think for now we will focus on code cleanup and on supporting different quantization schemes without worrying much about their performance, and we will get back to the performance issue later.

@casper-hansen (Contributor)

@casper-hansen

I do believe a Triton kernel is possible; there are just a few obstacles, like how to dequantize and how to run optimized instructions like asm volatile. Triton kernels exist for GPTQ, but it has also been shown that CUDA kernels are faster (ExLlama v2).

Got it. So you mean the dequantization logic is not easy to implement in Triton, right? Hmm... I think for now we will focus on code cleanup and on supporting different quantization schemes without worrying much about their performance, and we will get back to the performance issue later.

Yes, dequantization is easily the trickiest part to get right. The rest is normal FP16 operations.

I agree with the strategy. When you get to GPTQ, you should definitely explore the existing optimized work, like ExLlama v2. It's the fastest repository for running GPTQ models at batch size 1, but it has not been tested much for high throughput.

@esmeetu (Collaborator) commented Sep 16, 2023

@WoosukKwon Doesn't it support the Turing arch? My GPU's compute capability is 7.5, with CUDA 12.1.

Build error message:

ptxas /tmp/tmpxft_0006e7c4_00000000-6_gemm_kernels.ptx, line 928; error : Feature '.m16n8k16' requires .target sm_80 or higher

If not, I hope backward compatibility can be added to the kernel build.

@casper-hansen (Contributor)

@WoosukKwon Doesn't it support the Turing arch? My GPU's compute capability is 7.5, with CUDA 12.1.

Build error message:

ptxas /tmp/tmpxft_0006e7c4_00000000-6_gemm_kernels.ptx, line 928; error : Feature '.m16n8k16' requires .target sm_80 or higher

If not, I hope backward compatibility can be added to the kernel build.

No, it supports sm_80 and up, so Ampere and later are supported.

@esmeetu (Collaborator) commented Sep 16, 2023

@casper-hansen OK, so you mean AWQ quantization doesn't support compute capability < 8.0? And for < 8.0, the only choice is GPTQ for now?

@casper-hansen (Contributor)

@casper-hansen OK, so you mean AWQ quantization doesn't support compute capability < 8.0? And for < 8.0, the only choice is GPTQ for now?

Yes, Turing is not supported, so the T4 is not supported. The GEMM kernel makes use of tensor cores, but some instructions in the kernel make it incompatible with compute capability 7.5.
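A minimal pre-flight check along these lines can surface this before building (a sketch assuming a CUDA-enabled PyTorch install; the (8, 0) threshold corresponds to the sm_80 / m16n8k16 requirement in the error above):

    import torch

    # The AWQ GEMM kernel relies on mma instructions (e.g. m16n8k16) that need
    # sm_80 or newer, i.e. Ampere and later. Turing (sm_75, e.g. the T4) fails.
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) < (8, 0):
        raise RuntimeError(
            f"AWQ kernels need compute capability >= 8.0, got {major}.{minor}"
        )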

@esmeetu (Collaborator) commented Sep 16, 2023

@casper-hansen Thanks for your reply.

@ryanshrott

@WoosukKwon Would you mind providing an example of using AWQ support with a Llama 2 fine-tuned variant like this one?
https://huggingface.co/rshrott/description-awq-4bit/tree/main
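For reference, a minimal sketch of what such usage might look like with the offline LLM API, assuming the linked repository contains AWQ weights in the layout this PR expects:

    from vllm import LLM, SamplingParams

    # Load the (assumed) AWQ-quantized fine-tune and tell vLLM to use the AWQ kernels.
    llm = LLM(model="rshrott/description-awq-4bit", quantization="awq")

    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=128)
    outputs = llm.generate(["Describe this property listing:"], sampling_params)
    for output in outputs:
        print(output.outputs[0].text)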

@casper-hansen (Contributor) commented Sep 17, 2023

Got it. So you mean the dequantization logic is not easy to implement in Triton, right? Hmm... I think for now we will focus on code cleanup and on supporting different quantization schemes without worrying much about their performance, and we will get back to the performance issue later.

@WoosukKwon After more research, I found int4 dequantization kernels for Triton, linked below. Seems doable; it just needs to fit the AWQ format.
https://github.com/ModelTC/lightllm/blob/main/lightllm/common/basemodel/triton_kernel/dequantize_gemm_int4.py
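For reference, the math such a kernel has to reproduce is small. Below is a hedged PyTorch sketch of grouped int4 dequantization, assuming a simple sequential packing of eight 4-bit values per int32 along the input dimension (the real AWQ layout uses an interleaved pack order, so this is illustrative only, not the AWQ format itself):

    import torch

    def dequantize_int4(qweight, qzeros, scales, group_size=128):
        """qweight: (K // 8, N) int32, qzeros: (K // group_size // 8, N) int32,
        scales: (K // group_size, N) fp16. Returns the (K, N) fp16 weight."""
        shifts = torch.arange(0, 32, 4, device=qweight.device)       # 8 nibbles per int32
        # Expand every packed int32 into its 8 unsigned 4-bit values.
        w = (qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF    # (K // 8, 8, N)
        w = w.reshape(-1, qweight.shape[1])                          # (K, N)
        z = (qzeros.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF
        z = z.reshape(-1, qzeros.shape[1])                           # (K // group_size, N)
        # Each group of `group_size` input rows shares one scale and one zero point.
        z = z.repeat_interleave(group_size, dim=0)
        s = scales.repeat_interleave(group_size, dim=0)
        return (w - z).to(s.dtype) * s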

@AniZpZ commented Sep 20, 2023

We tested AWQ with vLLM as well and found it is hard to get a throughput improvement with the W4A16 method. Therefore, we implemented a W8A8 method (producing int8 weights with SmoothQuant), which can increase throughput by 20%, in this PR: #1112
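For context, a rough sketch of the weight side of W8A8 (symmetric per-output-channel int8, illustrative only; SmoothQuant additionally rescales channels to move activation outliers into the weights, and #1112 has the actual implementation):

    import torch

    def quantize_weight_int8(w: torch.Tensor):
        """Symmetric per-output-channel int8 quantization of a (out, in) weight."""
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
        return q, scale

    def dequantize_weight_int8(q: torch.Tensor, scale: torch.Tensor):
        # An int8 GEMM kernel keeps q as-is; this is only the reference mapping back to fp16.
        return q.to(torch.float16) * scale.to(torch.float16)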

@viktor-ferenczi (Contributor)

It is not only about throughput. It is also about the much worse quality (perplexity) of 4-bit weight-quantized models compared to 8-bit ones. I've tested 4-bit AWQ for my use cases and it is just not cutting it: the heavily quantized model makes frequent small mistakes. Running the same model at full 16 bits (or from an 8-bit GGUF using ctransformers) doesn't make such mistakes at all. I'm looking forward to the W8A8 quantization mostly because of this. Models: WizardCoder 13B and 34B.

@vince62s commented Dec 2, 2023

How did you guys deal with this: casper-hansen/AutoAWQ#234
GEMM and GEMV have their params reversed: (out, in) is the standard way for GEMV, but GEMM uses (in, out), which triggers a different way to slice when using tensor parallelism.
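To illustrate the slicing issue (not vLLM's actual loader code, just the layout difference): with tensor parallelism the shard axis flips depending on whether a checkpoint stores the weight as (out, in) or (in, out). A small sketch with hypothetical shapes:

    import torch

    out_features, in_features, tp_size, rank = 4096, 11008, 2, 0

    w_gemv = torch.randn(out_features, in_features)   # GEMV-style storage: (out, in)
    w_gemm = w_gemv.t().contiguous()                  # GEMM-style storage: (in, out)

    # A column-parallel layer shards the *output* dimension, so the slicing dim
    # depends on which storage convention the checkpoint uses.
    shard_gemv = torch.chunk(w_gemv, tp_size, dim=0)[rank]   # slice rows
    shard_gemm = torch.chunk(w_gemm, tp_size, dim=1)[rank]   # slice columns

    assert torch.equal(shard_gemv.t(), shard_gemm)           # same shard, different layout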

@vedantroy commented Dec 9, 2023

@WoosukKwon

I implemented Triton kernels for AWQ inference. They are much faster than the existing CUDA kernels, especially at larger batch sizes:

[benchmark image: Triton vs. existing CUDA AWQ kernel throughput across batch sizes]

They are also simpler (the core kernel is ~50-100 lines of Triton).

The weight format is a bit different (it uses the most recent AWQ weight format, although the weights are transposed):

    a: (M, K)
    qw: (K // pack_num, N)
    scales: (K // group_size, N)
    qzeros: (K // group_size // pack_num, N)

You can find the code here: https://github.com/vedantroy/gpu_kernels/tree/main
I'd be happy to help integrate these kernels into vLLM!

Warning: I'm pretty sure my kernels are correct, but not 100%. I'll be emailing the authors to double-check.
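As a sanity check on the layout above, how the packed shapes relate for 4-bit weights (assuming pack_num = 32 // 4 and hypothetical LLaMA-7B-like dimensions; illustrative numbers only):

    # Hypothetical projection: K = in_features, N = out_features.
    K, N = 4096, 4096
    bits, group_size = 4, 128
    pack_num = 32 // bits                                  # 8 int4 values per int32

    qw_shape     = (K // pack_num, N)                      # (512, 4096) packed int32
    scales_shape = (K // group_size, N)                    # (32, 4096) fp16
    qzeros_shape = (K // group_size // pack_num, N)        # (4, 4096) packed int32

    print(qw_shape, scales_shape, qzeros_shape)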

hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
sjchoi1 pushed a commit to casys-kaist-internal/vllm that referenced this pull request May 7, 2024