
Loading quantized models #392

Closed · abhinavkulkarni opened this issue Jul 7, 2023 · 31 comments

@abhinavkulkarni commented Jul 7, 2023

Hi,

Is there a way to load quantized models using vLLM? For example, I have been using AWQ quantization and have released a few models here.

The model loading process looks like the following:

  1. Model is first initialized with empty weights
  2. Linear layers are replaced by a custom linear layer that supports zero-point quantization
  3. Weights are loaded from a checkpoint onto a GPU

Please note that the matrix multiplication inside the linear layer is done by a Python extension that uses custom CUDA kernels, since AWQ uses 4-bit quantization. Everything else (the attention mechanism, etc.) is the same.
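A rough sketch of that flow, for reference (the paths and the `make_quant_linear` helper below are only illustrative placeholders, not the actual AWQ API):

```python
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

def make_quant_linear(linear: torch.nn.Linear) -> torch.nn.Module:
    # Placeholder: in practice this would construct AWQ's custom 4-bit linear
    # layer (packed weights, scales, zero points) from `linear`'s shape.
    return linear

def replace_linear(module: torch.nn.Module) -> None:
    # 2. Recursively swap nn.Linear modules for the quantized replacement.
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            setattr(module, name, make_quant_linear(child))
        else:
            replace_linear(child)

config = AutoConfig.from_pretrained("path/to/awq-model")

# 1. Initialize the model skeleton with empty (meta) weights.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

replace_linear(model)

# 3. Load the quantized checkpoint weights directly onto the GPU.
model = load_checkpoint_and_dispatch(model, "path/to/awq-model", device_map={"": "cuda:0"})
```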

The model otherwise supports all the HuggingFace AutoModelForCausalLM APIs.

Thanks!

@WoosukKwon (Collaborator) commented Jul 8, 2023

Hi @abhinavkulkarni, thanks for exploring vLLM and requesting the feature! For now, vLLM doesn't support quantized models. But as you mentioned, we believe it is not very difficult to add it to vLLM. We will definitely look into it after finishing other urgent issues (e.g., Falcon support and bug fixes). Besides, we are looking forward to contributions from the community! Please make a PR if you have time to implement it.

@johnnysmithy123467 commented

Hello, I am considering adding support for 4/8-bit quantization to vLLM. What specifically would have to be done to implement this feature? Can I emulate the on-the-fly weight quantization done in the FastChat repo? https://github.com/lm-sys/FastChat/blob/main/fastchat/model/compression.py

@abhinavkulkarni (Author) commented

@johnnysmithy123467: It would be best if the models were loaded upfront before they are passed to vLLM. Usually, quantized models modify the structure of the base model by replacing Linear and LayerNorm layers with custom ones that support zero-point quantization. They also use custom CUDA extensions to do 4-bit matrix multiplication.

You can see how GPTQ quantized models are loaded here and AWQ quantized models here.

@TheBloke commented

Given bitsandbytes' recent work on 4-bit performance (now claiming up to 4.2x faster performance from 4-bit versus fp16), adding bitsandbytes would seem like an obvious, and perhaps the easiest, first quantization format to support?

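For reference, the Transformers-side NF4 load looks roughly like this (the model name is just an example; vLLM would still need its own weight-loading integration on top):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 4-bit NF4 quantization config, handled by bitsandbytes under the hood.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",      # example model
    quantization_config=bnb_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
```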

@johnnysmithy123467 commented

@TheBloke, what would it take to add bitsandbytes support? I'm considering implementing it myself and just want to know how I would go about doing it; the weight-loading schemes seem very customized in vLLM.

@casper-hansen (Contributor) commented

> Given bitsandbytes' recent work on 4-bit performance (now claiming up to 4.2x faster performance from 4-bit versus fp16), adding bitsandbytes would seem like an obvious, and perhaps the easiest, first quantization format to support?

I like the theoretical speedup, but the problem is they do not support serialization/deserialization of the 4-bit models. Until that is resolved, it will not be great for production usage.

@creatorrr commented

@zhuohan123 support for bnb-nf4 would be amazing, especially for larger models. If you can point me in the right direction, I'd be happy to take a crack at implementing it.

@casper-hansen (Contributor) commented Jul 25, 2023

-1 for BNB and +1 for AWQ.

I will just let my test results speak for themselves. This is for the model mosaicml/mpt-7b-8k-chat. EDIT: LLaMA-2 7B also runs at 103 tokens/s with AWQ TinyChat on a 4090, according to my own testing.

System: 1 x RTX A6000, 4 vCPU, 61 GB RAM (RunPod Community Cloud)

| Method | Throughput | Load time |
| --- | --- | --- |
| BNB (NF4) | 7.5 tokens/s | 16.41 s |
| AWQ (W4A16) | 28 tokens/s | 14.27 s |
| AWQ TinyChat (W4A16) | 46.86 tokens/s | 22.29 s |

This is torch 2.0.1, bitsandbytes 0.41.0, transformers 4.31.0, and AWQ compiled from main.

CC @TheBloke and @WoosukKwon

@casper-hansen (Contributor) commented

117 tokens/s now on 4090 with MPT-7B model.

https://huggingface.co/casperhansen/mpt-7b-8k-chat-awq

@ri938 (Contributor) commented Jul 28, 2023

Was thinking of maybe working on this myself.

My understanding is that I only need to replace the linear layers and don't need to touch the attention code. It's a lot simpler to implement if I don't need to get into the attention code with its custom C++ kernels and paged attention. Can someone confirm my understanding there?

I was going to use GPTQ and use the text-generation-inference library as a template for how to implement this.
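(For anyone following along, here is a deliberately naive, dequantize-then-matmul sketch of what such a replacement linear layer does conceptually; real GPTQ/AWQ kernels pack the 4-bit values and fuse the dequantization into custom CUDA, but the module interface is the same.)

```python
import torch

class NaiveW4A16Linear(torch.nn.Module):
    """Conceptual stand-in for a quantized linear layer: weights are kept as
    low-bit integer codes with per-group scales/zero points and dequantized
    on the fly. Shown only to illustrate the interface, not an optimized kernel."""

    def __init__(self, qweight: torch.Tensor, scales: torch.Tensor,
                 zeros: torch.Tensor, group_size: int = 128):
        super().__init__()
        self.register_buffer("qweight", qweight)  # [out, in], integer codes
        self.register_buffer("scales", scales)    # [out, in // group_size]
        self.register_buffer("zeros", zeros)      # [out, in // group_size]
        self.group_size = group_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expand per-group scales/zeros to per-column, then dequantize:
        # w ≈ (q - zero) * scale, which matches zero-point quantization.
        scales = self.scales.repeat_interleave(self.group_size, dim=1)
        zeros = self.zeros.repeat_interleave(self.group_size, dim=1)
        weight = (self.qweight.to(x.dtype) - zeros.to(x.dtype)) * scales.to(x.dtype)
        return x @ weight.t()
```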

@johnnysmithy123467 commented

@ri938 That sounds right. Let us know how it goes; eager to hear from you. At least someone's taking a crack at it!

@ri938 (Contributor) commented Aug 6, 2023

Update on this: I have a hacky proof-of-concept quantisation working which reduces memory use, and inference quality looks high.

However, inference speed is much slower, even with a single batch size. I tried the default and exllama kernels from the TGI repo, but both ran at about 60% of the inference speed of fp16 models. It's too slow for my particular use case, and I think it's important to get it to inference speeds comparable to fp16.

One option is that I could open a PR for quantisation and then allow others with more experience in CUDA optimisations to contribute to this. Otherwise I will probably continue to work on this.

@casper-hansen (Contributor) commented

> Update on this: I have a hacky proof-of-concept quantisation working which reduces memory use, and inference quality looks high.
>
> However, inference speed is much slower, even with a single batch size. I tried the default and exllama kernels from the TGI repo, but both ran at about 60% of the inference speed of fp16 models. It's too slow for my particular use case, and I think it's important to get it to inference speeds comparable to fp16.
>
> One option is that I could open a PR for quantisation and then allow others with more experience in CUDA optimisations to contribute to this. Otherwise I will probably continue to work on this.

GPTQ is pretty slow in general unless you build hyper-optimized kernels on top. Did you try AWQ? It might give you better results.

@ri938 (Contributor) commented Aug 6, 2023

Not tried AWQ. But for GPTQ I did use the kernels from text-generation-inference (including a port of exllama), so I thought the performance should be good.

@ri938 (Contributor) commented Aug 9, 2023

AWQ is much faster than GPTQ. I feel it could be optimised a lot more and plan to give this a try.

Performance declines with larger batch sizes faster than for the fp16 version, which means I can only reach half the batch size I could with FP16 before hitting peak throughput.

I think I'll make a merge request and others with more experience with CUDA can continue to optimise it further.

@Jorghi12 commented Aug 9, 2023

Hey @ri938, I'd be happy to contribute to your repository if you'd like.

@petrasS3 commented

We need to add support for quantized models in the vLLM project; we need this to run a quantized LLaMA model via vLLM. This involves implementing quantization techniques to optimize memory usage and runtime performance. A reward of $500 will be granted to the contributor who successfully completes this task.

@ri938 (Contributor) commented Aug 14, 2023

#762 is my work on this so far. Some TODOs and cleanup are still required.

I tried lots of different quantization methods and AWQ performed the best.

WIP: if anyone wants to contribute more, they can send a MR.

An initial review and some comments on what we still need to do to merge this would be appreciated. @WoosukKwon

@ri938 (Contributor) commented Aug 14, 2023

I think there is lots of room for optimization because quantization scales poorly with batch size. But I reckon that's a separate issue.

@wasertech commented

My take on the issue is that quantization is a useful and important feature for LLMs, as it can enable faster and more efficient inference of such large language models. I think it would be beneficial to add support for different quantization techniques, such as bnb_nf4, GPTQ, and AWQ, and allow users to choose the best one for their use case. I also think it would be helpful to provide some documentation and examples on how to use quantized models with vLLM. I appreciate the efforts of the contributors who are working on this issue, and I hope they can successfully complete it soon.

@gesanqiu (Contributor) commented

Have you guys looked at OmniQuant? I think it is much easier to integrate OmniQuant into vLLM than GPTQ or AWQ.

@viktor-ferenczi (Contributor) commented

OmniQuant paper for reference: https://arxiv.org/abs/2308.13137

@wasertech commented

So, just in case you missed the exciting news, vLLM now proudly supports AWQ 🎉.
I was eager to test it on my Titan RTX, but it has a compute capability of 7.5, which is currently a work in progress (check out #1282 and #1252 for details). However, if your GPU is at compute capability 8.0 or higher, you're in luck: you can serve quantized models using the --quantization awq --dtype half flags.
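A minimal usage sketch of that path (the model name below is just an example AWQ checkpoint; any AWQ-quantized repo should work):

```python
from vllm import LLM, SamplingParams

# Example AWQ checkpoint; substitute your own AWQ-quantized model.
llm = LLM(model="TheBloke/Llama-2-7B-Chat-AWQ", quantization="awq", dtype="half")

outputs = llm.generate(
    ["What does AWQ quantization do?"],
    SamplingParams(temperature=0.7, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```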

@shatealaboxiaowang commented

> AWQ is much faster than GPTQ. I feel it could be optimised a lot more and plan to give this a try.
>
> Performance declines with larger batch sizes faster than for the fp16 version, which means I can only reach half the batch size I could with FP16 before hitting peak throughput.
>
> I think I'll make a merge request and others with more experience with CUDA can continue to optimise it further.

I conducted performance tests on codellama-13B-AWQ, and the results are as follows: the latency per request is smaller than for the non-quantized model (codellama-13B-hf), but the average response time is worse than the non-quantized model when the concurrency is increased.

Why is this?

Looking forward to your reply very much!

@gesanqiu (Contributor) commented

> I conducted performance tests on codellama-13B-AWQ, and the results are as follows: the latency per request is smaller than for the non-quantized model (codellama-13B-hf), but the average response time is worse than the non-quantized model when the concurrency is increased.
>
> Why is this?

The README of AutoAWQ explains this issue: AWQ is not good in large-context or large-batch scenarios. If you run more test cases, you will also find that the prefill latency of the AWQ model is much larger than that of the FP16 model.

@shatealaboxiaowang commented

> The README of AutoAWQ explains this issue: AWQ is not good in large-context or large-batch scenarios. If you run more test cases, you will also find that the prefill latency of the AWQ model is much larger than that of the FP16 model.

Thank you very much for your reply, I got it!

Is there any optimization technique that can improve the generation speed and greatly improve the throughput on the same hardware resources?

@gesanqiu (Contributor) commented

@shatealaboxiaowang Right now I see two directions:

  1. SmoothQuant W8A8: it uses torch-int and true INT8 kernels, and throughput is higher than the current AWQ solution; see Support W8A8 inference in vllm #1508 for more details (still WIP).
  2. Speculative decoding, like Medusa attention; see [Discussion] Will vLLM consider using Speculative Sampling to accelerating LLM decoding? #1171 for more details (also WIP).

@AniZpZ commented Nov 22, 2023

@shatealaboxiaowang We have just released a per-token W8A8 method in #1508, and now it is fully usable.
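(For readers unfamiliar with the approach, here is a conceptual, fp32-emulated sketch of per-token W8A8; the actual implementation in #1508 runs the matmul in INT8 via torch-int kernels, so treat this only as an illustration of the quantization scheme.)

```python
import torch

def per_token_quantize(x: torch.Tensor):
    # Symmetric per-token activation quantization to int8: one scale per token.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127.0
    q = torch.clamp((x / scale).round(), -128, 127).to(torch.int8)
    return q, scale

def w8a8_linear(x: torch.Tensor, int8_weight: torch.Tensor,
                weight_scale: torch.Tensor) -> torch.Tensor:
    # int8_weight: [out, in] int8 (per-output-channel quantized),
    # weight_scale: [out] dequantization scale per output channel.
    q_x, x_scale = per_token_quantize(x)
    # Real W8A8 kernels perform this matmul in INT8 on tensor cores;
    # here it is emulated in fp32 purely for clarity.
    acc = q_x.float() @ int8_weight.float().t()
    return acc * x_scale * weight_scale
```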

@dhruvmullick commented

Is there any information on when vLLM might support bnb quantization?

@KobanBanan commented

@shatealaboxiaowang Hi! Do I need to pass an already-quantized model, or will vLLM quantize it while loading?

@ann-lab52 commented

> @shatealaboxiaowang Hi! Do I need to pass an already-quantized model, or will vLLM quantize it while loading?

Hi @KobanBanan, you need to quantize the model first, then load it with vLLM following your quantization configuration.
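A sketch of that two-step workflow, roughly following AutoAWQ's documented quantize-then-save flow (the model names are examples; check the current AutoAWQ README for exact arguments):

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "meta-llama/Llama-2-7b-hf"   # example base model
quant_path = "llama-2-7b-awq"             # output directory

quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Step 1: quantize offline with AutoAWQ and save the quantized checkpoint.
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

# Step 2: load the pre-quantized checkpoint in vLLM with a matching config.
from vllm import LLM
llm = LLM(model=quant_path, quantization="awq", dtype="half")
```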

joerunde added a commit to joerunde/vllm that referenced this issue Jul 31, 2024
With all the extra fun refactors

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
jikunshang pushed a commit to jikunshang/vllm that referenced this issue Oct 17, 2024
This PR raises the allowed relative tolerance in GSM8K to 0.06, and
moves Llama-70B test to 4xG2 from 2xG2 until memory usage is
investigated (success run: vLLM-CI-Pipeline/206)