
[Kernel] Dynamic Per-Token Activation Quantization #5037

Merged: 80 commits, merged on Jun 7, 2024
Commits (80)
4d27a2c
Initial `CompressedTensors` config + Activation Quantization support …
dsikka Apr 30, 2024
92b3703
add get_quant method to compressed tensors config
dsikka Apr 30, 2024
2a3eb83
small rebase fixed
dsikka Apr 30, 2024
3dd1fe8
format
dsikka Apr 30, 2024
f2f8c52
fix mypy complaints
Apr 30, 2024
c9308eb
Merge branch 'main' into ds-quant
dsikka Apr 30, 2024
d9d49b5
format fixes
dsikka Apr 30, 2024
b111ee6
Merge branch 'main' into ds-quant
dsikka May 1, 2024
c31a7af
format fix post rebase
dsikka May 1, 2024
ca01b39
lazy import CompressedTensorsW8A8StaticTensor (#220)
varun-sundar-rabindranath May 1, 2024
f0197d4
lazy cutlass_gemm_dq import (#221)
varun-sundar-rabindranath May 1, 2024
4624b46
fix asm
May 1, 2024
75757d5
update shape change
dsikka May 2, 2024
e1df0eb
add todo
dsikka May 2, 2024
bc0991c
Rename quant_per_tensor -> static_scaled_int8_quant
May 2, 2024
74ad650
Remove cruft
May 2, 2024
43c43f3
Merge branch 'main' into ds-quant
dsikka May 14, 2024
cf5600f
fixes : typo
May 14, 2024
169ce7f
py-cutlass temporary hack for num_prompts==1
May 15, 2024
03b53e7
yapf
May 15, 2024
f9df31b
add test_int8_quant
May 16, 2024
ba4b6b3
call cpp cutlass
May 17, 2024
3c223c6
Merge branch 'main' into ds-quant
dsikka May 17, 2024
b27f31a
remove cutlass py interface
May 17, 2024
b589cdd
format.sh
May 17, 2024
98159cf
remove fake-quant
May 17, 2024
8dbeb31
add compressed tensors test
dsikka May 17, 2024
5eeb40a
remove torch.int8
dsikka May 17, 2024
c55e023
format
dsikka May 17, 2024
f5cbbd3
fix config parsing to match new model
dsikka May 20, 2024
a685957
revert parsing to use default pathway
dsikka May 20, 2024
4dfb37f
PR comments
dsikka May 21, 2024
de81f9e
Fix scales/zero-points device allocation
May 21, 2024
15f1863
ruff
May 21, 2024
bd53847
add better comments
May 21, 2024
b2926f3
add comment
dsikka May 22, 2024
1274386
Merge branch 'main' into ds-quant
dsikka May 22, 2024
18640c8
clang format
dsikka May 22, 2024
5c5dc84
clang format again
dsikka May 22, 2024
a44b4a0
address PR comments
May 22, 2024
6f0e6e1
clang-format
May 22, 2024
0090454
remove layer name
dsikka May 23, 2024
4b10fd7
remove unused import
dsikka May 23, 2024
68a59c7
remove parent name
dsikka May 23, 2024
b0afe67
Fix rounding
May 22, 2024
4f4951e
comment
May 23, 2024
869de3f
cruft
May 23, 2024
e68e391
yapf
May 23, 2024
d77cf50
remove unquantized check
dsikka May 23, 2024
51a4e59
update parsing to use compressed-tensors; add dynamic per token parsi…
dsikka May 2, 2024
6777319
add dynamic quantization arg, fill out create_weights/apply
dsikka May 2, 2024
54c797a
Add quant_per_token kernels
May 2, 2024
6bcab22
make changes to config parsing based on sparseml updates; test dynami…
dsikka May 3, 2024
ece93e1
fix shape for cutlass issues
dsikka May 6, 2024
1d87a99
remove dicts; use quantization args directly
dsikka May 6, 2024
3dd1b5f
update compressed-tensors; add docstring
dsikka May 7, 2024
fed7cdd
Dyn per token varun cleanup (#227)
varun-sundar-rabindranath May 13, 2024
66719a9
add test_int8_quant
May 16, 2024
2ec6a2c
remove fake quant
dsikka May 24, 2024
0c7f870
Merge branch 'main' into dyn-per-token
dsikka May 24, 2024
34e2e12
format
dsikka May 24, 2024
e79517e
combine static and dynamic quant computation
May 24, 2024
39e66d1
TORCH_CHECK and nits
May 24, 2024
59f8ec1
use Union
May 24, 2024
7a83601
clang-format
May 24, 2024
9ea47c8
fix typo
May 24, 2024
7abb2c8
isort
May 24, 2024
eb4e119
update test case
dsikka May 28, 2024
d62930d
fix isort
dsikka May 28, 2024
80b6fac
store input scales in gpu
May 29, 2024
fa1ceef
Merge branch 'main' into dyn-per-token
dsikka Jun 5, 2024
7075318
tensor device location fixes
Jun 6, 2024
60a6d73
format.sh
Jun 6, 2024
f36519b
remove compressed tensors
dsikka Jun 6, 2024
2c6e580
format fix
dsikka Jun 6, 2024
b3d692a
add comments; some clean-up
dsikka Jun 6, 2024
f3bf9e3
review comments
Jun 6, 2024
2bd62e0
review comments and const correctness
Jun 6, 2024
460f514
format.sh
Jun 6, 2024
dfcd61a
nit fixes
dsikka Jun 7, 2024
3 changes: 3 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale);

void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scales);

void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor lookup_table);

3 changes: 3 additions & 0 deletions csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
"Compute int8 quantized tensor for given scaling factor");

ops.def("dynamic_scaled_int8_quant", &dynamic_scaled_int8_quant,
"Compute int8 quantized tensor and scaling factor");

// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def("swap_blocks", &swap_blocks,
75 changes: 64 additions & 11 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <cmath>

#include "../../dispatch_utils.h"
#include "../../reduction_utils.cuh"

static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
@@ -27,17 +28,48 @@ namespace vllm {

template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
const scalar_t* __restrict__ input, int8_t* __restrict__ out,
const scale_type* scale_ptr, const int hidden_size) {
const int tid = threadIdx.x;
const int token_idx = blockIdx.x;
scale_type scale = *scale_ptr;
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr;

for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] =
float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
out[token_idx * hidden_size + i] = float_to_int8_rn(
static_cast<float>(input[token_idx * hidden_size + i]) / scale);
}
}

template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
float absmax_val = 0.0f;
float const zero = 0.0f;

for (int i = tid; i < hidden_size; i += blockDim.x) {
float val = static_cast<float>(input[token_idx * hidden_size + i]);
val = val > zero ? val : -val;
absmax_val = val > absmax_val ? val : absmax_val;
}

float const block_absmax_val_maybe = blockReduceMax(absmax_val);
__shared__ float block_absmax_val;
if (tid == 0) {
block_absmax_val = block_absmax_val_maybe;
scale[token_idx] = block_absmax_val / 127.0f;
}
__syncthreads();

float const tmp_scale = 127.0f / block_absmax_val;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn(
static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
}
}

} // namespace vllm

void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
@@ -47,10 +79,10 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scale.numel() == 1);

int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
@@ -60,3 +92,24 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
scale.data_ptr<float>(), hidden_size);
});
}

void dynamic_scaled_int8_quant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor& scales) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());

int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(),
scales.data_ptr<float>(), hidden_size);
});
}
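
For intuition, here is a minimal PyTorch sketch of the math that the dynamic per-token kernel above performs: each token (row) gets its own scale, absmax / 127, and values are rounded to int8 against that scale. This is an illustrative reference only, not the CUDA kernel or vLLM's API; the helper name and shapes below are made up.

```python
import torch

def dynamic_per_token_int8_quant_ref(x: torch.Tensor):
    """Reference sketch of per-token (per-row) dynamic int8 quantization."""
    int8_info = torch.iinfo(torch.int8)
    # Per-row absolute maximum in float32, mirroring blockReduceMax in the kernel.
    absmax = x.abs().amax(dim=-1, keepdim=True).to(torch.float32)
    scales = absmax / 127.0                       # scale[token] = absmax / 127
    q = (x.to(torch.float32) / scales).round()    # kernel multiplies by 127 / absmax
    q = q.clamp(int8_info.min, int8_info.max).to(torch.int8)
    return q, scales

# Example: 4 tokens with hidden size 8 (values are arbitrary).
x = torch.randn(4, 8, dtype=torch.float16)
q, scales = dynamic_per_token_int8_quant_ref(x)
assert q.shape == x.shape and scales.shape == (4, 1)
```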
54 changes: 41 additions & 13 deletions csrc/reduction_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,47 @@
#include "cuda_compat.h"

namespace vllm {

namespace detail {

template <typename T>
__inline__ __device__ T _max(T a, T b) {
return max(a, b);
}

template <typename T>
__inline__ __device__ T _sum(T a, T b) {
return a + b;
}

} // namespace detail

template <typename T>
using ReduceFnType = T (*)(T, T);

// Helper function to return the next largest power of 2
static constexpr int _nextPow2(unsigned int num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
Comment on lines +42 to +46

Collaborator: Is there a common place we can put CUDA utils like this? We have the exact same helper fn in csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu

Contributor: I did some sleuthing, but can't find a good place to put it. Should we create a math_utils.cuh file? @robertgshaw2-neuralmagic @mgoin

Collaborator: We definitely need another refactoring for csrc/quantization... but I don't have an out-of-box solution for this ATM.


template <typename T, int numLanes = WARP_SIZE>
__inline__ __device__ T warpReduceSum(T val) {
__inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
"numLanes is not a positive power of 2!");
static_assert(numLanes <= WARP_SIZE);
#pragma unroll
for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
val += VLLM_SHFL_XOR_SYNC(val, mask);
return val;
}
val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask));

// Helper function to return the next largest power of 2
static constexpr int _nextPow2(unsigned int num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
return val;
}

/* Calculate the sum of all elements in a block */
template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
__inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn) {
static_assert(maxBlockSize <= 1024);
if constexpr (maxBlockSize > WARP_SIZE) {
val = warpReduceSum<T>(val);
val = warpReduce<T>(val, fn);
// Calculates max number of lanes that need to participate in the last
// warpReduce
constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
@@ -56,12 +74,22 @@ __inline__ __device__ T blockReduceSum(T val) {

val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
: (T)(0.0f);
val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
} else {
// A single warpReduce is equal to blockReduce
val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
}
return val;
}

template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceMax(T val) {
return blockReduce<T, maxBlockSize>(val, detail::_max<T>);
}

template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
}

} // namespace vllm
44 changes: 38 additions & 6 deletions tests/kernels/test_int8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,59 @@
from vllm._C import ops

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 8192] # Arbitrary values for testing
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
8193] # Arbitrary values for testing
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
SEEDS = [0]
SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
Collaborator: should we add a larger hidden size (> 1024) that's not a nice number as well? I see 5120, but it is a multiple of 256.

Contributor: Added hidden sizes 5137 and 8193.

@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
int8_traits = torch.iinfo(torch.int8)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000

x_token_max, _ = x.max(dim=1)
x_token_max = x_token_max.to(dtype=torch.float32)
scales = (x_token_max / float(127.0))[:, None].to(device="cuda",
dtype=torch.float32)
torch_out = (x / scales).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)

ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda")
scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda")
ops.dynamic_scaled_int8_quant(ops_out, x, scales_out)

assert torch.allclose(scales_out, scales)
assert torch.allclose(torch_out, ops_out,
atol=1) # big atol to account for rounding errors


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("scale", SCALE)
@torch.inference_mode()
def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
seed: int, scale: float) -> None:
def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int,
scale: float) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
int8_traits = torch.iinfo(torch.int8)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000

out1 = (x / scale).round().clamp(
torch.iinfo(torch.int8).min,
torch.iinfo(torch.int8).max).to(torch.int8)
out1 = (x / scale).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
out2 = torch.empty_like(x, dtype=torch.int8)
scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")

19 changes: 18 additions & 1 deletion tests/quantization/test_compressed_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import torch

from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor)
CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken,
CompressedTensorsW8A8StaticTensor)


def test_compressed_tensors_w8a8_static_setup(vllm_runner):
@@ -34,3 +35,19 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):
assert qkv_proj.weight_scale.shard_splitter is not None
assert qkv_proj.weight_scale.logical_widths is not None
assert qkv_proj.input_scale.dtype is torch.float32


def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner):
model_path = "nm-testing/tinyllama-one-shot-dynamic-test"
llm = vllm_runner(model_path,
quantization="sparseml",
enforce_eager=True,
dtype=torch.float16)
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
layer = model.model.layers[0]

qkv_proj = layer.self_attn.qkv_proj

assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
assert qkv_proj.weight.dtype is torch.int8
28 changes: 20 additions & 8 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,21 +264,33 @@ def scaled_fp8_quant(


# int8
def static_scaled_int8_quant(input: torch.Tensor,
scale: torch.Tensor) -> torch.Tensor:
def scaled_int8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantize the input tensor to int8 and return the quantized tensor.
Quantize the input tensor to int8 and return the quantized tensor and scale.

Args:
input: The input tensor to be quantized to int8.
scale: Scaling factor for the int8 quantization.
scale: Optional scaling factor for the int8 quantization.
When not provided, we invoke dynamic-per-token quantization.

Returns:
torch.Tensor: Output tensor in int8.
Tuple[torch.Tensor, torch.Tensor]: Output int8 tensor and scales.
"""
q = torch.empty_like(input, dtype=torch.int8)
vllm_ops.static_scaled_int8_quant(q, input, scale)
return q
output = torch.empty_like(input, dtype=torch.int8)
if scale is not None:
Collaborator: Can we make the names of the variables used internally in this function match the scaled_fp8_quant function?

Contributor: Renamed q to output. I believe the other variables are good as they are; please take a look. Thanks.

# static-per-tensor quantization.
vllm_ops.static_scaled_int8_quant(output, input, scale)
return output, scale

# dynamic-per-token quantization.
input_scales = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
vllm_ops.dynamic_scaled_int8_quant(output, input, input_scales)
return output, input_scales


# moe
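
For completeness, a hedged usage sketch of the updated Python entry point follows. It assumes the module is importable as vllm._custom_ops (the file touched in this diff) and that a CUDA device with the compiled ops is available; the tensor shapes and scale value are made up for illustration.

```python
import torch
from vllm import _custom_ops as ops  # module path assumed from the file in this diff

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Dynamic per-token quantization: no scale supplied, so one scale per token
# is computed on the GPU and returned alongside the int8 output.
q_dyn, token_scales = ops.scaled_int8_quant(x)
assert q_dyn.dtype == torch.int8
assert token_scales.shape == (16, 1)

# Static per-tensor quantization: a single precomputed scale is passed in
# and returned unchanged with the quantized tensor.
scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
q_static, scale_out = ops.scaled_int8_quant(x, scale)
```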