diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py
index 5f4d58dd5fd3..f5a108e232ba 100644
--- a/tests/core/block/test_prefix_caching_block.py
+++ b/tests/core/block/test_prefix_caching_block.py
@@ -380,5 +380,131 @@ def create_immutable_chain(
             prev_block = allocator.allocate_immutable(
                 prev_block=prev_block, token_ids=block_token_ids)
             blocks.append(prev_block)
-
         return blocks
+
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [1024])
+    @pytest.mark.parametrize("block_size", [16])
+    @pytest.mark.parametrize("seed", list(range(20)))
+    def test_eviction_order(num_blocks: int, block_size: int, seed: int):
+        """Verify that blocks are evicted in LRU order: the free block with
+        the oldest last-accessed time is the first one to be reused.
+        """
+        random.seed(seed)
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                block_size=block_size)
+        num_blocks_to_consume = num_blocks + 1
+
+        token_ids = list(range(num_blocks_to_consume * block_size))
+
+        # First chain takes the first block
+        first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids[:block_size],
+            allocator=allocator,
+        )
+
+        # There should only be one block allocated at this point
+        assert allocator.get_num_free_blocks() == (num_blocks - 1)
+
+        # Set the last accessed time of the first chain's block to 1
+        allocator.access_all_blocks(1)
+
+        # Second chain takes the rest of the blocks
+        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids[block_size:-block_size],
+            allocator=allocator,
+        )
+
+        # There shouldn't be any blocks left at this point
+        assert allocator.get_num_free_blocks() == 0
+
+        # Free the one block in the first chain
+        assert len(first_chain) == 1
+        first_block_id = first_chain[0].block_id
+        allocator.free(first_chain[0])
+
+        # Set the last accessed time on all of the blocks in the second chain
+        # to 2
+        allocator.access_all_blocks(2)
+
+        # Free each block in the second chain.
+        for block in second_chain:
+            allocator.free(block)
+
+        # Allocate a new block and check that it reuses the least recently
+        # used block, which came from the first chain.
+        new_block = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids[-block_size:],
+            allocator=allocator,
+        )
+
+        assert new_block[0].block_id == first_block_id
+
+    # Test case where two last accessed times are equal
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [1024])
+    @pytest.mark.parametrize("block_size", [16])
+    @pytest.mark.parametrize("seed", list(range(20)))
+    def test_eviction_order_num_tokens(num_blocks: int, block_size: int,
+                                       seed: int):
+        """Verify the eviction tie-break: when two free blocks share the same
+        last-accessed time, the block holding more tokens is evicted first.
+        """
+        random.seed(seed)
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                block_size=block_size)
+        num_blocks_to_consume = num_blocks + 1
+
+        token_ids = list(range(num_blocks_to_consume * block_size))
+
+        num_blocks_in_first_chain = 2
+        num_tokens_in_first_chain = block_size * num_blocks_in_first_chain
+        # First chain takes the first num_blocks_in_first_chain blocks
+        first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids[:num_tokens_in_first_chain],
+            allocator=allocator,
+        )
+        # Only the first chain's blocks should be allocated at this point
+        assert allocator.get_num_free_blocks() == (num_blocks -
+                                                   num_blocks_in_first_chain)
+
+        # Set the last accessed time of the first chain's blocks to 1
+        allocator.access_all_blocks(1)
+
+        # Second chain takes the rest of the blocks
+        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids[num_tokens_in_first_chain:-block_size],
+            allocator=allocator,
+        )
+
+        # There shouldn't be any blocks left at this point
+        assert allocator.get_num_free_blocks() == 0
+
+        assert len(first_chain) == num_blocks_in_first_chain
+        last_block_id = first_chain[-1].block_id
+        # Free each block in the first chain.
+        for block in first_chain:
+            allocator.free(block)
+
+        # Set the last accessed time on all of the blocks in the second chain
+        # to 2
+        allocator.access_all_blocks(2)
+
+        # Free each block in the second chain.
+        for block in second_chain:
+            allocator.free(block)
+
+        # Allocate a new block; of the two equally-stale first-chain blocks,
+        # the one with more total tokens (the chain's last block) is evicted.
+        new_block = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids[-block_size:],
+            allocator=allocator,
+        )
+
+        assert new_block[0].block_id == last_block_id
diff --git a/tests/models/test_models.py b/tests/models/test_models.py
index cfe2539e3a05..7ffc5bbe73dc 100644
--- a/tests/models/test_models.py
+++ b/tests/models/test_models.py
@@ -38,7 +38,7 @@ def test_models(
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model
 
-    vllm_model = vllm_runner(model, dtype=dtype)
+    vllm_model = vllm_runner(model, dtype=dtype, enable_prefix_caching=True)
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     del vllm_model
 
diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py
index 3135e194c593..0ad8f3234336 100644
--- a/vllm/core/block/cpu_gpu_block_allocator.py
+++ b/vllm/core/block/cpu_gpu_block_allocator.py
@@ -195,6 +195,11 @@ def mark_blocks_as_computed(self) -> None:
         device = Device.GPU
         return self._allocators[device].mark_blocks_as_computed()
 
+    def access_all_blocks_in_seq(self, seq: List[int], now: float) -> None:
+        # Prefix caching only supported on GPU.
+        device = Device.GPU
+        return self._allocators[device].access_all_blocks_in_seq(seq, now)
+
     def get_common_computed_block_ids(
             self, seq_block_ids: List[List[int]]) -> List[int]:
         # Prefix caching only supported on GPU.
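
The two tests above pin down the eviction policy that the allocator change in the next file implements: least recently used first, with ties on the last-accessed time broken by evicting the block whose chain holds more tokens. A minimal standalone sketch of that policy, using only the standard library; the names `FreeBlock` and `pick_victim` are illustrative and not part of this diff:

```python
from collections import OrderedDict
from dataclasses import dataclass


@dataclass
class FreeBlock:
    # Hypothetical stand-in for the diff's BlockMetaData.
    last_accessed: float
    num_tokens: int


def pick_victim(free_blocks: "OrderedDict[int, FreeBlock]") -> int:
    """Pick the eviction victim: the oldest last_accessed wins; on a
    tie, the block whose chain holds more tokens is evicted first."""
    victim_id, victim = next(iter(free_blocks.items()))
    for block_id, meta in free_blocks.items():
        if meta.last_accessed < victim.last_accessed or (
                meta.last_accessed == victim.last_accessed
                and meta.num_tokens > victim.num_tokens):
            victim_id, victim = block_id, meta
    return victim_id


free = OrderedDict()
free[0] = FreeBlock(last_accessed=1.0, num_tokens=16)  # first chain block
free[1] = FreeBlock(last_accessed=1.0, num_tokens=32)  # same time, deeper
free[2] = FreeBlock(last_accessed=2.0, num_tokens=16)  # accessed later

# Block 0 loses the tie against block 1, whose chain holds more tokens.
assert pick_victim(free) == 1
```

For clarity the sketch scans every free block; the `LRUEvictorV2` introduced below instead exploits the fact that blocks enter its `OrderedDict` in the order they are freed, so it can stop scanning early.
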
diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py
index 6aa75a8abb80..f2e7971d9d05 100644
--- a/vllm/core/block/prefix_caching_block.py
+++ b/vllm/core/block/prefix_caching_block.py
@@ -1,17 +1,80 @@
 """Token blocks."""
 from itertools import takewhile
 from os.path import commonprefix
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, OrderedDict, Tuple
 
 from vllm.core.block.common import (CopyOnWriteTracker,
                                     get_all_blocks_recursively)
 from vllm.core.block.interfaces import Block, BlockAllocator
 from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
+from vllm.core.block import DEFAULT_LAST_ACCESSED_TIME
 
 PrefixHash = int
 BlockId = int
 
 
+class BlockMetaData:
+
+    def __init__(
+        self,
+        computed: bool = False,
+        last_accessed: float = DEFAULT_LAST_ACCESSED_TIME,
+        token_ids_len: int = 0,
+    ):
+        self.computed: bool = computed
+        self.last_accessed: float = last_accessed
+        self.token_ids_len: int = token_ids_len
+
+
+class LRUEvictorV2:
+
+    def __init__(self):
+        self._free_blocks: OrderedDict[PrefixHash, BlockId] = OrderedDict()
+
+    def __contains__(self, block_hash: PrefixHash) -> bool:
+        return block_hash in self._free_blocks
+
+    def evict(
+        self, block_meta_data: Dict[BlockId, BlockMetaData]
+    ) -> Tuple[BlockId, PrefixHash]:
+        evicted_block_hash, evicted_block_id = next(
+            iter(self._free_blocks.items()))
+
+        # Scan for the block with the lowest last-accessed time; blocks are
+        # freed oldest-first, so stop at the first strictly newer block.
+        for block_hash, block_id in self._free_blocks.items():
+            evicted_block_last_accessed = block_meta_data[
+                evicted_block_id].last_accessed
+            block_last_accessed = block_meta_data[block_id].last_accessed
+
+            if evicted_block_last_accessed < block_last_accessed:
+                break
+
+            # If two blocks have the same last-accessed time, evict the one
+            # whose chain holds more tokens, as it is less likely to be hit.
+            if block_meta_data[block_id].token_ids_len > block_meta_data[
+                    evicted_block_id].token_ids_len:
+                evicted_block_hash = block_hash
+                evicted_block_id = block_id
+            assert block_last_accessed == evicted_block_last_accessed
+
+        del self._free_blocks[evicted_block_hash]
+        return evicted_block_id, evicted_block_hash
+
+    def add(self, block_hash: PrefixHash, block_id: BlockId):
+        self._free_blocks[block_hash] = block_id
+
+    def remove(self, block_hash: PrefixHash):
+        if block_hash not in self._free_blocks:
+            raise ValueError(
+                "Attempting to remove a block that's not in the evictor")
+        del self._free_blocks[block_hash]
+
+    @property
+    def num_blocks(self) -> int:
+        return len(self._free_blocks)
+
+
 class PrefixCachingBlockAllocator(BlockAllocator):
     """A block allocator that implements prefix caching.
 
@@ -37,12 +100,11 @@ def __init__(
     ):
         # A mapping of prefix hash to block index. All blocks which have a
         # prefix hash will be in this dict, even if they have refcount 0.
-        self._cached_blocks: Dict[PrefixHash, BlockId] = {}
+        self._cached_blocks: OrderedDict[PrefixHash, BlockId] = OrderedDict()
 
-        # A mapping of prefix hash to block index. All blocks which have a
-        # prefix hash AND refcount 0 will be in this dict. Thus, it is a subset
-        # of self._cached_blocks.
-        self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {}
+        # A mapping of block index to block metadata. All blocks that are in
+        # _cached_blocks will be in this dict.
+        self._block_meta_data: Dict[BlockId, BlockMetaData] = {}
 
         # An allocator for blocks that do not have prefix hashes.
self._hashless_allocator = NaiveBlockAllocator( @@ -53,6 +115,7 @@ def __init__( ) self._block_size = block_size + self._evictor = LRUEvictorV2() # We share the refcounter between allocators. This allows us to promote # blocks originally allocated in the hashless allocator to immutable @@ -116,7 +179,6 @@ def allocate_immutable(self, prev_block: Optional[Block], block = self.allocate_mutable(prev_block) block.append_token_ids(token_ids) assert block.content_hash is not None - # TODO computed bit return block @@ -139,14 +201,14 @@ def allocate_mutable(self, prev_block: Block) -> Block: # We must check the unused cached blocks before raising OOM. pass - if self._unused_cached_blocks: - # TODO policy for selecting block to remove - content_hash_to_evict = next(iter(self._unused_cached_blocks)) + if self._evictor.num_blocks > 0: + block_id, content_hash_to_evict = self._evictor.evict( + self._block_meta_data) # Clear content hash mapping; the block will be overwritten. del self._cached_blocks[content_hash_to_evict] + del self._block_meta_data[block_id] - block_id = self._unused_cached_blocks.pop(content_hash_to_evict) refcount = self._refcounter.incr(block_id) assert refcount == 1 block = self._create_block( @@ -166,8 +228,8 @@ def _incr_refcount_cached_block(self, content_hash: int, block_id: BlockId) -> None: refcount = self._refcounter.incr(block_id) if refcount == 1: - assert content_hash in self._unused_cached_blocks - del self._unused_cached_blocks[content_hash] + assert content_hash in self._evictor + self._evictor.remove(content_hash) def free(self, block: Block) -> None: """Decrement the refcount of the block. If the decremented refcount is @@ -193,9 +255,11 @@ def _free_block_id_for_block(self, block_id: BlockId, # If no longer used, add the block to the unused cached blocks. if refcount == 0: - assert block.content_hash not in self._unused_cached_blocks + assert block.content_hash not in self._evictor assert block.content_hash in self._cached_blocks - self._unused_cached_blocks[block.content_hash] = block_id + self._evictor.add(block.content_hash, block.block_id) + self._block_meta_data[ + block_id].token_ids_len = block.num_tokens_total def fork(self, last_block: Block) -> List[Block]: """Creates a new sequence of blocks that shares the same underlying @@ -230,9 +294,9 @@ def fork(self, last_block: Block) -> List[Block]: def get_num_free_blocks(self) -> int: # The number of free blocks is the number of hashless free blocks - # plus the number of hashful blocks that are unused. - return self._hashless_allocator.get_num_free_blocks() + len( - self._unused_cached_blocks) + # plus the number of hashful blocks that are in the evictor. + return self._hashless_allocator.get_num_free_blocks( + ) + self._evictor.num_blocks @property def all_block_ids(self) -> frozenset[int]: @@ -263,8 +327,10 @@ def promote_to_immutable_block(self, # set this block as the cached block. if block.content_hash not in self._cached_blocks: self._cached_blocks[block.content_hash] = block.block_id + self._block_meta_data[block.block_id] = BlockMetaData( + token_ids_len=block.num_tokens_total) else: - self._free_block_id_for_block(block.block_id, block) + # Otherwise, increment the ref count of the cached block self._incr_refcount_cached_block( block.content_hash, self._cached_blocks[block.content_hash]) @@ -295,8 +361,22 @@ def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]: def mark_blocks_as_computed(self) -> None: """Mark blocks as computed, used in prefix caching.""" - # TODO Track computed blocks. 
-        pass
+        # Iterate over the cached blocks from newest to oldest; stop at the
+        # first computed block, since all blocks before it are computed too.
+        for _, block_id in reversed(self._cached_blocks.items()):
+            if self._block_meta_data[block_id].computed:
+                break
+            self._block_meta_data[block_id].computed = True
+
+    def get_all_computed_blocks(self, seq: List[int]) -> List[int]:
+        # Return the longest computed prefix of the sequence's block ids.
+        # The last block is excluded so that the entire prompt is never
+        # treated as cached (see get_common_computed_block_ids).
+
+        def computed(block_id):
+            return self._block_meta_data[block_id].computed
+
+        return list(takewhile(computed, seq[:-1]))
 
     def get_common_computed_block_ids(
             self, seq_block_ids: List[List[int]]) -> List[int]:
@@ -305,18 +385,21 @@ def get_common_computed_block_ids(
 
         Used in prefill (can skip prefill of some blocks).
         """
-        # TODO: Track computed blocks.
-        computed = lambda block_id: False
-
         # NOTE We exclude the last block to avoid the case where the entire
         # prompt is cached. This would cause erroneous behavior in model
         # runner.
-        ids_list = [
-            takewhile(lambda block_id: computed(block_id), seq[:-1])
-            for seq in seq_block_ids
-        ]
+        ids_list = [self.get_all_computed_blocks(seq) for seq in seq_block_ids]
 
         return commonprefix([ids for ids in ids_list if ids != []])
 
+    def access_all_blocks_in_seq(self, seq_block_ids: List[int], now: float):
+        for block_id in seq_block_ids:
+            # If there's a mutable block at the end of the sequence, it won't
+            # be in _block_meta_data yet since it's not cached.
+            if block_id not in self._block_meta_data:
+                assert block_id == seq_block_ids[-1]
+            else:
+                self._block_meta_data[block_id].last_accessed = now
+
 
 class PrefixCachingBlock(Block):
     """A block implementation that supports prefix caching.
 
@@ -349,9 +432,11 @@ def __init__(
         assert_prefix_caching_block_or_none(prev_block)
 
         self._prev_block = prev_block
+        self._prev_num_tokens = 0
+        if prev_block:
+            self._prev_num_tokens = prev_block.num_tokens_total
         self._cached_content_hash: Optional[int] = None
         self._prefix_caching_allocator = prefix_caching_allocator
-
         self._block = NaiveBlock(
             prev_block=prev_block,
             token_ids=token_ids,
@@ -382,6 +467,11 @@ def append_token_ids(self, token_ids: List[int]) -> None:
             self.block_id = (self._prefix_caching_allocator.
                              promote_to_immutable_block(self))
 
+    @property
+    def num_tokens_total(self) -> int:
+        # Tokens in the whole chain: all previous blocks plus this block.
+        return self._prev_num_tokens + len(self.token_ids)
+
     @property
     def block_id(self) -> Optional[int]:
         return self._block.block_id
diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py
index 19f0cf415eb3..e49213568336 100644
--- a/vllm/core/block_manager_v2.py
+++ b/vllm/core/block_manager_v2.py
@@ -64,7 +64,6 @@ def __init__(
         self.block_size = block_size
         self.num_total_gpu_blocks = num_gpu_blocks
         self.num_total_cpu_blocks = num_cpu_blocks
-
         assert sliding_window is None, "Sliding window not yet supported"
         self.block_sliding_window = None
 
@@ -72,14 +71,15 @@ def __init__(
         self.watermark = watermark
         assert watermark >= 0.0
 
-        assert not enable_caching, "Prefix caching not yet supported"
         self.enable_caching = enable_caching
 
         self.watermark_blocks = int(watermark * num_gpu_blocks)
 
+        allocator_type = "naive"
+        if enable_caching:
+            allocator_type = "prefix_caching"
         self.block_allocator = CpuGpuBlockAllocator.create(
-            # Currently, only naive blocks are supported (no prefix caching).
- allocator_type="naive", + allocator_type=allocator_type, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, block_size=block_size, @@ -194,17 +194,18 @@ def get_block_table(self, seq: Sequence) -> List[int]: assert all(b is not None for b in block_ids) return block_ids - def access_all_blocks_in_seq(self, seq, now): - # TODO add prefix caching support. - # Tracked here https://github.com/vllm-project/vllm/issues/3667 - pass + def access_all_blocks_in_seq(self, seq: Sequence, now: float): + if self.enable_caching: + self.block_allocator.access_all_blocks_in_seq( + self.get_block_table(seq), now) def mark_blocks_as_computed(self, seq_group: SequenceGroup): # We ignore the sequence group as its not necessary. After the batch is # formed by the scheduler, we do not need to mark blocks from individual # sequence groups as computed -- all blocks in the batch can be marked # as computed. - self.block_allocator.mark_blocks_as_computed() + if self.enable_caching: + self.block_allocator.mark_blocks_as_computed() def get_common_computed_block_ids( self, seqs: List[Sequence]) -> GenericSequence[int]: @@ -217,6 +218,8 @@ def get_common_computed_block_ids( This method determines which blocks can be safely skipped for all sequences in the sequence group. """ + if not self.enable_caching: + return [] seq_block_ids = [ self.block_tables[seq.seq_id].physical_block_ids for seq in seqs ]
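
For reference, the block intersection in `get_common_computed_block_ids` boils down to a `takewhile` over each sequence's block ids followed by `os.path.commonprefix` across sequences, which works element-wise on lists. A toy, stdlib-only illustration; the `is_computed` predicate stands in for the allocator's `BlockMetaData` lookup:

```python
from itertools import takewhile
from os.path import commonprefix

# Hypothetical set of block ids already marked computed.
computed_blocks = {0, 1, 2, 5}


def is_computed(block_id: int) -> bool:
    return block_id in computed_blocks


seq_block_ids = [
    [0, 1, 2, 7],  # computed prefix (last block excluded): [0, 1, 2]
    [0, 1, 5, 8],  # computed prefix: [0, 1, 5]
]
ids_list = [list(takewhile(is_computed, seq[:-1])) for seq in seq_block_ids]

# The shared, already-computed run of blocks across both sequences is
# [0, 1]; prefill can be skipped for exactly those blocks.
assert commonprefix([ids for ids in ids_list if ids != []]) == [0, 1]
```
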
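End to end, the new allocator is selected by passing `enable_prefix_caching=True` to the engine, as the test_models.py change does through `vllm_runner`. A hedged usage sketch, assuming a vLLM build carrying this diff in which `LLM` forwards `enable_prefix_caching` and `use_v2_block_manager` to the engine arguments:

```python
from vllm import LLM, SamplingParams

# enable_prefix_caching selects the prefix-caching allocator;
# use_v2_block_manager routes allocation through BlockSpaceManagerV2,
# which this diff teaches to accept enable_caching.
llm = LLM(
    model="facebook/opt-125m",
    enable_prefix_caching=True,
    use_v2_block_manager=True,
)

shared_prefix = "You are a helpful assistant. Answer concisely.\n\n"
params = SamplingParams(temperature=0.0, max_tokens=32)

# The first call fills the prefix blocks and marks them computed.
llm.generate([shared_prefix + "Q: What is a KV cache?\nA:"], params)

# The second call shares the prefix, so get_common_computed_block_ids
# lets the scheduler skip prefill for those blocks.
outputs = llm.generate([shared_prefix + "Q: What is paged attention?\nA:"],
                       params)
print(outputs[0].outputs[0].text)
```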