diff --git a/src/engine/batch.cpp b/src/engine/batch.cpp
index cedf7cc2..ef7a592e 100644
--- a/src/engine/batch.cpp
+++ b/src/engine/batch.cpp
@@ -99,7 +99,8 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
   std::vector<int32_t> q_cu_seq_lens = {0};
   // slot ids for new token
   std::vector<int32_t> new_token_slot_ids;
-  std::vector<std::vector<int32_t>> block_tables_vec;
+  std::vector<int32_t> block_tables;
+  std::vector<int32_t> cu_block_lens = {0};
   const int32_t num_sequences = static_cast<int32_t>(sequences_.size());
   for (int32_t i = 0; i < num_sequences; ++i) {
     auto* sequence = sequences_[i];
@@ -205,13 +206,11 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
     new_token_slot_ids.insert(
         new_token_slot_ids.end(), slot_ids.begin(), slot_ids.end());
 
-    // construct block ids for each sequence
-    std::vector<int32_t> block_ids;
-    block_ids.reserve(blocks.size());
+    // add block ids for each sequence
     for (const auto& block : blocks) {
-      block_ids.push_back(block.id());
+      block_tables.push_back(block.id());
     }
-    block_tables_vec.push_back(block_ids);
+    cu_block_lens.push_back(static_cast<int32_t>(block_tables.size()));
   }
 
   if (flatten_tokens_vec.empty()) {
@@ -238,7 +237,8 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
       }
       cu_seq_lens.push_back(cu_seq_lens.back() + num_decoding_tokens);
       q_cu_seq_lens.push_back(q_cu_seq_lens.back() + num_decoding_tokens);
-      block_tables_vec.emplace_back();
+      // padding sequences own no blocks; repeat the last cumulative length
+      cu_block_lens.push_back(cu_block_lens.back());
       }
     }
   }
@@ -256,8 +256,8 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
   input_params.q_cu_seq_lens = torch::tensor(q_cu_seq_lens, torch::kInt);
   input_params.new_cache_slots =
       torch::tensor(new_token_slot_ids, torch::kInt);
-  pad_2d_vector(block_tables_vec, /*pad_value=*/0);
-  input_params.block_tables = create_2d_tensor(block_tables_vec, torch::kInt);
+  input_params.block_tables = torch::tensor(block_tables, torch::kInt);
+  input_params.cu_block_lens = torch::tensor(cu_block_lens, torch::kInt);
 
   CHECK_EQ(sampling_params.size(), selected_token_idxes.size());
   if (!selected_token_idxes.empty()) {
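The batch.cpp change replaces the padded 2D block tables (`block_tables_vec` + `pad_2d_vector` + `create_2d_tensor`) with one flattened 1D `block_tables` array indexed through a cumulative-length array `cu_block_lens`, mirroring the existing `cu_seq_lens` convention. A minimal standalone sketch of the layout (illustrative only, not part of the patch):

```cpp
#include <cstdint>
#include <vector>

// Sequence i owns block_tables[cu_block_lens[i] .. cu_block_lens[i + 1]).
int main() {
  // e.g. three sequences holding 3, 1, and 2 KV-cache blocks
  const std::vector<std::vector<int32_t>> per_seq_blocks = {{7, 2, 9}, {4}, {5, 8}};

  std::vector<int32_t> block_tables;         // flattened block ids
  std::vector<int32_t> cu_block_lens = {0};  // cumulative lengths, size n + 1
  for (const auto& blocks : per_seq_blocks) {
    block_tables.insert(block_tables.end(), blocks.begin(), blocks.end());
    cu_block_lens.push_back(static_cast<int32_t>(block_tables.size()));
  }
  // block_tables  == {7, 2, 9, 4, 5, 8}
  // cu_block_lens == {0, 3, 4, 6}

  // recover the blocks of sequence i without any padding
  const int32_t i = 2;
  for (int32_t j = cu_block_lens[i]; j < cu_block_lens[i + 1]; ++j) {
    const int32_t block_id = block_tables[j];  // yields 5, then 8
    static_cast<void>(block_id);
  }
  return 0;
}
```

Compared to the padded layout, this avoids materializing `num_sequences * max_num_blocks` entries when per-sequence block counts are skewed, at the cost of one indirection through `cu_block_lens`.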
diff --git a/src/engine/model_runner.cpp b/src/engine/model_runner.cpp
index 2639ae46..3868f50f 100644
--- a/src/engine/model_runner.cpp
+++ b/src/engine/model_runner.cpp
@@ -57,7 +57,8 @@ ModelRunner::ModelRunner(CausalLM* model,
   const int64_t max_block_table_len =
       (max_seq_len + block_size - 1) / block_size + 1;
   block_tables_ =
-      torch::zeros({max_batch_size_, max_block_table_len}, tensor_options);
+      torch::zeros({max_batch_size_ * max_block_table_len}, tensor_options);
+  cu_block_lens_ = torch::zeros({max_batch_size_ + 1}, tensor_options);
 
   mem_pool_ = at::cuda::graph_pool_handle();
 }
@@ -91,8 +92,9 @@ void ModelRunner::capture_cuda_graphs(uint32_t batch_size,
         /*dim=*/0, /*start=*/0, /*end=*/batch_size + 1);
     params.kv_cu_seq_lens = kv_cu_seq_lens_.slice(
         /*dim=*/0, /*start=*/0, /*end=*/batch_size + 1);
-    params.block_tables = block_tables_.slice(
-        /*dim=*/0, /*start=*/0, /*end=*/batch_size);
+    params.block_tables = block_tables_;
+    params.cu_block_lens = cu_block_lens_.slice(
+        /*dim=*/0, /*start=*/0, /*end=*/batch_size + 1);
     params.new_cache_slots = new_cache_slots_.slice(
         /*dim=*/0, /*start=*/0, /*end=*/n_tokens);
 
@@ -156,6 +158,7 @@ void ModelRunner::CudaGraph::capture(at::cuda::MempoolId_t mem_pool,
   flatten_positions_ = flatten_positions;
   new_cache_slots_ = params.new_cache_slots;
   block_tables_ = params.block_tables;
+  cu_block_lens_ = params.cu_block_lens;
   q_cu_seq_lens_ = params.q_cu_seq_lens;
   kv_cu_seq_lens_ = params.kv_cu_seq_lens;
 
@@ -184,8 +187,8 @@ torch::Tensor ModelRunner::CudaGraph::replay(torch::Tensor flatten_tokens,
 
   const int64_t batch_size = params.num_sequences;
   const int64_t num_tokens = flatten_tokens.size(/*dim=*/0);
-  const int64_t block_table_len = params.block_tables.size(/*dim=*/1);
-  const int64_t max_block_table_len = block_tables_.size(/*dim=*/1);
+  const int64_t block_table_len = params.block_tables.size(/*dim=*/0);
+  const int64_t max_block_table_len = block_tables_.size(/*dim=*/0);
   CHECK_EQ(batch_size, batch_size_) << "batch size mismatch";
   CHECK_EQ(num_tokens, num_tokens_) << "num tokens mismatch";
   CHECK_GE(max_block_table_len, block_table_len) << "block table size mismatch";
@@ -198,8 +201,9 @@ torch::Tensor ModelRunner::CudaGraph::replay(torch::Tensor flatten_tokens,
   new_cache_slots_.copy_(params.new_cache_slots, /*non_blocking=*/true);
 
   // the incoming block table may be shorter than the captured buffer
-  block_tables_.slice(/*dim=*/1, /*start=*/0, /*end=*/block_table_len)
+  block_tables_.slice(/*dim=*/0, /*start=*/0, /*end=*/block_table_len)
       .copy_(params.block_tables, /*non_blocking=*/true);
+  cu_block_lens_.copy_(params.cu_block_lens, /*non_blocking=*/true);
 
   // replay the graph
   graph_->replay();
diff --git a/src/engine/model_runner.h b/src/engine/model_runner.h
index 6493cb3f..446b9d56 100644
--- a/src/engine/model_runner.h
+++ b/src/engine/model_runner.h
@@ -2,6 +2,7 @@
 
 #include
 #include
+#include
 #include
 #include
@@ -65,6 +66,7 @@ class ModelRunner final {
   torch::Tensor kv_cu_seq_lens_;
   torch::Tensor new_cache_slots_;
   torch::Tensor block_tables_;
+  torch::Tensor cu_block_lens_;
 
   // graph pool handler
   at::cuda::MempoolId_t mem_pool_;
@@ -99,6 +101,7 @@ class ModelRunner final {
     torch::Tensor flatten_positions_;
     torch::Tensor new_cache_slots_;
     torch::Tensor block_tables_;
+    torch::Tensor cu_block_lens_;
     torch::Tensor q_cu_seq_lens_;
     torch::Tensor kv_cu_seq_lens_;
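Since a captured CUDA graph replays fixed device addresses, the replay path must stage each step's inputs into the tensors recorded at capture time: the variable-length flattened block table is copied into a prefix of the captured buffer, while `cu_block_lens` always has exactly `batch_size + 1` entries and is copied wholesale. A minimal sketch of that staging (illustrative only; the toy sizes and CPU tensors are assumptions, not the patch's API):

```cpp
#include <torch/torch.h>

int main() {
  const int64_t max_batch_size = 8;
  const int64_t max_block_table_len = 16;
  // CPU tensors so the sketch runs anywhere; the runner uses CUDA tensors
  const auto opts = torch::dtype(torch::kInt);

  // persistent buffers, sized for the worst case at capture time
  auto block_tables = torch::zeros({max_batch_size * max_block_table_len}, opts);
  auto cu_block_lens = torch::zeros({max_batch_size + 1}, opts);

  // one decoding step: 3 live sequences, 5 padding sequences; padding
  // sequences own no blocks, so they repeat the last cumulative length
  auto step_block_tables = torch::tensor({7, 2, 9, 4, 5, 8}, opts);
  auto step_cu_block_lens = torch::tensor({0, 3, 4, 6, 6, 6, 6, 6, 6}, opts);

  const int64_t len = step_block_tables.size(0);
  TORCH_CHECK(len <= block_tables.size(0), "block table overflows capture buffer");
  block_tables.slice(/*dim=*/0, /*start=*/0, /*end=*/len).copy_(step_block_tables);
  cu_block_lens.copy_(step_cu_block_lens);
  return 0;
}
```

Entries past `len` in the captured buffer go stale after the copy, but that is harmless: a kernel consuming this layout only dereferences indices below `cu_block_lens[batch_size]`. This is also why `capture_cuda_graphs` now passes `block_tables_` through unsliced while slicing `cu_block_lens_` to `batch_size + 1`.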