Skip to content

Commit

Permalink
[Bugfix][Kernel] Prevent integer overflow in fp8 dynamic per-token qu…
Browse files Browse the repository at this point in the history
…antize kernel (vllm-project#9425)
  • Loading branch information
tlrmchlsmth authored Oct 16, 2024
1 parent 776dbd7 commit c3fab5f
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions csrc/quantization/fp8/common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;

scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size];
// Use int64 to avoid overflowing an int32 when calculating this offset
int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
scalar_t const* __restrict__ token_input = &input[offset];
FP8_TYPE* __restrict__ token_output = &out[offset];

// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
Expand Down

0 comments on commit c3fab5f

Please sign in to comment.