Skip to content

Commit

Permalink
[V1][Prefix Cache] Move the logic of num_computed_tokens into KVCacheManager (vllm-project#12003)
Browse files Browse the repository at this point in the history
  • Loading branch information
heheda12345 authored and jikunshang committed Jan 21, 2025
1 parent bcec922 commit c4e26c6
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 35 deletions.
71 changes: 47 additions & 24 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ def test_prefill():
unique_token_ids = [3] * 7
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
computed_blocks = manager.get_computed_blocks(req0)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(req0.kv_block_hashes) == 3
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]

Expand All @@ -73,9 +74,10 @@ def test_prefill():
# Incomplete 1 block (5 tokens)
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req1)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(req1.kv_block_hashes) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5, 6]
Expand All @@ -91,7 +93,7 @@ def test_prefill():
# All blocks should be available.
assert manager.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (7, 8)]
# [unallocated (7, 8, 9)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
Expand All @@ -103,9 +105,10 @@ def test_prefill():
# Incomplete 1 block (6 tokens)
unique_token_ids = [3] * 6
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req2)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(req2.kv_block_hashes) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [7, 8]
Expand All @@ -123,8 +126,9 @@ def test_prefill():

# Cache miss and eviction.
req3 = make_request("3", [99] * (16 * 9))
computed_blocks = manager.get_computed_blocks(req3)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
# This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
Expand All @@ -150,8 +154,9 @@ def test_decode():
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req0)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]

Expand Down Expand Up @@ -197,16 +202,18 @@ def test_evict():

last_token_id = 5 * 16 + 7
req0 = make_request("0", list(range(last_token_id)))
computed_blocks = manager.get_computed_blocks(req0)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated

# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
last_token_id + 3 * 16)))
computed_blocks = manager.get_computed_blocks(req1)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
assert len(blocks) == 3 # 3 full blocks
last_token_id += 3 * 16
Expand All @@ -222,8 +229,9 @@ def test_evict():

# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks = manager.get_computed_blocks(req2)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_blocks] == [0, 1]
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [6, 5]
assert manager.free_block_queue.num_free_blocks == 6
Expand All @@ -247,8 +255,9 @@ def test_hash_block_correct_reuse():
# Allocate 1 block and cache it.
num_tokens = block_size * 1
req = make_request("0", list(range(num_tokens)))
computed_blocks = manager.get_computed_blocks(req)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
assert len(blocks) == 1

Expand All @@ -258,8 +267,9 @@ def test_hash_block_correct_reuse():
# Allocate a new block that's not full, make sure hash info on the
# block is cleared.
req = make_request("1", list(range(num_tokens - 1)))
computed_blocks = manager.get_computed_blocks(req)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
assert len(blocks) == 1

Expand All @@ -284,16 +294,18 @@ def test_computed_blocks_not_evicted():
# Allocate a block and cache it.
num_tokens = block_size * 1
req0 = make_request("0", list(range(num_tokens)))
computed_blocks = manager.get_computed_blocks(req0)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 0

# Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
computed_blocks = manager.get_computed_blocks(req1)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 1
Expand All @@ -305,9 +317,10 @@ def test_computed_blocks_not_evicted():
# Now if we have a cache hit on the first block, we should evict the second
# cached block rather than the first one.
req2 = make_request("2", list(range(num_tokens * 2)))
computed_blocks = manager.get_computed_blocks(req2)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks) == 1
assert computed_blocks[0].block_id == 0
assert num_computed_tokens == block_size

blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
computed_blocks)
Expand All @@ -331,8 +344,9 @@ def test_basic_prefix_caching_disabled():

req1 = make_request("1", list(range(10))) # 2 blocks and some more

computed_blocks = manager.get_computed_blocks(req1)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 10, computed_blocks)
assert len(blocks) == 3

Expand All @@ -341,15 +355,17 @@ def test_basic_prefix_caching_disabled():

# No caching.
req2 = make_request("2", list(range(16))) # shared prefix
computed_blocks = manager.get_computed_blocks(req2)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 16, computed_blocks)
assert len(blocks) == 4

# New requests should not have any blocks.
req3 = make_request("3", list(range(4)))
computed_blocks = manager.get_computed_blocks(req3)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 4, computed_blocks)
assert not blocks

Expand All @@ -371,8 +387,9 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size)

req = make_request("0", list(range(block_size * 30)))
computed_blocks = manager.get_computed_blocks(req)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks
assert num_computed_tokens == 0
# Just ask for 1 block.
blocks = manager.allocate_slots(req, block_size, computed_blocks)
req.num_computed_tokens = block_size
Expand Down Expand Up @@ -469,10 +486,11 @@ def test_mm_prefix_caching():
all_token_ids,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks = manager.get_computed_blocks(req0)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)

# Completed block should have hashes with extra keys.
assert not computed_blocks
assert num_computed_tokens == 0
assert len(req0.kv_block_hashes) == 3
assert req0.kv_block_hashes[0].extra_keys == ("aaa", )
assert req0.kv_block_hashes[1].extra_keys == ("aaa", "bbb")
Expand Down Expand Up @@ -503,8 +521,9 @@ def test_mm_prefix_caching():
all_token_ids,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks = manager.get_computed_blocks(req1)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(computed_blocks) == 3
assert num_computed_tokens == 3 * 16


def test_prefill_not_enough_free_blocks_with_computed_blocks():
Expand All @@ -527,15 +546,17 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | ... |
common_token_ids = [i for i in range(3) for _ in range(16)]
req0 = make_request("0", common_token_ids)
computed_blocks = manager.get_computed_blocks(req0)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks
assert num_computed_tokens == 0
manager.allocate_slots(req0, 48, computed_blocks)
block_part0 = manager.req_to_blocks[req0.request_id]

# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1 = make_request("1", common_token_ids * 2)
computed_blocks = manager.get_computed_blocks(req1)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert computed_blocks == block_part0
assert num_computed_tokens == 3 * 16
manager.allocate_slots(req1, 48, computed_blocks)
block_part1 = manager.req_to_blocks[req1.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
Expand All @@ -547,17 +568,19 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2 = make_request("2", [7] * block_size * 2)
computed_blocks = manager.get_computed_blocks(req2)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks
assert num_computed_tokens == 0
manager.allocate_slots(req2, block_size * 2, computed_blocks)

# Req3 is Req1 + 3 new blocks, so the first 6 blocks are computed,
# but it cannot be allocated due to insufficient free blocks (2).
# In this case, the ref_cnt of the computed blocks should not be changed.
assert manager.free_block_queue.num_free_blocks == 5
req3 = make_request("3", common_token_ids * 3)
computed_blocks = manager.get_computed_blocks(req3)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert computed_blocks == block_part1
assert num_computed_tokens == 6 * 16
# Req3 cannot be allocated.
assert manager.allocate_slots(req3, 48, computed_blocks) is None
# Block 0-2 are used by Req 1.
Expand Down
17 changes: 12 additions & 5 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Dict, Iterable, List, Optional
from typing import Dict, Iterable, List, Optional, Tuple

from vllm.logger import init_logger
from vllm.utils import cdiv
Expand Down Expand Up @@ -69,19 +69,22 @@ def __init__(
# is finished.
self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {}

def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
def get_computed_blocks(
self, request: Request) -> Tuple[List[KVCacheBlock], int]:
"""Get the computed (cached) blocks for the request.
Note that the computed blocks must be full.
Args:
request: The request to get the computed blocks.
Returns:
A list of blocks that are computed for the request.
A tuple containing:
- A list of blocks that are computed for the request.
- The number of computed tokens.
"""
if not self.enable_caching:
# Prefix caching is disabled.
return []
return [], 0

computed_blocks = []

Expand All @@ -101,7 +104,11 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
else:
break

return computed_blocks
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size
return computed_blocks, num_computed_tokens

def append_slots(
self,
Expand Down
8 changes: 2 additions & 6 deletions vllm/v1/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,8 @@ def schedule(self) -> "SchedulerOutput":

request = self.waiting[0]
# Get already-cached tokens.
computed_blocks = self.kv_cache_manager.get_computed_blocks(
request)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size
computed_blocks, num_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(request)
# Number of tokens to be scheduled.
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed requests,
Expand Down

0 comments on commit c4e26c6

Please sign in to comment.