[Kernel] AQ AZP 3/4: Asymmetric quantization kernels #7270

Merged Sep 16, 2024 (29 commits)

Commits
9c5b95b
DRAFT dynamic azp quant kernel - failing non-deterministically
ProExpertProg Jul 22, 2024
c0916db
Fixed blockReduce bug! Also using round-to-even for azp
ProExpertProg Jul 23, 2024
69f9493
Remove scale adjustment
ProExpertProg Jul 23, 2024
15e4c72
Fixed saturation in kernel
ProExpertProg Jul 23, 2024
6353a8b
Integer allclose comparison
ProExpertProg Jul 23, 2024
a95790a
utils fix
ProExpertProg Jul 24, 2024
84db5cd
Fixed torch ref conversion
ProExpertProg Jul 24, 2024
d11340c
Format
ProExpertProg Jul 24, 2024
fe91441
Inverted azp sign to be consistent with RFC, unit tests, and compress…
ProExpertProg Jul 24, 2024
f769c99
Fix order of rounding in test (doesn't matter for small numbers, just…
ProExpertProg Jul 25, 2024
9e49812
Fewer tests
ProExpertProg Jul 25, 2024
c1ad358
Static per-tensor kernels added
ProExpertProg Jul 25, 2024
25d0f58
Reduced test size, fixed custom_ops wrapper
ProExpertProg Jul 27, 2024
e05068c
format
ProExpertProg Aug 6, 2024
5d249fe
Merge remote-tracking branch 'refs/remotes/upstream/main' into luka/a…
ProExpertProg Aug 27, 2024
d02c568
Merge fixes
ProExpertProg Aug 27, 2024
e4dc101
Fix for AMD build
ProExpertProg Aug 29, 2024
31b3e44
PR comments: Python nits
ProExpertProg Sep 10, 2024
5a9762e
PR comments: saturation code
ProExpertProg Sep 10, 2024
8aed02a
explicit nearest rounding mode
ProExpertProg Sep 10, 2024
557db87
Added rounding mode guard
ProExpertProg Sep 10, 2024
2b24032
Rounding mode stuff removed, added comment
ProExpertProg Sep 10, 2024
5e9a0cb
Fixed test
ProExpertProg Sep 10, 2024
65b2f9c
Improved nearbyint rounding comment
ProExpertProg Sep 10, 2024
45e1d9e
Added saturating cast test
ProExpertProg Sep 10, 2024
2232b6d
Fixed scaled_int8_quant in qqq
ProExpertProg Sep 11, 2024
8df3b2d
Merge remote-tracking branch 'upstream/main' into luka/aq-azp-kernels
ProExpertProg Sep 11, 2024
04a539e
Fixed ops_check & azp test atol
ProExpertProg Sep 12, 2024
a3b9f6a
Fixed cpu bindings
ProExpertProg Sep 12, 2024
Files changed (changes from 17 commits)
6 changes: 4 additions & 2 deletions csrc/ops.h
@@ -165,10 +165,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
#endif

void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
- torch::Tensor const& scale);
+ torch::Tensor const& scale,
+ c10::optional<torch::Tensor> const& azp);

void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
- torch::Tensor& scales);
+ torch::Tensor& scales,
+ c10::optional<torch::Tensor> const& azp);

void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor lookup_table);
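Both entry points now take an optional zero-point tensor, so existing symmetric callers are unchanged when `azp` is absent. For reference, the mapping the new asymmetric kernels below implement is

$$q = \operatorname{clamp}\!\left(\operatorname{round}\!\left(\frac{x}{s}\right) + \mathrm{azp},\; -128,\; 127\right)$$

and, on the dynamic per-token path, the scale and zero point are derived from the token's observed range

$$s = \frac{\max(x) - \min(x)}{255}, \qquad \mathrm{azp} = \operatorname{round}\!\left(-128 - \frac{\min(x)}{s}\right)$$

so that the int8 range $[-128, 127]$ covers $[\min(x), \max(x)]$.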
150 changes: 138 additions & 12 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -14,9 +14,9 @@

static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
- static const float i8_min =
+ static const auto i8_min =
static_cast<float>(std::numeric_limits<int8_t>::min());
- static const float i8_max =
+ static const auto i8_max =
static_cast<float>(std::numeric_limits<int8_t>::max());
// round
float dst = std::nearbyint(x);
@@ -31,6 +31,43 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
#endif
}

static inline __device__ int32_t float_to_int32_rn(float x) {
#ifdef USE_ROCM
static const auto i32_min =
static_cast<float>(std::numeric_limits<int32_t>::min());
static const auto i32_max =
static_cast<float>(std::numeric_limits<int32_t>::max());
// round
float dst = std::nearbyint(x);
// saturate
dst = std::clamp(dst, i32_min, i32_max);
return static_cast<int32_t>(dst);
#else
// CUDA path
uint32_t dst;
asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x));
return reinterpret_cast<const int32_t&>(dst);
#endif
}

static inline __device__ int8_t int32_to_int8(int32_t x) {
#ifdef USE_ROCM
static const auto i8_min =
static_cast<int32_t>(std::numeric_limits<int8_t>::min());
static const auto i8_max =
static_cast<int32_t>(std::numeric_limits<int8_t>::max());

// saturate
int32_t dst = std::clamp(x, i8_min, i8_max);
return static_cast<int8_t>(dst);
#else
// CUDA path
uint32_t dst;
asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x));
return reinterpret_cast<const int8_t&>(dst);
#endif
}

namespace vllm {

template <typename scalar_t, typename scale_type>
@@ -47,6 +84,23 @@ __global__ void static_scaled_int8_quant_kernel(
}
}

template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void static_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, azp_type const* azp_ptr,
const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr;
azp_type const azp = *azp_ptr;

for (int i = tid; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
out[token_idx * hidden_size + i] = quant_val;
}
}

template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
@@ -80,14 +134,68 @@ __global__ void dynamic_scaled_int8_quant_kernel(
}
}

template <typename scalar_t, typename scale_type, typename azp_type>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, azp_type* azp, const int hidden_size) {
int const token_idx = blockIdx.x;

// Scan for the min and max value for this token
float max_val = std::numeric_limits<float>::min();
float min_val = std::numeric_limits<float>::max();
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto val = static_cast<float>(input[token_idx * hidden_size + i]);
max_val = std::max(max_val, val);
min_val = std::min(min_val, val);
}

// Reduce the max and min values across the block
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStorage;
max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
__syncthreads(); // Make sure min doesn't mess with max shared memory
min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);

__shared__ scale_type scale_sh;
__shared__ azp_type azp_sh;

// Compute the scale and zero point and store them, only on the first thread
if (threadIdx.x == 0) {
float const scale_val = (max_val - min_val) / 255.0f;
// Use rounding to even (same as torch.round)
auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
auto const azp_val = static_cast<azp_type>(azp_float);

// Store the scale and azp into shared and global
scale[token_idx] = scale_sh = scale_val;
azp[token_idx] = azp_sh = azp_val;
}

// Wait for the scale and azp to be computed
__syncthreads();

float const scale_val = scale_sh;
azp_type const azp_val = azp_sh;

// Quantize the values
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
auto const quant_val =
int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
out[token_idx * hidden_size + i] = quant_val;
}
}

} // namespace vllm

void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
- torch::Tensor const& scale) {
+ torch::Tensor const& scale,
+ c10::optional<torch::Tensor> const& azp) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scale.numel() == 1);
TORCH_CHECK(!azp || azp->numel() == 1);

int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
@@ -96,19 +204,29 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
- vllm::static_scaled_int8_quant_kernel<scalar_t, float>
- <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
- out.data_ptr<int8_t>(),
- scale.data_ptr<float>(), hidden_size);
+ if (!azp) {
+ vllm::static_scaled_int8_quant_kernel<scalar_t, float>
+ <<<grid, block, 0, stream>>>(
+ input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+ scale.data_ptr<float>(), hidden_size);
+ } else {
+ vllm::static_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
+ <<<grid, block, 0, stream>>>(
+ input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+ scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
+ hidden_size);
+ }
});
}

void dynamic_scaled_int8_quant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
- torch::Tensor& scales) {
+ torch::Tensor& scales, c10::optional<torch::Tensor> const& azp) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scales.is_contiguous());
TORCH_CHECK(!azp || azp->is_contiguous());

int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
@@ -117,9 +235,17 @@ void dynamic_scaled_int8_quant(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
- vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
- <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
- out.data_ptr<int8_t>(),
- scales.data_ptr<float>(), hidden_size);
+ if (!azp) {
+ vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
+ <<<grid, block, 0, stream>>>(
+ input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+ scales.data_ptr<float>(), hidden_size);
+ } else {
+ vllm::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
+ <<<grid, block, 0, stream>>>(
+ input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+ scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
+ hidden_size);
+ }
});
}
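As a readability aid, here is a minimal PyTorch sketch that mirrors, per token, what the dynamic AZP kernel above computes (it follows the same reference math as the tests in this PR). It is an illustration only, not the kernel itself or vLLM's API, and the helper name `ref_dynamic_azp_quant` is made up for this sketch.

```python
import torch

def ref_dynamic_azp_quant(x: torch.Tensor):
    """Per-token asymmetric int8 quantization (illustrative reference)."""
    x = x.to(torch.float32)
    x_max = x.max(dim=-1, keepdim=True).values
    x_min = x.min(dim=-1, keepdim=True).values

    # Scale maps the observed per-token range onto the 255 representable steps.
    scales = (x_max - x_min) / 255.0
    # Zero point shifts min(x) onto -128; torch.round is round-half-even,
    # matching the kernel's std::nearbyint.
    azp = torch.round(-128.0 - x_min / scales).to(torch.int32)

    # Quantize: round, add the zero point, then saturate to the int8 range.
    q = torch.clamp((x / scales).round() + azp, -128, 127).to(torch.int8)
    return q, scales, azp

# Example: a single token whose values span [-300, 700].
x = torch.tensor([[-300.0, 0.0, 700.0]])
q, s, azp = ref_dynamic_azp_quant(x)
print(q, s, azp)  # q = [-128, -52, 126], s ≈ 3.92, azp = -52
```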
8 changes: 4 additions & 4 deletions csrc/torch_bindings.cpp
@@ -248,14 +248,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
"()");
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
"Tensor? azp) -> ()");
ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);

// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
"()");
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
"Tensor!? azp) -> ()");
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
&dynamic_scaled_int8_quant);
}
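With these schemas, `azp` is optional at the Python boundary: passing None keeps the original symmetric kernels, while an int32 tensor selects the new AZP variants. A minimal sketch of calling the registered ops directly, in the same style as the tests below (the tensor shapes and values are assumptions for illustration):

```python
import torch

x = torch.rand(16, 1024, dtype=torch.float16, device="cuda") * 1000 - 300
out = torch.empty_like(x, dtype=torch.int8)

# Static, symmetric: azp=None dispatches to the original kernel.
scale = torch.tensor([2.1], dtype=torch.float32, device="cuda")
torch.ops._C.static_scaled_int8_quant(out, x, scale, None)

# Static, asymmetric: a one-element int32 zero point enables the AZP kernel.
azp = torch.tensor([-52], dtype=torch.int32, device="cuda")
torch.ops._C.static_scaled_int8_quant(out, x, scale, azp)

# Dynamic, asymmetric: per-token scales and zero points are written by the kernel.
scales = torch.empty(x.shape[0], 1, dtype=torch.float32, device="cuda")
azps = torch.empty(x.shape[0], 1, dtype=torch.int32, device="cuda")
torch.ops._C.dynamic_scaled_int8_quant(out, x, scales, azps)
```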
89 changes: 84 additions & 5 deletions tests/kernels/test_int8_quant.py
@@ -12,6 +12,18 @@
SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]


def allclose_int(input, other, atol: int = 0, rtol: float = 1e-5):
INT_DTYPES = [
torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8,
torch.uint16, torch.uint32, torch.uint64
]
assert input.dtype in INT_DTYPES and other.dtype in INT_DTYPES
diff = torch.abs(input.to(torch.int64) - other.to(torch.int64))
return torch.all(
diff <= atol +
torch.ceil(rtol * torch.abs(other).to(torch.float32)).to(torch.int64))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@@ -27,14 +39,52 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
# reference
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8)
# kernel
- ops_out, ops_scales = scaled_int8_quant(x)
+ ops_out, ops_scales, _ = scaled_int8_quant(x)

torch.testing.assert_close(ops_scales, ref_scales)
torch.testing.assert_close(
ops_out, ref_out, atol=1,
rtol=0.0) # big atol to account for rounding errors


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
int8_traits = torch.iinfo(torch.int8)

x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") * 1000 - 300

x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True)
x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True)

# calculate scale and azp, and adjust the range
scales = (x_token_max - x_token_min) / torch.tensor(255.0)
azps = torch.round(-128.0 - x_token_min / scales).to(torch.int32)

torch_out = ((x / scales).round() + azps).clamp(
int8_traits.min, int8_traits.max).to(torch.int8)
assert torch_out.min() >= int8_traits.min and torch_out.max(
) <= int8_traits.max

ops_out = torch.empty_like(x, dtype=torch.int8)
scales_out = torch.empty_like(scales, dtype=torch.float32)
azp_out = torch.empty_like(azps, dtype=torch.int32)
torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out)

if (not torch.allclose(scales_out, scales)):
print(torch.argmax(torch.abs(scales_out - scales)))
torch.testing.assert_close(scales_out, scales)
assert allclose_int(azp_out, azps, atol=1) # azp rounding error
assert allclose_int(torch_out, ops_out, atol=1) # azp rounding error


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@@ -53,8 +103,37 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,

out1 = (x / scale).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
- out2, _ = scaled_int8_quant(x, scale)
+ out2, _, _ = scaled_int8_quant(x, scale)

- torch.testing.assert_close(
- out1, out2, atol=1,
- rtol=0.0) # big atol to account for rounding errors
+ # big atol to account for rounding errors
+ torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("scale", SCALE[2:]) # Reduce test time
@pytest.mark.parametrize("azp", [-255, 54])
@torch.inference_mode()
def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int,
scale: float, azp: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
int8_traits = torch.iinfo(torch.int8)

x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") * 1000 - 300

out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
out2 = torch.empty_like(x, dtype=torch.int8)
scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
azp_argument = torch.tensor([azp], dtype=torch.int32, device="cuda")

torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument,
azp_argument)

# big atol to account for rounding errors
torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
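One further check, not part of this PR's tests but sometimes handy when debugging the asymmetric path, is a dequantization round trip: removing the zero point and rescaling should roughly reconstruct the input. A hedged sketch, reusing the tensor names from the dynamic AZP test above:

```python
# Dequantize: undo the zero point, then rescale (broadcasts per token).
dequant = (ops_out.to(torch.float32) - azp_out.to(torch.float32)) * scales_out
# Rounding of the value and of azp, plus saturation at the range edges, keeps
# each element within a couple of quantization steps of the original input.
assert torch.allclose(dequant, x.to(torch.float32), atol=2 * float(scales_out.max()))
```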