Commit

update
guocuimi committed Sep 5, 2024
1 parent fb4f0c0 commit 47ead2c
Showing 3 changed files with 22 additions and 15 deletions.
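At a glance: the padded 2D block_tables tensor (one row of block ids per sequence) becomes a single flattened 1D tensor of block ids, paired with a new cu_block_lens tensor of cumulative per-sequence block counts, mirroring the existing cu_seq_lens/q_cu_seq_lens convention and removing the pad_2d_vector/create_2d_tensor step.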
18 changes: 9 additions & 9 deletions src/engine/batch.cpp
@@ -99,7 +99,8 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
   std::vector<int32_t> q_cu_seq_lens = {0};
   // slot ids for new token
   std::vector<int32_t> new_token_slot_ids;
-  std::vector<std::vector<int32_t>> block_tables_vec;
+  std::vector<int32_t> block_tables;
+  std::vector<int32_t> cu_block_lens = {0};
   const int32_t num_sequences = static_cast<int32_t>(sequences_.size());
   for (int32_t i = 0; i < num_sequences; ++i) {
     auto* sequence = sequences_[i];
@@ -205,13 +206,11 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
     new_token_slot_ids.insert(
         new_token_slot_ids.end(), slot_ids.begin(), slot_ids.end());
 
-    // construct block ids for each sequence
-    std::vector<int32_t> block_ids;
-    block_ids.reserve(blocks.size());
+    // add block ids for each sequence
     for (const auto& block : blocks) {
-      block_ids.push_back(block.id());
+      block_tables.push_back(block.id());
     }
-    block_tables_vec.push_back(block_ids);
+    cu_block_lens.push_back(static_cast<int32_t>(block_tables.size()));
   }
 
   if (flatten_tokens_vec.empty()) {
@@ -238,7 +237,8 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
       }
       cu_seq_lens.push_back(cu_seq_lens.back() + num_decoding_tokens);
       q_cu_seq_lens.push_back(q_cu_seq_lens.back() + num_decoding_tokens);
-      block_tables_vec.emplace_back();
+      // empty block table for padding sequences
+      cu_block_lens.push_back(cu_block_lens.back());
     }
   }
 }
@@ -256,8 +256,8 @@ ModelInput Batch::prepare_model_input(uint32_t num_decoding_tokens,
   input_params.q_cu_seq_lens = torch::tensor(q_cu_seq_lens, torch::kInt);
   input_params.new_cache_slots = torch::tensor(new_token_slot_ids, torch::kInt);
 
-  pad_2d_vector(block_tables_vec, /*pad_value=*/0);
-  input_params.block_tables = create_2d_tensor(block_tables_vec, torch::kInt);
+  input_params.block_tables = torch::tensor(block_tables, torch::kInt);
+  input_params.cu_block_lens = torch::tensor(cu_block_lens, torch::kInt);
 
   CHECK_EQ(sampling_params.size(), selected_token_idxes.size());
   if (!selected_token_idxes.empty()) {
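For a concrete picture of the new layout, here is a small self-contained sketch of the flattened block table the hunks above build; the block ids and batch shape are made up for illustration:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical per-sequence block ids for a batch of three sequences.
  const std::vector<std::vector<int32_t>> per_seq_blocks = {{3, 7}, {2}, {5, 9, 4}};

  // Flattened layout, as built in Batch::prepare_model_input.
  std::vector<int32_t> block_tables;
  std::vector<int32_t> cu_block_lens = {0};
  for (const auto& blocks : per_seq_blocks) {
    for (int32_t id : blocks) {
      block_tables.push_back(id);
    }
    cu_block_lens.push_back(static_cast<int32_t>(block_tables.size()));
  }
  // block_tables  -> [3, 7, 2, 5, 9, 4]
  // cu_block_lens -> [0, 2, 3, 6]

  // A consumer recovers sequence i's blocks as the half-open range
  // block_tables[cu_block_lens[i] .. cu_block_lens[i+1]).
  for (size_t i = 0; i + 1 < cu_block_lens.size(); ++i) {
    std::cout << "seq " << i << ":";
    for (int32_t j = cu_block_lens[i]; j < cu_block_lens[i + 1]; ++j) {
      std::cout << ' ' << block_tables[j];
    }
    std::cout << '\n';
  }
  return 0;
}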
16 changes: 10 additions & 6 deletions src/engine/model_runner.cpp
@@ -57,7 +57,8 @@ ModelRunner::ModelRunner(CausalLM* model,
   const int64_t max_block_table_len =
       (max_seq_len + block_size - 1) / block_size + 1;
   block_tables_ =
-      torch::zeros({max_batch_size_, max_block_table_len}, tensor_options);
+      torch::zeros({max_batch_size_ * max_block_table_len}, tensor_options);
+  cu_block_lens_ = torch::zeros({max_batch_size_ + 1}, tensor_options);
 
   mem_pool_ = at::cuda::graph_pool_handle();
 }
@@ -91,8 +92,9 @@ void ModelRunner::capture_cuda_graphs(uint32_t batch_size,
       /*dim=*/0, /*start=*/0, /*end=*/batch_size + 1);
   params.kv_cu_seq_lens = kv_cu_seq_lens_.slice(
       /*dim=*/0, /*start=*/0, /*end=*/batch_size + 1);
-  params.block_tables = block_tables_.slice(
-      /*dim=*/0, /*start=*/0, /*end=*/batch_size);
+  params.block_tables = block_tables_;
+  params.cu_block_lens = cu_block_lens_.slice(
+      /*dim=*/0, /*start=*/0, /*end=*/batch_size + 1);
   params.new_cache_slots = new_cache_slots_.slice(
       /*dim=*/0, /*start=*/0, /*end=*/n_tokens);
 
@@ -156,6 +158,7 @@ void ModelRunner::CudaGraph::capture(at::cuda::MempoolId_t mem_pool,
   flatten_positions_ = flatten_positions;
   new_cache_slots_ = params.new_cache_slots;
   block_tables_ = params.block_tables;
+  cu_block_lens_ = params.cu_block_lens;
   q_cu_seq_lens_ = params.q_cu_seq_lens;
   kv_cu_seq_lens_ = params.kv_cu_seq_lens;
 
@@ -184,8 +187,8 @@ torch::Tensor ModelRunner::CudaGraph::replay(torch::Tensor flatten_tokens,
 
   const int64_t batch_size = params.num_sequences;
   const int64_t num_tokens = flatten_tokens.size(/*dim=*/0);
-  const int64_t block_table_len = params.block_tables.size(/*dim=*/1);
-  const int64_t max_block_table_len = block_tables_.size(/*dim=*/1);
+  const int64_t block_table_len = params.block_tables.size(/*dim=*/0);
+  const int64_t max_block_table_len = block_tables_.size(/*dim=*/0);
   CHECK_EQ(batch_size, batch_size_) << "batch size mismatch";
   CHECK_EQ(num_tokens, num_tokens_) << "num tokens mismatch";
   CHECK_GE(max_block_table_len, block_table_len) << "block table size mismatch";
@@ -198,8 +201,9 @@ torch::Tensor ModelRunner::CudaGraph::replay(torch::Tensor flatten_tokens,
   new_cache_slots_.copy_(params.new_cache_slots, /*non_blocking=*/true);
 
   // the incoming block table may be shorter than the max-sized buffer
-  block_tables_.slice(/*dim=*/1, /*start=*/0, /*end=*/block_table_len)
+  block_tables_.slice(/*dim=*/0, /*start=*/0, /*end=*/block_table_len)
       .copy_(params.block_tables, /*non_blocking=*/true);
+  cu_block_lens_.copy_(params.cu_block_lens, /*non_blocking=*/true);
 
   // replay the graph
   graph_->replay();
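For context: a captured CUDA graph replays against the tensor addresses recorded at capture time, which is why block_tables_ is preallocated at its maximum size and refreshed in place before each replay rather than reassigned. A minimal sketch of that copy-before-replay pattern, with hypothetical function and parameter names (not the runner's actual API):

#include <torch/torch.h>

// Sketch only: refresh the graph's static inputs in place before replay.
// The captured graph reads these buffers by address, so new values must be
// copied into them instead of rebinding the tensor variables.
void refresh_graph_inputs(torch::Tensor& block_tables_buf,    // [max_batch * max_blocks]
                          torch::Tensor& cu_block_lens_buf,   // [max_batch + 1]
                          const torch::Tensor& block_tables,  // flattened, 1D
                          const torch::Tensor& cu_block_lens) {
  // The fresh flattened table may be shorter than the max-sized buffer, so
  // only its leading slice is overwritten (along dim 0, now that it is 1D).
  block_tables_buf.slice(/*dim=*/0, /*start=*/0, /*end=*/block_tables.size(0))
      .copy_(block_tables, /*non_blocking=*/true);
  cu_block_lens_buf.copy_(cu_block_lens, /*non_blocking=*/true);
}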
3 changes: 3 additions & 0 deletions src/engine/model_runner.h
@@ -2,6 +2,7 @@
 
 #include <ATen/cuda/CUDAGraph.h>
 #include <absl/container/flat_hash_map.h>
+#include <c10/core/TensorImpl.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/torch.h>
 
@@ -65,6 +66,7 @@ class ModelRunner final {
   torch::Tensor kv_cu_seq_lens_;
   torch::Tensor new_cache_slots_;
   torch::Tensor block_tables_;
+  torch::Tensor cu_block_lens_;
 
   // graph pool handler
   at::cuda::MempoolId_t mem_pool_;
Expand Down Expand Up @@ -99,6 +101,7 @@ class ModelRunner final {
torch::Tensor flatten_positions_;
torch::Tensor new_cache_slots_;
torch::Tensor block_tables_;
torch::Tensor cu_block_lens_;
torch::Tensor q_cu_seq_lens_;
torch::Tensor kv_cu_seq_lens_;
