Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: replace begin_forward/forward/end_forward with plan/run #466

Merged
merged 4 commits into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 14 additions & 33 deletions include/flashinfer/attention/handler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -309,11 +309,11 @@ class BatchDecodeHandler {
template <uint32_t HEAD_DIM, PageStorage page_storage, LogitsPostHook LOGITS_POST_HOOK,
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut,
typename IdType>
cudaError_t BeginForwardDispatched(void* float_buffer, size_t float_workspace_size_in_bytes,
void* int_buffer, size_t int_workspace_size_in_bytes,
IdType* indptr_h, IdType* last_page_len_h, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t page_size) {
cudaError_t PlanDispatched(void* float_buffer, size_t float_workspace_size_in_bytes,
void* int_buffer, size_t int_workspace_size_in_bytes, IdType* indptr_h,
IdType* last_page_len_h, uint32_t batch_size, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t page_size) {
Clear();
batch_size_before_partition_ = batch_size;
bool split_kv;
uint32_t max_grid_size, max_num_pages_per_batch, new_batch_size;
Expand Down Expand Up @@ -438,12 +438,10 @@ class BatchDecodeHandler {
}
}
});
forward_started_ = true;
return cudaSuccess;
}

cudaError_t EndForward() {
forward_started_ = false;
void Clear() {
padded_batch_size_ = 0;
batch_size_before_partition_ = 0;
batch_size_after_partition_ = 0;
Expand All @@ -456,11 +454,8 @@ class BatchDecodeHandler {
batch_idx_map_ = nullptr;
chunk_start_pos_ = nullptr;
seq_lengths_before_partition_ = nullptr;
return cudaSuccess;
}

bool IsForwardStarted() const { return forward_started_; }

void UpdatePageLockedBufferSize(size_t int_workspace_size_in_bytes) {
cudaFreeHost(page_locked_buffer_);
cudaMallocHost(&page_locked_buffer_, int_workspace_size_in_bytes);
Expand Down Expand Up @@ -490,16 +485,12 @@ class BatchDecodeHandler {
batch_idx_map_(nullptr),
chunk_start_pos_(nullptr),
seq_lengths_before_partition_(nullptr),
forward_started_(false),
cuda_graph_enabled_(enable_cuda_graph),
fixed_batch_size_(batch_size),
stream_(nullptr) {
cudaMallocHost(&page_locked_buffer_, 8 * 1024 * 1024);
}
~BatchDecodeHandler() {
EndForward();
cudaFreeHost(page_locked_buffer_);
}
~BatchDecodeHandler() { cudaFreeHost(page_locked_buffer_); }

bool IsCUDAGraphEnabled() const { return cuda_graph_enabled_; }

Expand All @@ -516,7 +507,6 @@ class BatchDecodeHandler {
void* batch_idx_map_;
void* chunk_start_pos_;
void* seq_lengths_before_partition_;
bool forward_started_;
bool cuda_graph_enabled_;
uint32_t padded_batch_size_;
uint32_t fixed_batch_size_;
Expand Down Expand Up @@ -656,19 +646,17 @@ class BatchPrefillHandler {

uint32_t GetTotalNumRows() const { return total_num_rows_; }

bool IsForwardStarted() const { return request_indices_ != nullptr; }

void UpdatePageLockedBufferSize(size_t int_workspace_size_in_bytes) {
cudaFreeHost(page_locked_buffer_);
cudaMallocHost(&page_locked_buffer_, int_workspace_size_in_bytes);
}

template <typename DTypeOut, typename IdType>
cudaError_t BeginForward(void* float_buffer, size_t float_workspace_size_in_bytes,
void* int_buffer, size_t int_workspace_size_in_bytes,
IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
uint32_t page_size) {
cudaError_t Plan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer,
size_t int_workspace_size_in_bytes, IdType* qo_indptr_h, IdType* kv_indptr_h,
uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t head_dim, uint32_t page_size) {
Clear();
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
Expand Down Expand Up @@ -812,8 +800,7 @@ class BatchPrefillHandler {
return cudaSuccess;
}

cudaError_t EndForward() {
forward_started_ = false;
void Clear() {
request_indices_ = nullptr;
qo_tile_indices_ = nullptr;
kv_tile_indices_ = nullptr;
Expand All @@ -826,7 +813,6 @@ class BatchPrefillHandler {
total_num_rows_ = 0U;
padded_batch_size_ = 0U;
warp_layout_ = WarpLayout::k4x1x2;
return cudaSuccess;
}

cudaStream_t GetCUDAStream() const { return stream_; }
Expand All @@ -848,15 +834,11 @@ class BatchPrefillHandler {
total_num_rows_(0U),
padded_batch_size_(0U),
warp_layout_(WarpLayout::k4x1x2),
forward_started_(false),
enable_cuda_graph_(enable_cuda_graph),
stream_(nullptr) {
cudaMallocHost(&page_locked_buffer_, 8 * 1024 * 1024);
}
~BatchPrefillHandler() {
EndForward();
cudaFreeHost(page_locked_buffer_);
}
~BatchPrefillHandler() { cudaFreeHost(page_locked_buffer_); }

protected:
void* page_locked_buffer_;
Expand All @@ -872,7 +854,6 @@ class BatchPrefillHandler {
uint32_t total_num_rows_;
uint32_t padded_batch_size_;
WarpLayout warp_layout_;
bool forward_started_;
bool enable_cuda_graph_;
cudaStream_t stream_;
};
Expand Down
27 changes: 10 additions & 17 deletions include/flashinfer/decode_attention_decl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,16 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
DTypeOut* tmp_v = handler->GetTempV<DTypeOut>();
float* tmp_s = handler->GetTempS();

if (handler->IsForwardStarted()) {
if (tmp_v != nullptr) {
// create auxiliary information for cooperative kernels
new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition();
new_paged_kv.indptr = handler->GetNewIndPtr<IdType>();
new_paged_kv.last_page_len = handler->GetNewLastPageLen<IdType>();
kv_partition_info.batch_size_before_partition = handler->GetBatchSizeBeforePartition();
kv_partition_info.chunk_indptr = handler->GetChunkIndPtr<IdType>();
kv_partition_info.batch_idx_map = handler->GetBatchIdxMap<IdType>();
kv_partition_info.chunk_start_pos = handler->GetChunkStartPos<IdType>();
kv_partition_info.seq_lens_before_partition = handler->GetSeqLengthsBeforePartition<IdType>();
}
} else {
std::ostringstream err_msg;
err_msg << "Please call BatchDecodeHandler's BeginForward() before calling "
"BatchDecodeWithPagedKVCacheWrapper()";
throw std::runtime_error(err_msg.str());
if (tmp_v != nullptr) {
// create auxiliary information for cooperative kernels
new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition();
new_paged_kv.indptr = handler->GetNewIndPtr<IdType>();
new_paged_kv.last_page_len = handler->GetNewLastPageLen<IdType>();
kv_partition_info.batch_size_before_partition = handler->GetBatchSizeBeforePartition();
kv_partition_info.chunk_indptr = handler->GetChunkIndPtr<IdType>();
kv_partition_info.batch_idx_map = handler->GetBatchIdxMap<IdType>();
kv_partition_info.chunk_start_pos = handler->GetChunkStartPos<IdType>();
kv_partition_info.seq_lens_before_partition = handler->GetSeqLengthsBeforePartition<IdType>();
}

return BatchDecodeWithPagedKVCacheDispatched<HEAD_DIM, PAGE_STORAGE, LOGITS_POST_HOOK,
Expand Down
62 changes: 24 additions & 38 deletions include/flashinfer/prefill_attention_decl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,25 +81,18 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
WarpLayout warp_layout;
uint32_t padded_batch_size = 0U;
uint32_t total_num_rows = 0U;
if (handler->IsForwardStarted()) {
tmp_v = handler->GetTempV<DTypeOut>();
tmp_s = handler->GetTempS();
request_indices = handler->GetRequestIndices<IdType>();
qo_tile_indices = handler->GetQOTileIndices<IdType>();
kv_tile_indices = handler->GetKVTileIndices<IdType>();
block_valid_mask = handler->GetBlockValidMask();
o_indptr = handler->GetOIndptr<IdType>();
merge_indptr = handler->GetMergeIndptr<IdType>();
kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
warp_layout = handler->GetWarpLayout();
padded_batch_size = handler->GetPaddedBatchSize();
total_num_rows = handler->GetTotalNumRows();
} else {
std::ostringstream err_msg;
err_msg << "Please call BatchPrefillHandler's BeginForward() before calling "
"BatchPrefillWithPagedKVCacheWrapper()";
throw std::runtime_error(err_msg.str());
}
tmp_v = handler->GetTempV<DTypeOut>();
tmp_s = handler->GetTempS();
request_indices = handler->GetRequestIndices<IdType>();
qo_tile_indices = handler->GetQOTileIndices<IdType>();
kv_tile_indices = handler->GetKVTileIndices<IdType>();
block_valid_mask = handler->GetBlockValidMask();
o_indptr = handler->GetOIndptr<IdType>();
merge_indptr = handler->GetMergeIndptr<IdType>();
kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
warp_layout = handler->GetWarpLayout();
padded_batch_size = handler->GetPaddedBatchSize();
total_num_rows = handler->GetTotalNumRows();

DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, {
return BatchPrefillWithPagedKVCacheDispatched<
Expand Down Expand Up @@ -131,25 +124,18 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
WarpLayout warp_layout;
uint32_t padded_batch_size = 0U;
uint32_t total_num_rows = 0U;
if (handler->IsForwardStarted()) {
tmp_v = handler->GetTempV<DTypeOut>();
tmp_s = handler->GetTempS();
request_indices = handler->GetRequestIndices<IdType>();
qo_tile_indices = handler->GetQOTileIndices<IdType>();
kv_tile_indices = handler->GetKVTileIndices<IdType>();
block_valid_mask = handler->GetBlockValidMask();
o_indptr = handler->GetOIndptr<IdType>();
merge_indptr = handler->GetMergeIndptr<IdType>();
kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
warp_layout = handler->GetWarpLayout();
padded_batch_size = handler->GetPaddedBatchSize();
total_num_rows = handler->GetTotalNumRows();
} else {
std::ostringstream err_msg;
err_msg << "Please call BatchPrefillHandler's BeginForward() before calling "
"BatchPrefillWithRaggedKVWrapperCache()";
throw std::runtime_error(err_msg.str());
}
tmp_v = handler->GetTempV<DTypeOut>();
tmp_s = handler->GetTempS();
request_indices = handler->GetRequestIndices<IdType>();
qo_tile_indices = handler->GetQOTileIndices<IdType>();
kv_tile_indices = handler->GetKVTileIndices<IdType>();
block_valid_mask = handler->GetBlockValidMask();
o_indptr = handler->GetOIndptr<IdType>();
merge_indptr = handler->GetMergeIndptr<IdType>();
kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
warp_layout = handler->GetWarpLayout();
padded_batch_size = handler->GetPaddedBatchSize();
total_num_rows = handler->GetTotalNumRows();

DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, {
return BatchPrefillWithRaggedKVCacheDispatched<WARP_LAYOUT, HEAD_DIM, LOGITS_POST_HOOK,
Expand Down
29 changes: 13 additions & 16 deletions python/csrc/batch_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

using namespace flashinfer;

void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
void BatchDecodeWithPagedKVCachePyTorchWrapper::Plan(
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor indptr,
torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
Expand Down Expand Up @@ -62,15 +62,15 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
return DISPATCH_pos_encoding_mode(
PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
cudaError_t status =
handler_->BeginForwardDispatched<HEAD_DIM, PageStorage::kIndices,
LOGITS_POST_HOOK, POS_ENCODING_MODE, qkv_type,
qkv_type, qkv_type, int32_t>(
static_cast<void*>(float_workspace_buffer.data_ptr()),
float_workspace_size_in_bytes,
static_cast<void*>(int_workspace_buffer.data_ptr()),
int_workspace_size_in_bytes, static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
num_kv_heads, page_size);
handler_
->PlanDispatched<HEAD_DIM, PageStorage::kIndices, LOGITS_POST_HOOK,
POS_ENCODING_MODE, qkv_type, qkv_type, qkv_type, int32_t>(
static_cast<void*>(float_workspace_buffer.data_ptr()),
float_workspace_size_in_bytes,
static_cast<void*>(int_workspace_buffer.data_ptr()),
int_workspace_size_in_bytes, static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size,
num_qo_heads, num_kv_heads, page_size);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
return true;
Expand All @@ -86,9 +86,8 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
return DISPATCH_pos_encoding_mode(
PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
cudaError_t status =
handler_->BeginForwardDispatched<HEAD_DIM, PageStorage::kIndices,
LOGITS_POST_HOOK, POS_ENCODING_MODE, q_type,
kv_type, q_type, int32_t>(
handler_->PlanDispatched<HEAD_DIM, PageStorage::kIndices, LOGITS_POST_HOOK,
POS_ENCODING_MODE, q_type, kv_type, q_type, int32_t>(
static_cast<void*>(float_workspace_buffer.data_ptr()),
float_workspace_size_in_bytes,
static_cast<void*>(int_workspace_buffer.data_ptr()),
Expand All @@ -107,14 +106,12 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
}
}

void BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }

void BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
unsigned int int_workspace_size_in_bytes) {
handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes);
}

std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Run(
torch::Tensor q, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache, std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
Expand Down
20 changes: 8 additions & 12 deletions python/csrc/batch_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

using namespace flashinfer;

void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
void BatchPrefillWithPagedKVCachePyTorchWrapper::Plan(
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
Expand Down Expand Up @@ -48,7 +48,7 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
handler_->SetCUDAStream(torch_current_stream);

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] {
cudaError_t status = handler_->BeginForward<q_type, int32_t>(
cudaError_t status = handler_->Plan<q_type, int32_t>(
static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
static_cast<void*>(int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes,
static_cast<int32_t*>(qo_indptr.data_ptr()),
Expand All @@ -60,14 +60,12 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
});
}

void BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }

void BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
unsigned int int_workspace_size_in_bytes) {
handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes);
}

std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Run(
torch::Tensor q, torch::Tensor qo_indptr, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache, std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
Expand Down Expand Up @@ -257,7 +255,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
}
}

std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask(
std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::RunCustomMask(
torch::Tensor q, torch::Tensor qo_indptr, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache, std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
Expand Down Expand Up @@ -452,7 +450,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
}
}

void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
void BatchPrefillWithRaggedKVCachePyTorchWrapper::Plan(
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
torch::Tensor qo_indptr, torch::Tensor kv_indptr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
Expand All @@ -479,7 +477,7 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
handler_->SetCUDAStream(torch_current_stream);

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] {
cudaError_t status = handler_->BeginForward<q_type, int32_t>(
cudaError_t status = handler_->Plan<q_type, int32_t>(
static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
static_cast<void*>(int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes,
static_cast<int32_t*>(qo_indptr.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
Expand All @@ -491,14 +489,12 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
});
}

void BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }

void BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
unsigned int int_workspace_size_in_bytes) {
handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes);
}

std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Run(
torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v,
torch::Tensor kv_indptr, bool causal, unsigned int pos_encoding_mode,
bool allow_fp16_qk_reduction, int window_left, float logits_soft_cap, float sm_scale,
Expand Down Expand Up @@ -605,7 +601,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
}
}

std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask(
std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::RunCustomMask(
torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v,
torch::Tensor kv_indptr, torch::Tensor custom_mask, torch::Tensor qk_indptr,
unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, int window_left,
Expand Down
Loading