Skip to content

Commit

Permalink
Kv cache interface (apache#4)
Browse files Browse the repository at this point in the history
Add the MarkSend interface, and call the KVTransfer kernel in attention

```
void MarkSend(int64_t seq_id, int64_t begin,
                        const IntTuple& compressed_remote_position_map,
                        int32_t recver_pe_offset) = 0;
```
  • Loading branch information
jinhongyii authored Oct 27, 2024
1 parent 8a54170 commit 60caa16
Show file tree
Hide file tree
Showing 8 changed files with 874 additions and 75 deletions.
18 changes: 12 additions & 6 deletions python/tvm/relax/frontend/nn/llm/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,9 @@ def tir_kv_cache_transpose_append(
T.func_attr({"tir.noalias": T.bool(True)})
ntoken = T.SizeVar("num_tokens_excluding_cache", "int64")
num_pages = T.int64()
pages_elem_offset = T.int64()
position_map_elem_offset = T.int32()
pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 16, head_dim), dtype)
pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 16, head_dim), dtype, elem_offset=pages_elem_offset)
k_data = T.match_buffer(var_k_data, (ntoken, num_key_value_heads, head_dim), dtype)
v_data = T.match_buffer(var_v_data, (ntoken, num_key_value_heads, head_dim), dtype)
position_map = T.match_buffer(
Expand Down Expand Up @@ -453,8 +454,9 @@ def tir_kv_cache_debug_get_kv(
seqlen = T.SizeVar("num_tokens_including_cache", "int64")
page_size = T.SizeVar("page_size", "int64")
num_pages = T.int64()
pages_elem_offset = T.int64()
position_map_elem_offset = T.int64()
pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype)
pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype,elem_offset=pages_elem_offset)
position_map = T.match_buffer(
var_position_map, (seqlen,), "int32", elem_offset=position_map_elem_offset
)
Expand Down Expand Up @@ -594,6 +596,7 @@ def batch_prefill_paged_kv(
total_len = T.int32(is_size_var=True)
nnz_pages = T.int32(is_size_var=True)
max_num_pages = T.int32(is_size_var=True)
pages_elem_offset = T.int64(is_size_var=True)
q_indptr_elem_offset = T.int32(is_size_var=True)
page_indptr_elem_offset = T.int32(is_size_var=True)
page_values_elem_offset = T.int32(is_size_var=True)
Expand All @@ -603,7 +606,7 @@ def batch_prefill_paged_kv(

q = T.match_buffer(var_q, (total_len, h_q, d), dtype)
q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset)
pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype)
pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype, elem_offset=pages_elem_offset)
page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset)
page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset)
k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
Expand Down Expand Up @@ -971,6 +974,7 @@ def batch_decode_paged_kv(
B = T.int32(is_size_var=True)
nnz_pages = T.int32(is_size_var=True)
max_num_pages = T.int32(is_size_var=True)
pages_elem_offset = T.int64(is_size_var=True)
page_indptr_elem_offset = T.int32(is_size_var=True)
page_values_elem_offset = T.int32(is_size_var=True)
k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
Expand All @@ -979,7 +983,7 @@ def batch_decode_paged_kv(

Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype)
pages = T.match_buffer(
pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype
pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype, elem_offset=pages_elem_offset
)
page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", elem_offset=page_indptr_elem_offset)
page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", elem_offset=page_values_elem_offset)
Expand Down Expand Up @@ -1907,7 +1911,8 @@ def copy_single_page(
):
T.func_attr({"tir.is_scheduled": 1})
num_pages = T.int32()
pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype)
pages_elem_offset = T.int64()
pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype, elem_offset=pages_elem_offset)

for b in T.thread_binding(
(copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x"
Expand Down Expand Up @@ -1951,7 +1956,8 @@ def compact_kv_copy(
total_copy_length = T.int32()
copy_length_indptr_elem_offset = T.int32()
copy_src_dst_pos_elem_offset = T.int32()
pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype)
pages_elem_offset = T.int64()
pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype, elem_offset=pages_elem_offset)
copy_length_indptr = T.match_buffer(
var_copy_length_indptr,
(batch_size + 1,),
Expand Down
69 changes: 37 additions & 32 deletions src/runtime/contrib/nvshmem/kv_transfer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ __device__ int calc_flattened_index(int shape[dim], int index[dim]) {

template <typename T, int local_num_kv_head, int remote_num_kv_head, int head_dim, int page_size>
__global__ void KVTransfer(T* pages, T* k_data, T* v_data, int32_t* remote_position_map,
int ntokens, int remote_layer_id, int local_tp_rank,
int ntokens, int local_tp_rank,
int remote_tp_group_pe_offset, int remote_num_pages) {
// launch grid: [num_blocks, 1, 1], [32, local_num_kv_head, 1]
// pages(remote): [remote_num_layers, remote_num_pages, 2, remote_num_kv_head, page_size, head_dim]
// pages(remote): [remote_num_pages, 2, remote_num_kv_head, page_size, head_dim]
// k_data: [ntokens, local_num_kv_head, head_dim]
// v_data: [ntokens, local_num_kv_head, head_dim]
int remote_pe;
Expand All @@ -48,15 +48,15 @@ __global__ void KVTransfer(T* pages, T* k_data, T* v_data, int32_t* remote_posit
};
int page_id = position / page_size;
int offset_in_page = position % page_size;
int pages_shape[6] = {1, remote_num_pages, 2, remote_num_kv_head, page_size, head_dim};
int k_page_index[6] = {remote_layer_id, page_id, 0, remote_kv_head_index, offset_in_page, 0};
int v_page_index[6] = {remote_layer_id, page_id, 1, remote_kv_head_index, offset_in_page, 0};
int pages_shape[5] = {remote_num_pages, 2, remote_num_kv_head, page_size, head_dim};
int k_page_index[5] = {page_id, 0, remote_kv_head_index, offset_in_page, 0};
int v_page_index[5] = {page_id, 1, remote_kv_head_index, offset_in_page, 0};
int k_v_shape[3] = {ntokens, local_num_kv_head, head_dim};
int k_v_index[3] = {global_pos, h, 0};
nvshmemx_putmem_nbi_warp(pages + calc_flattened_index<6>(pages_shape, k_page_index),
nvshmemx_putmem_nbi_warp(pages + calc_flattened_index<5>(pages_shape, k_page_index),
k_data + calc_flattened_index<3>(k_v_shape, k_v_index),
head_dim * sizeof(T), remote_pe);
nvshmemx_putmem_nbi_warp(pages + calc_flattened_index<6>(pages_shape, v_page_index),
nvshmemx_putmem_nbi_warp(pages + calc_flattened_index<5>(pages_shape, v_page_index),
v_data + calc_flattened_index<3>(k_v_shape, k_v_index),
head_dim * sizeof(T), remote_pe);
}
Expand Down Expand Up @@ -109,24 +109,26 @@ __global__ void KVTransfer(T* pages, T* k_data, T* v_data, int32_t* remote_posit
LOG(FATAL) << "Unsupported num_kv_head " << num_kv_head; \
}

int _KVTransfer(DLTensor* pages, DLTensor* k, DLTensor* v, DLTensor* remote_position_map,
int remote_num_pages, int remote_num_layers, int remote_num_kv_head,
int remote_layer_id, int remote_tp_group_pe_offset) {
CHECK_EQ(pages->device.device_type, kDLCUDA) << "The device of q matrix must be CUDA.";
int _KVTransfer(DLTensor* remote_pages, DLTensor* k, DLTensor* v, DLTensor* remote_position_map,
int remote_tp_group_pe_offset, TVMStreamHandle transfer_stream) {
CHECK_EQ(remote_pages->device.device_type, kDLCUDA) << "The device of remote_pages matrix must be CUDA.";
CHECK_EQ(k->device.device_type, kDLCUDA) << "The device of k matrix must be CUDA.";
CHECK_EQ(v->device.device_type, kDLCUDA) << "The device of v matrix must be CUDA.";
CHECK_EQ(remote_position_map->device.device_type, kDLCUDA)
<< "The device of o matrix must be CUDA.";

size_t dev_id = pages->device.device_id;
CHECK_EQ(k->device.device_id, dev_id) << "The device id of q and k matrix doesn't match.";
CHECK_EQ(v->device.device_id, dev_id) << "The device id of q and v matrix doesn't match.";
<< "The device of remote_position_map matrix must be CUDA.";
size_t dev_id = remote_pages->device.device_id;
CHECK_EQ(k->device.device_id, dev_id)
<< "The device id of remote_pages and k matrix doesn't match.";
CHECK_EQ(v->device.device_id, dev_id)
<< "The device id of remote_pages and v matrix doesn't match.";
CHECK_EQ(remote_position_map->device.device_id, dev_id)
<< "The device id of q and o matrix doesn't match.";
<< "The device id of remote_pages and remote_position_map matrix doesn't match.";

CHECK_GE(pages->ndim, 6);
int page_size = pages->shape[pages->ndim - 2];
int head_dim = pages->shape[pages->ndim - 1];
CHECK_EQ(remote_pages->ndim, 5);
int remote_num_pages = remote_pages->shape[0];
int remote_num_kv_head = remote_pages->shape[2];
int page_size = remote_pages->shape[3];
int head_dim = remote_pages->shape[4];

CHECK_GE(k->ndim, 3);
int kv_len = k->shape[k->ndim - 3];
Expand All @@ -138,35 +140,38 @@ int _KVTransfer(DLTensor* pages, DLTensor* k, DLTensor* v, DLTensor* remote_posi
CHECK_EQ(local_num_kv_heads, v->shape[v->ndim - 2]);
CHECK_EQ(head_dim, v->shape[v->ndim - 1]);

CHECK(pages->dtype.lanes == 1 && k->dtype.lanes == 1 && v->dtype.lanes == 1);
CHECK(pages->dtype.bits == k->dtype.bits && pages->dtype.code == k->dtype.code);
CHECK(pages->dtype.bits == v->dtype.bits && pages->dtype.code == v->dtype.code);
CHECK(remote_pages->dtype.lanes == 1 && k->dtype.lanes == 1 && v->dtype.lanes == 1);
CHECK(remote_pages->dtype.bits == k->dtype.bits && remote_pages->dtype.code == k->dtype.code);
CHECK(remote_pages->dtype.bits == v->dtype.bits && remote_pages->dtype.code == v->dtype.code);
int local_tp_rank;
tvm::runtime::DiscoWorker* worker = tvm::runtime::ThreadLocalDiscoWorker::Get()->worker;
if (worker == nullptr){
local_tp_rank = 0;
} else {
local_tp_rank = worker->worker_id;
}

dim3 blocks(8, 1, 1);
dim3 threads(32, local_num_kv_heads, 1);
DISPATCH_TVM_CUDA_DTYPE(
pages->dtype, dtype_in,
remote_pages->dtype, dtype_in,
{DISPATCH_HEAD_DIM(
head_dim, HEAD_DIM,
{DISPATCH_PAGE_SIZE(
page_size, PAGE_SIZE,
{DISPATCH_NUM_KV_HEAD(
remote_num_kv_head, REMOTE_NUM_KV_HEAD,
{DISPATCH_NUM_KV_HEAD(local_num_kv_heads, LOCAL_NUM_KV_HEAD, {
{DISPATCH_NUM_KV_HEAD(remote_num_kv_head, REMOTE_NUM_KV_HEAD,
{DISPATCH_NUM_KV_HEAD(local_num_kv_heads, LOCAL_NUM_KV_HEAD, {
dtype_in* remote_pages_data = (dtype_in*)((char*)remote_pages->data + remote_pages->byte_offset);
dtype_in* k_data = (dtype_in*)((char*)k->data + k->byte_offset);
dtype_in* v_data = (dtype_in*)((char*)v->data + v->byte_offset);
int32_t* remote_position_map_data = (int32_t*)((char*)remote_position_map->data + remote_position_map->byte_offset);
KVTransfer<dtype_in, LOCAL_NUM_KV_HEAD, REMOTE_NUM_KV_HEAD, HEAD_DIM, PAGE_SIZE>
<<<blocks, threads>>>(
(dtype_in*)pages->data, (dtype_in*)k->data, (dtype_in*)v->data,
(int32_t*)remote_position_map->data, kv_len, remote_layer_id,
<<<blocks, threads, 0, static_cast<cudaStream_t>(transfer_stream)>>>(
remote_pages_data, k_data, v_data, remote_position_map_data, kv_len,
local_tp_rank, remote_tp_group_pe_offset, remote_num_pages);
})})})})})
})})})})})

return 0;
return 0;
}

TVM_REGISTER_GLOBAL("nvshmem.KVTransfer").set_body_typed(_KVTransfer);
2 changes: 2 additions & 0 deletions src/runtime/relax_vm/kv_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
// Attention KV Cache methods
TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_prepare_recv")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DisaggPrepareRecv);
TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_mark_send")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::MarkSend);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::EnableSlidingWindowForSeq);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes")
Expand Down
5 changes: 5 additions & 0 deletions src/runtime/relax_vm/kv_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ class AttentionKVCacheObj : public KVStateObj {
/*! \brief Prepare for the disaggregation KV data receive for the specified sequence and length.*/
virtual IntTuple DisaggPrepareRecv(int64_t seq_id, int length) = 0;

/*! \brief Mark which tokens' KV cache needs to be sent to other devices */
virtual void MarkSend(int64_t seq_id, int64_t begin,
const IntTuple& compressed_remote_position_map,
int32_t recver_pe_offset) = 0;

/************** Attention **************/

/*!
Expand Down
Loading

0 comments on commit 60caa16

Please sign in to comment.