Skip to content

Commit

Permalink
[KVCache] Disaggregation PrepareRecv (apache#2)
Browse files Browse the repository at this point in the history
* [KVCache] Disaggregation PrepareRecv

This PR introduces the `DisaggPrepareRecv` function. It is essentially
a wrapper of `BeginForward` and returns the host "append position map"
array. It does not involve data copy from CPU to GPU.

* Update src/runtime/relax_vm/paged_kv_cache.cc

Co-authored-by: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com>

---------

Co-authored-by: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com>
  • Loading branch information
MasterJH5574 and CharlieFRuan authored Oct 26, 2024
1 parent 6e3d8b9 commit 8a54170
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/runtime/relax_vm/kv_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
.set_body_method<KVState>(&KVStateObj::EndForward);

// Attention KV Cache methods
// Expose DisaggPrepareRecv to the VM: prepares the cache to receive
// disaggregated KV data and returns the compressed host append position map.
TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_prepare_recv")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DisaggPrepareRecv);
// Expose per-sequence sliding-window enabling to the VM.
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::EnableSlidingWindowForSeq);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes")
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/relax_vm/kv_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ class AttentionKVCacheObj : public KVStateObj {
virtual void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids,
const IntTuple& leaf_indices) = 0;

/*!
 * \brief Prepare for the disaggregation KV data receive for the specified
 * sequence and append length. Does not copy data from CPU to GPU.
 * \param seq_id The ID of the sequence to receive KV data for.
 * \param length The number of tokens whose KV data will be received.
 * \return The host append position map in compressed form
 * "[n, begin_1, length_1, ..., begin_n, length_n]", where each
 * (begin_i, length_i) pair denotes a run of consecutive positions.
 */
virtual IntTuple DisaggPrepareRecv(int64_t seq_id, int length) = 0;

/************** Attention **************/

/*!
Expand Down
30 changes: 30 additions & 0 deletions src/runtime/relax_vm/paged_kv_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1726,6 +1726,36 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
}

/*!
 * \brief Prepare to receive disaggregated KV data for a sequence.
 * \param seq_id The ID of the sequence to receive KV data for.
 * \param append_length The number of tokens whose KV data will be received.
 * \return The host append position map, run-length compressed as
 * "[n, begin_1, length_1, ..., begin_n, length_n]", which decompresses to
 * "[begin_1, ..., begin_1+length_1-1, ..., begin_n, ..., begin_n+length_n-1]".
 * No CPU-to-GPU data copy is involved.
 */
IntTuple DisaggPrepareRecv(int64_t seq_id, int append_length) final {
// (step 1.) Redirect the preparation to BeginForward, which fills
// append_position_map_host_ for the requested (seq_id, append_length).
BeginForward({seq_id}, {append_length}, /*opt_token_tree_parent_ptr=*/NullOpt);
// (step 2.) Fetch the append position map, compress it and return.
CHECK_EQ(append_position_map_host_.size(), append_length);
if (append_length == 0) {
// Empty append: zero segments. This guard prevents the out-of-bounds
// accesses "[0]" and ".back()" on an empty position map below.
return IntTuple{std::vector<int64_t>{/*num_segments=*/0}};
}
std::vector<int64_t> compressed_append_pos_map{/*num_segments=*/1,
append_position_map_host_[0]};
for (int i = 1; i < append_length; ++i) {
if (append_position_map_host_[i] != append_position_map_host_[i - 1] + 1) {
// Terminate the current segment: its length is
// (last position) - (segment begin, the last pushed element) + 1.
compressed_append_pos_map.push_back(append_position_map_host_[i - 1] -
compressed_append_pos_map.back() + 1);
// Start a new segment at the discontinuity.
++compressed_append_pos_map[0];
compressed_append_pos_map.push_back(append_position_map_host_[i]);
}
}
// Terminate the last segment.
compressed_append_pos_map.push_back(append_position_map_host_.back() -
compressed_append_pos_map.back() + 1);
// Sanity check: the compressed array size must be "num_segments * 2 + 1".
CHECK_EQ(compressed_append_pos_map.size(), compressed_append_pos_map[0] * 2 + 1);
return IntTuple{compressed_append_pos_map};
}

void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
NDArray o_data, double attn_score_scaling_factor) final {
// Part 1. Shape and dtype check.
Expand Down

0 comments on commit 8a54170

Please sign in to comment.