From c520c762600d8263da9e164bdf83f55e75d57db3 Mon Sep 17 00:00:00 2001
From: Yuanheng Zhao
Date: Tue, 5 Dec 2023 17:27:19 +0800
Subject: [PATCH] function refactored

---
 colossalai/inference/kv_cache/block_cache.py |  3 +
 .../inference/kv_cache/kvcache_manager.py    | 77 +++++++++++++------
 2 files changed, 57 insertions(+), 23 deletions(-)

diff --git a/colossalai/inference/kv_cache/block_cache.py b/colossalai/inference/kv_cache/block_cache.py
index 855e228f0660..3aa55717a9fb 100644
--- a/colossalai/inference/kv_cache/block_cache.py
+++ b/colossalai/inference/kv_cache/block_cache.py
@@ -52,3 +52,6 @@ def clear(self) -> None:
         assert self.ref_count > 0, f"Block#{self.block_id} has no reference to free."
         self.ref_count = 0
         self.allocated_size = 0
+
+    def __repr__(self):
+        return f"CacheBlock#{self.block_id}(ref#{self.ref_count}, allocated#{self.allocated_size})"
diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py
index a759b4b741e9..c4f9a88e6ed3 100644
--- a/colossalai/inference/kv_cache/kvcache_manager.py
+++ b/colossalai/inference/kv_cache/kvcache_manager.py
@@ -95,6 +95,7 @@ def allocate_from_last_block_idx(self, last_block_id: int, space_asked: int = 1)
         blocks.append(last_block_id)
         while space_asked > 0:
             new_block: CacheBlock = self._free_blocks.pop(0)
+            new_block.add_ref()
             space_asked = self._allocate_on_block(new_block, space_asked)
             self._allocated_blocks.append(new_block)
             self.available_blocks -= 1
@@ -103,37 +104,56 @@ def allocate_from_last_block_idx(self, last_block_id: int, space_asked: int = 1)
 
         return blocks
 
-    def allocate_from_block_table(self, block_table: torch.Tensor, seq_length: int, space_asked: int = 1) -> None:
+    def allocate_from_block_table(
+        self, block_table: torch.Tensor, already_allocated_len: int, space_asked: int = 1
+    ) -> None:
         """Allocate the logical cache blocks for a single sequence.
         It updates the provided block table with the allocated block(s).
 
         Args:
             block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
-            seq_length: The length of the sequence.
+            already_allocated_len: The number of already-allocated tokens in the sequence.
+                For a new sequence in prefill stage, it should be 0; in generation stage, it should be the sequence length.
             space_asked: i.e. The number of tokens to be assigned space for.
         """
         assert block_table.dim() == 1
         # the last-allocated block can be fully occupied, and can be occupied only one slot as well
-        last_allocated_block_local_idx = seq_length // self.block_size
+        last_allocated_block_idx = already_allocated_len // self.block_size
         # the right-most block that to be allocated, fully or partially
-        last_newly_allocated_block_local_idx = (seq_length + space_asked - 1) // self.block_size
-        last_newly_allocated_block_local_idx = min(last_newly_allocated_block_local_idx, block_table.numel())
-
-        for i in range(last_allocated_block_local_idx, last_newly_allocated_block_local_idx + 1):
-            block_global_id = block_table[i].item()
-            if block_global_id < 0:
-                assert self.available_blocks > 0, "No available blocks to allocate."
-                free_block: CacheBlock = self._free_blocks.pop(0)
-                self._allocated_blocks.append(free_block)
-                block_global_id = free_block.block_id
-                self.available_blocks -= 1
-                self._block_states[block_global_id] = 0
-                block_table[i] = block_global_id
-
-            block: CacheBlock = self._cache_blocks[block_global_id]
-            space_asked = self._allocate_on_block(self, block, space_asked)
-
-    def free_cache_blocks(self, block_table: torch.Tensor):
+        last_newly_allocated_block_idx = (already_allocated_len + space_asked - 1) // self.block_size
+        last_newly_allocated_block_idx = min(last_newly_allocated_block_idx, block_table.numel())
+
+        for block_local_idx in range(last_allocated_block_idx, last_newly_allocated_block_idx + 1):
+            space_asked = self.allocate_single_block(block_table, block_local_idx, space_asked)
+
+    def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int, space_asked: int) -> int:
+        """Allocate space asked on a single block in the block table, specified by the provided position id.
+        It updates the provided block table with the allocated block.
+
+        Args:
+            block_table: A 1D tensor of shape [max_blocks_per_sequence], storing mapping of token_position_id -> block_id.
+            block_local_idx: The index of the block in the block table.
+            space_asked: i.e. The number of tokens to be assigned space for.
+        Returns:
+            The remaining space required to be allocated (in other blocks).
+        """
+        assert block_table.dim() == 1
+        block_global_id = block_table[block_local_idx].item()
+        if block_global_id < 0:
+            # Allocate a new block if the current position is not assigned a block yet
+            assert self.available_blocks > 0, "No available blocks to allocate."
+            block: CacheBlock = self._free_blocks.pop(0)
+            block.add_ref()
+            self._allocated_blocks.append(block)
+            block_global_id = block.block_id
+            self.available_blocks -= 1
+            self._block_states[block_global_id] = 0
+            block_table[block_local_idx] = block_global_id
+
+        block: CacheBlock = self._cache_blocks[block_global_id]
+        return self._allocate_on_block(block, space_asked)
+
+    def free_cache_blocks(self, block_table: torch.Tensor) -> None:
         """Free the logical cache blocks for **a single sequence**."""
         assert block_table.dim() == 1
         assert torch.all(block_table >= 0)
@@ -141,13 +161,24 @@ def free_cache_blocks(self, block_table: torch.Tensor):
             global_block_id = block_table[i].item()
             block: CacheBlock = self._cache_blocks[global_block_id]
             block.remove_ref()  # not going to clear the block thoroughly
+            if not block.has_ref():
+                block.allocated_size = 0
+                self._free_blocks.append(block)
+                self.available_blocks += 1
+                self._block_states[global_block_id] = 1
+            # NOTE reset the block id in the block table (if we maintain a 2D tensors as block tables in Engine)
+            block_table[i] = -1
 
     def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
         """Get the corresponding tensor for the provided block id for a specific layer."""
         return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx]
 
-    def _allocate_on_block(self, block: CacheBlock, space_asked: int) -> None:
-        """Allocate a specific size of space on a cache block."""
+    def _allocate_on_block(self, block: CacheBlock, space_asked: int) -> int:
+        """Allocate a specific size of space on a cache block.
+
+        Returns:
+            The remaining space required to be allocated (in other blocks).
+        """
         available_space = block.available_space()
         assert available_space > 0, "No available blocks to allocate."
         space_to_allocate = min(available_space, space_asked)
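
Usage note (not part of the patch): the sketch below shows how the refactored per-sequence allocation path could be driven by a caller, under stated assumptions. The manager instance `cache_manager`, its block size of 16, and the table size of 8 are hypothetical; this diff does not show the manager's class name or constructor, only the methods touched above.

import torch


def run_one_sequence(cache_manager, max_blocks_per_sequence: int = 8) -> None:
    """Illustrative driver for the methods refactored in this patch.

    `cache_manager` is assumed to be an instance of the manager class defined in
    colossalai/inference/kv_cache/kvcache_manager.py, built with block_size = 16.
    """
    # Block table for one sequence; -1 marks positions with no cache block assigned yet.
    block_table = torch.full((max_blocks_per_sequence,), -1, dtype=torch.int64)

    # Prefill: nothing is allocated yet, so already_allocated_len is 0.
    # With block_size = 16, asking space for 20 tokens fills block 0 and part of block 1.
    cache_manager.allocate_from_block_table(block_table, already_allocated_len=0, space_asked=20)

    # One decoding step: 20 tokens already own cache slots; ask space for one more token.
    cache_manager.allocate_from_block_table(block_table, already_allocated_len=20, space_asked=1)

    # When the sequence finishes, free the blocks it occupies (free_cache_blocks expects
    # only non-negative entries, so pass just the assigned prefix). Each block's ref count
    # is dropped; fully unreferenced blocks return to the free list and the passed entries
    # are reset to -1.
    cache_manager.free_cache_blocks(block_table[:2])

The behavioral point the sketch leans on is that allocate_single_block (and _allocate_on_block) now return the space still to be allocated, so allocate_from_block_table can thread space_asked through consecutive blocks.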