[Kernel] Dynamic Per-Token Activation Quantization #5037

Merged: 80 commits merged into vllm-project:main on Jun 7, 2024

Conversation

@dsikka (Contributor) commented on May 24, 2024

Summary

  • Adds an additional CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken, to support w8a8 models with dynamic per-token activation quantization. This scheme supports w8a8 dynamic per-token models quantized through sparseml and saved through compressed-tensors.
  • Expands CompressedTensorsConfig to use QuantizationArgs, QuantizationStrategy, and find_first_name_or_class_match to match the appropriate scheme to each layer.
  • Adds a dynamic_int8_quant_kernel CUDA kernel that performs int8 dynamic per-token quantization (a reference sketch of the computation follows below).
  • Refactors utilities in reduction_utils.cuh.

From Neural Magic, co-authored by @varun-sundar-rabindranath
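For readers unfamiliar with the scheme, here is a minimal PyTorch sketch of what dynamic per-token int8 quantization computes: each token (row) of the activation gets its own scale derived from its absolute maximum at runtime, so no activation calibration is needed offline. This is a reference sketch only, not the CUDA kernel added by this PR; the function name and the exact rounding/clamping details are illustrative assumptions.

```python
import torch


def ref_dynamic_per_token_int8_quant(x: torch.Tensor):
    """Reference (non-kernel) dynamic per-token int8 quantization sketch."""
    int8_max = 127.0
    # Flatten to [num_tokens, hidden_size] and compute one scale per token.
    x_2d = x.reshape(-1, x.shape[-1]).to(torch.float32)
    absmax = x_2d.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    scales = absmax / int8_max
    q = torch.round(x_2d / scales).clamp(-128, 127).to(torch.int8)
    return q, scales


# Example: dequantizing recovers the input to within one quantization step per token.
x = torch.randn(4, 5120)
q, scales = ref_dynamic_per_token_int8_quant(x)
assert torch.all((q.float() * scales - x).abs() <= scales)
```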

dsikka and others added 30 commits April 30, 2024 18:50
…for static W8A8 per tensor (#195)

- Depending on how we end up parsing `ignore` and `targets` (layer_name vs layer_type), we may not need layer_name to be added to the linear_method. Will experiment with using a compressed-tensors function in a follow-up PR.

- Initial implementation for Compressed Config support + Activation Quantization for static per-tensor w8a8
- Includes fused kernels added by @varun-sundar-rabindranath

```python
from vllm import LLM, SamplingParams
import torch

prompts = [
    "Hello, my name is",
    "The capital of France is",
    "The US president is",
    "The future of AI is"
]
sampling_params = SamplingParams(temperature=0.80, top_p=0.95)

llm = LLM(model="nm-testing/tinyllama-one-shot-static-quant-test", enforce_eager=True, dtype=torch.float32, quantization="sparseml")

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

- Verification of the different inputs expected for `targets` and `ignore` --> use functions to parse the layer names, which can be shared by sparseml and vllm; these would live in compressed-tensors
(https://github.com/neuralmagic/compressed-tensors/blob/67005d76107d4659787f1efd53fe7e6b1d192818/src/compressed_tensors/quantization/lifecycle/apply.py#L86)
- Updates to further optimize fake quant

---------

Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
vllm CI fixes

---------

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
lazy cutlass_gemm_dq import

---------

Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
```python
# Dynamic-per-token quantization.
input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                           dtype=torch.float32,
                           device="cuda")
```
Collaborator:
to keep consistent with fp8, maybe this should be input.device

Contributor:
Fixed it 👍

"""
q = torch.empty_like(input, dtype=torch.int8)
vllm_ops.static_scaled_int8_quant(q, input, scale)
return q
if scale is not None:
Collaborator:
Can we make the names of the variables used internally in this function match the scaled_fp8_quant function?

Contributor:
Renamed q to output. I believe the other variables are fine as they are; please take a look. Thanks.
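Putting the diff context in this thread together, the wrapper roughly dispatches on whether a scale was provided: static per-tensor quantization when it was, dynamic per-token quantization otherwise. A sketch, assuming the op calls visible in this thread; the import path and the convention of returning a (quantized, scales) tuple in both branches are assumptions of this sketch, not necessarily the code merged in this PR.

```python
from typing import Optional

import torch

# The thread refers to these custom ops simply as "vllm_ops"; this binding is
# an assumption for illustration.
from vllm._C import ops as vllm_ops


def scaled_int8_quant(input: torch.Tensor,
                      scale: Optional[torch.Tensor] = None):
    """Quantize `input` to int8, statically (given scale) or per token."""
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # Static per-tensor quantization, as in the diff context above.
        vllm_ops.static_scaled_int8_quant(output, input, scale)
        return output, scale
    # Dynamic per-token quantization: one fp32 scale per token, allocated on
    # the input's device (per the review comment above).
    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                               dtype=torch.float32,
                               device=input.device)
    vllm_ops.dynamic_scaled_int8_quant(output, input, input_scales)
    return output, input_scales
```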

@robertgshaw2-neuralmagic (Collaborator) left a comment:
LGTM.

csrc/ops.h (outdated), comment on lines 100 to 101:

```cpp
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
                               torch::Tensor& scales);
```
Collaborator:
Suggested change:

```diff
-void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
-                               torch::Tensor& scales);
+void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
+                               torch::Tensor& scales);
```

```cuda
  const int tid = threadIdx.x;
  const int token_idx = blockIdx.x;

  float amax_val = 0.0f;
```
Collaborator:
nit: would it be more readable as absmax_val?

Contributor:
Yes. amax is confusing.

```cuda
  const float zero = 0.0f;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    float val = (float)input[token_idx * hidden_size + i];
```
Collaborator:
nit: It's best to use static cast instead of C-style casts when possible, since they are checked by the compiler.

Suggested change:

```diff
-    float val = (float)input[token_idx * hidden_size + i];
+    float val = static_cast<float>(input[token_idx * hidden_size + i]);
```

Comment on lines 69 to 70:

```cuda
    out[token_idx * hidden_size + i] = float_to_int8_rn(
        ((float)input[token_idx * hidden_size + i]) * tmp_scale);
```
Collaborator:
Suggested change:

```diff
-    out[token_idx * hidden_size + i] = float_to_int8_rn(
-        ((float)input[token_idx * hidden_size + i]) * tmp_scale);
+    out[token_idx * hidden_size + i] = float_to_int8_rn(
+        (static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale));
```

Comment on lines +42 to +46:

```cpp
// Helper function to return the next largest power of 2
static constexpr int _nextPow2(unsigned int num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
```
Collaborator:
Is there a common place we can put CUDA utils like this? We have the exact same helper fn in csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu

Contributor:
I did some sleuthing, but I can't find a good place to put it. Should we create a math_utils.cuh file? @robertgshaw2-neuralmagic @mgoin

Collaborator:
We definitely need another refactoring for csrc/quantization...but I don't have an out-of-box solution for this ATM.
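For context, the helper above computes the smallest power of two greater than or equal to its argument. A quick Python cross-check of the same bit trick, for illustration only (not part of the PR):

```python
def next_pow2(num: int) -> int:
    """Smallest power of two >= num; mirrors the C++ _nextPow2 helper."""
    if num <= 1:
        return num
    return 1 << (num - 1).bit_length()


# Spot-check a few values against hand-computed results.
assert [next_pow2(n) for n in (1, 2, 3, 5, 1024, 1025)] == [1, 2, 4, 8, 1024, 2048]
```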

Test file diff, @@ -10,21 +10,52 @@:

```python
SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
```
Collaborator:
Should we add a larger hidden size (> 1024) that's not a nice round number as well? I see 5120, but it is a multiple of 256.

Contributor:
Added hidden-sizes 5137 and 8193
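For illustration, a hedged sketch of how a parametrization over non-multiple-of-256 hidden sizes might look, checked against a pure-PyTorch per-token reference rather than the actual vLLM op; the size lists (apart from 5137 and 8193 mentioned above) and the tolerance are assumptions, not the test added in this PR.

```python
import pytest
import torch

NUM_TOKENS = [1, 7, 83, 256]
HIDDEN_SIZES = [16, 67, 768, 5120, 5137, 8193]  # includes sizes that are not multiples of 256


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
def test_dynamic_per_token_reference_roundtrip(num_tokens, hidden_size):
    x = torch.randn(num_tokens, hidden_size, dtype=torch.float32)
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    scales = absmax / 127.0
    q = torch.round(x / scales).clamp(-128, 127).to(torch.int8)
    # Dequantization error is at most half a quantization step per token.
    assert torch.all((q.float() * scales - x).abs() <= 0.5 * scales + 1e-6)
```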

@comaniac (Collaborator) left a comment:
LGTM

@WoosukKwon merged commit ca3ea51 into vllm-project:main on Jun 7, 2024. 101 of 103 checks passed.
Downstream commits that referenced this pull request (each carrying the same Co-authored-by trailers for Varun Sundar Rabindranath <varunsundar08@gmail.com> and <varun@neuralmagic.com>):
- dtrifiro pushed a commit to opendatahub-io/vllm on Jun 10, 2024
- robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm on Jun 11, 2024
- joerunde pushed a commit to joerunde/vllm on Jun 17, 2024
- xjpang pushed commits to xjpang/vllm on Jun 27, Jul 8, and Jul 24, 2024
- Temirulan pushed a commit to Temirulan/vllm-whisper on Sep 6, 2024