From 4a42fdd7b52cb6107ab88d779a9d94ece1bbdd99 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Sun, 27 Oct 2024 10:59:03 -0400
Subject: [PATCH] [KVCache] Python constructor for disaggregation (#5)

---
 python/tvm/relax/frontend/nn/llm/kv_cache.py  | 21 ++++-
 src/runtime/relax_vm/kv_state.cc              |  4 +-
 src/runtime/relax_vm/kv_state.h               |  6 +-
 src/runtime/relax_vm/paged_kv_cache.cc        | 84 ++++++++++---------
 .../test_runtime_builtin_kv_cache_transfer.py | 65 +++++++++-----
 5 files changed, 110 insertions(+), 70 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index b7f032d886946..a524463b7b1d7 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -169,6 +169,7 @@ def __init__(  # pylint: disable=too-many-locals
         rope_scaling: Dict[str, Any],
         rope_ext_factors: rx.Expr,
         rotary_dim: int,
+        enable_disaggregation: bool,
         dtype: str,
         target: Target,
         name: str = "paged_kv_cache",
@@ -214,6 +215,8 @@ def __init__(  # pylint: disable=too-many-locals
             The RoPE extension factors when "longrope" mode RoPE scaling is enabled.
         rotary_dim : int
             The number of dimensions in the embedding that RoPE is applied to.
+        enable_disaggregation : bool
+            Whether to enable disaggregation in the KV cache.
         """
         if rope_mode == RopeMode.INLINE:
             assert rotary_dim == head_dim, "FlashInfer RoPE does not support partial rotary dim."
@@ -262,6 +265,8 @@ def __init__(  # pylint: disable=too-many-locals
             # fmt: on
             # pylint: enable=line-too-long
         ]
+        if enable_disaggregation:
+            args.append(rx.extern("nvshmem.KVTransfer"))
         super().__init__(
             _expr=rx.call_pure_packed(
                 "vm.builtin.paged_attention_kv_cache_create",
@@ -293,6 +298,7 @@ def __init__(  # pylint: disable=too-many-locals
         rope_scaling: Dict[str, Any],
         rope_ext_factors: rx.Expr,
         rotary_dim: int,
+        enable_disaggregation: bool,
         dtype: str,
         target: Target,
         name: str = "paged_kv_cache",
@@ -338,6 +344,8 @@ def __init__(  # pylint: disable=too-many-locals
             The RoPE extension factors when "longrope" mode RoPE scaling is enabled.
         rotary_dim : int
             The number of dimensions in the embedding that RoPE is applied to.
+        enable_disaggregation : bool
+            Whether to enable disaggregation in the KV cache.
         target : Target
             The target to build the model to.
""" @@ -380,6 +388,8 @@ def __init__( # pylint: disable=too-many-locals # fmt: on # pylint: enable=line-too-long ] + if enable_disaggregation: + args.append(rx.extern("nvshmem.KVTransfer")) super().__init__( _expr=rx.call_pure_packed( "vm.builtin.paged_attention_kv_cache_create_reduced", @@ -1912,7 +1922,12 @@ def copy_single_page( T.func_attr({"tir.is_scheduled": 1}) num_pages = T.int32() pages_elem_offset = T.int64() - pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype, elem_offset=pages_elem_offset) + pages = T.match_buffer( + var_pages, + (num_pages, 2, num_heads, page_size, head_dim), + dtype, + elem_offset=pages_elem_offset, + ) for b in T.thread_binding( (copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" @@ -1957,7 +1972,9 @@ def compact_kv_copy( copy_length_indptr_elem_offset = T.int32() copy_src_dst_pos_elem_offset = T.int32() pages_elem_offset = T.int64() - pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype, elem_offset=pages_elem_offset) + pages = T.match_buffer( + var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype, elem_offset=pages_elem_offset + ) copy_length_indptr = T.match_buffer( var_copy_length_indptr, (batch_size + 1,), diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc index d55a34a9136c6..5dd60b25801aa 100644 --- a/src/runtime/relax_vm/kv_state.cc +++ b/src/runtime/relax_vm/kv_state.cc @@ -58,8 +58,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward") // Attention KV Cache methods TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_prepare_recv") .set_body_method(&AttentionKVCacheObj::DisaggPrepareRecv); -TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_mark_send") - .set_body_method(&AttentionKVCacheObj::MarkSend); +TVM_REGISTER_GLOBAL("vm.builtin.kv_cache_disagg_mark_send") + .set_body_method(&AttentionKVCacheObj::DIsaggMarkSend); TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq") .set_body_method(&AttentionKVCacheObj::EnableSlidingWindowForSeq); TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes") diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h index bc640b81d2310..525a024c63cef 100644 --- a/src/runtime/relax_vm/kv_state.h +++ b/src/runtime/relax_vm/kv_state.h @@ -161,9 +161,9 @@ class AttentionKVCacheObj : public KVStateObj { virtual IntTuple DisaggPrepareRecv(int64_t seq_id, int length) = 0; /*! 
   \brief Mark which tokens' KV cache needs to be sent to other devices */
-  virtual void MarkSend(int64_t seq_id, int64_t begin,
-                        const IntTuple& compressed_remote_position_map,
-                        int32_t recver_pe_offset) = 0;
+  virtual void DisaggMarkSend(int64_t seq_id, int64_t begin,
+                              const IntTuple& compressed_remote_position_map,
+                              int32_t recver_pe_offset) = 0;
 
   /************** Attention **************/
 
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index b7c931c93aaf0..f7d93b419fe89 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -400,7 +400,8 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
     k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device);
     q_rope_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
     append_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
-    kv_transfer_remote_position_map_device = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
+    kv_transfer_remote_position_map_device =
+        NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
     commit_copy_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device);
     commit_copy_src_dst_pos_in_page_table_device_ =
         NDArray::Empty({2, std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size)},
@@ -464,8 +465,8 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
     return view;
   }
   NDArray CopyKVTransferRemotePositionMapAsync(HostMemoryVector* data) final {
-    NDArray view =
-        kv_transfer_remote_position_map_device.CreateView({static_cast<int64_t>(data->size())}, dtype_aux_);
+    NDArray view = kv_transfer_remote_position_map_device.CreateView(
+        {static_cast<int64_t>(data->size())}, dtype_aux_);
     CopyVecDataToArray(view, data->data());
     return view;
   }
@@ -826,7 +827,6 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
   NDArray merged_compact_kv_aux_data_device_;
 };
 
-
 /*!
  * \brief The paged KV cache for attention.
  * - It supports managing the K/V data of **multiple sequences**.
@@ -935,9 +935,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   std::vector<bool> use_decode_kernel_;
   /*! \brief Whether the attention request is a decode request, set in BeginForwardFunction. */
   bool is_decode_request_;
-  /*! \brief The KV transfer recver disco group's PE offset in this forward. If no KV is transfered, recver is -1. 
-  Assume that all the KV are transfered to the same recver in the forward. 
+  /*! \brief The KV transfer recver disco group's PE offset in this forward. If no KV is transferred, recver is -1.
+  Assume that all the KV is transferred to the same recver in the forward.
   todo: support multiple recver. */
   int transfer_kv_recver;
   /*! \brief The auxiliary data manager for attention.
   */
@@ -1047,7 +1047,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
                            Optional<PackedFunc> f_attention_prefill_end_forward,
                            Optional<PackedFunc> f_attention_decode_begin_forward,
                            Optional<PackedFunc> f_attention_decode_end_forward, PackedFunc f_merge_inplace,
-                           PackedFunc f_split_rotary, PackedFunc f_copy_single_page, Optional<PackedFunc> f_transfer_kv, Optional<PackedFunc> f_debug_get_kv)
+                           PackedFunc f_split_rotary, PackedFunc f_copy_single_page, Optional<PackedFunc> f_transfer_kv,
+                           Optional<PackedFunc> f_debug_get_kv)
       : page_size_(page_size),
         num_layers_(num_layers),
         layer_id_begin_offset_(layer_id_begin_offset),
@@ -1086,13 +1087,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         f_debug_get_kv_(std::move(f_debug_get_kv)),
         device_(device) {
     pages_.reserve(num_layers);
-    if(f_transfer_kv_.defined()) {
+    if (f_transfer_kv_.defined()) {
       ICHECK(Registry::Get("runtime.disco.nvshmem.init_nvshmem"));
       auto f_nvshmem_empty = runtime::Registry::Get("runtime.disco.nvshmem.empty");
       nvshmem_pages_ = (*f_nvshmem_empty)(
           ShapeTuple({num_layers, num_total_pages, 2, num_kv_heads, page_size, head_dim}), dtype,
          device);
-      for (int i = 0; i < num_layers; i++){
+      for (int i = 0; i < num_layers; i++) {
         pages_.push_back(nvshmem_pages_.CreateView(
             {num_total_pages_, 2, num_kv_heads_, page_size_, head_dim_}, nvshmem_pages_->dtype,
             i * num_total_pages_ * 2 * num_kv_heads_ * page_size_ * head_dim_ *
@@ -1779,7 +1780,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       } else {
         kv_transfer_remote_position_map_host_.push_back(
             sequences[i]->kv_transfer_metadata.remote_position_map[pos_in_seq - seq_send_start]);
-        if(transfer_kv_recver == -1){
+        if (transfer_kv_recver == -1) {
           transfer_kv_recver = sequences[i]->kv_transfer_metadata.recver_pe_offset;
         } else {
           ICHECK_EQ(transfer_kv_recver, sequences[i]->kv_transfer_metadata.recver_pe_offset);
@@ -1831,8 +1832,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     return IntTuple{compressed_append_pos_map};
   }
 
-  void MarkSend(int64_t seq_id, int64_t begin, const IntTuple& compressed_remote_position_map,
-                int32_t recver_pe_offset) {
+  void DisaggMarkSend(int64_t seq_id, int64_t begin, const IntTuple& compressed_remote_position_map,
+                      int32_t recver_pe_offset) {
     ICHECK(f_transfer_kv_.defined());
     auto it = seq_map_.find(seq_id);
     CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
@@ -1914,19 +1915,18 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_);
     }
     // Part 4: KV transfer
-    if(transfer_kv_recver != -1){
-      // FIXME: if the sender and recver's PP/TP degree do not match, we will need to first 
+    if (transfer_kv_recver != -1) {
+      // FIXME: if the sender and recver's PP/TP degree do not match, we will need to first
       // get the view of remote pages, and then take the specific remote layer.
       f_transfer_kv_.value()(pages_[local_layer_id], k_data, v_data,
-                            kv_transfer_remote_position_map_view_, transfer_kv_recver,
-                            kv_transfer_stream_);
+                             kv_transfer_remote_position_map_view_, transfer_kv_recver,
+                             kv_transfer_stream_);
     }
     // Part 4: perform attention
     AttentionInternal(layer_id, q_data, k_data, v_data, o_data_view, attn_score_scaling_factor);
     // Part 5. Append k/v data to kv-cache if flag "append_before_attn" is not set.
     if (!append_before_attn_) {
-      f_transpose_append_(pages_[local_layer_id], k_data, v_data,
-                          append_position_map_view_);
+      f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_);
     }
   }
 
@@ -2075,8 +2075,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
                 append_position_map.data() + start_pos,
                 (end_pos - start_pos) * ((dtype_aux_.bits * dtype_aux_.lanes + 7) / 8));
     for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) {
-      f_debug_get_kv_.value()(pages_[layer_id], position_map_device, k_data,
-                              v_data, layer_id);
+      f_debug_get_kv_.value()(pages_[layer_id], position_map_device, k_data, v_data, layer_id);
     }
   }
 
@@ -2565,18 +2564,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
                   attn_score_scaling_factor, tree_attn_mn_indptr_view_[d], tree_attn_mask_view_[d]);
       } else if (use_decode_kernel_[d]) {
         // Use decode kernel for depth d
-        f_decode(/*depth=*/d, q_data, pages_[local_layer_id],
-                 page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
-                 length_info_on_depths_view_[d], k_rope_pos_offset_view_[d],
-                 q_rope_position_map_view_, attn_output, attn_scores,
+        f_decode(/*depth=*/d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d],
+                 page_indices_on_depths_view_[d], length_info_on_depths_view_[d],
+                 k_rope_pos_offset_view_[d], q_rope_position_map_view_, attn_output, attn_scores,
                  /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
                  attn_score_scaling_factor);
       } else {
         // Use prefill kernel for depth d
-        f_prefill(/*depth=*/d, q_data, qo_indptr_on_depths_view_[d],
-                  pages_[local_layer_id], page_indptr_on_depths_view_[d],
-                  page_indices_on_depths_view_[d], length_info_on_depths_view_[d],
-                  k_rope_pos_offset_view_[d], q_rope_position_map_view_, attn_output, attn_scores,
+        f_prefill(/*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id],
+                  page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
+                  length_info_on_depths_view_[d], k_rope_pos_offset_view_[d],
+                  q_rope_position_map_view_, attn_output, attn_scores,
                   /*causal=*/0,
                   /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
                   attn_score_scaling_factor);
@@ -2691,8 +2689,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     append_position_map_view_ =
         aux_data_manager_->CopyAppendPositionMapAsync(&append_position_map_host_);
     // 10. kv_transfer_remote_position_map
-    kv_transfer_remote_position_map_view_ =
-        aux_data_manager_->CopyKVTransferRemotePositionMapAsync(&kv_transfer_remote_position_map_host_);
+    kv_transfer_remote_position_map_view_ = aux_data_manager_->CopyKVTransferRemotePositionMapAsync(
+        &kv_transfer_remote_position_map_host_);
     // 11. tree_attn_mask and tree_attn_mn_indptr
     for (int d = 0; d < num_depths_; ++d) {
       if (!is_chain_on_depths_[d]) {
@@ -2726,7 +2724,7 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);
 
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
-      CHECK(args.size() == 29 || args.size() == 30)
+      CHECK(args.size() == 28 || args.size() == 29 || args.size() == 30)
           << "Invalid number of KV cache constructor args.";
       ShapeTuple cache_config = args[0];
       ShapeTuple layer_indptr_tuple = args[1];
@@ -2766,11 +2764,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
      PackedFunc f_compact_copy = args[25];
      PackedFunc f_attention_prefill_with_tree_mask = args[26];
      PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[27];
-      Optional<PackedFunc> f_transfer_kv = args[28];
      Optional<NDArray> rope_ext_factors = NullOpt;
+      Optional<PackedFunc> f_transfer_kv = NullOpt;
 
-      if (args.size() >= 30 && args[29].IsObjectRef<NDArray>()) {
-        rope_ext_factors = args[29].AsObjectRef<NDArray>();
+      if (args.size() >= 29 && args[28].IsObjectRef<NDArray>()) {
+        rope_ext_factors = args[28].AsObjectRef<NDArray>();
+      }
+      if (args.size() >= 30 && args[29].IsObjectRef<PackedFunc>()) {
+        f_transfer_kv = args[29].AsObjectRef<PackedFunc>();
      }
 
      CHECK_EQ(cache_config.size(), 5);
@@ -2798,14 +2799,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
          std::move(f_attention_prefill_ragged_end_forward),
          std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward),
          std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward),
-          std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page), std::move(f_transfer_kv),
-          std::move(f_debug_get_kv));
+          std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page),
+          std::move(f_transfer_kv), std::move(f_debug_get_kv));
      *rv = AttentionKVCache(std::move(n));
    });
 
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
     .set_body([](TVMArgs args, TVMRetValue* rv) {
-      CHECK(args.size() == 23 || args.size() == 24)
+      CHECK(args.size() == 22 || args.size() == 23 || args.size() == 24)
           << "Invalid number of KV cache constructor args.";
       ShapeTuple cache_config = args[0];
       ShapeTuple layer_indptr_tuple = args[1];
@@ -2839,11 +2840,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
      PackedFunc f_compact_copy = args[19];
      PackedFunc f_attention_prefill_with_tree_mask = args[20];
      PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[21];
-      Optional<PackedFunc> f_transfer_kv = args[22];
      Optional<NDArray> rope_ext_factors = NullOpt;
+      Optional<PackedFunc> f_transfer_kv = NullOpt;
 
-      if (args.size() >= 24 && args[23].IsObjectRef<NDArray>()) {
-        rope_ext_factors = args[23].AsObjectRef<NDArray>();
+      if (args.size() >= 23 && args[22].IsObjectRef<NDArray>()) {
+        rope_ext_factors = args[22].AsObjectRef<NDArray>();
+      }
+      if (args.size() >= 24 && args[23].IsObjectRef<PackedFunc>()) {
+        f_transfer_kv = args[23].AsObjectRef<PackedFunc>();
      }
 
      CHECK_EQ(cache_config.size(), 5);
diff --git a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
index aa0920367dcb8..5f9848ac3af47 100644
--- a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
+++ b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
@@ -21,6 +21,7 @@
 import numpy as np
 import pytest
 import scipy.special
+from mpi4py import MPI
 
 import tvm
 import tvm.testing
@@ -39,7 +40,6 @@
     tree_attn_with_paged_kv_cache,
 )
 from tvm.runtime import ShapeTuple
-from mpi4py import MPI
 
 comm = MPI.COMM_WORLD
 rank = comm.Get_rank()
@@ -73,7 +73,7 @@
 fnvshmem_get_uid = None
 fnvshmem_init = None
 ftransfer_kv = None
-fmark_send = None
+fdisagg_mark_send = None
 fdisagg_prepare_recv = None
 
 ftranspose_append = None
@@ -91,6 +91,7 @@
 fcopy_single_page = None
 fcompact_copy = None
 
+
 def set_global_func(head_dim, dtype):
     global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fenable_sliding_window_for_seq
     global fpopn, fbegin_forward, fend_forward, fcommit_accepted_token_tree_nodes
@@ -99,7 +100,7 @@ def set_global_func(head_dim, dtype):
     global fattn_prefill_ragged, fattn_prefill_with_tree_mask, fattn_prefill_with_tree_mask_paged_kv_cache
     global fattn_prefill_sliding_window, fattn_decode_sliding_window
     global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page, fcompact_copy
-    global fnvshmem_get_uid, fnvshmem_init, ftransfer_kv, fmark_send, fdisagg_prepare_recv
+    global fnvshmem_get_uid, fnvshmem_init, ftransfer_kv, fdisagg_mark_send, fdisagg_prepare_recv
 
     fclear = tvm.get_global_func("vm.builtin.kv_state_clear")
     fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
@@ -119,13 +120,13 @@ def set_global_func(head_dim, dtype):
     )
     fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty")
     fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")
-    
+
     fnvshmem_get_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
     fnvshmem_init = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem")
     ftransfer_kv = tvm.get_global_func("nvshmem.KVTransfer")
-    fmark_send = tvm.get_global_func("vm.builtin.kv_cache_mark_send")
+    fdisagg_mark_send = tvm.get_global_func("vm.builtin.kv_cache_disagg_mark_send")
     fdisagg_prepare_recv = tvm.get_global_func("vm.builtin.kv_cache_disagg_prepare_recv")
-    
+
     target = tvm.target.Target.from_device(device)
     builts = []
     for tir_func in [
@@ -207,8 +208,8 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
         fcompact_copy,
         fattn_prefill_with_tree_mask,
         fattn_prefill_with_tree_mask_paged_kv_cache,
-        ftransfer_kv,
         None,
+        ftransfer_kv,
     )
     return cache
 
@@ -293,8 +294,8 @@ def apply_attention(
     attn_sink_sizes: Optional[List[int]] = None,
     token_tree_parent_ptr_list: Optional[List[List[int]]] = None,
    accepted_leaf_indices: Optional[List[int]] = None,
-    only_update_host = False,
-    skip_add_sequence = False,
+    only_update_host=False,
+    skip_add_sequence=False,
 ) -> None:
     seq_ids = []
     append_lengths = []
@@ -385,9 +386,11 @@ def apply_attention(
                             rope_offset,
                             rope_scale,
                             rope_theta,
-                            token_tree_node_depths_list[i][-append_length:]
-                            if token_tree_node_depths_list[i] is not None
-                            else None,
+                            (
+                                token_tree_node_depths_list[i][-append_length:]
+                                if token_tree_node_depths_list[i] is not None
+                                else None
+                            ),
                         )
                     )
                     for l in range(num_layers)
@@ -430,9 +433,11 @@ def apply_attention(
                     rope_offset,
                     rope_scale,
                     rope_theta,
-                    token_tree_node_depths_list[i][-append_length:]
-                    if token_tree_node_depths_list[i] is not None
-                    else None,
+                    (
+                        token_tree_node_depths_list[i][-append_length:]
+                        if token_tree_node_depths_list[i] is not None
+                        else None
+                    ),
                 )
             ).transpose(1, 0, 2)
             k_seq = (
@@ -581,6 +586,7 @@ def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config):
     for batch in operation_seq:
         apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
 
+
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
 def test_paged_attention_kv_cache_transfer(kv_cache_and_config):
@@ -591,7 +597,7 @@ def test_paged_attention_kv_cache_transfer(kv_cache_and_config):
     np.random.seed(0)
     fclear(kv_cache)
     # Prefill.
-    prefill_operation_seq= [[(0, 6)], [(1, 8)], [(2, 11)], [(3, 16)], [(4, 19), (5, 20)]]
+    prefill_operation_seq = [[(0, 6)], [(1, 8)], [(2, 11)], [(3, 16)], [(4, 19), (5, 20)]]
     prefill_operation_seq += [[(6, 21), (7, 24)], [(2, 5), (4, 7), (8, 24)]]
     prefill_operation_seq += [[(6, 13)], [(8, 19)], [(0, 1)], [(1, 3), (3, 8), (5, 12), (7, 11)]]
     prefill_len = {i: 0 for i in range(9)}
@@ -599,8 +605,12 @@ def test_paged_attention_kv_cache_transfer(kv_cache_and_config):
         for seq_id, append_length in batch:
             prefill_len[seq_id] += append_length
     # Decode
-    decode_operation_seq = [[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]]
-    decode_operation_seq += [[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]]
+    decode_operation_seq = [
+        [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]
+    ]
+    decode_operation_seq += [
+        [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]
+    ]
     decode_operation_seq += [[(0, 1), (2, 1), (4, 1), (6, 1), (8, 1)]]
     decode_operation_seq += [[(4, 1), (5, 1), (6, 1), (7, 1), (8, 1)]]
 
@@ -616,7 +626,7 @@ def test_paged_attention_kv_cache_transfer(kv_cache_and_config):
         comm.Barrier()
         print("phase 1")
         for seq_id in prefill_len.keys():
-            fmark_send(kv_cache, seq_id, 0, ShapeTuple(remote_pos_maps[seq_id]), 1)
+            fdisagg_mark_send(kv_cache, seq_id, 0, ShapeTuple(remote_pos_maps[seq_id]), 1)
         for batch in prefill_operation_seq:
             apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, skip_add_sequence=True)
         device.sync()
@@ -630,12 +640,20 @@ def test_paged_attention_kv_cache_transfer(kv_cache_and_config):
         remote_pos_maps = comm.bcast(remote_pos_maps, root=1)
         comm.Barrier()
         for batch in prefill_operation_seq:
-            apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, only_update_host=True, skip_add_sequence=True)
+            apply_attention(
+                kv_cache,
+                rope_mode,
+                batch,
+                cached_k,
+                cached_v,
+                only_update_host=True,
+                skip_add_sequence=True,
+            )
         comm.Barrier()
         print("phase 2")
         for batch in decode_operation_seq:
             apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, skip_add_sequence=True)
-    
+
 
 def init_nvshmem(num_workers, pe_offset):
     if rank == 0:
@@ -646,7 +664,8 @@ def init_nvshmem(num_workers, pe_offset):
     uid = comm.bcast(uid, root=0)
     init_func = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem")
     init_func(uid, num_workers, pe_offset)
-    
+
+
 if __name__ == "__main__":
     HEAD_DIMS = [128]
     DTYPES = ["float16"]
@@ -659,4 +678,4 @@ def init_nvshmem(num_workers, pe_offset):
         set_global_func(head_dim, dtype)
         cache = create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window)
         cache_and_config = (cache, rope_mode, support_sliding_window)
-        test_paged_attention_kv_cache_transfer(cache_and_config)
\ No newline at end of file
+        test_paged_attention_kv_cache_transfer(cache_and_config)
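
---

A note for reviewers on the constructor argument layout this patch establishes: rope_ext_factors and f_transfer_kv are now both optional trailing arguments, distinguished on the C++ side purely by object type (an NDArray at args[28]/args[22], a PackedFunc at args[29]/args[23]). The sketch below is illustrative only (the helper name is hypothetical; just the ordering contract comes from the patch):

# Hypothetical helper illustrating the trailing-argument contract of
# "vm.builtin.paged_attention_kv_cache_create" after this patch. The C++ side
# type-probes the trailing slots, so order matters: an NDArray is taken as
# rope_ext_factors and a PackedFunc as f_transfer_kv.
def append_trailing_args(args, rope_ext_factors, enable_disaggregation, kv_transfer_extern):
    if rope_ext_factors is not None:
        args.append(rope_ext_factors)  # probed as NDArray (args[28] / args[22])
    if enable_disaggregation:
        args.append(kv_transfer_extern)  # probed as PackedFunc (args[29] / args[23])
    return args

Note that the probe only recognizes the transfer function in the last slot after rope_ext_factors; as written, a 29-argument call whose final argument is a PackedFunc would leave f_transfer_kv unset.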
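For completeness, a minimal sketch of the disaggregated transfer handshake the new builtins enable, modeled on the test above: rank 1 (the receiver, PE offset 1) prepares pages and publishes its compressed position map, rank 0 marks the sequence for sending, and the actual NVSHMEM copy runs layer by layer inside the attention forward (the "Part 4: KV transfer" block). It assumes NVSHMEM is already initialized on both ranks via init_nvshmem and that kv_cache was built with the transfer function registered; the point-to-point send/recv in place of the test's bcast is illustrative:

from mpi4py import MPI

import tvm
from tvm.runtime import ShapeTuple

comm = MPI.COMM_WORLD

fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
fdisagg_prepare_recv = tvm.get_global_func("vm.builtin.kv_cache_disagg_prepare_recv")
fdisagg_mark_send = tvm.get_global_func("vm.builtin.kv_cache_disagg_mark_send")


def handshake(kv_cache, seq_id=0, length=16):
    """Pair a receiver (rank 1, PE offset 1) with a sender (rank 0)."""
    if comm.Get_rank() == 1:
        # Receiver: allocate pages for `length` tokens and compute the
        # compressed position map that tells the sender where each token's
        # KV lives in this PE's page table.
        fadd_sequence(kv_cache, seq_id)
        pos_map = list(fdisagg_prepare_recv(kv_cache, seq_id, length))
        comm.send(pos_map, dest=0)
    else:
        # Sender: mark tokens of `seq_id` starting at position 0 for transfer
        # to the PE at offset 1. The one-sided copy itself is issued inside
        # the subsequent attention forward passes.
        pos_map = comm.recv(source=1)
        fdisagg_mark_send(kv_cache, seq_id, 0, ShapeTuple(pos_map), 1)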