Skip to content

Commit

Permalink
[Misc][V1] Fix type in v1 prefix caching (#11151)
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored Dec 13, 2024
1 parent db6c264 commit 78ed8f5
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 15 deletions.
12 changes: 8 additions & 4 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_prefill():
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
assert manager.block_pool[block_id].block_hash == block_hash
assert manager.block_pool[block_id].ref_cnt == 1
parent_block_hash = block_hash
parent_block_hash = block_hash.hash_value

# Check partial/preallocated block metadata
for block_id in (3, 4):
Expand Down Expand Up @@ -360,11 +360,15 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
assert not computed_blocks
# Just ask for 1 block.
blocks = manager.allocate_slots(req, block_size, computed_blocks)
req.num_computed_tokens = block_size
assert len(blocks) == 1 + num_preallocated_blocks

# Append slots to the block.
req.num_computed_tokens = block_size * len(blocks) # Assume all used.
blocks = manager.append_slots(req, block_size) # Append 1 block.
# Assume all computed.
manager.append_slots(req, block_size * (len(blocks) - 1))
req.num_computed_tokens = block_size * len(blocks)

# Append 1 block.
blocks = manager.append_slots(req, block_size)
assert len(blocks) == 1 + num_preallocated_blocks


Expand Down
8 changes: 4 additions & 4 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,8 @@ def _cache_full_blocks(
prev_block: The previous block in the chain.
"""
# Update the new blocks with the block hashes through the chain.
prev_block_hash = (prev_block.block_hash
if prev_block is not None else None)
prev_block_hash_value = (prev_block.block_hash.hash_value
if prev_block is not None else None)
for i, blk in enumerate(full_blocks):
blk_idx = blk_start_idx + i

Expand All @@ -390,10 +390,10 @@ def _cache_full_blocks(
f"{request.request_id}({request})")

# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash,
block_hash = hash_block_tokens(prev_block_hash_value,
tuple(block_tokens))

# Update and added the full block to the cache.
blk.block_hash = block_hash
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
prev_block_hash = block_hash
prev_block_hash_value = block_hash.hash_value
22 changes: 15 additions & 7 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
"""KV-Cache Utilities."""
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import List, NamedTuple, Optional, Tuple

from vllm.logger import init_logger

logger = init_logger(__name__)

BlockHashType = Tuple[int, Tuple[int]]

class BlockHashType(NamedTuple):
"""Hash value of a block and the token IDs in the block.
The reason we keep a tuple of token IDs is to make sure no hash
collision happens when the hash value is the same.
"""
hash_value: int
token_ids: Tuple[int]


@dataclass
Expand Down Expand Up @@ -171,8 +178,8 @@ def hash_block_tokens(parent_block_hash: Optional[int],
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
return (hash(
(parent_block_hash, *curr_block_token_ids)), curr_block_token_ids)
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
curr_block_token_ids)


def hash_request_tokens(block_size: int,
Expand All @@ -188,14 +195,15 @@ def hash_request_tokens(block_size: int,
The list of computed hash values.
"""
ret = []
parent_block_hash = None
parent_block_hash_value = None
for start in range(0, len(token_ids), block_size):
end = start + block_size
block_token_ids = tuple(token_ids[start:end])
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break
block_hash = hash_block_tokens(parent_block_hash, block_token_ids)
block_hash = hash_block_tokens(parent_block_hash_value,
block_token_ids)
ret.append(block_hash)
parent_block_hash = block_hash
parent_block_hash_value = block_hash.hash_value
return ret

0 comments on commit 78ed8f5

Please sign in to comment.