From 07ab160741a486bbef23efbf26aaf2ea8a785ae1 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Fri, 9 Aug 2024 17:07:06 +0300 Subject: [PATCH] [Model][Jamba] Mamba cache single buffer (#6739) Co-authored-by: Mor Zusman --- vllm/model_executor/models/jamba.py | 269 +++++++++++++++------------- vllm/worker/model_runner.py | 3 - 2 files changed, 148 insertions(+), 124 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index cf407c86acd7..ededf9c533f0 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -609,12 +609,8 @@ def __init__( # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) - # Current step used indices - self.current_indices: List[int] = [] # Used to track and store by the Mamba cache between steps. self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() - # Used as an input_buffer for the CUDA graph runs. - self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple() # Maps between the request id and a dict that maps between the seq_id # and its index inside the self.mamba_cache self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} @@ -644,95 +640,148 @@ def forward(self, batch_size = input_ids.shape[0] if attn_metadata.prefill_metadata: batch_size = len(request_ids_to_seq_ids) - ( - current_seqlen_agnostic_cache, - indices, - ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - batch_size, - finished_requests_ids) + mamba_cache = self._prepare_current_run_mamba_cache( + request_ids_to_seq_ids, batch_size, finished_requests_ids) else: # CUDA graph capturing runs - current_seqlen_agnostic_cache, indices = ( - kwargs["seqlen_agnostic_capture_inputs"], - [], - ) - self.current_indices = indices + mamba_cache = kwargs["seqlen_agnostic_capture_inputs"] hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, - current_seqlen_agnostic_cache[0], - current_seqlen_agnostic_cache[1]) - - if "seqlen_agnostic_capture_inputs" not in kwargs: - self._copy_mamba_cache_by_indices(self.current_indices, - current_seqlen_agnostic_cache) - + attn_metadata, mamba_cache[0], + mamba_cache[1]) return hidden_states - def _copy_mamba_cache_by_indices( - self, indices: List[int], - current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]): - for i, offset in enumerate(indices): - self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache) + def _swap_mamba_cache(self, from_index: int, to_index: int): + assert len(self.mamba_cache) > 0 + for cache_t in self.mamba_cache: + cache_t[:, [to_index,from_index]] = \ + cache_t[:, [from_index,to_index]] - def _copy_mamba_cache(self, index_to: int, index_from: int, - from_buffer: Tuple[torch.Tensor, torch.Tensor]): + def _copy_mamba_cache(self, from_index: int, to_index: int): assert len(self.mamba_cache) > 0 - for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer): - cache_t[:, index_to].copy_(from_buffer_t[:, index_from], + for cache_t in self.mamba_cache: + cache_t[:, to_index].copy_(cache_t[:, from_index], non_blocking=True) - def _assign_seq_id_to_mamba_cache(self, cur_rid: str, - seqs_id: List[int]) -> List[int]: - indices_for_current_run = [] - for seq_id in seqs_id: - if cur_rid not in self.mamba_cache_indices_mapping: - self.mamba_cache_indices_mapping[cur_rid] = {} - first_free_index = self._first_free_index_in_mamba_cache() - self.mamba_cache_indices_mapping[cur_rid][ - seq_id] = first_free_index - index_for_current_run = first_free_index - ## case of decoding n>1, copy prefill cache to decoding indices - elif seq_id not in (seq_ids2indices := - self.mamba_cache_indices_mapping[cur_rid]): - first_free_index = self._first_free_index_in_mamba_cache() - index_exist = list(seq_ids2indices.values())[0] - self._copy_mamba_cache(index_from=index_exist, - index_to=first_free_index, - from_buffer=self.mamba_cache) - self.mamba_cache_indices_mapping[cur_rid][ - seq_id] = first_free_index - index_for_current_run = first_free_index - else: - index_for_current_run = self.mamba_cache_indices_mapping[ - cur_rid][seq_id] - - indices_for_current_run.append(index_for_current_run) - return indices_for_current_run + def _move_out_if_already_occupied(self, index: int, + all_occupied_indices: List[int]): + if index in all_occupied_indices: + first_free_index = self._first_free_index_in_mamba_cache() + # In case occupied, move the occupied to a new empty block + self._move_cache_index_and_mappings(from_index=index, + to_index=first_free_index) + + def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str, + seq_id: int, + destination_index: int): + """ + Assign (req_id,seq_id) pair to a `destination_index` index, if + already occupied, move the occupying index to a free index. + """ + all_occupied_indices = self._get_all_occupied_indices() + if cur_rid not in self.mamba_cache_indices_mapping: + self._move_out_if_already_occupied( + index=destination_index, + all_occupied_indices=all_occupied_indices) + self.mamba_cache_indices_mapping[cur_rid] = { + seq_id: destination_index + } + elif seq_id not in (seq_ids2indices := + self.mamba_cache_indices_mapping[cur_rid]): + # parallel sampling , where n > 1, assume prefill have + # already happened now we only need to copy the already + # existing cache into the siblings seq_ids caches + self._move_out_if_already_occupied( + index=destination_index, + all_occupied_indices=all_occupied_indices) + index_exists = list(seq_ids2indices.values())[0] + # case of decoding n>1, copy prefill cache to decoding indices + self._copy_mamba_cache(from_index=index_exists, + to_index=destination_index) + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = destination_index + else: + # already exists + cache_index_already_exists = self.mamba_cache_indices_mapping[ + cur_rid][seq_id] + if cache_index_already_exists != destination_index: + # In case the seq id already exists but not in + # the right destination, swap it with what's occupying it + self._swap_pair_indices_and_mappings( + from_index=cache_index_already_exists, + to_index=destination_index) def _prepare_current_run_mamba_cache( - self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int, - finished_requests_ids: List[str] - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]: - indices_for_current_run = [] - for request_id, seqs_id in request_ids_to_seq_ids.items(): + self, request_ids_to_seq_ids: Dict[str, list[int]], + batch_size: int, finished_requests_ids: List[str]): + running_indices = [] + request_ids_to_seq_ids_flatten = [ + (req_id, seq_id) + for req_id, seq_ids in request_ids_to_seq_ids.items() + for seq_id in seq_ids + ] + for dest_index, (request_id, + seq_id) in enumerate(request_ids_to_seq_ids_flatten): if request_id in finished_requests_ids: - # Do not allocate cache for requests that run + # Do not allocate cache index for requests that run # and finish right after continue - indices_for_current_run += self._assign_seq_id_to_mamba_cache( - request_id, seqs_id) - ## Pad the batch in case of running batch that was not captured via CG - padded_indices = indices_for_current_run.copy() - pad_index = self._first_free_index_in_mamba_cache() + self._assign_seq_id_to_mamba_cache_in_specific_dest( + request_id, seq_id, dest_index) + running_indices.append(dest_index) - for _ in range(batch_size - len(indices_for_current_run)): - padded_indices.append(pad_index) + self._clean_up_first_bs_blocks(batch_size, running_indices) + conv_state = self.mamba_cache[0][:, :batch_size] + temporal_state = self.mamba_cache[1][:, :batch_size] - conv_state = self.mamba_cache[0][:, padded_indices] - temporal_state = self.mamba_cache[1][:, padded_indices] + return (conv_state, temporal_state) - return (conv_state, temporal_state), indices_for_current_run + def _get_all_occupied_indices(self): + return [ + cache_idx + for seq_ids2indices in self.mamba_cache_indices_mapping.values() + for cache_idx in seq_ids2indices.values() + ] + + def _clean_up_first_bs_blocks(self, batch_size: int, + indices_for_current_run: List[int]): + # move out all of the occupied but currently not running blocks + # outside of the first n blocks + destination_indices = set([range(batch_size)]) + max_possible_batch_size = self.mamba_cache[0].shape[1] + for destination_index in destination_indices: + if destination_index in self._get_all_occupied_indices() and \ + destination_index not in indices_for_current_run: + # move not running indices outside of the batch + all_other_indices = list( + range(batch_size, max_possible_batch_size)) + first_avail_index = self._first_free_index_in_mamba_cache( + all_other_indices) + self._swap_indices(from_index=destination_index, + to_index=first_avail_index) + + def _move_cache_index_and_mappings(self, from_index: int, to_index: int): + self._copy_mamba_cache(from_index=from_index, to_index=to_index) + self._update_mapping_index(from_index=from_index, to_index=to_index) + + def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int): + self._swap_mamba_cache(from_index=from_index, to_index=to_index) + self._swap_mapping_index(from_index=from_index, to_index=to_index) + + def _swap_mapping_index(self, from_index: int, to_index: int): + for seq_ids2index in self.mamba_cache_indices_mapping.values(): + for seq_id, index in seq_ids2index.items(): + if from_index == index: + seq_ids2index.update({seq_id: to_index}) + elif to_index == index: + seq_ids2index.update({seq_id: from_index}) + + def _update_mapping_index(self, from_index: int, to_index: int): + for seq_ids2index in self.mamba_cache_indices_mapping.values(): + for seq_id, index in seq_ids2index.items(): + if from_index == index: + seq_ids2index.update({seq_id: to_index}) + return def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): """ @@ -747,28 +796,9 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): self._release_mamba_cache(finished_requests_ids) request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] cg_batch_size = input_buffers['input_ids'].shape[0] - ( - current_mamba_cache, - indices, - ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, - cg_batch_size, - finished_requests_ids) - self.current_indices = indices - - for input_buffer, current_cache_buffer in zip( - input_buffers["seqlen_agnostic_capture_inputs"], - current_mamba_cache): - input_buffer.copy_(current_cache_buffer, non_blocking=True) - - def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs): - """ - Copy the relevant Mamba cache from the CUDA graph input_buffers - back to the JambaForCausalLM.mamba_cache after CUDA - graph replay run is done. - """ - self._copy_mamba_cache_by_indices( - self.current_indices, - input_buffers["seqlen_agnostic_capture_inputs"]) + self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + cg_batch_size, + finished_requests_ids) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): """ @@ -776,26 +806,25 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): The buffer is used to maintain the Mamba Cache during the CUDA graph replay runs. """ - return tuple(buffer[:, :batch_size] - for buffer in self.mamba_gc_cache_buffer) + return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache) def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: if req_id in self.mamba_cache_indices_mapping: self.mamba_cache_indices_mapping.pop(req_id) - def _first_free_index_in_mamba_cache(self) -> int: - if self.mamba_cache: + def _first_free_index_in_mamba_cache( + self, indices_range: Optional[List[int]] = None) -> int: + assert self.mamba_cache is not None + if indices_range is None: max_possible_batch_size = self.mamba_cache[0].shape[1] - occupied = [ - id for seq_ids in self.mamba_cache_indices_mapping.values() - for id in seq_ids.values() - ] - first_free_index = [ - i not in occupied for i in range(max_possible_batch_size) - ].index(True) - return first_free_index - return 0 + indices_range = list(range(max_possible_batch_size)) + all_occupied_indices = self._get_all_occupied_indices() + for i in indices_range: + if i not in all_occupied_indices: + return i + raise Exception("Couldn't find a free spot in the mamba cache! This" + "should never happen") def _get_mamba_cache_shape( self @@ -819,20 +848,18 @@ def _prepare_mamba_cache(self): [layer_type == "mamba" for layer_type in layers_type]) max_batch_size = (_get_graph_batch_size( self.scheduler_config.max_num_seqs) if self.scheduler_config else - max(_BATCH_SIZES_TO_CAPTURE)) + 10 + max(_BATCH_SIZES_TO_CAPTURE) + 2) conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() assert conv_state_shape is not None and temporal_state_shape is not None - for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]: - buffer = (torch.empty(size=(mamba_layers, max_batch_size) + - conv_state_shape, - dtype=dtype, - device="cuda"), - torch.empty(size=(mamba_layers, max_batch_size) + - temporal_state_shape, - dtype=dtype, - device="cuda")) - setattr(self, buffername, buffer) + self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, + device="cuda"), + torch.empty(size=(mamba_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda")) def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 913a08ce9f53..2731bddba76d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1711,9 +1711,6 @@ def forward( non_blocking=True) # Run the graph. self.graph.replay() - if "seqlen_agnostic_capture_inputs" in self.input_buffers: - self.model.copy_outputs_after_cuda_graphs(self.input_buffers, - **kwargs) # Return the output tensor. if get_pp_group().is_last_rank: return self.output_buffers["hidden_states"]