From 9e583d6b5ae9af241c13369b8ae7f291a02f335f Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 8 Aug 2024 15:54:00 +0300 Subject: [PATCH] Add comment --- vllm/model_executor/models/jamba.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 58c336e04867..30d3bec380ea 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -651,20 +651,20 @@ def forward(self, mamba_cache[1]) return hidden_states - def _swap_mamba_cache(self, to_index: int, from_index: int): + def _swap_mamba_cache(self, from_index: int, to_index: int): assert len(self.mamba_cache) > 0 for cache_t in self.mamba_cache: cache_t[:, [to_index,from_index]] = \ cache_t[:, [from_index,to_index]] - def _copy_mamba_cache(self, to_index: int, from_index: int): + def _copy_mamba_cache(self, from_index: int, to_index: int): assert len(self.mamba_cache) > 0 for cache_t in self.mamba_cache: cache_t[:, to_index].copy_(cache_t[:, from_index], non_blocking=True) def _move_out_if_already_occupied(self, index: int, - all_occupied_indices: List[int]): + all_occupied_indices: List[int]): if index in all_occupied_indices: first_free_index = self._first_free_index_in_mamba_cache() # In case occupied, move the occupied to a new empty block @@ -674,6 +674,10 @@ def _move_out_if_already_occupied(self, index: int, def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str, seq_id: int, destination_index: int): + """ + Assign (req_id,seq_id) pair to a `destination_index` index, if + already occupied, move the occupying index to a free index. + """ all_occupied_indices = self._get_all_occupied_indices() if cur_rid not in self.mamba_cache_indices_mapping: self._move_out_if_already_occupied( @@ -697,7 +701,7 @@ def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str, self.mamba_cache_indices_mapping[cur_rid][ seq_id] = destination_index else: - ## already exists + # already exists cache_index_already_exists = self.mamba_cache_indices_mapping[ cur_rid][seq_id] if cache_index_already_exists != destination_index: