diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py
index 44ac05a1430b..9473a33f0ee6 100644
--- a/tests/core/test_block_manager.py
+++ b/tests/core/test_block_manager.py
@@ -4,7 +4,7 @@
 
 from vllm import SamplingParams
 from vllm.block import PhysicalTokenBlock
-from vllm.core.block_manager import (BlockAllocator, BlockSpaceManager,
+from vllm.core.block_manager import (UncachedBlockAllocator, BlockSpaceManager,
                                      AllocStatus)
 from vllm.utils import Device
 from vllm.sequence import Sequence, SequenceGroup, SequenceStatus, Logprob
@@ -15,7 +15,8 @@ def test_block_allocator_allocate():
     block_size = 4
     num_cpu_blocks = 4
-    cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
+    cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
+                                           num_cpu_blocks)
 
     # Allocate all available cpu blocks.
     num_free = num_cpu_blocks
@@ -24,7 +25,7 @@ def test_block_allocator_allocate():
         block = cpu_allocator.allocate()
         num_free -= 1
 
-        assert block.block_hash not in cpu_allocator.evictor
+        assert block not in cpu_allocator.free_blocks
         assert cpu_allocator.get_num_free_blocks() == num_free
 
     with pytest.raises(ValueError):
@@ -34,14 +35,15 @@ def test_block_allocator_allocate():
 def test_block_allocator_free():
     block_size = 4
     num_cpu_blocks = 4
-    cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
+    cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
+                                           num_cpu_blocks)
 
     # Allocate all available cpu blocks.
     blocks: List[PhysicalTokenBlock] = []
     for _ in range(num_cpu_blocks):
         block = cpu_allocator.allocate()
         blocks.append(block)
-        assert block.block_hash not in cpu_allocator.evictor
+        assert block not in cpu_allocator.free_blocks
 
     # Free all allocated cpu blocks.
     num_free = 0
@@ -49,7 +51,7 @@ def test_block_allocator_free():
     for block in blocks:
         cpu_allocator.free(block)
         num_free += 1
-        assert block.block_hash in cpu_allocator.evictor
+        assert block in cpu_allocator.free_blocks
         assert cpu_allocator.get_num_free_blocks() == num_free
 
     with pytest.raises(ValueError):
diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py
index c83551c36ef1..cb61aac3975a 100644
--- a/tests/prefix_caching/test_prefix_caching.py
+++ b/tests/prefix_caching/test_prefix_caching.py
@@ -4,7 +4,7 @@
 """
 import pytest
 
-from vllm.core.block_manager import BlockAllocator
+from vllm.core.block_manager import CachedBlockAllocator
 from vllm.utils import Device
 
 
@@ -15,10 +15,7 @@ def test_block_allocator(
     num_blocks: int,
 ):
     block_hash = 1
-    block_allocator = BlockAllocator(Device.CPU,
-                                     block_size,
-                                     num_blocks,
-                                     enable_caching=True)
+    block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
 
     # Allocate two PhysicalTokenBlocks with the same hash and check
     # that they are the same PhysicalTokenBlock
@@ -45,10 +42,7 @@
 @pytest.mark.parametrize("num_blocks", [16])
 def test_eviction(num_blocks: int, ):
     block_size = 16
-    block_allocator = BlockAllocator(Device.CPU,
-                                     block_size,
-                                     num_blocks,
-                                     enable_caching=True)
+    block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
 
     blocks = []
     for i in range(num_blocks):
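The two test files above now target the split allocators directly. As a quick orientation before the implementation diff below, here is a minimal usage sketch (hypothetical driver code, not part of the patch) mirroring what the tests exercise: the uncached allocator hands out blocks from a plain free list, while the cached allocator keys blocks by content hash so equal hashes share one physical block.

```python
from vllm.core.block_manager import (CachedBlockAllocator,
                                     UncachedBlockAllocator)
from vllm.utils import Device

# Uncached: a plain free-list allocator, no hashing involved.
uncached = UncachedBlockAllocator(Device.CPU, 4, 4)
block = uncached.allocate()               # pops a block, ref_count == 1
assert block not in uncached.free_blocks
uncached.free(block)                      # ref_count -> 0, back on the free list

# Cached: blocks are keyed by content hash; equal hashes alias.
cached = CachedBlockAllocator(Device.CPU, 4, 4)
first = cached.allocate(block_hash=1, num_hashed_tokens=4)
second = cached.allocate(block_hash=1, num_hashed_tokens=4)
assert first == second                    # same physical block
assert second.ref_count == 2
```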
diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py
index 8b089a5650f4..ad9b557fd9a8 100644
--- a/vllm/core/block_manager.py
+++ b/vllm/core/block_manager.py
@@ -3,6 +3,7 @@
 from itertools import count, takewhile
 from os.path import commonprefix
 from typing import Dict, List, Optional, Set, Tuple
+from abc import ABC, abstractmethod
 
 from vllm.block import BlockTable, PhysicalTokenBlock
 from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
@@ -10,7 +11,7 @@
 from vllm.core.evictor import Evictor, EvictionPolicy, make_evictor
 
 
-class BlockAllocator:
+class BlockAllocatorBase(ABC):
     """Manages free physical token blocks for a device.
 
     The allocator maintains a list of free blocks and allocates a block when
@@ -18,23 +19,57 @@ class BlockAllocator:
     the reference count becomes zero, the block is added back to the free list.
     """
 
+    @abstractmethod
     def __init__(self,
                  device: Device,
                  block_size: int,
                  num_blocks: int,
-                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
-                 enable_caching: bool = False) -> None:
+                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU):
+        pass
+
+    @abstractmethod
+    def allocate(self,
+                 block_hash: Optional[int] = None,
+                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
+        pass
+
+    @abstractmethod
+    def free(self, block: PhysicalTokenBlock) -> None:
+        pass
+
+    @abstractmethod
+    def get_num_free_blocks(self) -> int:
+        pass
+
+    @abstractmethod
+    def contains_block(self, block_hash: int) -> bool:
+        pass
+
+    @abstractmethod
+    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
+        pass
+
+
+class CachedBlockAllocator(BlockAllocatorBase):
+    """Manages free physical token blocks for a device.
+
+    The allocator maintains a list of free blocks and allocates a block when
+    requested. When a block is freed, its reference count is decremented. If
+    the reference count becomes zero, the block is added back to the free list.
+    """
+
+    def __init__(self,
+                 device: Device,
+                 block_size: int,
+                 num_blocks: int,
+                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None:
         self.device = device
         self.block_size = block_size
         self.num_blocks = num_blocks
-        self.enable_caching = enable_caching
 
         self.current_num_blocks = 0
         self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}
 
-        # Switch over to FIFO eviction when caching is disabled
-        if not self.enable_caching:
-            eviction_policy = EvictionPolicy.FIFO
         self.evictor: Evictor = make_evictor(eviction_policy)
 
         self.default_hash_ctr = count()
@@ -57,13 +92,6 @@ def allocate_block(self, block_hash: int,
     def allocate(self,
                  block_hash: Optional[int] = None,
                  num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
-        # If caching is disabled, just allocate a new block and return it
-        if not self.enable_caching:
-            block = self.allocate_block(next(self.default_hash_ctr),
-                                        num_hashed_tokens)
-            block.ref_count += 1
-            return block
-
         if block_hash is None:
             block_hash = next(self.default_hash_ctr)
         if block_hash in self.evictor:
@@ -90,9 +118,8 @@ def free(self, block: PhysicalTokenBlock) -> None:
             assert block.block_hash not in self.evictor
             self.evictor.add(block)
 
-            # If caching is enabled, remove the block from the cached_blocks
-            if self.enable_caching:
-                del self.cached_blocks[block.block_hash]
+            # Remove the block from the cached_blocks
+            del self.cached_blocks[block.block_hash]
 
     def get_num_free_blocks(self) -> int:
         return (self.num_blocks - self.current_num_blocks +
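The free/allocate pair above means a freed cached block is parked in the evictor rather than destroyed, and a later request for the same hash resurrects it. A small sketch of that round trip (hypothetical, mirroring the code paths in the hunks above):

```python
from vllm.core.block_manager import CachedBlockAllocator
from vllm.utils import Device

allocator = CachedBlockAllocator(Device.CPU, 16, 16)
b = allocator.allocate(block_hash=7, num_hashed_tokens=16)
allocator.free(b)                    # ref_count -> 0: parked in the evictor
assert 7 in allocator.evictor        # still cached, just reclaimable
assert allocator.get_num_free_blocks() == 16  # evictor blocks count as free
again = allocator.allocate(block_hash=7)
assert again is b                    # resurrected, not rebuilt from scratch
```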
@@ -102,14 +129,68 @@ def contains_block(self, block_hash: int) -> bool:
         return block_hash in self.cached_blocks or block_hash in self.evictor
 
     def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
-        # If caching is enabled, update the hash of block and the
-        # cached_blocks dictionary.
-        if self.enable_caching:
-            assert not self.contains_block(block_hash)
-            old_hash = block.block_hash
-            block.block_hash = block_hash
-            del self.cached_blocks[old_hash]
-            self.cached_blocks[block_hash] = block
+        # Update the hash of block and the cached_blocks dictionary.
+        assert not self.contains_block(block_hash)
+        old_hash = block.block_hash
+        block.block_hash = block_hash
+        del self.cached_blocks[old_hash]
+        self.cached_blocks[block_hash] = block
+
+
+class UncachedBlockAllocator(BlockAllocatorBase):
+    """Manages free physical token blocks for a device.
+
+    The allocator maintains a list of free blocks and allocates a block when
+    requested. When a block is freed, its reference count is decremented. If
+    the reference count becomes zero, the block is added back to the free list.
+    """
+
+    def __init__(
+        self,
+        device: Device,
+        block_size: int,
+        num_blocks: int,
+    ) -> None:
+        self.device = device
+        self.block_size = block_size
+        self.num_blocks = num_blocks
+
+        # Initialize the free blocks.
+        self.free_blocks: BlockTable = []
+        for i in range(num_blocks):
+            block = PhysicalTokenBlock(device=device,
+                                       block_number=i,
+                                       block_size=block_size,
+                                       block_hash=-1,
+                                       num_hashed_tokens=0)
+            self.free_blocks.append(block)
+
+    def allocate(self,
+                 block_hash: Optional[int] = None,
+                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
+        if not self.free_blocks:
+            raise ValueError("Out of memory! No free blocks are available.")
+        block = self.free_blocks.pop()
+        block.ref_count = 1
+        return block
+
+    def free(self, block: PhysicalTokenBlock) -> None:
+        if block.ref_count == 0:
+            raise ValueError(f"Double free! {block} is already freed.")
+        block.ref_count -= 1
+        if block.ref_count == 0:
+            self.free_blocks.append(block)
+
+    def get_num_free_blocks(self) -> int:
+        return len(self.free_blocks)
+
+    def contains_block(self, block_hash: int) -> bool:
+        raise NotImplementedError(
+            "Invalid codepath for uncached block allocator.")
+
+    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
+        raise NotImplementedError(
+            "Invalid codepath for uncached block allocator.")
 
 
 class AllocStatus(enum.Enum):
@@ -142,6 +223,10 @@ def __init__(
         self.num_total_gpu_blocks = num_gpu_blocks
         self.num_total_cpu_blocks = num_cpu_blocks
 
+        if enable_caching and sliding_window is not None:
+            raise NotImplementedError(
+                "Sliding window is not allowed with prefix caching enabled!")
+
         self.block_sliding_window = None
         if sliding_window is not None:
             assert sliding_window % block_size == 0, (sliding_window,
@@ -154,14 +239,17 @@ def __init__(
         self.enable_caching = enable_caching
 
         self.watermark_blocks = int(watermark * num_gpu_blocks)
-        self.gpu_allocator = BlockAllocator(Device.GPU,
-                                            block_size,
-                                            num_gpu_blocks,
-                                            enable_caching=enable_caching)
-        self.cpu_allocator = BlockAllocator(Device.CPU,
-                                            block_size,
-                                            num_cpu_blocks,
-                                            enable_caching=enable_caching)
+
+        if self.enable_caching:
+            self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
+                                                      num_gpu_blocks)
+            self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
+                                                      num_cpu_blocks)
+        else:
+            self.gpu_allocator = UncachedBlockAllocator(
+                Device.GPU, block_size, num_gpu_blocks)
+            self.cpu_allocator = UncachedBlockAllocator(
+                Device.CPU, block_size, num_cpu_blocks)
 
         # Mapping: seq_id -> BlockTable.
         self.block_tables: Dict[int, BlockTable] = {}
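With the allocator choice made once in the constructor, `enable_caching` selects the implementation for both device pools, and the new guard rejects the sliding-window plus prefix-caching combination up front. A construction sketch (keyword names taken from the parameters shown above; the values and any omitted defaults such as `watermark` are assumptions):

```python
from vllm.core.block_manager import BlockSpaceManager

# Picks CachedBlockAllocator for both the GPU and CPU pools.
manager = BlockSpaceManager(block_size=16,
                            num_gpu_blocks=256,
                            num_cpu_blocks=64,
                            enable_caching=True)

# Sliding window with prefix caching now fails fast instead of misbehaving.
try:
    BlockSpaceManager(block_size=16,
                      num_gpu_blocks=256,
                      num_cpu_blocks=64,
                      sliding_window=2048,
                      enable_caching=True)
except NotImplementedError as exc:
    print(exc)  # Sliding window is not allowed with prefix caching enabled!
```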
@@ -198,10 +286,16 @@ def allocate(self, seq_group: SequenceGroup) -> None:
             if (self.block_sliding_window is not None
                     and logical_idx >= self.block_sliding_window):
                 block = block_table[logical_idx % self.block_sliding_window]
-            else:
+                # Set the reference counts of the token blocks.
+                block.ref_count = seq_group.num_seqs()
+            elif self.enable_caching:
                 block = self.gpu_allocator.allocate(
                     seq.hash_of_block(logical_idx),
                     seq.num_hashed_tokens_of_block(logical_idx))
+            else:
+                block = self.gpu_allocator.allocate()
+                # Set the reference counts of the token blocks.
+                block.ref_count = seq_group.num_seqs()
             block_table.append(block)
 
         # Assign the block table for each sequence.
@@ -220,8 +314,10 @@ def _promote_last_block(
         seq: Sequence,
         last_block: PhysicalTokenBlock,
     ) -> PhysicalTokenBlock:
-        # Compute a new hash for the block so that it can be shared by
-        # other Sequences
+        assert self.enable_caching
+
+        # Compute a new hash for the block so that it can be shared by other
+        # Sequences
         new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
 
         # if new_hash is already in the cached table, then free last_block
@@ -254,6 +350,8 @@ def _allocate_last_physical_block(
         self,
         seq: Sequence,
     ) -> PhysicalTokenBlock:
+        if not self.enable_caching:
+            return self.gpu_allocator.allocate()
         block_hash: Optional[int] = None
         if (self._is_last_block_full(seq)):
             block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
@@ -293,10 +391,12 @@ def append_slot(
             assert last_block.device == Device.GPU
             if last_block.ref_count == 1:
                 # Not shared with other sequences. Appendable.
-                # If the last block is now complete, promote it to a full block so
-                # that it can be shared
-                new_block = self._maybe_promote_last_block(seq, last_block)
-                block_table[-1] = new_block
+                if self.enable_caching:
+                    # If the last block is now complete, we may reuse an old block
+                    # to save memory.
+                    maybe_new_block = self._maybe_promote_last_block(
+                        seq, last_block)
+                    block_table[-1] = maybe_new_block
                 return None
             else:
                 # The last block is shared with other sequences.
@@ -440,9 +540,12 @@ def access_all_blocks_in_seq(
         seq: Sequence,
         access_time: float,
     ) -> None:
-        block_table = self.block_tables[seq.seq_id]
-        for block in block_table:
-            block.last_accessed = access_time
+        if self.enable_caching:
+            # Update the last accessed time of all the blocks accessed
+            # in this step.
+            block_table = self.block_tables[seq.seq_id]
+            for block in block_table:
+                block.last_accessed = access_time
 
     def compute_full_blocks_in_seq(self, seq: Sequence):
         if seq.seq_id not in self.block_tables:
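The `access_all_blocks_in_seq` change above makes the `last_accessed` bookkeeping conditional on caching, since only the cached allocator's LRU evictor (next file) consumes it. A sketch of that interaction, echoing `test_eviction`; timestamps are set by hand here, whereas in the scheduler they come from `access_all_blocks_in_seq`:

```python
import time

from vllm.core.block_manager import CachedBlockAllocator
from vllm.utils import Device

allocator = CachedBlockAllocator(Device.GPU, 16, 2)
old = allocator.allocate(block_hash=0)
new = allocator.allocate(block_hash=1)
old.last_accessed = time.time()
new.last_accessed = time.time() + 1.0
allocator.free(old)
allocator.free(new)

# The pool is exhausted, so a third hash must evict; LRU picks the block
# with the smaller last_accessed timestamp.
third = allocator.allocate(block_hash=2)
assert third is old
```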
diff --git a/vllm/core/evictor.py b/vllm/core/evictor.py
index 1d81f5a97d71..9f401cba3fbe 100644
--- a/vllm/core/evictor.py
+++ b/vllm/core/evictor.py
@@ -1,5 +1,5 @@
 import enum
-from typing import Dict, List, Optional
+from typing import Dict
 from abc import ABC, abstractmethod, abstractproperty
 
 from vllm.block import PhysicalTokenBlock
@@ -10,7 +10,6 @@ class EvictionPolicy(enum.Enum):
     Evictor subclass.
     """
     LRU = enum.auto()
-    FIFO = enum.auto()
 
 
 class Evictor(ABC):
@@ -66,37 +65,18 @@ def __contains__(self, block_hash: int) -> bool:
 
     # TODO: The performance of this evict function can be optimized further.
     def evict(self) -> PhysicalTokenBlock:
-        free_blocks: List[PhysicalTokenBlock] = list(self.free_table.values())
-        if len(free_blocks) == 0:
+        if len(self.free_table) == 0:
             raise ValueError("No usable cache memory left")
+        free_blocks = self.free_table.values()
 
-        # Find lowest timestamp
-        lowest_timestamp = free_blocks[0].last_accessed
-        for block in free_blocks:
-            if block.last_accessed < lowest_timestamp:
-                lowest_timestamp = block.last_accessed
+        # Get evicted block
+        evicted_block: PhysicalTokenBlock = next(iter(free_blocks))
 
-        # Find all blocks with the lowest timestamp
-        least_recent: List[PhysicalTokenBlock] = []
         for block in free_blocks:
-            if block.last_accessed == lowest_timestamp:
-                least_recent.append(block)
-
-        # Find highest prefix count per block
-        highest_num_hashed_tokens = 0
-        for block in least_recent:
-            if block.num_hashed_tokens > highest_num_hashed_tokens:
-                highest_num_hashed_tokens = block.num_hashed_tokens
-
-        evicted_block: Optional[PhysicalTokenBlock] = None
-
-        # Find the first block with the lowest timestamp
-        for block in least_recent:
-            if block.num_hashed_tokens == highest_num_hashed_tokens:
+            if (block.last_accessed < evicted_block.last_accessed
+                    or block.last_accessed == evicted_block.last_accessed and
+                    block.num_hashed_tokens > evicted_block.num_hashed_tokens):
                 evicted_block = block
-                break
-
-        assert evicted_block is not None
 
         del self.free_table[evicted_block.block_hash]
@@ -119,43 +99,8 @@ def num_blocks(self) -> int:
         return len(self.free_table)
 
 
-class RandomEvictor(Evictor):
-    """Evicts in a first-in-first-out order"""
-
-    def __init__(self):
-        self.free_table: Dict[int, PhysicalTokenBlock] = {}
-
-    def __contains__(self, block_hash: int) -> bool:
-        return block_hash in self.free_table
-
-    def evict(self) -> PhysicalTokenBlock:
-        if len(self.free_table) == 0:
-            raise ValueError("No usable cache memory left")
-        evicted_block = next(iter(self.free_table.values()))
-        evicted_block.computed = False
-        del self.free_table[evicted_block.block_hash]
-        return evicted_block
-
-    def add(self, block: PhysicalTokenBlock):
-        self.free_table[block.block_hash] = block
-
-    def remove(self, block_hash: int) -> PhysicalTokenBlock:
-        if block_hash not in self.free_table:
-            raise ValueError(
-                "Attempting to remove block that's not in the evictor")
-        block: PhysicalTokenBlock = self.free_table[block_hash]
-        del self.free_table[block_hash]
-        return block
-
-    @property
-    def num_blocks(self) -> int:
-        return len(self.free_table)
-
-
 def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
     if eviction_policy == EvictionPolicy.LRU:
         return LRUEvictor()
-    elif eviction_policy == EvictionPolicy.FIFO:
-        return RandomEvictor()
    else:
         raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
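The rewritten `evict()` folds the old four-loop scan into a single pass: a strictly older `last_accessed` wins, and among blocks with the same timestamp the one covering more hashed tokens is preferred (note that `and` binds tighter than `or` in that condition). With the uncached allocator no longer routed through an evictor, the FIFO policy and `RandomEvictor` lose their only caller and are deleted. Equivalently, the selection rule can be restated over the free table as a `min()` with a composite key (a hypothetical helper, not part of the patch):

```python
from typing import Dict

from vllm.block import PhysicalTokenBlock


def pick_victim(
        free_table: Dict[int, PhysicalTokenBlock]) -> PhysicalTokenBlock:
    if not free_table:
        raise ValueError("No usable cache memory left")
    # One pass, like the loop above: min() orders by (oldest access time,
    # most hashed tokens) and keeps the first block on full ties.
    return min(free_table.values(),
               key=lambda b: (b.last_accessed, -b.num_hashed_tokens))
```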