[Model][Jamba] Mamba cache single buffer (#6739)
Co-authored-by: Mor Zusman <morz@ai21.com>
mzusman and Mor Zusman authored Aug 9, 2024
1 parent b4e9528 commit 07ab160
Showing 2 changed files with 148 additions and 124 deletions.
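The commit replaces the previous two-buffer design (a persistent mamba_cache plus a separate mamba_gc_cache_buffer used as the CUDA graph input) with a single pre-allocated buffer: before each run, the cache slots of the running sequences are moved into the first batch_size positions of that buffer, so eager runs, CUDA graph capture, and CUDA graph replay all read and write the Mamba state in place and no copy-back step is needed. Below is a minimal, self-contained sketch of that slot-management idea; the class name SingleBufferCache and its methods are hypothetical stand-ins, not vLLM's API, and the cache shape is a toy (layers, slots, state_dim) layout.

# Illustrative sketch only, assuming a toy cache layout; not vLLM code.
from typing import Dict, List, Tuple

import torch


class SingleBufferCache:
    """One pre-allocated buffer pair whose first `batch_size` slots always
    hold the states of the currently running sequences."""

    def __init__(self, num_layers: int, max_batch: int, state_dim: int):
        # (conv state, temporal state), shared by eager and CUDA-graph runs.
        self.cache: Tuple[torch.Tensor, torch.Tensor] = (
            torch.zeros(num_layers, max_batch, state_dim),
            torch.zeros(num_layers, max_batch, state_dim),
        )
        # (request_id, seq_id) -> slot index inside the buffer.
        self.mapping: Dict[Tuple[str, int], int] = {}

    def _first_free_slot(self) -> int:
        occupied = set(self.mapping.values())
        return next(i for i in range(self.cache[0].shape[1])
                    if i not in occupied)

    def _swap_slots(self, a: int, b: int) -> None:
        # Swap both the cached states and the bookkeeping entries.
        for t in self.cache:
            t[:, [a, b]] = t[:, [b, a]]
        for key, idx in self.mapping.items():
            if idx == a:
                self.mapping[key] = b
            elif idx == b:
                self.mapping[key] = a

    def prepare(self, running: List[Tuple[str, int]]):
        # Arrange the running sequences into slots 0..len(running)-1.
        for dest, key in enumerate(running):
            cur = self.mapping.get(key)
            if cur is None:
                if dest in self.mapping.values():
                    # Evict whatever occupies the destination slot first.
                    self._swap_slots(dest, self._first_free_slot())
                self.mapping[key] = dest
            elif cur != dest:
                self._swap_slots(cur, dest)
        bs = len(running)
        # Views into the single buffer; the model mutates them in place,
        # so nothing has to be copied back after the forward pass.
        return self.cache[0][:, :bs], self.cache[1][:, :bs]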
269 changes: 148 additions & 121 deletions vllm/model_executor/models/jamba.py
@@ -609,12 +609,8 @@ def __init__(
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
# Current step used indices
self.current_indices: List[int] = []
# Used to track and store the Mamba cache between steps.
self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
# Used as an input_buffer for the CUDA graph runs.
self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple()
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the self.mamba_cache
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
@@ -644,95 +640,148 @@ def forward(self,
batch_size = input_ids.shape[0]
if attn_metadata.prefill_metadata:
batch_size = len(request_ids_to_seq_ids)
(
current_seqlen_agnostic_cache,
indices,
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
batch_size,
finished_requests_ids)
mamba_cache = self._prepare_current_run_mamba_cache(
request_ids_to_seq_ids, batch_size, finished_requests_ids)
else:
# CUDA graph capturing runs
current_seqlen_agnostic_cache, indices = (
kwargs["seqlen_agnostic_capture_inputs"],
[],
)
self.current_indices = indices
mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]

hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata,
current_seqlen_agnostic_cache[0],
current_seqlen_agnostic_cache[1])

if "seqlen_agnostic_capture_inputs" not in kwargs:
self._copy_mamba_cache_by_indices(self.current_indices,
current_seqlen_agnostic_cache)

attn_metadata, mamba_cache[0],
mamba_cache[1])
return hidden_states

def _copy_mamba_cache_by_indices(
self, indices: List[int],
current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]):
for i, offset in enumerate(indices):
self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache)
def _swap_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for cache_t in self.mamba_cache:
cache_t[:, [to_index,from_index]] = \
cache_t[:, [from_index,to_index]]

def _copy_mamba_cache(self, index_to: int, index_from: int,
from_buffer: Tuple[torch.Tensor, torch.Tensor]):
def _copy_mamba_cache(self, from_index: int, to_index: int):
assert len(self.mamba_cache) > 0
for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer):
cache_t[:, index_to].copy_(from_buffer_t[:, index_from],
for cache_t in self.mamba_cache:
cache_t[:, to_index].copy_(cache_t[:, from_index],
non_blocking=True)

def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
seqs_id: List[int]) -> List[int]:
indices_for_current_run = []
for seq_id in seqs_id:
if cur_rid not in self.mamba_cache_indices_mapping:
self.mamba_cache_indices_mapping[cur_rid] = {}
first_free_index = self._first_free_index_in_mamba_cache()
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = first_free_index
index_for_current_run = first_free_index
## case of decoding n>1, copy prefill cache to decoding indices
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
first_free_index = self._first_free_index_in_mamba_cache()
index_exist = list(seq_ids2indices.values())[0]
self._copy_mamba_cache(index_from=index_exist,
index_to=first_free_index,
from_buffer=self.mamba_cache)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = first_free_index
index_for_current_run = first_free_index
else:
index_for_current_run = self.mamba_cache_indices_mapping[
cur_rid][seq_id]

indices_for_current_run.append(index_for_current_run)
return indices_for_current_run
def _move_out_if_already_occupied(self, index: int,
all_occupied_indices: List[int]):
if index in all_occupied_indices:
first_free_index = self._first_free_index_in_mamba_cache()
# In case occupied, move the occupied to a new empty block
self._move_cache_index_and_mappings(from_index=index,
to_index=first_free_index)

def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
seq_id: int,
destination_index: int):
"""
Assign (req_id,seq_id) pair to a `destination_index` index; if it is
already occupied, move the occupying index to a free index.
"""
all_occupied_indices = self._get_all_occupied_indices()
if cur_rid not in self.mamba_cache_indices_mapping:
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
self.mamba_cache_indices_mapping[cur_rid] = {
seq_id: destination_index
}
elif seq_id not in (seq_ids2indices :=
self.mamba_cache_indices_mapping[cur_rid]):
# parallel sampling, where n > 1: assume prefill has already
# happened; now we only need to copy the existing cache into the
# sibling seq_ids' caches
self._move_out_if_already_occupied(
index=destination_index,
all_occupied_indices=all_occupied_indices)
index_exists = list(seq_ids2indices.values())[0]
# case of decoding n>1, copy prefill cache to decoding indices
self._copy_mamba_cache(from_index=index_exists,
to_index=destination_index)
self.mamba_cache_indices_mapping[cur_rid][
seq_id] = destination_index
else:
# already exists
cache_index_already_exists = self.mamba_cache_indices_mapping[
cur_rid][seq_id]
if cache_index_already_exists != destination_index:
# In case the seq id already exists but not in
# the right destination, swap it with what's occupying it
self._swap_pair_indices_and_mappings(
from_index=cache_index_already_exists,
to_index=destination_index)

def _prepare_current_run_mamba_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
finished_requests_ids: List[str]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
indices_for_current_run = []
for request_id, seqs_id in request_ids_to_seq_ids.items():
self, request_ids_to_seq_ids: Dict[str, list[int]],
batch_size: int, finished_requests_ids: List[str]):
running_indices = []
request_ids_to_seq_ids_flatten = [
(req_id, seq_id)
for req_id, seq_ids in request_ids_to_seq_ids.items()
for seq_id in seq_ids
]
for dest_index, (request_id,
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
if request_id in finished_requests_ids:
# Do not allocate cache for requests that run
# Do not allocate cache index for requests that run
# and finish right after
continue
indices_for_current_run += self._assign_seq_id_to_mamba_cache(
request_id, seqs_id)
## Pad the batch in case of running batch that was not captured via CG
padded_indices = indices_for_current_run.copy()
pad_index = self._first_free_index_in_mamba_cache()
self._assign_seq_id_to_mamba_cache_in_specific_dest(
request_id, seq_id, dest_index)
running_indices.append(dest_index)

for _ in range(batch_size - len(indices_for_current_run)):
padded_indices.append(pad_index)
self._clean_up_first_bs_blocks(batch_size, running_indices)
conv_state = self.mamba_cache[0][:, :batch_size]
temporal_state = self.mamba_cache[1][:, :batch_size]

conv_state = self.mamba_cache[0][:, padded_indices]
temporal_state = self.mamba_cache[1][:, padded_indices]
return (conv_state, temporal_state)

return (conv_state, temporal_state), indices_for_current_run
def _get_all_occupied_indices(self):
return [
cache_idx
for seq_ids2indices in self.mamba_cache_indices_mapping.values()
for cache_idx in seq_ids2indices.values()
]

def _clean_up_first_bs_blocks(self, batch_size: int,
indices_for_current_run: List[int]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
destination_indices = range(batch_size)
max_possible_batch_size = self.mamba_cache[0].shape[1]
for destination_index in destination_indices:
if destination_index in self._get_all_occupied_indices() and \
destination_index not in indices_for_current_run:
# move not running indices outside of the batch
all_other_indices = list(
range(batch_size, max_possible_batch_size))
first_avail_index = self._first_free_index_in_mamba_cache(
all_other_indices)
self._swap_pair_indices_and_mappings(from_index=destination_index,
to_index=first_avail_index)

def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
self._copy_mamba_cache(from_index=from_index, to_index=to_index)
self._update_mapping_index(from_index=from_index, to_index=to_index)

def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
self._swap_mamba_cache(from_index=from_index, to_index=to_index)
self._swap_mapping_index(from_index=from_index, to_index=to_index)

def _swap_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
elif to_index == index:
seq_ids2index.update({seq_id: from_index})

def _update_mapping_index(self, from_index: int, to_index: int):
for seq_ids2index in self.mamba_cache_indices_mapping.values():
for seq_id, index in seq_ids2index.items():
if from_index == index:
seq_ids2index.update({seq_id: to_index})
return

def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""
@@ -747,55 +796,35 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
self._release_mamba_cache(finished_requests_ids)
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
cg_batch_size = input_buffers['input_ids'].shape[0]
(
current_mamba_cache,
indices,
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
cg_batch_size,
finished_requests_ids)
self.current_indices = indices

for input_buffer, current_cache_buffer in zip(
input_buffers["seqlen_agnostic_capture_inputs"],
current_mamba_cache):
input_buffer.copy_(current_cache_buffer, non_blocking=True)

def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
"""
Copy the relevant Mamba cache from the CUDA graph input_buffers
back to the JambaForCausalLM.mamba_cache after CUDA
graph replay run is done.
"""
self._copy_mamba_cache_by_indices(
self.current_indices,
input_buffers["seqlen_agnostic_capture_inputs"])
self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
cg_batch_size,
finished_requests_ids)

def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""
Provide the CUDA graph capture runs with a buffer in adjusted size.
The buffer is used to maintain the Mamba Cache during the CUDA graph
replay runs.
"""
return tuple(buffer[:, :batch_size]
for buffer in self.mamba_gc_cache_buffer)
return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)

def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.mamba_cache_indices_mapping:
self.mamba_cache_indices_mapping.pop(req_id)

def _first_free_index_in_mamba_cache(self) -> int:
if self.mamba_cache:
def _first_free_index_in_mamba_cache(
self, indices_range: Optional[List[int]] = None) -> int:
assert self.mamba_cache is not None
if indices_range is None:
max_possible_batch_size = self.mamba_cache[0].shape[1]
occupied = [
id for seq_ids in self.mamba_cache_indices_mapping.values()
for id in seq_ids.values()
]
first_free_index = [
i not in occupied for i in range(max_possible_batch_size)
].index(True)
return first_free_index
return 0
indices_range = list(range(max_possible_batch_size))
all_occupied_indices = self._get_all_occupied_indices()
for i in indices_range:
if i not in all_occupied_indices:
return i
raise Exception("Couldn't find a free spot in the mamba cache! This"
"should never happen")

def _get_mamba_cache_shape(
self
Expand All @@ -819,20 +848,18 @@ def _prepare_mamba_cache(self):
[layer_type == "mamba" for layer_type in layers_type])
max_batch_size = (_get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config else
max(_BATCH_SIZES_TO_CAPTURE)) + 10
max(_BATCH_SIZES_TO_CAPTURE) + 2)
conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
assert conv_state_shape is not None and temporal_state_shape is not None

for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
conv_state_shape,
dtype=dtype,
device="cuda"),
torch.empty(size=(mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=dtype,
device="cuda"))
setattr(self, buffername, buffer)
self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) +
conv_state_shape,
dtype=dtype,
device="cuda"),
torch.empty(size=(mamba_layers, max_batch_size) +
temporal_state_shape,
dtype=dtype,
device="cuda"))

def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
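In the new _assign_seq_id_to_mamba_cache_in_specific_dest path, a request sampled with n > 1 owns a single prefill state, and its sibling seq_ids receive copies of that state when they first appear at decode time. A small stand-alone sketch of that behaviour follows; the variable names and the toy (layers, slots, state_dim) shape are hypothetical, chosen only to illustrate the copy.

# Illustrative sketch only: parallel sampling (n > 1) reusing the prefill state.
import torch

cache = torch.zeros(2, 4, 3)       # toy (layers, slots, state_dim) buffer
mapping = {("req-0", 0): 0}        # seq_id 0 holds the prefill state
cache[:, 0] = 1.0                  # pretend prefill wrote some state

for sibling_seq_id in (1, 2):      # n = 3 parallel samples
    free_slot = next(i for i in range(cache.shape[1])
                     if i not in mapping.values())
    src_slot = mapping[("req-0", 0)]
    # Decode for the sibling starts from the request's prefill state.
    cache[:, free_slot].copy_(cache[:, src_slot])
    mapping[("req-0", sibling_seq_id)] = free_slot

print(mapping)   # {('req-0', 0): 0, ('req-0', 1): 1, ('req-0', 2): 2}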
3 changes: 0 additions & 3 deletions vllm/worker/model_runner.py
@@ -1711,9 +1711,6 @@ def forward(
non_blocking=True)
# Run the graph.
self.graph.replay()
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
self.model.copy_outputs_after_cuda_graphs(self.input_buffers,
**kwargs)
# Return the output tensor.
if get_pp_group().is_last_rank:
return self.output_buffers["hidden_states"]
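Dropping copy_outputs_after_cuda_graphs from the graph runner is possible because the seqlen-agnostic capture inputs are now views of the single mamba_cache buffer, so whatever a graph replay writes already lives in the cache. A tiny sketch of that view-sharing argument, with toy shapes and no GPU required:

# Illustrative sketch only: a slice of a tensor shares storage with it.
import torch

mamba_cache = torch.zeros(4, 8)      # stand-in for one cache tensor
capture_inputs = mamba_cache[:, :2]  # view handed to the captured graph

capture_inputs.add_(1.0)             # stand-in for a graph replay updating state
assert torch.equal(mamba_cache[:, :2], torch.ones(4, 2))  # no copy-back needed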
