Commit: function refactored

yuanheng-zhao committed Dec 5, 2023
1 parent 1df4123 commit c520c76

Showing 2 changed files with 57 additions and 23 deletions.
3 changes: 3 additions & 0 deletions colossalai/inference/kv_cache/block_cache.py
@@ -52,3 +52,6 @@ def clear(self) -> None:
        assert self.ref_count > 0, f"Block#{self.block_id} has no reference to free."
        self.ref_count = 0
        self.allocated_size = 0
+
+    def __repr__(self):
+        return f"CacheBlock#{self.block_id}(ref#{self.ref_count}, allocated#{self.allocated_size})"
77 changes: 54 additions & 23 deletions colossalai/inference/kv_cache/kvcache_manager.py
@@ -95,6 +95,7 @@ def allocate_from_last_block_idx(self, last_block_id: int, space_asked: int = 1)
        blocks.append(last_block_id)
        while space_asked > 0:
            new_block: CacheBlock = self._free_blocks.pop(0)
+            new_block.add_ref()
            space_asked = self._allocate_on_block(new_block, space_asked)
            self._allocated_blocks.append(new_block)
            self.available_blocks -= 1
@@ -103,51 +104,81 @@ def allocate_from_last_block_idx(self, last_block_id: int, space_asked: int = 1)

        return blocks
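The while-loop above pops blocks off the free list, takes a reference on each, and lets the unmet remainder spill onto the next block. A standalone sketch of that spill pattern (`BLOCK_SIZE` and the free-list contents are invented, and every popped block is assumed empty):

```python
BLOCK_SIZE = 16  # tokens per cache block (illustrative value)

def allocate_spill(free_block_ids: list[int], space_asked: int) -> list[tuple[int, int]]:
    """Pop free blocks until the requested token count is covered,
    mirroring the while-loop in allocate_from_last_block_idx."""
    taken = []
    while space_asked > 0:
        block_id = free_block_ids.pop(0)        # FIFO, like self._free_blocks.pop(0)
        granted = min(BLOCK_SIZE, space_asked)  # what _allocate_on_block would grant
        space_asked -= granted                  # remainder spills to the next block
        taken.append((block_id, granted))
    return taken

# Asking for 40 tokens with 16-token blocks consumes 3 blocks: 16 + 16 + 8.
print(allocate_spill([0, 1, 2, 3], 40))  # [(0, 16), (1, 16), (2, 8)]
```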

-    def allocate_from_block_table(self, block_table: torch.Tensor, seq_length: int, space_asked: int = 1) -> None:
+    def allocate_from_block_table(
+        self, block_table: torch.Tensor, already_allocated_len: int, space_asked: int = 1
+    ) -> None:
        """Allocate the logical cache blocks for a single sequence.
        It updates the provided block table with the allocated block(s).

        Args:
            block_table: A 1D tensor of shape [max_blocks_per_sequence], storing the mapping of token_position_id -> block_id.
-            seq_length: The length of the sequence.
+            already_allocated_len: The number of already-allocated tokens in the sequence.
+                For a new sequence in the prefill stage, it should be 0; in the generation stage, it should be the current sequence length.
            space_asked: The number of tokens to allocate space for.
        """
        assert block_table.dim() == 1
        # the last-allocated block may be fully occupied, or have only a single slot occupied
-        last_allocated_block_local_idx = seq_length // self.block_size
+        last_allocated_block_idx = already_allocated_len // self.block_size
        # the right-most block to be allocated, fully or partially
-        last_newly_allocated_block_local_idx = (seq_length + space_asked - 1) // self.block_size
-        last_newly_allocated_block_local_idx = min(last_newly_allocated_block_local_idx, block_table.numel())

-        for i in range(last_allocated_block_local_idx, last_newly_allocated_block_local_idx + 1):
-            block_global_id = block_table[i].item()
-            if block_global_id < 0:
-                assert self.available_blocks > 0, "No available blocks to allocate."
-                free_block: CacheBlock = self._free_blocks.pop(0)
-                self._allocated_blocks.append(free_block)
-                block_global_id = free_block.block_id
-                self.available_blocks -= 1
-                self._block_states[block_global_id] = 0
-                block_table[i] = block_global_id
-
-            block: CacheBlock = self._cache_blocks[block_global_id]
-            space_asked = self._allocate_on_block(block, space_asked)
-
-    def free_cache_blocks(self, block_table: torch.Tensor):
+        last_newly_allocated_block_idx = (already_allocated_len + space_asked - 1) // self.block_size
+        last_newly_allocated_block_idx = min(last_newly_allocated_block_idx, block_table.numel())
+
+        for block_local_idx in range(last_allocated_block_idx, last_newly_allocated_block_idx + 1):
+            space_asked = self.allocate_single_block(block_table, block_local_idx, space_asked)
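A worked example of the index arithmetic (values invented): with a block size of 16, a sequence holding 30 tokens that asks for 5 more touches blocks 1 and 2, topping up the partially filled block 1 before claiming block 2.

```python
block_size = 16
already_allocated_len = 30  # tokens already cached for this sequence
space_asked = 5             # new tokens to reserve

last_allocated_block_idx = already_allocated_len // block_size  # 30 // 16 = 1
last_newly_allocated_block_idx = (already_allocated_len + space_asked - 1) // block_size  # 34 // 16 = 2

# Block 1 holds tokens 16..29, so it has 2 free slots; 2 tokens land there
# and the remaining 3 go to block 2.
for idx in range(last_allocated_block_idx, last_newly_allocated_block_idx + 1):
    print("allocate on block", idx)  # prints blocks 1 and 2
```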

+    def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int, space_asked: int) -> int:
+        """Allocate the space asked on a single block in the block table, specified by the provided position id.
+        It updates the provided block table with the allocated block.
+
+        Args:
+            block_table: A 1D tensor of shape [max_blocks_per_sequence], storing the mapping of token_position_id -> block_id.
+            block_local_idx: The index of the block in the block table.
+            space_asked: The number of tokens to allocate space for.
+        Returns:
+            The remaining space required to be allocated (in other blocks).
+        """
+        assert block_table.dim() == 1
+        block_global_id = block_table[block_local_idx].item()
+        if block_global_id < 0:
+            # Allocate a new block if the current position is not assigned a block yet
+            assert self.available_blocks > 0, "No available blocks to allocate."
+            block: CacheBlock = self._free_blocks.pop(0)
+            block.add_ref()
+            self._allocated_blocks.append(block)
+            block_global_id = block.block_id
+            self.available_blocks -= 1
+            self._block_states[block_global_id] = 0
+            block_table[block_local_idx] = block_global_id
+
+        block: CacheBlock = self._cache_blocks[block_global_id]
+        return self._allocate_on_block(block, space_asked)
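The `block_global_id < 0` branch relies on `-1` being the "no block assigned" sentinel in block tables (`free_cache_blocks` below writes `-1` back when a slot is released). A minimal sketch of that convention, with invented sizes and ids:

```python
import torch

max_blocks_per_sequence = 4  # illustrative capacity
block_table = torch.full((max_blocks_per_sequence,), -1, dtype=torch.int64)

# Position 0 has no block yet (id < 0), so a block id popped from the free
# list (7 here is invented) is written in, as allocate_single_block does.
if block_table[0].item() < 0:
    block_table[0] = 7
print(block_table)  # tensor([ 7, -1, -1, -1])
```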

+    def free_cache_blocks(self, block_table: torch.Tensor) -> None:
        """Free the logical cache blocks for **a single sequence**."""
        assert block_table.dim() == 1
        assert torch.all(block_table >= 0)
        for i in range(block_table.numel()):
            global_block_id = block_table[i].item()
            block: CacheBlock = self._cache_blocks[global_block_id]
            block.remove_ref()  # not going to clear the block thoroughly
            if not block.has_ref():
                block.allocated_size = 0
                self._free_blocks.append(block)
                self.available_blocks += 1
                self._block_states[global_block_id] = 1
            # NOTE reset the block id in the block table (if we maintain 2D tensors as block tables in the Engine)
            block_table[i] = -1
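Since freeing is reference-counted, a block shared by several sequences (say, a common prefix) only returns to the free list once the last holder releases it. A standalone sketch of that behaviour (the mock class is invented for this note):

```python
class MockCacheBlock:
    def __init__(self, block_id: int):
        self.block_id = block_id
        self.ref_count = 0
        self.allocated_size = 16

    def remove_ref(self) -> None:
        self.ref_count -= 1

    def has_ref(self) -> bool:
        return self.ref_count > 0

free_blocks = []
block = MockCacheBlock(0)
block.ref_count = 2  # two sequences share this block

for _ in range(2):  # each sequence frees its cache in turn
    block.remove_ref()
    if not block.has_ref():  # only the second release recycles the block
        block.allocated_size = 0
        free_blocks.append(block)

print(len(free_blocks))  # 1
```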

    def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the corresponding tensor for the provided block id for a specific layer."""
        return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx]
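The double indexing implies `_kv_caches` is a pair (key caches, value caches), each indexed by layer and then by block. A minimal layout consistent with that access pattern (every dimension below is an assumption, not the actual ColossalAI shape):

```python
import torch

num_layers, num_blocks, block_size, num_heads, head_dim = 2, 8, 16, 4, 32

# kv_caches[0] -> per-layer key caches, kv_caches[1] -> per-layer value caches
kv_caches = [
    [torch.zeros(num_blocks, block_size, num_heads, head_dim) for _ in range(num_layers)]
    for _ in range(2)
]

k_block = kv_caches[0][1][3]  # key cache of layer 1, block 3
v_block = kv_caches[1][1][3]  # matching value cache
print(k_block.shape)  # torch.Size([16, 4, 32])
```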

-    def _allocate_on_block(self, block: CacheBlock, space_asked: int) -> None:
-        """Allocate a specific size of space on a cache block."""
+    def _allocate_on_block(self, block: CacheBlock, space_asked: int) -> int:
+        """Allocate a specific size of space on a cache block.
+
+        Returns:
+            The remaining space required to be allocated (in other blocks).
+        """
        available_space = block.available_space()
        assert available_space > 0, "No available space on the block to allocate."
        space_to_allocate = min(available_space, space_asked)
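The visible lines already pin down the contract: grant `min(available_space, space_asked)` and, in the rest of the function (not shown here), hand back what did not fit. A standalone sketch of that arithmetic, not the actual truncated body:

```python
def allocate_on_block(allocated_size: int, block_size: int, space_asked: int) -> tuple[int, int]:
    """Return (new allocated_size, space still needed on other blocks)."""
    available_space = block_size - allocated_size
    assert available_space > 0, "No available space on the block to allocate."
    space_to_allocate = min(available_space, space_asked)
    return allocated_size + space_to_allocate, space_asked - space_to_allocate

# A block with 14 of 16 slots used absorbs 2 of the 5 requested tokens;
# the remaining 3 spill to the next block.
print(allocate_on_block(14, 16, 5))  # (16, 3)
```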
