From 68df9c487e672b4a4ea3be97aed63a48aac5945b Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Thu, 8 Aug 2024 23:05:31 -0700 Subject: [PATCH] feat: more sampling operator options (#431) 1. implement the first top-k then top-p sampling to align with vllm and huggingface's behavior https://github.com/vllm-project/vllm/pull/7137#issuecomment-2275167700 2. add options of using a scalar/tensor for top-p/top-k thresholds for all sampling operators. --- docs/api/python/sampling.rst | 2 + include/flashinfer/sampling.cuh | 311 ++++++++++++++++++++++---------- python/csrc/flashinfer_ops.cu | 1 + python/csrc/flashinfer_ops.h | 28 ++- python/csrc/sampling.cu | 166 +++++++++++++---- python/flashinfer/__init__.py | 52 ++++-- python/flashinfer/sampling.py | 235 +++++++++++++++++++++--- python/tests/test_sampling.py | 80 +++++++- src/bench_sampling.cu | 8 +- src/test_sampling.cu | 4 +- 10 files changed, 691 insertions(+), 196 deletions(-) diff --git a/docs/api/python/sampling.rst b/docs/api/python/sampling.rst index 63df15f6..e3f5cda3 100644 --- a/docs/api/python/sampling.rst +++ b/docs/api/python/sampling.rst @@ -14,7 +14,9 @@ Kernels for LLM sampling. top_p_sampling_from_probs top_k_sampling_from_probs min_p_sampling_from_probs + top_k_top_p_sampling_from_logits top_k_top_p_sampling_from_probs top_p_renorm_prob top_k_renorm_prob + top_k_mask_logits chain_speculative_sampling diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 2d631a48..4f4272f6 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -279,10 +279,11 @@ template __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output, - bool* success, uint32_t k, uint32_t d, - uint32_t max_top_k_rounds) { + bool* success, IdType* top_k_arr, uint32_t top_k_val, + uint32_t d, uint32_t max_top_k_rounds) { const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; + uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; extern __shared__ __align__( alignof(SamplingTempStorage)) @@ -365,13 +366,11 @@ template __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output, bool* success, IdType* row_indices, float* top_p_arr, - float top_p, uint32_t d, uint32_t max_top_p_rounds) { + float top_p_val, uint32_t d, uint32_t max_top_p_rounds) { const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; + float top_p = (top_p_arr == nullptr) ? top_p_val : top_p_arr[bx]; - if (top_p_arr != nullptr) { - top_p = top_p_arr[bx]; - } const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx]; extern __shared__ __align__( @@ -451,12 +450,12 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, template -__global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples, DType* min_p, - IdType* output, bool* success, uint32_t d, - uint32_t max_min_p_rounds) { +__global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples, DType* min_p_arr, + IdType* output, bool* success, float min_p_val, + uint32_t d, uint32_t max_min_p_rounds) { const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; - DType p = min_p[bx]; + DType p = (min_p_arr == nullptr) ? 
min_p_val : min_p_arr[bx]; extern __shared__ __align__( alignof(SamplingTempStorage)) @@ -557,13 +556,14 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples, template -__global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* top_k, - DType* top_p, IdType* output, bool* success, +__global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, + IdType* top_k_arr, DType* top_p_arr, IdType* output, + bool* success, IdType top_k_val, DType top_p_val, uint32_t d, uint32_t max_rounds) { const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; - IdType k = top_k[bx]; - DType p = top_p[bx]; + IdType k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; + DType p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx]; extern __shared__ __align__( alignof(SamplingTempStorage)) @@ -685,7 +685,7 @@ cudaError_t ParallelSamplingFromProb(T* probs, T* uniform_samples, IdType* outpu template cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success, - IdType top_k, uint32_t batch_size, uint32_t d, + T* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, uint32_t max_top_k_rounds, bool deterministic, cudaStream_t stream = 0) { constexpr uint32_t BLOCK_THREADS = 1024; @@ -694,7 +694,8 @@ cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &uniform_samples, &output, &success, &top_k, &d, &max_top_k_rounds}; + void* args[] = {&probs, &uniform_samples, &output, &success, + &top_k_arr, &top_k_val, &d, &max_top_k_rounds}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { @@ -710,7 +711,7 @@ cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b template cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success, - T top_p, uint32_t batch_size, uint32_t d, + T* top_p_arr, uint32_t batch_size, T top_p_val, uint32_t d, uint32_t max_top_p_rounds, bool deterministic, cudaStream_t stream = 0) { constexpr uint32_t BLOCK_THREADS = 1024; @@ -720,16 +721,8 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); IdType* row_indices_placeholder = nullptr; - T* top_p_arr_placeholder = nullptr; - void* args[] = {&probs, - &uniform_samples, - &output, - &success, - &row_indices_placeholder, - &top_p_arr_placeholder, - &top_p, - &d, - &max_top_p_rounds}; + void* args[] = {&probs, &uniform_samples, &output, &success, &row_indices_placeholder, + &top_p_arr, &top_p_val, &d, &max_top_p_rounds}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { @@ -744,8 +737,8 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b } template -cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p, IdType* output, - bool* success, uint32_t batch_size, uint32_t d, +cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p_arr, IdType* output, + bool* success, uint32_t batch_size, float min_p_val, uint32_t d, uint32_t max_rounds, bool deterministic, cudaStream_t stream = 0) { constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); @@ -753,7 +746,8 @@ cudaError_t MinPSamplingFromProb(T* probs, T* 
uniform_samples, T* min_p, IdType* const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &uniform_samples, &min_p, &output, &success, &d, &max_rounds}; + void* args[] = {&probs, &uniform_samples, &min_p_arr, &output, + &success, &min_p_val, &d, &max_rounds}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { @@ -768,17 +762,18 @@ cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p, IdType* } template -cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* top_k, T* top_p, - IdType* output, bool* success, uint32_t batch_size, uint32_t d, - uint32_t max_rounds, bool deterministic, - cudaStream_t stream = 0) { +cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* top_k_arr, T* top_p_arr, + IdType* output, bool* success, uint32_t batch_size, + IdType top_k_val, T top_p_val, uint32_t d, uint32_t max_rounds, + bool deterministic, cudaStream_t stream = 0) { constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &uniform_samples, &top_k, &top_p, &output, &success, &d, &max_rounds}; + void* args[] = {&probs, &uniform_samples, &top_k_arr, &top_p_arr, &output, + &success, &top_k_val, &top_p_val, &d, &max_rounds}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { @@ -796,23 +791,27 @@ template ::TempStorage reduce; + typename BlockReduce::TempStorage reduce_int; typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_pair; } block_prim; struct { T max_val; + T min_val; union { T value; + int count; Pair pair; } block_aggregate; } data; }; template -__global__ void TopPRenormProbKernel(DType* probs, IdType* renormed_prob, float p, float eps, - uint32_t d) { + typename DType> +__global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType* top_p_arr, + float top_p_val, float eps, uint32_t d) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; const uint32_t row_idx = bx; + float p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx]; extern __shared__ __align__(alignof(RenormTempStorage)) uint8_t smem_renorm[]; @@ -898,49 +897,120 @@ __global__ void TopPRenormProbKernel(DType* probs, IdType* renormed_prob, float template -__global__ void TopKRenormProbKernel(DType* probs, IdType* renormed_prob, uint32_t k, float eps, - uint32_t d) { +__global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType* top_k_arr, + uint32_t top_k_val, float eps, uint32_t d) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; const uint32_t row_idx = bx; + uint32_t k = top_k_arr == nullptr ? 
top_k_val : top_k_arr[bx]; + float pivot = -std::numeric_limits::infinity(); + vec_t logits_vec; + if (k < d) { + extern __shared__ __align__(alignof(RenormTempStorage)) + uint8_t smem_renorm[]; + auto& temp_storage = + reinterpret_cast&>(smem_renorm); + DType logits_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0 + + DType threadlocal_max_val = DType(-std::numeric_limits::infinity()), + threadlocal_min_val = DType(std::numeric_limits::infinity()); + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + logits_vec.fill(DType(0)); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + } +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + logits_greater_than_pivot[j] = logits_vec[j]; + } + threadlocal_max_val = + max(threadlocal_max_val, + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(logits_greater_than_pivot, cub::Max())); + __syncthreads(); + threadlocal_min_val = + min(threadlocal_min_val, + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(logits_greater_than_pivot, cub::Min())); + } + if (tx == 0) { + temp_storage.data.max_val = threadlocal_max_val; + temp_storage.data.min_val = threadlocal_min_val; + } + __syncthreads(); + threadlocal_max_val = temp_storage.data.max_val; + threadlocal_min_val = temp_storage.data.min_val; + + float low = threadlocal_min_val - 1, high = threadlocal_max_val; + // f(x) = len(nonzero(probs > x)), f(x) is non-increasing + // loop invariant: f(low) >= k, f(high) < k + while (high - low > eps) { + int threadlocal_count_sum = 0; + int probs_greater_than_pivot_count[VEC_SIZE]; // pivot initialized to 0 + float mid = (low + high) / 2; + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + logits_vec.fill(DType(0)); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + } +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + probs_greater_than_pivot_count[j] = + logits_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d; + } + threadlocal_count_sum += + BlockReduce(temp_storage.block_prim.reduce_int) + .Sum(probs_greater_than_pivot_count); + __syncthreads(); + } + if (tx == 0) { + temp_storage.data.block_aggregate.count = threadlocal_count_sum; + } + __syncthreads(); + threadlocal_count_sum = temp_storage.data.block_aggregate.count; + if (threadlocal_count_sum >= k) { + low = mid; + } else { + high = mid; + } + } + pivot = low; + } - extern __shared__ __align__(alignof(RenormTempStorage)) - uint8_t smem_renorm[]; - auto& temp_storage = - reinterpret_cast&>(smem_renorm); - temp_storage.data.max_val = DType(0); - vec_t probs_vec; - DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0 - - DType threadlocal_max_val = DType(0); + // masking for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); + logits_vec.fill(DType(0)); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); } #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_greater_than_pivot[j] = probs_vec[j]; + logits_vec[j] = + (logits_vec[j] > pivot) ? 
logits_vec[j] : DType(-std::numeric_limits::infinity()); + } + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + logits_vec.store(masked_logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); } - threadlocal_max_val = - max(threadlocal_max_val, - BlockReduce(temp_storage.block_prim.reduce) - .Reduce(probs_greater_than_pivot, cub::Max())); - __syncthreads(); - } - if (tx == 0) { - temp_storage.data.max_val = threadlocal_max_val; } - __syncthreads(); - threadlocal_max_val = temp_storage.data.max_val; +} - float low = 0, high = threadlocal_max_val; - DType sum_low(1); - // f(x) = len(nonzero(probs > x)), f(x) is non-increasing - // loop invariant: f(low) >= k, f(high) < k - while (high - low > eps) { - Pair threadlocal_sum{DType(0), 0}; - Pair probs_greater_than_pivot_pair[VEC_SIZE]; // pivot initialized to 0 - float mid = (low + high) / 2; +template +__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, + uint32_t top_k_val, float eps, uint32_t d) { + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + const uint32_t row_idx = bx; + uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx]; + float pivot = -std::numeric_limits::infinity(), normalizer = 1; + vec_t probs_vec; + if (k < d) { + extern __shared__ __align__(alignof(RenormTempStorage)) + uint8_t smem_renorm[]; + auto& temp_storage = + reinterpret_cast&>(smem_renorm); + temp_storage.data.max_val = DType(0); + DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0 + + DType threadlocal_max_val = DType(0); for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(DType(0)); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { @@ -948,29 +1018,60 @@ __global__ void TopKRenormProbKernel(DType* probs, IdType* renormed_prob, uint32 } #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_greater_than_pivot_pair[j] = { - (probs_vec[j] > mid) ? probs_vec[j] : DType(0), - (probs_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + probs_greater_than_pivot[j] = probs_vec[j]; } - threadlocal_sum += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_pair) - .Sum(probs_greater_than_pivot_pair); + threadlocal_max_val = + max(threadlocal_max_val, + BlockReduce(temp_storage.block_prim.reduce) + .Reduce(probs_greater_than_pivot, cub::Max())); __syncthreads(); } if (tx == 0) { - temp_storage.data.block_aggregate.pair = threadlocal_sum; + temp_storage.data.max_val = threadlocal_max_val; } __syncthreads(); - threadlocal_sum = temp_storage.data.block_aggregate.pair; - if (threadlocal_sum.count >= k) { - low = mid; - sum_low = float(threadlocal_sum.value); - } else { - high = mid; + threadlocal_max_val = temp_storage.data.max_val; + + float low = 0, high = threadlocal_max_val; + DType sum_low(1); + // f(x) = len(nonzero(probs > x)), f(x) is non-increasing + // loop invariant: f(low) >= k, f(high) < k + while (high - low > eps) { + Pair threadlocal_sum{DType(0), 0}; + Pair probs_greater_than_pivot_pair[VEC_SIZE]; // pivot initialized to 0 + float mid = (low + high) / 2; + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(DType(0)); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + } +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + probs_greater_than_pivot_pair[j] = { + (probs_vec[j] > mid) ? 
probs_vec[j] : DType(0), + (probs_vec[j] > mid && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + } + threadlocal_sum += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_pair) + .Sum(probs_greater_than_pivot_pair); + __syncthreads(); + } + if (tx == 0) { + temp_storage.data.block_aggregate.pair = threadlocal_sum; + } + __syncthreads(); + threadlocal_sum = temp_storage.data.block_aggregate.pair; + if (threadlocal_sum.count >= k) { + low = mid; + sum_low = float(threadlocal_sum.value); + } else { + high = mid; + } } - } - float normalizer = math::ptx_rcp(max(sum_low, eps)); + normalizer = math::ptx_rcp(max(sum_low, eps)); + pivot = low; + } // normalize for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { @@ -980,7 +1081,7 @@ __global__ void TopKRenormProbKernel(DType* probs, IdType* renormed_prob, uint32 } #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_vec[j] = (probs_vec[j] > low) ? probs_vec[j] * normalizer : DType(0); + probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : DType(0); } if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); @@ -988,18 +1089,19 @@ __global__ void TopKRenormProbKernel(DType* probs, IdType* renormed_prob, uint32 } } -template -cudaError_t TopPRenormProb(DType* probs, IdType* renormed_prob, float p, float eps, - uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) { +template +cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr, float eps, + uint32_t batch_size, float top_p_val, uint32_t d, + cudaStream_t stream = 0) { const uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(RenormTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &renormed_prob, &p, &eps, &d}; + void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &eps, &d}; DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopPRenormProbKernel; + auto kernel = TopPRenormProbKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); @@ -1008,15 +1110,16 @@ cudaError_t TopPRenormProb(DType* probs, IdType* renormed_prob, float p, float e } template -cudaError_t TopKRenormProb(DType* probs, IdType* renormed_prob, uint32_t k, float eps, - uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) { +cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr, float eps, + uint32_t batch_size, uint32_t top_k_val, uint32_t d, + cudaStream_t stream = 0) { const uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(RenormTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &renormed_prob, &k, &eps, &d}; + void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &eps, &d}; DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { auto kernel = TopKRenormProbKernel; FLASHINFER_CUDA_CALL( @@ -1026,6 +1129,26 @@ cudaError_t TopKRenormProb(DType* probs, IdType* renormed_prob, uint32_t k, floa return cudaSuccess; } +template +cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_arr, float eps, + uint32_t batch_size, uint32_t top_k_val, uint32_t d, + cudaStream_t stream = 0) { + const uint32_t 
BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); + + const uint32_t smem_size = sizeof(RenormTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &eps, &d}; + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = TopKMaskLogitsKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + }); + return cudaSuccess; +} + template diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index 16800f13..c44ec371 100644 --- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -34,6 +34,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Top-k and top-p sampling from probabilities"); m.def("top_k_renorm_prob", &top_k_renorm_prob, "Renormalize probabilities by top-k mask"); m.def("top_p_renorm_prob", &top_p_renorm_prob, "Renormalize probabilities by top-p mask"); + m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask"); m.def("chain_speculative_sampling", &chain_speculative_sampling, "Speculative sampling from sequence of probabilities"); m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 0eee5495..988716d3 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -39,25 +39,33 @@ torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_sam bool deterministic); std::vector top_p_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, double top_p, - bool deterministic); + torch::Tensor uniform_samples, + std::optional maybe_top_p_arr, + double top_p_val, bool deterministic); std::vector top_k_sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples, - unsigned int top_k, bool deterministic); + std::optional maybe_top_k_arr, + unsigned int top_k_val, bool deterministic); std::vector min_p_sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples, - torch::Tensor min_p, bool deterministic); + std::optional maybe_min_p_arr, + double min_p_val, bool deterministic); + +std::vector top_k_top_p_sampling_from_probs( + torch::Tensor probs, torch::Tensor uniform_samples, + std::optional maybe_top_k_arr, double top_k_val, + std::optional maybe_top_p_arr, double top_p_val, bool deterministic); -std::vector top_k_top_p_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - torch::Tensor top_k, torch::Tensor top_p, - bool deterministic); +torch::Tensor top_p_renorm_prob(torch::Tensor probs, std::optional maybe_top_p_arr, + double top_p_val, double eps); -torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps); +torch::Tensor top_k_renorm_prob(torch::Tensor probs, std::optional maybe_top_k_arr, + unsigned int top_k_val, double eps); -torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps); +torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional maybe_top_k_arr, + unsigned int top_k_val, double eps); torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples, torch::Tensor target_probs, diff --git a/python/csrc/sampling.cu b/python/csrc/sampling.cu index 2be782cb..623a0c45 100644 --- a/python/csrc/sampling.cu +++ b/python/csrc/sampling.cu @@ -47,8 +47,9 
@@ torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_sam } std::vector top_p_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, double top_p, - bool deterministic) { + torch::Tensor uniform_samples, + std::optional maybe_top_p_arr, + double top_p_val, bool deterministic) { CHECK_INPUT(probs); CHECK_INPUT(uniform_samples); auto device = probs.device(); @@ -59,8 +60,17 @@ std::vector top_p_sampling_from_probs(torch::Tensor probs, unsigned int batch_size = probs.size(0); unsigned int vocab_size = probs.size(1); unsigned int max_top_p_rounds = uniform_samples.size(0); + bool has_top_p_arr = maybe_top_p_arr.has_value(); + auto top_p_arr = maybe_top_p_arr.value_or(torch::empty({0}, torch::dtype(torch::kFloat32))); + if (has_top_p_arr) { + CHECK_INPUT(top_p_arr); + CHECK_DIM(1, top_p_arr); // top_p_arr: (batch_size,) + CHECK_EQ(top_p_arr.size(0), batch_size); + CHECK_EQ(top_p_arr.device(), device); + } probs = probs.to(torch::kFloat32); uniform_samples = uniform_samples.to(torch::kFloat32); + top_p_arr = top_p_arr.to(torch::kFloat32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device)); @@ -68,8 +78,9 @@ std::vector top_p_sampling_from_probs(torch::Tensor probs, cudaError_t status = sampling::TopPSamplingFromProb( static_cast(probs.data_ptr()), static_cast(uniform_samples.data_ptr()), - static_cast(samples.data_ptr()), static_cast(success.data_ptr()), top_p, - batch_size, vocab_size, max_top_p_rounds, deterministic, torch_current_stream); + static_cast(samples.data_ptr()), static_cast(success.data_ptr()), + has_top_p_arr ? static_cast(top_p_arr.data_ptr()) : nullptr, batch_size, top_p_val, + vocab_size, max_top_p_rounds, deterministic, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "TopPSamplingFromProbs failed with error code " + std::string(cudaGetErrorString(status))); @@ -78,7 +89,8 @@ std::vector top_p_sampling_from_probs(torch::Tensor probs, std::vector top_k_sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples, - unsigned int top_k, bool deterministic) { + std::optional maybe_top_k_arr, + unsigned int top_k_val, bool deterministic) { CHECK_INPUT(probs); CHECK_INPUT(uniform_samples); auto device = probs.device(); @@ -89,8 +101,17 @@ std::vector top_k_sampling_from_probs(torch::Tensor probs, unsigned int batch_size = probs.size(0); unsigned int vocab_size = probs.size(1); unsigned int max_top_k_rounds = uniform_samples.size(0); + bool has_top_k_arr = maybe_top_k_arr.has_value(); + auto top_k_arr = maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32))); + if (has_top_k_arr) { + CHECK_INPUT(top_k_arr); + CHECK_DIM(1, top_k_arr); // top_k_arr: (batch_size,) + CHECK_EQ(top_k_arr.size(0), batch_size); + CHECK_EQ(top_k_arr.device(), device); + } probs = probs.to(torch::kFloat32); uniform_samples = uniform_samples.to(torch::kFloat32); + top_k_arr = top_k_arr.to(torch::kInt32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device)); @@ -98,8 +119,9 @@ std::vector top_k_sampling_from_probs(torch::Tensor probs, cudaError_t status = sampling::TopKSamplingFromProb( static_cast(probs.data_ptr()), static_cast(uniform_samples.data_ptr()), - static_cast(samples.data_ptr()), static_cast(success.data_ptr()), top_k, - batch_size, vocab_size, max_top_k_rounds, deterministic, 
torch_current_stream); + static_cast(samples.data_ptr()), static_cast(success.data_ptr()), + has_top_k_arr ? static_cast(top_k_arr.data_ptr()) : nullptr, batch_size, top_k_val, + vocab_size, max_top_k_rounds, deterministic, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "TopKSamplingFromProbs failed with error code " + std::string(cudaGetErrorString(status))); @@ -108,24 +130,29 @@ std::vector top_k_sampling_from_probs(torch::Tensor probs, std::vector min_p_sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples, - torch::Tensor min_p, bool deterministic) { + std::optional maybe_min_p_arr, + double min_p_val, bool deterministic) { CHECK_INPUT(probs); CHECK_INPUT(uniform_samples); - CHECK_INPUT(min_p); auto device = probs.device(); CHECK_EQ(uniform_samples.device(), device); - CHECK_EQ(min_p.device(), device); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) CHECK_DIM(2, uniform_samples); // uniform_samples: (max_rounds, batch_size) - CHECK_DIM(1, min_p); // min_p: (batch_size,) unsigned int batch_size = probs.size(0); unsigned int vocab_size = probs.size(1); unsigned int max_rounds = uniform_samples.size(0); CHECK_EQ(uniform_samples.size(1), batch_size); - CHECK_EQ(min_p.size(0), batch_size); + bool has_min_p_arr = maybe_min_p_arr.has_value(); + auto min_p_arr = maybe_min_p_arr.value_or(torch::empty({0}, torch::dtype(torch::kFloat32))); + if (has_min_p_arr) { + CHECK_INPUT(min_p_arr); + CHECK_DIM(1, min_p_arr); // min_p_arr: (batch_size,) + CHECK_EQ(min_p_arr.size(0), batch_size); + CHECK_EQ(min_p_arr.device(), device); + } + min_p_arr = min_p_arr.to(torch::kFloat32); probs = probs.to(torch::kFloat32); uniform_samples = uniform_samples.to(torch::kFloat32); - min_p = min_p.to(torch::kFloat32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device)); @@ -133,41 +160,49 @@ std::vector min_p_sampling_from_probs(torch::Tensor probs, cudaError_t status = sampling::MinPSamplingFromProb( static_cast(probs.data_ptr()), static_cast(uniform_samples.data_ptr()), - static_cast(min_p.data_ptr()), static_cast(samples.data_ptr()), - static_cast(success.data_ptr()), batch_size, vocab_size, max_rounds, deterministic, - torch_current_stream); + has_min_p_arr ? 
static_cast(min_p_arr.data_ptr()) : nullptr, + static_cast(samples.data_ptr()), static_cast(success.data_ptr()), batch_size, + min_p_val, vocab_size, max_rounds, deterministic, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "MinPSamplingFromProb failed with error code " + std::string(cudaGetErrorString(status))); return {samples, success}; } -std::vector top_k_top_p_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - torch::Tensor top_k, torch::Tensor top_p, - bool deterministic) { +std::vector top_k_top_p_sampling_from_probs( + torch::Tensor probs, torch::Tensor uniform_samples, + std::optional maybe_top_k_arr, double top_k_val, + std::optional maybe_top_p_arr, double top_p_val, bool deterministic) { CHECK_INPUT(probs); CHECK_INPUT(uniform_samples); - CHECK_INPUT(top_k); - CHECK_INPUT(top_p); auto device = probs.device(); CHECK_EQ(uniform_samples.device(), device); - CHECK_EQ(top_k.device(), device); - CHECK_EQ(top_p.device(), device); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) CHECK_DIM(2, uniform_samples); // uniform_samples: (max_rounds, batch_size) - CHECK_DIM(1, top_k); // top_k: (batch_size,) - CHECK_DIM(1, top_p); // top_p: (batch_size,) unsigned int batch_size = probs.size(0); unsigned int vocab_size = probs.size(1); unsigned int max_rounds = uniform_samples.size(0); CHECK_EQ(uniform_samples.size(1), batch_size); - CHECK_EQ(top_k.size(0), batch_size); - CHECK_EQ(top_p.size(0), batch_size); + bool has_top_k_arr = maybe_top_k_arr.has_value(); + auto top_k_arr = maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32))); + if (has_top_k_arr) { + CHECK_INPUT(top_k_arr); + CHECK_DIM(1, top_k_arr); // top_k_arr: (batch_size,) + CHECK_EQ(top_k_arr.size(0), batch_size); + CHECK_EQ(top_k_arr.device(), device); + } + top_k_arr = top_k_arr.to(torch::kInt32); + bool has_top_p_arr = maybe_top_p_arr.has_value(); + auto top_p_arr = maybe_top_p_arr.value_or(torch::empty({0}, torch::dtype(torch::kFloat32))); + if (has_top_p_arr) { + CHECK_INPUT(top_p_arr); + CHECK_DIM(1, top_p_arr); // top_p_arr: (batch_size,) + CHECK_EQ(top_p_arr.size(0), batch_size); + CHECK_EQ(top_p_arr.device(), device); + } + top_p_arr = top_p_arr.to(torch::kFloat32); probs = probs.to(torch::kFloat32); uniform_samples = uniform_samples.to(torch::kFloat32); - top_k = top_k.to(torch::kInt32); - top_p = top_p.to(torch::kFloat32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device)); @@ -175,21 +210,32 @@ std::vector top_k_top_p_sampling_from_probs(torch::Tensor probs, cudaError_t status = sampling::TopKTopPSamplingFromProb( static_cast(probs.data_ptr()), static_cast(uniform_samples.data_ptr()), - static_cast(top_k.data_ptr()), static_cast(top_p.data_ptr()), + has_top_k_arr ? static_cast(top_k_arr.data_ptr()) : nullptr, + has_top_p_arr ? 
static_cast(top_p_arr.data_ptr()) : nullptr, static_cast(samples.data_ptr()), static_cast(success.data_ptr()), batch_size, - vocab_size, max_rounds, deterministic, torch_current_stream); + top_k_val, top_p_val, vocab_size, max_rounds, deterministic, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "TopKTopPSamplingFromProbs failed with error code " + std::string(cudaGetErrorString(status))); return {samples, success}; } -torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps) { +torch::Tensor top_p_renorm_prob(torch::Tensor probs, std::optional maybe_top_p_arr, + double top_p_val, double eps) { CHECK_INPUT(probs); auto device = probs.device(); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) unsigned int batch_size = probs.size(0); unsigned int vocab_size = probs.size(1); + bool has_top_p_arr = maybe_top_p_arr.has_value(); + auto top_p_arr = maybe_top_p_arr.value_or(torch::empty({0}, torch::dtype(torch::kFloat32))); + if (has_top_p_arr) { + CHECK_INPUT(top_p_arr); + CHECK_DIM(1, top_p_arr); // top_p_arr: (batch_size,) + CHECK_EQ(top_p_arr.size(0), batch_size); + CHECK_EQ(top_p_arr.device(), device); + } + top_p_arr = top_p_arr.to(torch::kFloat32); probs = probs.to(torch::kFloat32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); @@ -197,19 +243,30 @@ torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps) { torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device)); cudaError_t status = sampling::TopPRenormProb( - static_cast(probs.data_ptr()), static_cast(renorm_probs.data_ptr()), top_p, - eps, batch_size, vocab_size, torch_current_stream); + static_cast(probs.data_ptr()), static_cast(renorm_probs.data_ptr()), + has_top_p_arr ? static_cast(top_p_arr.data_ptr()) : nullptr, eps, batch_size, + top_p_val, vocab_size, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "TopPRenormProb failed with error code " + std::string(cudaGetErrorString(status))); return renorm_probs; } -torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps) { +torch::Tensor top_k_renorm_prob(torch::Tensor probs, std::optional maybe_top_k_arr, + unsigned int top_k_val, double eps) { CHECK_INPUT(probs); auto device = probs.device(); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) unsigned int batch_size = probs.size(0); unsigned int vocab_size = probs.size(1); + bool has_top_k_arr = maybe_top_k_arr.has_value(); + auto top_k_arr = maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32))); + if (has_top_k_arr) { + CHECK_INPUT(top_k_arr); + CHECK_DIM(1, top_k_arr); // top_k_arr: (batch_size,) + CHECK_EQ(top_k_arr.size(0), batch_size); + CHECK_EQ(top_k_arr.device(), device); + } + top_k_arr = top_k_arr.to(torch::kInt32); probs = probs.to(torch::kFloat32); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); @@ -217,14 +274,47 @@ torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device)); cudaError_t status = sampling::TopKRenormProb( - static_cast(probs.data_ptr()), static_cast(renorm_probs.data_ptr()), top_k, - eps, batch_size, vocab_size, torch_current_stream); + static_cast(probs.data_ptr()), static_cast(renorm_probs.data_ptr()), + has_top_k_arr ? 
static_cast(top_k_arr.data_ptr()) : nullptr, eps, batch_size, top_k_val, + vocab_size, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "TopKRenormProb failed with error code " + std::string(cudaGetErrorString(status))); return renorm_probs; } +torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional maybe_top_k_arr, + unsigned int top_k_val, double eps) { + CHECK_INPUT(logits); + auto device = logits.device(); + CHECK_DIM(2, logits); // logits: (batch_size, vocab_size) + unsigned int batch_size = logits.size(0); + unsigned int vocab_size = logits.size(1); + bool has_top_k_arr = maybe_top_k_arr.has_value(); + auto top_k_arr = maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32))); + if (has_top_k_arr) { + CHECK_INPUT(top_k_arr); + CHECK_DIM(1, top_k_arr); // top_k_arr: (batch_size,) + CHECK_EQ(top_k_arr.size(0), batch_size); + CHECK_EQ(top_k_arr.device(), device); + } + top_k_arr = top_k_arr.to(torch::kInt32); + logits = logits.to(torch::kFloat32); + + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + auto mask_logits = + torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device)); + + cudaError_t status = sampling::TopKMaskLogits( + static_cast(logits.data_ptr()), static_cast(mask_logits.data_ptr()), + has_top_k_arr ? static_cast(top_k_arr.data_ptr()) : nullptr, eps, batch_size, top_k_val, + vocab_size, torch_current_stream); + + TORCH_CHECK(status == cudaSuccess, + "TopKMaskLogits failed with error code " + std::string(cudaGetErrorString(status))); + return mask_logits; +} + torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples, torch::Tensor target_probs, bool deterministic) { diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py index f1429e76..8794c135 100644 --- a/python/flashinfer/__init__.py +++ b/python/flashinfer/__init__.py @@ -14,26 +14,46 @@ limitations under the License. 
""" -from .cascade import (BatchDecodeWithSharedPrefixPagedKVCacheWrapper, - BatchPrefillWithSharedPrefixPagedKVCacheWrapper, - merge_state, merge_state_in_place, merge_states) -from .decode import (BatchDecodeWithPagedKVCacheWrapper, - CUDAGraphBatchDecodeWithPagedKVCacheWrapper, - single_decode_with_kv_cache) +from .cascade import ( + BatchDecodeWithSharedPrefixPagedKVCacheWrapper, + BatchPrefillWithSharedPrefixPagedKVCacheWrapper, + merge_state, + merge_state_in_place, + merge_states, +) +from .decode import ( + BatchDecodeWithPagedKVCacheWrapper, + CUDAGraphBatchDecodeWithPagedKVCacheWrapper, + single_decode_with_kv_cache, +) from .group_gemm import SegmentGEMMWrapper from .norm import fused_add_rmsnorm, rmsnorm from .page import append_paged_kv_cache -from .prefill import (BatchPrefillWithPagedKVCacheWrapper, - BatchPrefillWithRaggedKVCacheWrapper, - single_prefill_with_kv_cache, - single_prefill_with_kv_cache_return_lse) +from .prefill import ( + BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, + single_prefill_with_kv_cache, + single_prefill_with_kv_cache_return_lse, +) from .quantization import packbits, segment_packbits -from .rope import (apply_llama31_rope, apply_llama31_rope_inplace, apply_rope, - apply_rope_inplace) -from .sampling import (chain_speculative_sampling, sampling_from_probs, - top_k_renorm_prob, top_k_sampling_from_probs, - top_k_top_p_sampling_from_probs, top_p_renorm_prob, - top_p_sampling_from_probs) +from .rope import ( + apply_llama31_rope, + apply_llama31_rope_inplace, + apply_rope, + apply_rope_inplace, +) +from .sampling import ( + chain_speculative_sampling, + sampling_from_probs, + top_k_renorm_prob, + top_k_mask_logits, + top_k_sampling_from_probs, + top_k_top_p_sampling_from_probs, + top_k_top_p_sampling_from_logits, + top_p_renorm_prob, + top_p_sampling_from_probs, + min_p_sampling_from_probs, +) from .sparse import BlockSparseAttentionWrapper try: diff --git a/python/flashinfer/sampling.py b/python/flashinfer/sampling.py index 8d54ad7c..b6c23d73 100644 --- a/python/flashinfer/sampling.py +++ b/python/flashinfer/sampling.py @@ -15,7 +15,7 @@ """ import torch -from typing import Tuple +from typing import Tuple, Union # mypy: disable-error-code="attr-defined" try: @@ -31,6 +31,13 @@ raise e +def _to_tensor_scalar_tuple(x): + if isinstance(x, torch.Tensor): + return (x, 0) + else: + return (None, x) + + def sampling_from_probs( probs: torch.Tensor, uniform_samples: torch.Tensor, deterministic: bool = True ) -> torch.Tensor: @@ -81,7 +88,7 @@ def sampling_from_probs( def top_p_sampling_from_probs( probs: torch.Tensor, uniform_samples: torch.Tensor, - top_p: float, + top_p: Union[torch.Tensor, float], deterministic: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities, @@ -98,8 +105,10 @@ def top_p_sampling_from_probs( The uniform samples used as needle for sampling, shape ``(max_top_p_rounds, batch_size,)``, where the first dimension is the maximum number of rounds for rejection sampling. Expected to be uniformly distributed in ``[0, 1)``. - top_p: float - The threshold for top-p sampling. + top_p: Union[torch.Tensor, float] + Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling. + If a scalar, the same threshold is used for all requests. + If a tensor, each request has its own threshold. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. 
@@ -142,14 +151,14 @@ def top_p_sampling_from_probs(
     implementation usually use much fewer rounds for rejection sampling because of early stopping.
     """
     return _kernels.top_p_sampling_from_probs(
-        probs, uniform_samples, top_p, deterministic
+        probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic
     )


 def top_k_sampling_from_probs(
     probs: torch.Tensor,
     uniform_samples: torch.Tensor,
-    top_k: int,
+    top_k: Union[torch.Tensor, int],
     deterministic: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""Fused GPU kernel for top-k sampling from probabilities,
@@ -166,8 +175,10 @@ def top_k_sampling_from_probs(
         The uniform samples used as needle for sampling, shape ``(max_top_k_rounds, batch_size,)``,
         where the first dimension is the maximum number of rounds for rejection sampling.
         Expected to be uniformly distributed in ``[0, 1)``.
-    top_k: int
-        The k in "top-k".
+    top_k: Union[torch.Tensor, int]
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
     deterministic: bool
         Whether to use deterministic kernel implementation, default is ``True``.

@@ -210,14 +221,14 @@ def top_k_sampling_from_probs(
     implementation usually use much fewer rounds for rejection sampling because of early stopping.
     """
     return _kernels.top_k_sampling_from_probs(
-        probs, uniform_samples, top_k, deterministic
+        probs, uniform_samples, *_to_tensor_scalar_tuple(top_k), deterministic
     )


 def min_p_sampling_from_probs(
     probs: torch.Tensor,
     uniform_samples: torch.Tensor,
-    min_p: torch.Tensor,
+    min_p: Union[torch.Tensor, float],
     deterministic: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""Fused GPU kernel for `min_p sampling <https://arxiv.org/abs/2407.01082>`_ from probabilities,
@@ -236,7 +247,9 @@ def min_p_sampling_from_probs(
         where the first dimension is the maximum number of rounds for rejection sampling.
         Expected to be uniformly distributed in ``[0, 1)``.
     min_p: torch.Tensor
-        The :math:`p_{\text{base}}` in min_p sampling for each request, shape ``(batch_size,)``.
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for min-p sampling.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
     deterministic: bool
         Whether to use deterministic kernel implementation, default is ``True``.

@@ -280,18 +293,124 @@ def min_p_sampling_from_probs(
     implementation usually use much fewer rounds for rejection sampling because of early stopping.
     """
     return _kernels.min_p_sampling_from_probs(
-        probs, uniform_samples, min_p, deterministic
+        probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
     )


+def top_k_top_p_sampling_from_logits(
+    logits: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_k: Union[torch.Tensor, int],
+    top_p: Union[torch.Tensor, float],
+    filter_apply_order: str = "top_k_first",
+    deterministic: bool = True,
+    **kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""Fused GPU kernel for top-k and top-p sampling from pre-softmax logits,
+
+    this operator implements GPU-based rejection sampling without explicit sorting.
+
+    The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
+    which is more efficient than the naive implementation that launches a series of kernels.
+
+    Parameters
+    ----------
+    logits: torch.Tensor
+        Pre-softmax logits, shape ``(batch_size, num_classes)``.
+    uniform_samples: torch.Tensor
+        The uniform samples used as needle for sampling, shape ``(max_top_k_rounds, batch_size,)``,
+        where the first dimension is the maximum number of rounds for rejection sampling.
+        Expected to be uniformly distributed in ``[0, 1)``.
+    top_k: Union[torch.Tensor, int]
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
+    top_p: Union[torch.Tensor, float]
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
+    filter_apply_order: str
+        The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``.
+        If ``"top_k_first"``, we first apply the top-k filter, then apply top-p sampling on the top-k results.
+        If ``"joint"``, we apply the top-k and top-p filters simultaneously in each round.
+    deterministic: bool
+        Whether to use deterministic kernel implementation, default is ``True``.
+
+    Returns
+    -------
+    samples: torch.Tensor
+        Sampled categories, shape ``(batch_size,)``.
+    success: torch.Tensor
+        Whether the sampling is successful within ``max_top_k_rounds`` rounds,
+        shape ``(batch_size,)``.
+
+    Examples
+    --------
+
+    >>> import torch
+    >>> import flashinfer
+    >>> torch.manual_seed(42)
+    >>> batch_size = 4
+    >>> vocab_size = 5
+    >>> max_rounds = 3
+    >>> top_p = 0.5
+    >>> top_k = 3
+    >>> logits = torch.randn(batch_size, vocab_size).to(0)
+    >>> logits
+    tensor([[ 1.9269,  1.4873,  0.9007, -2.1055, -0.7581],
+            [ 1.0783,  0.8008,  1.6806,  0.3559, -0.6866],
+            [-0.4934,  0.2415, -0.2316,  0.0418, -0.2516],
+            [ 0.8599, -0.3097, -0.3957,  0.8034, -0.6216]], device='cuda:0')
+    >>> uniform_samples = torch.rand(max_rounds, batch_size).to(0)
+    >>> samples, success = flashinfer.sampling.top_k_top_p_sampling_from_logits(logits, uniform_samples, top_k, top_p)
+    >>> samples
+    tensor([0, 2, 1, 3], device='cuda:0', dtype=torch.int32)
+    >>> success
+    tensor([True, True, True, True], device='cuda:0')
+    >>> probs = torch.softmax(logits, dim=-1)
+    >>> probs
+    tensor([[0.4788, 0.3085, 0.1716, 0.0085, 0.0327],
+            [0.2358, 0.1787, 0.4307, 0.1145, 0.0404],
+            [0.1358, 0.2831, 0.1764, 0.2318, 0.1729],
+            [0.3613, 0.1122, 0.1029, 0.3415, 0.0821]], device='cuda:0')
+    >>> samples
+    tensor([0, 2, 1, 3], device='cuda:0', dtype=torch.int32)
+    >>> success
+    tensor([True, True, True, True], device='cuda:0')
+
+    Notes
+    -----
+    This function expects float32 inputs, and the output is int32.
+    We encourage users to set ``max_rounds`` to a reasonable value, e.g., 32. The actual
+    implementation usually uses much fewer rounds for rejection sampling because of early stopping.
+ """ + if filter_apply_order == "top_k_first": + masked_logits = top_k_mask_logits(probs, top_k, **kwargs) + probs = torch.softmax(masked_logits, dim=-1) + return top_p_sampling_from_probs(probs, uniform_samples, top_p, deterministic) + elif filter_apply_order == "joint": + probs = torch.softmax(probs, dim=-1) + return _kernels.top_k_top_p_sampling_from_probs( + probs, + uniform_samples, + *_to_tensor_scalar_tuple(top_k), + *_to_tensor_scalar_tuple(top_p), + deterministic, + ) + else: + raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}") + + def top_k_top_p_sampling_from_probs( probs: torch.Tensor, uniform_samples: torch.Tensor, - top_k: torch.Tensor, - top_p: torch.Tensor, + top_k: Union[torch.Tensor, int], + top_p: Union[torch.Tensor, float], + filter_apply_order: str = "top_k_first", deterministic: bool = True, + **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Fused GPU kernel for joint top-k and top-p sampling from probabilities, + r"""Fused GPU kernel for top-k and top-p sampling from probabilities, this operator implements GPU-based rejection sampling without explicit sorting. @@ -306,10 +425,18 @@ def top_k_top_p_sampling_from_probs( The uniform samples used as needle for sampling, shape ``(max_top_k_rounds, batch_size,)``, where the first dimension is the maximum number of rounds for rejection sampling. Expected to be uniformly distributed in ``[0, 1)``. - top_k: torch.Tensor - The k in "top-k" for each request, shape ``(batch_size,)``. - top_p: torch.Tensor - The threshold for top-p sampling for each request, shape ``(batch_size,)``. + top_k: Union[torch.Tensor, int] + Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-k sampling. + If a scalar, the same threshold is used for all requests. + If a tensor, each request has its own threshold. + top_p: Union[torch.Tensor, float] + Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for top-p sampling. + If a scalar, the same threshold is used for all requests. + If a tensor, each request has its own threshold. + filter_apply_order: str + The order of applying top-k and top-p sampling, should be either ``"top_k_first"`` or ``"joint"``. + If ``"top_k_first"``, we first apply top-k filter, then apply top-p sampling on the top-k results. + If ``"joint"``, we apply top-k and top-p filter simultaneously in each round. deterministic: bool Whether to use deterministic kernel implementation, default is ``True``. @@ -352,13 +479,25 @@ def top_k_top_p_sampling_from_probs( We encourage users to set ``max_rounds`` to a reasonable value, e.g., 32. The actual implementation usually use much fewer rounds for rejection sampling because of early stopping. 
""" - return _kernels.top_k_top_p_sampling_from_probs( - probs, uniform_samples, top_k, top_p, deterministic - ) + if filter_apply_order == "top_k_first": + renorm_probs = top_k_renorm_prob(probs, top_k, **kwargs) + return top_p_sampling_from_probs( + renorm_probs, uniform_samples, top_p, deterministic + ) + elif filter_apply_order == "joint": + return _kernels.top_k_top_p_sampling_from_probs( + probs, + uniform_samples, + *_to_tensor_scalar_tuple(top_k), + *_to_tensor_scalar_tuple(top_p), + deterministic, + ) + else: + raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}") def top_p_renorm_prob( - probs: torch.Tensor, top_p: float, eps: float = 1e-5 + probs: torch.Tensor, top_p: Union[torch.Tensor, float], eps: float = 1e-6 ) -> torch.Tensor: r"""Fused GPU kernel for renormalizing probabilities by top-p thresholding. @@ -366,8 +505,11 @@ def top_p_renorm_prob( ---------- probs: torch.Tensor Probabilities, shape ``(batch_size, num_classes)``. - top_p: float - The threshold for re-normalizing probabilities, should be in ``(0, 1)``. + top_p: Union[torch.Tensor, float] + Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-p threshold for for + re-normalizing probabilities, should be in ``(0, 1)``. + If a scalar, the same threshold is used for all requests. + If a tensor, each request has its own threshold. We mask out the probabilities less than `threshold` where the cumulative sum of ``probs[probs >= threshold]`` is `top_p`, and renormalize the probabilities. eps: float @@ -381,11 +523,11 @@ def top_p_renorm_prob( This combination of ``top_p_renorm_prob`` and ``sampling_from_probs`` should be equivalent to ``top_p_sampling_from_probs``. """ - return _kernels.top_p_renorm_prob(probs, top_p, eps) + return _kernels.top_p_renorm_prob(probs, *_to_tensor_scalar_tuple(top_p), eps) def top_k_renorm_prob( - probs: torch.Tensor, top_k: int, eps: float = 1e-5 + probs: torch.Tensor, top_k: Union[torch.Tensor, int], eps: float = 1e-6 ) -> torch.Tensor: r"""Fused GPU kernel for renormalizing probabilities by top-k thresholding. @@ -393,8 +535,11 @@ def top_k_renorm_prob( ---------- probs: torch.Tensor Probabilities, shape ``(batch_size, num_classes)``. - top_k: int - The threshold for re-normalizing probabilities, should be in ``(0, num_classes)``. + top_k: Union[torch.Tensor, int] + Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold for for + for re-normalizing probabilities, should be in ``(0, num_classes)``. + If a scalar, the same threshold is used for all requests. + If a tensor, each request has its own threshold. We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities. eps: float The epsilon value for numerical stability. @@ -409,7 +554,37 @@ def top_k_renorm_prob( This combination of ``top_k_renorm_prob`` and ``sampling_from_probs`` should be equivalent to ``top_k_sampling_from_probs``. """ - return _kernels.top_k_renorm_prob(probs, top_k, eps) + return _kernels.top_k_renorm_prob(probs, *_to_tensor_scalar_tuple(top_k), eps) + + +def top_k_mask_logits( + logits: torch.Tensor, top_k: Union[torch.Tensor, int], eps: float = 1e-5 +) -> torch.Tensor: + r"""Fused GPU kernel for masking logits by top-k thresholding. + + Parameters + ---------- + logits: torch.Tensor + Logits before softmax, shape ``(batch_size, num_classes)``. 
+    top_k: Union[torch.Tensor, int]
+        Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold for
+        masking logits, should be in ``(0, num_classes)``.
+        If a scalar, the same threshold is used for all requests.
+        If a tensor, each request has its own threshold.
+        We keep the top-k logits, set the rest to negative infinity.
+    eps: float
+        The epsilon value for numerical stability.
+
+    Returns
+    -------
+    masked_logits: torch.Tensor
+        Masked logits, shape ``(batch_size, num_classes)``.
+
+    Note
+    ----
+    The combination of ``top_k_mask_logits`` and ``softmax`` should be equivalent to ``top_k_renorm_prob``.
+    """
+    return _kernels.top_k_mask_logits(logits, *_to_tensor_scalar_tuple(top_k), eps)


 def chain_speculative_sampling(
diff --git a/python/tests/test_sampling.py b/python/tests/test_sampling.py
index 26c2395f..43a49bfe 100644
--- a/python/tests/test_sampling.py
+++ b/python/tests/test_sampling.py
@@ -132,7 +132,8 @@ def test_min_p_sampling(batch_size, vocab_size, p):
 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
 @pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
 @pytest.mark.parametrize("p", [0.1, 0.5])
-def test_top_k_top_p_sampling(batch_size, vocab_size, p):
+def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
+    torch.manual_seed(42)
     if p == 0.1:
         k = int(vocab_size * 0.5)
     elif p == 0.5:
@@ -164,7 +165,11 @@ def test_top_k_top_p_sampling(batch_size, vocab_size, p):
     for _ in range(num_trails):
         uniform_samples.uniform_()
         samples, success = flashinfer.sampling.top_k_top_p_sampling_from_probs(
-            normalized_prob, uniform_samples, top_k_tensor, top_p_tensor
+            normalized_prob,
+            uniform_samples,
+            top_k_tensor,
+            top_p_tensor,
+            filter_apply_order="joint",
         )
         assert torch.all(success)
         assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
@@ -173,6 +178,55 @@
     ]


+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
+@pytest.mark.parametrize("k", [100])
+@pytest.mark.parametrize("p", [0.1, 0.5])
+def test_top_k_top_p_sampling_from_probs_logits_alignment(batch_size, vocab_size, k, p):
+    torch.manual_seed(42)
+    logits = torch.randn(batch_size, vocab_size).to(0) * 5
+    uniform_samples = torch.empty(32, batch_size).to(0)
+    samples, success = flashinfer.sampling.top_k_top_p_sampling_from_logits(
+        logits, uniform_samples, k, p, filter_apply_order="top_k_first"
+    )
+    samples_ref, success_ref = flashinfer.sampling.top_k_top_p_sampling_from_probs(
+        torch.softmax(logits, dim=-1),
+        uniform_samples,
+        k,
+        p,
+        filter_apply_order="top_k_first",
+    )
+    assert torch.all(samples == samples_ref)
+    assert torch.all(success)
+    assert torch.all(success_ref)
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
+@pytest.mark.parametrize("p", [0.1, 0.5])
+def test_top_k_top_p_joint_sampling_from_logits(batch_size, vocab_size, p):
+    torch.manual_seed(42)
+    logits = torch.rand(batch_size, vocab_size).to(0) * 5
+    uniform_samples = torch.empty(32, batch_size).to(0)
+    if p == 0.1:
+        k = int(vocab_size * 0.5)
+    elif p == 0.5:
+        k = int(vocab_size * 0.1)
+    else:
+        raise ValueError("p not recognized")
+
+    samples, success = flashinfer.sampling.top_k_top_p_sampling_from_logits(
+        logits, uniform_samples, k, p, filter_apply_order="joint"
+    )
+
+    samples_ref, success_ref = flashinfer.sampling.top_k_top_p_sampling_from_probs(
torch.softmax(logits, dim=-1), uniform_samples, k, p, filter_apply_order="joint" + ) + assert torch.all(samples == samples_ref) + assert torch.all(success) + assert torch.all(success_ref) + + @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) @pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) @@ -226,6 +280,27 @@ def test_top_k_renorm_prob(batch_size, vocab_size, k): ) +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("k", [10, 100, 500]) +def test_top_k_mask_logits(batch_size, vocab_size, k): + if k > vocab_size: + pytest.skip("k should be less than vocab_size") + torch.manual_seed(42) + logits = torch.randn(batch_size, vocab_size).to(0) * 5 + probs = torch.softmax(logits, dim=-1) + masked_logits = flashinfer.sampling.top_k_mask_logits(logits, k) + renormed_probs = torch.softmax(masked_logits, dim=-1) + renormed_probs_ref = flashinfer.sampling.top_k_renorm_prob(probs, k, 1e-8) + + numpy.testing.assert_allclose( + renormed_probs.cpu().numpy(), + renormed_probs_ref.cpu().numpy(), + rtol=1e-3, + atol=1e-3, + ) + + @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) @pytest.mark.parametrize("num_speculate_tokens", [1, 3, 5, 7]) @@ -291,5 +366,6 @@ def test_chain_speculative_sampling( test_top_k_sampling(3, 111, 10) test_top_p_renorm_prob(3, 111, 0.9) test_top_k_renorm_prob(3, 111, 10) + test_top_k_mask_logits(99, 989, 10) test_chain_speculative_sampling(3, 111, 3, False) test_chain_speculative_sampling(3, 111, 3, True) diff --git a/src/bench_sampling.cu b/src/bench_sampling.cu index 343d7b86..e328f601 100644 --- a/src/bench_sampling.cu +++ b/src/bench_sampling.cu @@ -101,8 +101,8 @@ void bench_top_p_sampling_with_probability(nvbench::state& state) { cudaError_t status = sampling::TopPSamplingFromProb( thrust::raw_pointer_cast(probs_d.data()), thrust::raw_pointer_cast(uniform_samples_d.data()), - thrust::raw_pointer_cast(output_d.data()), thrust::raw_pointer_cast(success_d.data()), p, - batch_size, vocab_size, max_top_p_rounds, deterministic); + thrust::raw_pointer_cast(output_d.data()), thrust::raw_pointer_cast(success_d.data()), + /*top_p_arr=*/nullptr, batch_size, p, vocab_size, max_top_p_rounds, deterministic); timer.stop(); if (status != cudaSuccess) { state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); @@ -147,8 +147,8 @@ void bench_top_k_sampling_with_probability(nvbench::state& state) { cudaError_t status = sampling::TopKSamplingFromProb( thrust::raw_pointer_cast(probs_d.data()), thrust::raw_pointer_cast(uniform_samples_d.data()), - thrust::raw_pointer_cast(output_d.data()), thrust::raw_pointer_cast(success_d.data()), k, - batch_size, vocab_size, max_top_k_rounds, deterministic); + thrust::raw_pointer_cast(output_d.data()), thrust::raw_pointer_cast(success_d.data()), + /*top_k_arr=*/nullptr, batch_size, k, vocab_size, max_top_k_rounds, deterministic); timer.stop(); if (status != cudaSuccess) { state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); diff --git a/src/test_sampling.cu b/src/test_sampling.cu index fc774a4a..8a0a05fe 100644 --- a/src/test_sampling.cu +++ b/src/test_sampling.cu @@ -61,7 +61,7 @@ void _TestTopKSamplingFromProb(size_t batch_size, uint32_t k, size_t vocab_size) thrust::raw_pointer_cast(probs_d.data()), thrust::raw_pointer_cast(uniform_samples_d.data()), 
thrust::raw_pointer_cast(sampled_ids_d.data()), thrust::raw_pointer_cast(success_d.data()), - k, batch_size, vocab_size, max_top_p_rounds, /*deterministic=*/true); + /*top_k_arr=*/nullptr, batch_size, k, vocab_size, max_top_p_rounds, /*deterministic=*/true); EXPECT_EQ(status, cudaSuccess) << "TopKSamplingFromProb kernel launch failed, error message: " << cudaGetErrorString(status); @@ -126,7 +126,7 @@ void _TestTopPSamplingFromProb(size_t batch_size, uint32_t k, size_t vocab_size) thrust::raw_pointer_cast(probs_d.data()), thrust::raw_pointer_cast(uniform_samples_d.data()), thrust::raw_pointer_cast(sampled_ids_d.data()), thrust::raw_pointer_cast(success_d.data()), - p, batch_size, vocab_size, max_top_p_rounds, /*deterministic=*/true); + /*top_p_arr=*/nullptr, batch_size, p, vocab_size, max_top_p_rounds, /*deterministic=*/true); EXPECT_EQ(status, cudaSuccess) << "TopPSamplingFromProb kernel launch failed, error message: " << cudaGetErrorString(status);
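As a quick sanity check of the new logits path, the snippet below mirrors ``test_top_k_mask_logits`` above: masking in logit space and then applying softmax should match ``top_k_renorm_prob`` applied in probability space, up to the ``eps`` precision of the pivot binary search. (A minimal sketch with made-up sizes; a CUDA device and the Python package built from this patch are assumed.)

    import torch
    import flashinfer

    torch.manual_seed(42)
    logits = torch.randn(4, 128).to(0) * 5

    # Keep the top-10 logits per row, set the rest to -inf, so that
    # softmax renormalizes over the surviving entries only.
    masked_logits = flashinfer.sampling.top_k_mask_logits(logits, 10)
    renormed_probs = torch.softmax(masked_logits, dim=-1)

    # The same filter applied directly in probability space.
    probs = torch.softmax(logits, dim=-1)
    renormed_probs_ref = flashinfer.sampling.top_k_renorm_prob(probs, 10, 1e-8)

    # Tolerances mirror the test; bitwise equality is not expected because
    # both kernels stop their pivot search at finite eps precision.
    torch.testing.assert_close(renormed_probs, renormed_probs_ref, rtol=1e-3, atol=1e-3)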