[PREFIX CACHING FOLLOW UP] A bunch of fixes to block allocator performance when automatic prefix caching is disabled #3357

Merged: 20 commits, Mar 20, 2024
14 changes: 8 additions & 6 deletions tests/core/test_block_manager.py
@@ -4,7 +4,7 @@

 from vllm import SamplingParams
 from vllm.block import PhysicalTokenBlock
-from vllm.core.block_manager import (BlockAllocator, BlockSpaceManager,
+from vllm.core.block_manager import (UncachedBlockAllocator, BlockSpaceManager,
                                      AllocStatus)
 from vllm.utils import Device
 from vllm.sequence import Sequence, SequenceGroup, SequenceStatus, Logprob
@@ -15,7 +15,8 @@
 def test_block_allocator_allocate():
     block_size = 4
     num_cpu_blocks = 4
-    cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
+    cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
+                                           num_cpu_blocks)

     # Allocate all available cpu blocks.
     num_free = num_cpu_blocks
@@ -24,7 +25,7 @@ def test_block_allocator_allocate():
         block = cpu_allocator.allocate()
         num_free -= 1

-        assert block.block_hash not in cpu_allocator.evictor
+        assert block not in cpu_allocator.free_blocks
         assert cpu_allocator.get_num_free_blocks() == num_free

     with pytest.raises(ValueError):
@@ -34,22 +35,23 @@ def test_block_allocator_allocate():
 def test_block_allocator_free():
     block_size = 4
     num_cpu_blocks = 4
-    cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
+    cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
+                                           num_cpu_blocks)

     # Allocate all available cpu blocks.
     blocks: List[PhysicalTokenBlock] = []
     for _ in range(num_cpu_blocks):
         block = cpu_allocator.allocate()
         blocks.append(block)
-        assert block.block_hash not in cpu_allocator.evictor
+        assert block not in cpu_allocator.free_blocks

     # Free all allocated cpu blocks.
     num_free = 0
     assert cpu_allocator.get_num_free_blocks() == num_free
     for block in blocks:
         cpu_allocator.free(block)
         num_free += 1
-        assert block.block_hash in cpu_allocator.evictor
+        assert block in cpu_allocator.free_blocks
         assert cpu_allocator.get_num_free_blocks() == num_free

     with pytest.raises(ValueError):
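The updated assertions above check membership in a plain free_blocks list instead of consulting an evictor. As a reading aid, here is a minimal, self-contained sketch of that free-list discipline; Block and FreeListAllocator are illustrative stand-ins, not vLLM's actual classes. Allocation pops a block off the list and sets its reference count to one; freeing returns it only once the count drops back to zero.

from dataclasses import dataclass
from typing import List


@dataclass
class Block:
    """Illustrative stand-in for vllm.block.PhysicalTokenBlock."""
    block_number: int
    ref_count: int = 0


class FreeListAllocator:
    """Toy uncached allocator: a plain LIFO free list, no hashing, no evictor."""

    def __init__(self, num_blocks: int) -> None:
        self.free_blocks: List[Block] = [Block(i) for i in range(num_blocks)]

    def allocate(self) -> Block:
        if not self.free_blocks:
            raise ValueError("Out of memory! No free blocks are available.")
        block = self.free_blocks.pop()
        block.ref_count = 1
        return block

    def free(self, block: Block) -> None:
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:
            self.free_blocks.append(block)


# Mirrors the flow of the tests above: drain the pool, then refill it.
allocator = FreeListAllocator(num_blocks=4)
held = [allocator.allocate() for _ in range(4)]
assert allocator.free_blocks == []
for b in held:
    allocator.free(b)
assert len(allocator.free_blocks) == 4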
12 changes: 3 additions & 9 deletions tests/prefix_caching/test_prefix_caching.py
@@ -4,7 +4,7 @@
 """
 import pytest

-from vllm.core.block_manager import BlockAllocator
+from vllm.core.block_manager import CachedBlockAllocator
 from vllm.utils import Device


@@ -15,10 +15,7 @@ def test_block_allocator(
     num_blocks: int,
 ):
     block_hash = 1
-    block_allocator = BlockAllocator(Device.CPU,
-                                     block_size,
-                                     num_blocks,
-                                     enable_caching=True)
+    block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)

     # Allocate two PhysicalTokenBlocks with the same hash and check
     # that they are the same PhysicalTokenBlock
@@ -45,10 +42,7 @@ def test_block_allocator(
 @pytest.mark.parametrize("num_blocks", [16])
 def test_eviction(num_blocks: int, ):
     block_size = 16
-    block_allocator = BlockAllocator(Device.CPU,
-                                     block_size,
-                                     num_blocks,
-                                     enable_caching=True)
+    block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
    blocks = []

     for i in range(num_blocks):
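The defining behavior that test_block_allocator exercises above is hash-keyed reuse: allocating twice with the same hash must return the same physical block with a bumped reference count. A rough sketch of just that reuse path, under illustrative names (ToyCachedAllocator is not the real CachedBlockAllocator, and eviction is omitted):

from dataclasses import dataclass
from typing import Dict


@dataclass
class Block:
    block_hash: int
    ref_count: int = 0


class ToyCachedAllocator:
    """Toy cached allocator: reuse a live block whose hash is already known."""

    def __init__(self, num_blocks: int) -> None:
        self.num_blocks = num_blocks
        self.current_num_blocks = 0
        self.cached_blocks: Dict[int, Block] = {}

    def allocate(self, block_hash: int) -> Block:
        if block_hash not in self.cached_blocks:
            if self.current_num_blocks >= self.num_blocks:
                raise ValueError("Out of memory! No free blocks are available.")
            self.cached_blocks[block_hash] = Block(block_hash)
            self.current_num_blocks += 1
        block = self.cached_blocks[block_hash]
        block.ref_count += 1
        return block


# Two allocations with the same hash resolve to one physical block.
allocator = ToyCachedAllocator(num_blocks=16)
first = allocator.allocate(block_hash=1)
second = allocator.allocate(block_hash=1)
assert first is second and first.ref_count == 2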
189 changes: 146 additions & 43 deletions vllm/core/block_manager.py
@@ -3,38 +3,73 @@
 from itertools import count, takewhile
 from os.path import commonprefix
 from typing import Dict, List, Optional, Set, Tuple
+from abc import ABC, abstractmethod

 from vllm.block import BlockTable, PhysicalTokenBlock
 from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
 from vllm.utils import Device
 from vllm.core.evictor import Evictor, EvictionPolicy, make_evictor


-class BlockAllocator:
+class BlockAllocatorBase(ABC):
     """Manages free physical token blocks for a device.

     The allocator maintains a list of free blocks and allocates a block when
     requested. When a block is freed, its reference count is decremented. If
     the reference count becomes zero, the block is added back to the free list.
     """

+    @abstractmethod
     def __init__(self,
                  device: Device,
                  block_size: int,
                  num_blocks: int,
-                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
-                 enable_caching: bool = False) -> None:
+                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU):
+        pass
+
+    @abstractmethod
+    def allocate(self,
+                 block_hash: Optional[int] = None,
+                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
+        pass
+
+    @abstractmethod
+    def free(self, block: PhysicalTokenBlock) -> None:
+        pass
+
+    @abstractmethod
+    def get_num_free_blocks(self) -> int:
+        pass
+
+    @abstractmethod
+    def contains_block(self, block_hash: int) -> bool:
+        pass
+
+    @abstractmethod
+    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
+        pass
+
+
+class CachedBlockAllocator(BlockAllocatorBase):
+    """Manages free physical token blocks for a device.
+
+    The allocator maintains a list of free blocks and allocates a block when
+    requested. When a block is freed, its reference count is decremented. If
+    the reference count becomes zero, the block is added back to the free list.
+    """
+
+    def __init__(self,
+                 device: Device,
+                 block_size: int,
+                 num_blocks: int,
+                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None:
         self.device = device
         self.block_size = block_size
         self.num_blocks = num_blocks
-        self.enable_caching = enable_caching

         self.current_num_blocks = 0
         self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}

-        # Switch over to FIFO eviction when caching is disabled
-        if not self.enable_caching:
-            eviction_policy = EvictionPolicy.FIFO
         self.evictor: Evictor = make_evictor(eviction_policy)

         self.default_hash_ctr = count()
@@ -57,13 +92,6 @@ def allocate_block(self, block_hash: int,
     def allocate(self,
                  block_hash: Optional[int] = None,
                  num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
-        # If caching is disabled, just allocate a new block and return it
-        if not self.enable_caching:
-            block = self.allocate_block(next(self.default_hash_ctr),
-                                        num_hashed_tokens)
-            block.ref_count += 1
-            return block
-
         if block_hash is None:
             block_hash = next(self.default_hash_ctr)
         if block_hash in self.evictor:
@@ -90,9 +118,8 @@ def free(self, block: PhysicalTokenBlock) -> None:
             assert block.block_hash not in self.evictor
             self.evictor.add(block)

-            # If caching is enabled, remove the block from the cached_blocks
-            if self.enable_caching:
-                del self.cached_blocks[block.block_hash]
+            # Remove the block from the cached_blocks
+            del self.cached_blocks[block.block_hash]

     def get_num_free_blocks(self) -> int:
         return (self.num_blocks - self.current_num_blocks +
@@ -102,14 +129,68 @@ def contains_block(self, block_hash: int) -> bool:
         return block_hash in self.cached_blocks or block_hash in self.evictor

     def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
-        # If caching is enabled, update the hash of block and the
-        # cached_blocks dictionary.
-        if self.enable_caching:
-            assert not self.contains_block(block_hash)
-            old_hash = block.block_hash
-            block.block_hash = block_hash
-            del self.cached_blocks[old_hash]
-            self.cached_blocks[block_hash] = block
+        # Update the hash of block and the cached_blocks dictionary.
+        assert not self.contains_block(block_hash)
+        old_hash = block.block_hash
+        block.block_hash = block_hash
+        del self.cached_blocks[old_hash]
+        self.cached_blocks[block_hash] = block
+
+
+class UncachedBlockAllocator(BlockAllocatorBase):
+    """Manages free physical token blocks for a device.
+
+    The allocator maintains a list of free blocks and allocates a block when
+    requested. When a block is freed, its reference count is decremented. If
+    the reference count becomes zero, the block is added back to the free list.
+    """
+
+    def __init__(
+        self,
+        device: Device,
+        block_size: int,
+        num_blocks: int,
+    ) -> None:
+        self.device = device
+        self.block_size = block_size
+        self.num_blocks = num_blocks
+
+        # Initialize the free blocks.
+        self.free_blocks: BlockTable = []
+        for i in range(num_blocks):
+            block = PhysicalTokenBlock(device=device,
+                                       block_number=i,
+                                       block_size=block_size,
+                                       block_hash=-1,
+                                       num_hashed_tokens=0)
+            self.free_blocks.append(block)
+
+    def allocate(self,
+                 block_hash: Optional[int] = None,
+                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
+        if not self.free_blocks:
+            raise ValueError("Out of memory! No free blocks are available.")
+        block = self.free_blocks.pop()
+        block.ref_count = 1
+        return block
+
+    def free(self, block: PhysicalTokenBlock) -> None:
+        if block.ref_count == 0:
+            raise ValueError(f"Double free! {block} is already freed.")
+        block.ref_count -= 1
+        if block.ref_count == 0:
+            self.free_blocks.append(block)
+
+    def get_num_free_blocks(self) -> int:
+        return len(self.free_blocks)
+
+    def contains_block(self, block_hash: int) -> bool:
+        raise NotImplementedError(
+            "Invalid codepath for uncached block allocator.")
+
+    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
+        raise NotImplementedError(
+            "Invalid codepath for uncached block allocator.")


 class AllocStatus(enum.Enum):
@@ -142,6 +223,10 @@ def __init__(
         self.num_total_gpu_blocks = num_gpu_blocks
         self.num_total_cpu_blocks = num_cpu_blocks

+        if enable_caching and sliding_window is not None:
+            raise NotImplementedError(
+                "Sliding window is not allowed with prefix caching enabled!")
+
         self.block_sliding_window = None
         if sliding_window is not None:
             assert sliding_window % block_size == 0, (sliding_window,
@@ -154,14 +239,17 @@ def __init__(
         self.enable_caching = enable_caching

         self.watermark_blocks = int(watermark * num_gpu_blocks)
-        self.gpu_allocator = BlockAllocator(Device.GPU,
-                                            block_size,
-                                            num_gpu_blocks,
-                                            enable_caching=enable_caching)
-        self.cpu_allocator = BlockAllocator(Device.CPU,
-                                            block_size,
-                                            num_cpu_blocks,
-                                            enable_caching=enable_caching)
+
+        if self.enable_caching:
+            self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
+                                                      num_gpu_blocks)
+            self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
+                                                      num_cpu_blocks)
+        else:
+            self.gpu_allocator = UncachedBlockAllocator(
+                Device.GPU, block_size, num_gpu_blocks)
+            self.cpu_allocator = UncachedBlockAllocator(
+                Device.CPU, block_size, num_cpu_blocks)
         # Mapping: seq_id -> BlockTable.
         self.block_tables: Dict[int, BlockTable] = {}

@@ -198,10 +286,16 @@ def allocate(self, seq_group: SequenceGroup) -> None:
             if (self.block_sliding_window is not None
                     and logical_idx >= self.block_sliding_window):
                 block = block_table[logical_idx % self.block_sliding_window]
-            else:
+                # Set the reference counts of the token blocks.
+                block.ref_count = seq_group.num_seqs()
+            elif self.enable_caching:

Member: Does prefix caching work with sliding window now? Should we explicitly check somewhere that, if we enable caching, sliding window should not be enabled?

Contributor (author): The prefix caching functionality is simply not used when we have sliding windows. We have specific checks for that in different places in the code. Putting it in a more central place sounds like a better idea, though, and less confusing.

                 block = self.gpu_allocator.allocate(
                     seq.hash_of_block(logical_idx),
                     seq.num_hashed_tokens_of_block(logical_idx))
-            # Set the reference counts of the token blocks.
-            block.ref_count = seq_group.num_seqs()
+            else:
+                block = self.gpu_allocator.allocate()
+                # Set the reference counts of the token blocks.
+                block.ref_count = seq_group.num_seqs()
             block_table.append(block)

         # Assign the block table for each sequence.
@@ -220,8 +314,10 @@ def _promote_last_block(
         seq: Sequence,
         last_block: PhysicalTokenBlock,
     ) -> PhysicalTokenBlock:
-        # Compute a new hash for the block so that it can be shared by
-        # other Sequences
+        assert self.enable_caching
+
+        # Compute a new hash for the block so that it can be shared by other
+        # Sequences
         new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)

         # if new_hash is already in the cached table, then free last_block
@@ -254,6 +350,8 @@ def _allocate_last_physical_block(
         self,
         seq: Sequence,
     ) -> PhysicalTokenBlock:
+        if not self.enable_caching:
+            return self.gpu_allocator.allocate()
         block_hash: Optional[int] = None
         if (self._is_last_block_full(seq)):
             block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
@@ -293,10 +391,12 @@ def append_slot(
         assert last_block.device == Device.GPU
         if last_block.ref_count == 1:
             # Not shared with other sequences. Appendable.
-            # If the last block is now complete, promote it to a full block so
-            # that it can be shared
-            new_block = self._maybe_promote_last_block(seq, last_block)
-            block_table[-1] = new_block
+            if self.enable_caching:
+                # If the last block is now complete, we may reuse an old block
+                # to save memory.
+                maybe_new_block = self._maybe_promote_last_block(
+                    seq, last_block)
+                block_table[-1] = maybe_new_block
             return None
         else:
             # The last block is shared with other sequences.
@@ -440,9 +540,12 @@ def access_all_blocks_in_seq(
         seq: Sequence,
         access_time: float,
     ) -> None:
-        block_table = self.block_tables[seq.seq_id]
-        for block in block_table:
-            block.last_accessed = access_time
+        if self.enable_caching:
+            # Update the last accessed time of all the blocks accessed
+            # in this step.
+            block_table = self.block_tables[seq.seq_id]
+            for block in block_table:
+                block.last_accessed = access_time

     def compute_full_blocks_in_seq(self, seq: Sequence):
         if seq.seq_id not in self.block_tables:
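Stepping back, the pattern running through the diff is that the enable_caching decision is made once, when BlockSpaceManager constructs its allocators, and the hot allocate/free paths either dispatch through the shared abstract interface or branch on a single cached flag. A condensed sketch of that shape, with simplified toy signatures rather than the real constructors:

from abc import ABC, abstractmethod
from typing import Dict, List


class AllocatorBase(ABC):
    """Shared interface, so the manager never type-checks its allocator."""

    @abstractmethod
    def allocate(self) -> int:
        ...

    @abstractmethod
    def free(self, block: int) -> None:
        ...


class ToyCached(AllocatorBase):
    def __init__(self) -> None:
        self.cached_blocks: Dict[int, int] = {}  # hash bookkeeping lives only here

    def allocate(self) -> int:
        return 0  # placeholder: a real allocator would hash and dedupe

    def free(self, block: int) -> None:
        pass


class ToyUncached(AllocatorBase):
    def __init__(self) -> None:
        self.free_blocks: List[int] = list(range(16))  # plain LIFO free list

    def allocate(self) -> int:
        return self.free_blocks.pop()

    def free(self, block: int) -> None:
        self.free_blocks.append(block)


class ToyManager:
    def __init__(self, enable_caching: bool) -> None:
        # Decide once at construction; the per-call paths need no branching.
        self.allocator: AllocatorBase = (
            ToyCached() if enable_caching else ToyUncached())


manager = ToyManager(enable_caching=False)
block = manager.allocator.allocate()
manager.allocator.free(block)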