Skip to content

Commit

Permalink
Revert "Dynamically configure shared memory size for moe_align_block_…
Browse files Browse the repository at this point in the history
…size_kernel (vllm-project#3376)"

This reverts commit 78b6c48.
  • Loading branch information
simon-mo committed Mar 16, 2024
1 parent 21b0dbb commit fe983cc
Showing 1 changed file with 13 additions and 29 deletions.
42 changes: 13 additions & 29 deletions csrc/moe_align_block_size_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,10 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"

const static size_t NUM_MAX_EXPERTS = 64;
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))

namespace vllm {

namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) {
// don't worry about overflow because num_experts is relatively small
return row * total_col + col;
}
}

template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
int32_t *sorted_token_ids,
Expand All @@ -28,14 +21,10 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
size_t numel) {
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;

extern __shared__ int32_t shared_mem[];

int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1)

__shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
__shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
tokens_cnts[threadIdx.x + 1][i] = 0;
}

/**
Expand All @@ -44,15 +33,15 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
* to expert expert_index.
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
}

__syncthreads();

// For each expert we accumulate the token counts from the different threads.
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
tokens_cnts[0][threadIdx.x] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)];
tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
}

__syncthreads();
Expand All @@ -61,7 +50,7 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size;
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size;
}
*total_tokens_post_pad = cumsum[num_experts];
}
Expand Down Expand Up @@ -89,9 +78,9 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
* stores the indices of the tokens processed by the expert with expert_id within
* the current thread's token shard.
*/
int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id];
int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
++tokens_cnts[threadIdx.x][expert_id];
}
}
}
Expand All @@ -104,16 +93,11 @@ void moe_align_block_size(
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
assert(num_experts <= NUM_MAX_EXPERTS);
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors
const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);

// set dynamic shared mem
auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem));
kernel<<<1, num_experts, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(),
vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(),
Expand Down

0 comments on commit fe983cc

Please sign in to comment.