
[TPU] Add Load-time W8A16 quantization for TPU Backend #7005

Merged (7 commits into vllm-project:main) on Aug 9, 2024

Conversation

lsy323 (Contributor) commented Jul 31, 2024

Add load-time W8A16 quantization for the TPU backend. The workflow is similar to the existing load-time fp8 quantization. Opening the PR to facilitate the discussion process.

  • Added a new quantization type, tpu_int8, for load-time int8 weight-only quantization on the TPU backend (e.g. LLM(model="google/gemma-2b", quantization="tpu_int8")); a usage sketch follows this list.
  • Added TPUInt8LinearMethod, which quantizes bfloat16 weights to int8 for linear layers and calls the TPU quantized ops in the forward pass.
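A minimal usage sketch of the example above, assuming vLLM is installed on a TPU host with the torch_xla backend available; the prompt and sampling settings are illustrative:

```python
from vllm import LLM, SamplingParams

# Load-time W8A16: weights are quantized to int8 when the model is loaded,
# while activations stay in bfloat16.
llm = LLM(model="google/gemma-2b", quantization="tpu_int8")

params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```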


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which consists of a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build in the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

WoosukKwon added the tpu (Related to Google TPUs) label on Jul 31, 2024
mgoin (Collaborator) left a comment


Interesting, this looks reasonable to me. The important note is to lazily import the torch xla function when needed, rather than at the top of the quant file.
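A rough sketch of the lazy-import pattern being suggested (hypothetical class and method names, not the code in this PR; the exact torch_xla module that registers the op is also an assumption):

```python
import torch

class TPUInt8LinearSketch:
    """Illustrative only; mirrors the idea of TPUInt8LinearMethod."""

    def apply(self, weight: torch.Tensor, scale: torch.Tensor,
              x: torch.Tensor) -> torch.Tensor:
        # Lazy import: pulled in only when the op is actually needed, so
        # importing the quantization config module does not require torch_xla.
        import torch_xla.experimental.xla_quantized_matmul  # noqa: F401 (assumed path)
        return torch.ops.xla.quantized_matmul(x, weight, scale)
```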

Review threads: examples/offline_inference_tpu.py, vllm/model_executor/layers/quantization/tpu_int8.py
robertgshaw2-neuralmagic (Collaborator)

Super cool!

As a follow-up, we can work on hooking this up to some of the existing checkpoints we have, in addition to in-place quantization.

robertgshaw2-neuralmagic (Collaborator)

By chance, what schemes does the following support:

  • torch.ops.xla.quantized_matmul(x, weight, scale)

Channelwise?
Activations?

lsy323 changed the title from "Add Load-time W8A16 quantization for TPU Backend" to "[TPU] Add Load-time W8A16 quantization for TPU Backend" on Aug 1, 2024
lsy323 force-pushed the lsiyuan/quant branch 2 times, most recently from f4b8dd7 to b1a04b3, on August 2, 2024 at 21:37
lsy323 requested a review from mgoin on August 2, 2024 at 23:47
lsy323 (Contributor, Author) commented Aug 2, 2024

Hi @mgoin, @robertgshaw2-neuralmagic,

Thank you for reviewing my PR! Excited to work with you to enable quantization for the TPU backend through compressed-tensors!

By chance, what schemes does the following support:

  • torch.ops.xla.quantized_matmul(x, weight, scale)

Channelwise? Activations?

We have the quantized ops (equivalent to the quantized CUDA kernels in vLLM, but for TPU) in PyTorch/XLA here. The quantized matmul kernel is registered as a torch op and is compatible with torch.compile. It can be configured to support per-channel or blockwise quantization, and both int8 and int4 are supported (int4 is not yet optimized). The quantized matmul support matrix is here.
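For context, a minimal sketch of what per-channel symmetric int8 weight quantization looks like in plain PyTorch (illustrative only; the actual scheme used by the torch_xla kernel may differ):

```python
import torch

def quantize_weight_per_channel_int8(w: torch.Tensor):
    # Symmetric per-output-channel int8 quantization of a [out, in] weight:
    # one scale per output channel, scale = max|w| along the input dim / 127.
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    w_int8 = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return w_int8, scale.to(w.dtype)

# Reference dequantized matmul: y ≈ x @ (w_int8.to(x.dtype) * scale).T
```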

robertgshaw2-neuralmagic (Collaborator)

Hi @mgoin, @robertgshaw2-neuralmagic,

Thank you for reviewing my PR! Excited to work with you to enable quantization for the TPU backend through compressed-tensors!

By chance, what schemes does the following support:

  • torch.ops.xla.quantized_matmul(x, weight, scale)

Channelwise? Activations?

We have the quantized ops (equivalent to the quantized CUDA kernels in vLLM, but for TPU) in PyTorch/XLA here. The quantized matmul kernel is registered as a torch op and is compatible with torch.compile. It can be configured to support per-channel or blockwise quantization, and both int8 and int4 are supported (int4 is not yet optimized). The quantized matmul support matrix is here.

This is so awesome!!!! Running the same compressed models on various hardware backends is going to be an awesome feature

WoosukKwon (Collaborator) left a comment


@lsy323 Thanks for the PR! I'm really looking forward to using this feature!

I think we have two things to figure out on this PR:

  1. Where to put this quantization config and the linear method? Do we want to put this as a new quantization config (like in the current PR) or in compressed-tensors?
  2. IIRC, this currently does not support the case where the BF16 weights exceed the TPU's HBM size but the INT8 weights would fit (e.g., Llama 8B on TPU v5e, which has 16 GB of HBM). Could you please remind us why this isn't supported?

Review threads: vllm/model_executor/layers/quantization/tpu_int8.py (4), vllm/model_executor/model_loader/loader.py
miladm commented Aug 5, 2024

TorchXLA:TPU FP8 support is WIP (partially supported). @lsy323 do we have an outlined plan somewhere that extends this effort to FP8?

lsy323 (Contributor, Author) commented Aug 5, 2024

  1. Where to put this quantization config and the linear method? Do we want to put this as a new quantization config (like in the current PR) or in compressed-tensors?

I think for the compressed-tensors config, we assume all checkpoints are in compressed-tensors format. This load-time quantization doesn't seem to belong to that flow, so keeping it in a separate file looks cleaner.

  1. IIRC, this currently does not support the case where the BF16 weights exceed the TPU's HBM size but the INT8 weights would fit (e.g., Llama 8B on TPU v5e, which has 16 GB of HBM). Could you please remind us why this isn't supported?

Sure, the current flow is:

  1. Move the bfloat16 weights from host to TPU.
  2. Quantize the bfloat16 weights to int8 on device.

When the BF16 weights exceed the TPU's HBM size, step 1 hits OOM. To avoid this, we could delay the weight transfer when load-time quantization is enabled.
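A schematic of the two orderings (a sketch with a simplified quantizer, not vLLM's loader code):

```python
import torch

def _quantize_int8(w: torch.Tensor):
    # Minimal symmetric per-channel quantizer, for illustration only.
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    return torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8), scale

def load_weight_current(w_bf16: torch.Tensor, device):
    # Current flow: transfer first, quantize second. OOMs if the bf16 copy
    # alone exceeds TPU HBM.
    w_tpu = w_bf16.to(device)                 # step 1: host -> TPU in bf16
    return _quantize_int8(w_tpu)              # step 2: quantize on device

def load_weight_quantize_on_host(w_bf16: torch.Tensor, device):
    # Possible alternative: quantize on the host, then move only the int8
    # weights (half the bytes of bf16) plus a small scale tensor.
    w_int8, scale = _quantize_int8(w_bf16)
    return w_int8.to(device), scale.to(device)
```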

lsy323 (Contributor, Author) commented Aug 6, 2024

TorchXLA:TPU FP8 support is WIP (partially supported). @lsy323 do we have an outlined plan somewhere that extends this effort to FP8?

I don't have a concrete plan for this yet; the alternatives are as follows:

  1. Reuse fp8.py, which works for the CUDA workflow (supports both load-time fp8 quantization and offline-quantized fp8 checkpoints using TensorRT-LLM; ref).
  2. Add another quantization config for fp8 on TPU (e.g. fp8_tpu).
  3. Extend compressed_tensors to support the TPU backend (this would support fp8 checkpoints in compressed_tensors format).

robertgshaw2-neuralmagic (Collaborator) commented Aug 6, 2024

TorchXLA:TPU FP8 support is WIP (partially supported). @lsy323 do we have an outlined plan somewhere that extends this effort to FP8?

I don't have a concrete plan for this yet; the alternatives are as follows:

  1. Reuse fp8.py, which works for the CUDA workflow (supports both load-time fp8 quantization and offline-quantized fp8 checkpoints using TensorRT-LLM; ref).
  2. Add another quantization config for fp8 on TPU (e.g. fp8_tpu).
  3. Extend compressed_tensors to support the TPU backend (this would support fp8 checkpoints in compressed_tensors format).

Hey guys - there are a couple considerations here. For vLLM, we want to support both cases:

  • In-place quantization
  • Pre-quantized checkpoints

We will be making all go-forward checkpoints inside the compressed-tensors integration for mixed precision, integer activation quantization, and floating-point activation quantization, so I think we should focus on this pathway.

Both fp8.py and compressed-tensors share the same backend code (you can see they both use apply_fp8_linear). We factored this utility out so that the kernel calls are shared by the various integrations. If you add the TPU calls to this function, you should get all of these integrations "for free".
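A purely illustrative sketch of the shared-utility idea (hypothetical helper and signature; apply_fp8_linear's real interface is not reproduced here): both the load-time config and the compressed-tensors config would call one function, so a TPU branch added there covers every integration routed through it.

```python
import torch

def apply_w8_linear(x, weight, scale, bias=None, *, use_tpu=False):
    # Hypothetical shared helper in the spirit of apply_fp8_linear: every
    # quantization integration calls this one function, so a TPU branch here
    # is picked up by all of them "for free".
    if use_tpu:
        # Assumed registration path for torch.ops.xla.quantized_matmul.
        import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
        out = torch.ops.xla.quantized_matmul(x, weight, scale)
    else:
        # Stand-in for the CUDA scaled-matmul kernel call used on GPUs.
        out = x @ (weight.to(x.dtype) * scale).t()
    return out if bias is None else out + bias
```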

lsy323 (Contributor, Author) commented Aug 7, 2024

  1. IIRC, this currently does not support the case where the BF16 weights exceed the TPU's HBM size but the INT8 weights would fit (e.g., Llama 8B on TPU v5e, which has 16 GB of HBM). Could you please remind us why this isn't supported?

Hi @WoosukKwon, I looked into this in detail; it doesn't look like a straightforward change, so I think we can consider supporting it in a separate PR.

In the current flow, the weights are moved to the device as the model is initialized (ref), and load-time quantization is then done on device (ref). We would need to introduce a new flow to support this case.

WoosukKwon (Collaborator)

@lsy323 Seems like my previous comment was not addressed for some reason. Can you please check it again?

lsy323 and others added 3 commits August 8, 2024 10:09
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
lsy323 (Contributor, Author) commented Aug 8, 2024

@lsy323 Seems like my previous comment was not addressed for some reason. Can you please check it again?

@WoosukKwon Somehow I force-pushed without the suggested-change commits. It should be fixed now. Thank you for the reminder!

WoosukKwon (Collaborator) left a comment


LGTM! Thanks for the PR! It works well on my machine 🎉 🎉

Looking forward to the next step! (adding INT8 activation quantization in tpu-int8).

WoosukKwon merged commit 0fa1490 into vllm-project:main on Aug 9, 2024
27 checks passed
lsy323 deleted the lsiyuan/quant branch on August 9, 2024 at 21:09
sfc-gh-mkeralapura pushed a commit to sfc-gh-mkeralapura/vllm that referenced this pull request Aug 12, 2024
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
fialhocoelho pushed a commit to opendatahub-io/vllm that referenced this pull request Aug 22, 2024
Labels: tpu (Related to Google TPUs)
5 participants