[KVCache] Python constructor for disaggregation (apache#5)
MasterJH5574 authored Oct 27, 2024
1 parent 76a0245 commit 4a42fdd
Showing 5 changed files with 110 additions and 70 deletions.
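In brief, the change threads a new `enable_disaggregation` flag through both Python-side paged KV cache constructors; when the flag is set, `rx.extern("nvshmem.KVTransfer")` is appended as a trailing constructor argument, and the C++ registrations are updated to pick that optional argument up by position and type. A minimal sketch of the flag's effect on the argument list (the helper name is illustrative, not part of this commit):

from tvm import relax as rx

def kv_cache_ctor_args(base_args, enable_disaggregation):
    # Append the NVSHMEM-backed KV transfer extern as the final argument;
    # the runtime constructor detects it among the optional trailing args.
    args = list(base_args)
    if enable_disaggregation:
        args.append(rx.extern("nvshmem.KVTransfer"))
    return args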
21 changes: 19 additions & 2 deletions python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -169,6 +169,7 @@ def __init__( # pylint: disable=too-many-locals
rope_scaling: Dict[str, Any],
rope_ext_factors: rx.Expr,
rotary_dim: int,
enable_disaggregation: bool,
dtype: str,
target: Target,
name: str = "paged_kv_cache",
@@ -214,6 +215,8 @@ def __init__( # pylint: disable=too-many-locals
The RoPE extension factors when "longrope" mode RoPE scaling is enabled.
rotary_dim : int
The number of dimensions in the embedding that RoPE is applied to.
enable_disaggregation : bool
Whether to enable disaggregation support (cross-instance KV transfer) in the KV cache.
"""
if rope_mode == RopeMode.INLINE:
assert rotary_dim == head_dim, "FlashInfer RoPE does not support partial rotary dim."
@@ -262,6 +265,8 @@ def __init__( # pylint: disable=too-many-locals
# fmt: on
# pylint: enable=line-too-long
]
if enable_disaggregation:
args.append(rx.extern("nvshmem.KVTransfer"))
super().__init__(
_expr=rx.call_pure_packed(
"vm.builtin.paged_attention_kv_cache_create",
@@ -293,6 +298,7 @@ def __init__( # pylint: disable=too-many-locals
rope_scaling: Dict[str, Any],
rope_ext_factors: rx.Expr,
rotary_dim: int,
enable_disaggregation: bool,
dtype: str,
target: Target,
name: str = "paged_kv_cache",
@@ -338,6 +344,8 @@ def __init__( # pylint: disable=too-many-locals
The RoPE extension factors when "longrope" mode RoPE scaling is enabled.
rotary_dim : int
The number of dimensions in the embedding that RoPE is applied to.
enable_disaggregation : bool
Whether to enable disaggregation support (cross-instance KV transfer) in the KV cache.
target : Target
The target to build the model to.
"""
@@ -380,6 +388,8 @@ def __init__( # pylint: disable=too-many-locals
# fmt: on
# pylint: enable=line-too-long
]
if enable_disaggregation:
args.append(rx.extern("nvshmem.KVTransfer"))
super().__init__(
_expr=rx.call_pure_packed(
"vm.builtin.paged_attention_kv_cache_create_reduced",
@@ -1912,7 +1922,12 @@ def copy_single_page(
T.func_attr({"tir.is_scheduled": 1})
num_pages = T.int32()
pages_elem_offset = T.int64()
pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype, elem_offset=pages_elem_offset)
pages = T.match_buffer(
var_pages,
(num_pages, 2, num_heads, page_size, head_dim),
dtype,
elem_offset=pages_elem_offset,
)

for b in T.thread_binding(
(copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x"
@@ -1957,7 +1972,9 @@ def compact_kv_copy(
copy_length_indptr_elem_offset = T.int32()
copy_src_dst_pos_elem_offset = T.int32()
pages_elem_offset = T.int64()
pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype, elem_offset=pages_elem_offset)
pages = T.match_buffer(
var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype, elem_offset=pages_elem_offset
)
copy_length_indptr = T.match_buffer(
var_copy_length_indptr,
(batch_size + 1,),
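For reference, the effect of the reformatted `copy_single_page` kernel above, sketched in NumPy over the `(num_pages, 2, num_heads, page_size, head_dim)` page layout (a behavioral sketch only; the actual TIR kernel is thread-bound over `blockIdx.x`/`threadIdx.x`):

import numpy as np

def copy_single_page_ref(pages, src_page, tgt_page, copy_length):
    # pages: (num_pages, 2, num_heads, page_size, head_dim); axis 1 holds K (0) and V (1).
    # Copy the first `copy_length` token slots of one page, for both K and V and all heads.
    pages[tgt_page, :, :, :copy_length, :] = pages[src_page, :, :, :copy_length, :]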
4 changes: 2 additions & 2 deletions src/runtime/relax_vm/kv_state.cc
@@ -58,8 +58,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
// Attention KV Cache methods
TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_prepare_recv")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DisaggPrepareRecv);
TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_mark_send")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::MarkSend);
TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_mark_send")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DisaggMarkSend);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::EnableSlidingWindowForSeq);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes")
6 changes: 3 additions & 3 deletions src/runtime/relax_vm/kv_state.h
@@ -161,9 +161,9 @@ class AttentionKVCacheObj : public KVStateObj {
virtual IntTuple DisaggPrepareRecv(int64_t seq_id, int length) = 0;

/*! \brief Mark which tokens' KV cache needs to be sent to other devices */
virtual void MarkSend(int64_t seq_id, int64_t begin,
const IntTuple& compressed_remote_position_map,
int32_t recver_pe_offset) = 0;
virtual void DisaggMarkSend(int64_t seq_id, int64_t begin,
const IntTuple& compressed_remote_position_map,
int32_t recver_pe_offset) = 0;

/************** Attention **************/

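Together with the kv_state.cc hunk above, the rename brings the packed-function name in line with the `Disagg` prefix already used by `DisaggPrepareRecv`. A hedged sketch of driving the renamed binding from Python, assuming `cache` is an already-created paged KV cache object and the argument values are illustrative:

import tvm

mark_send = tvm.get_global_func("vm.builtin.kv_cache_disagg_mark_send")
# Mark sequence 0's KV, starting at token offset 16, for transfer to the
# receiver at PE offset 1; the compressed position map encodes where the
# transferred entries land on the remote side.
mark_send(cache, 0, 16, tvm.runtime.ShapeTuple([100, 101, 102]), 1)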
84 changes: 44 additions & 40 deletions src/runtime/relax_vm/paged_kv_cache.cc
@@ -400,7 +400,8 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device);
q_rope_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
append_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
kv_transfer_remote_position_map_device = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
kv_transfer_remote_position_map_device =
NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
commit_copy_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device);
commit_copy_src_dst_pos_in_page_table_device_ =
NDArray::Empty({2, std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size)},
@@ -464,8 +465,8 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
return view;
}
NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) final {
NDArray view =
kv_transfer_remote_position_map_device.CreateView({static_cast<int64_t>(data->size())}, dtype_aux_);
NDArray view = kv_transfer_remote_position_map_device.CreateView(
{static_cast<int64_t>(data->size())}, dtype_aux_);
CopyVecDataToArray(view, data->data());
return view;
}
@@ -826,7 +827,6 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
NDArray merged_compact_kv_aux_data_device_;
};


/*!
* \brief The paged KV cache for attention.
* - It supports managing the K/V data of **multiple sequences**.
@@ -935,9 +935,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
std::vector<bool> use_decode_kernel_;
/*! \brief Whether the attention request is a decode request, set in BeginForwardFunction. */
bool is_decode_request_;
/*! \brief The KV transfer recver disco group's PE offset in this forward.
If no KV is transferred, the recver is -1.
Assume that all the KV is transferred to the same recver in this forward.
TODO: support multiple recvers. */
int transfer_kv_recver;
/*! \brief The auxiliary data manager for attention. */
@@ -1047,7 +1047,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
Optional<PackedFunc> f_attention_prefill_end_forward,
Optional<PackedFunc> f_attention_decode_begin_forward,
Optional<PackedFunc> f_attention_decode_end_forward, PackedFunc f_merge_inplace,
PackedFunc f_split_rotary, PackedFunc f_copy_single_page, Optional<PackedFunc> f_transfer_kv, Optional<PackedFunc> f_debug_get_kv)
PackedFunc f_split_rotary, PackedFunc f_copy_single_page, Optional<PackedFunc> f_transfer_kv,
Optional<PackedFunc> f_debug_get_kv)
: page_size_(page_size),
num_layers_(num_layers),
layer_id_begin_offset_(layer_id_begin_offset),
@@ -1086,13 +1087,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
f_debug_get_kv_(std::move(f_debug_get_kv)),
device_(device) {
pages_.reserve(num_layers);
if(f_transfer_kv_.defined()) {
if (f_transfer_kv_.defined()) {
ICHECK(Registry::Get("runtime.disco.nvshmem.init_nvshmem"));
auto f_nvshmem_empty = runtime::Registry::Get("runtime.disco.nvshmem.empty");
nvshmem_pages_ = (*f_nvshmem_empty)(
ShapeTuple({num_layers, num_total_pages, 2, num_kv_heads, page_size, head_dim}), dtype,
device);
for (int i = 0; i < num_layers; i++){
for (int i = 0; i < num_layers; i++) {
pages_.push_back(nvshmem_pages_.CreateView(
{num_total_pages_, 2, num_kv_heads_, page_size_, head_dim_}, nvshmem_pages_->dtype,
i * num_total_pages_ * 2 * num_kv_heads_ * page_size_ * head_dim_ *
@@ -1779,7 +1780,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
} else {
kv_transfer_remote_position_map_host_.push_back(
sequences[i]->kv_transfer_metadata.remote_position_map[pos_in_seq - seq_send_start]);
if(transfer_kv_recver == -1){
if (transfer_kv_recver == -1) {
transfer_kv_recver = sequences[i]->kv_transfer_metadata.recver_pe_offset;
} else {
ICHECK_EQ(transfer_kv_recver, sequences[i]->kv_transfer_metadata.recver_pe_offset);
@@ -1831,8 +1832,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
return IntTuple{compressed_append_pos_map};
}

void MarkSend(int64_t seq_id, int64_t begin, const IntTuple& compressed_remote_position_map,
int32_t recver_pe_offset) {
void DisaggMarkSend(int64_t seq_id, int64_t begin, const IntTuple& compressed_remote_position_map,
int32_t recver_pe_offset) {
ICHECK(f_transfer_kv_.defined());
auto it = seq_map_.find(seq_id);
CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
@@ -1914,19 +1915,18 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_);
}
// Part 4: KV transfer
if(transfer_kv_recver != -1){
if (transfer_kv_recver != -1) {
// FIXME: if the sender and recver's PP/TP degree do not match, we will need to first
// get the view of remote pages, and then take the specific remote layer.
f_transfer_kv_.value()(pages_[local_layer_id], k_data, v_data,
kv_transfer_remote_position_map_view_, transfer_kv_recver,
kv_transfer_stream_);
kv_transfer_remote_position_map_view_, transfer_kv_recver,
kv_transfer_stream_);
}
// Part 5: perform attention
AttentionInternal(layer_id, q_data, k_data, v_data, o_data_view, attn_score_scaling_factor);
// Part 6. Append k/v data to kv-cache if flag "append_before_attn" is not set.
if (!append_before_attn_) {
f_transpose_append_(pages_[local_layer_id], k_data, v_data,
append_position_map_view_);
f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_);
}
}

@@ -2075,8 +2075,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
append_position_map.data() + start_pos,
(end_pos - start_pos) * ((dtype_aux_.bits * dtype_aux_.lanes + 7) / 8));
for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) {
f_debug_get_kv_.value()(pages_[layer_id], position_map_device, k_data,
v_data, layer_id);
f_debug_get_kv_.value()(pages_[layer_id], position_map_device, k_data, v_data, layer_id);
}
}

@@ -2565,18 +2564,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
attn_score_scaling_factor, tree_attn_mn_indptr_view_[d], tree_attn_mask_view_[d]);
} else if (use_decode_kernel_[d]) {
// Use decode kernel for depth d
f_decode(/*depth=*/d, q_data, pages_[local_layer_id],
page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
length_info_on_depths_view_[d], k_rope_pos_offset_view_[d],
q_rope_position_map_view_, attn_output, attn_scores,
f_decode(/*depth=*/d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d],
page_indices_on_depths_view_[d], length_info_on_depths_view_[d],
k_rope_pos_offset_view_[d], q_rope_position_map_view_, attn_output, attn_scores,
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
attn_score_scaling_factor);
} else {
// Use prefill kernel for depth d
f_prefill(/*depth=*/d, q_data, qo_indptr_on_depths_view_[d],
pages_[local_layer_id], page_indptr_on_depths_view_[d],
page_indices_on_depths_view_[d], length_info_on_depths_view_[d],
k_rope_pos_offset_view_[d], q_rope_position_map_view_, attn_output, attn_scores,
f_prefill(/*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id],
page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
length_info_on_depths_view_[d], k_rope_pos_offset_view_[d],
q_rope_position_map_view_, attn_output, attn_scores,
/*causal=*/0,
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
attn_score_scaling_factor);
@@ -2691,8 +2689,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
append_position_map_view_ =
aux_data_manager_->CopyAppendPositionMapAsync(&append_position_map_host_);
// 10. kv_transfer_remote_position_map
kv_transfer_remote_position_map_view_ =
aux_data_manager_->CopyKVTransferRemotePositionMapAsync(&kv_transfer_remote_position_map_host_);
kv_transfer_remote_position_map_view_ = aux_data_manager_->CopyKVTransferRemotePositionMapAsync(
&kv_transfer_remote_position_map_host_);
// 11. tree_attn_mask and tree_attn_mn_indptr
for (int d = 0; d < num_depths_; ++d) {
if (!is_chain_on_depths_[d]) {
@@ -2726,7 +2724,7 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);

TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK(args.size() == 29 || args.size() == 30)
CHECK(args.size() == 28 || args.size() == 29 || args.size() == 30)
<< "Invalid number of KV cache constructor args.";
ShapeTuple cache_config = args[0];
ShapeTuple layer_indptr_tuple = args[1];
@@ -2766,11 +2764,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
PackedFunc f_compact_copy = args[25];
PackedFunc f_attention_prefill_with_tree_mask = args[26];
PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[27];
Optional<PackedFunc> f_transfer_kv = args[28];
Optional<NDArray> rope_ext_factors = NullOpt;
Optional<PackedFunc> f_transfer_kv = NullOpt;

if (args.size() >= 30 && args[29].IsObjectRef<NDArray>()) {
rope_ext_factors = args[29].AsObjectRef<NDArray>();
if (args.size() >= 29 && args[28].IsObjectRef<NDArray>()) {
rope_ext_factors = args[28].AsObjectRef<NDArray>();
}
if (args.size() >= 30 && args[29].IsObjectRef<PackedFunc>()) {
f_transfer_kv = args[29].AsObjectRef<PackedFunc>();
}

CHECK_EQ(cache_config.size(), 5);
@@ -2798,14 +2799,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
std::move(f_attention_prefill_ragged_end_forward),
std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward),
std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward),
std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page), std::move(f_transfer_kv),
std::move(f_debug_get_kv));
std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page),
std::move(f_transfer_kv), std::move(f_debug_get_kv));
*rv = AttentionKVCache(std::move(n));
});

TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
.set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK(args.size() == 23 || args.size() == 24)
CHECK(args.size() == 22 || args.size() == 23 || args.size() == 24)
<< "Invalid number of KV cache constructor args.";
ShapeTuple cache_config = args[0];
ShapeTuple layer_indptr_tuple = args[1];
@@ -2839,11 +2840,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
PackedFunc f_compact_copy = args[19];
PackedFunc f_attention_prefill_with_tree_mask = args[20];
PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[21];
Optional<PackedFunc> f_transfer_kv = args[22];
Optional<NDArray> rope_ext_factors = NullOpt;
Optional<PackedFunc> f_transfer_kv = NullOpt;

if (args.size() >= 24 && args[23].IsObjectRef<NDArray>()) {
rope_ext_factors = args[23].AsObjectRef<NDArray>();
if (args.size() >= 23 && args[22].IsObjectRef<NDArray>()) {
rope_ext_factors = args[22].AsObjectRef<NDArray>();
}
if (args.size() >= 24 && args[23].IsObjectRef<PackedFunc>()) {
f_transfer_kv = args[23].AsObjectRef<PackedFunc>();
}

CHECK_EQ(cache_config.size(), 5);
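Both registrations now accept one or two optional trailing arguments, distinguished positionally and by type: an `NDArray` for `rope_ext_factors` followed by a `PackedFunc` for `f_transfer_kv`. A Python rendering of that dispatch for the full constructor, under the 28-required-argument layout shown above (a sketch, not the actual implementation):

import tvm

def parse_optional_tail(args):
    rope_ext_factors = None  # NDArray, supplied when "longrope" ext factors are used
    f_transfer_kv = None     # PackedFunc, supplied when disaggregation is enabled
    if len(args) >= 29 and isinstance(args[28], tvm.nd.NDArray):
        rope_ext_factors = args[28]
    if len(args) >= 30 and callable(args[29]):
        f_transfer_kv = args[29]
    return rope_ext_factors, f_transfer_kv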