Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc][V1] Fix type in v1 prefix caching #11151

Merged
merged 1 commit into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_prefill():
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
assert manager.block_pool[block_id].block_hash == block_hash
assert manager.block_pool[block_id].ref_cnt == 1
parent_block_hash = block_hash
parent_block_hash = block_hash.hash_value

# Check partial/preallocated block metadata
for block_id in (3, 4):
Expand Down Expand Up @@ -360,11 +360,15 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
assert not computed_blocks
# Just ask for 1 block.
blocks = manager.allocate_slots(req, block_size, computed_blocks)
req.num_computed_tokens = block_size
assert len(blocks) == 1 + num_preallocated_blocks

# Append slots to the block.
req.num_computed_tokens = block_size * len(blocks) # Assume all used.
blocks = manager.append_slots(req, block_size) # Append 1 block.
# Assume all computed.
manager.append_slots(req, block_size * (len(blocks) - 1))
req.num_computed_tokens = block_size * len(blocks)

# Append 1 block.
blocks = manager.append_slots(req, block_size)
assert len(blocks) == 1 + num_preallocated_blocks


Expand Down
8 changes: 4 additions & 4 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,8 @@ def _cache_full_blocks(
prev_block: The previous block in the chain.
"""
# Update the new blocks with the block hashes through the chain.
prev_block_hash = (prev_block.block_hash
if prev_block is not None else None)
prev_block_hash_value = (prev_block.block_hash.hash_value
if prev_block is not None else None)
for i, blk in enumerate(full_blocks):
blk_idx = blk_start_idx + i

Expand All @@ -390,10 +390,10 @@ def _cache_full_blocks(
f"{request.request_id}({request})")

# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash,
block_hash = hash_block_tokens(prev_block_hash_value,
tuple(block_tokens))

# Update and add the full block to the cache.
blk.block_hash = block_hash
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
prev_block_hash = block_hash
prev_block_hash_value = block_hash.hash_value
22 changes: 15 additions & 7 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
"""KV-Cache Utilities."""
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import List, NamedTuple, Optional, Tuple

from vllm.logger import init_logger

logger = init_logger(__name__)

BlockHashType = Tuple[int, Tuple[int]]

class BlockHashType(NamedTuple):
    """Hash value of a block and the token IDs in the block.

    The token IDs are kept alongside the hash value so that two blocks
    whose hash values happen to collide can still be distinguished:
    NamedTuple equality compares all fields, so the token IDs act as a
    collision guard when this tuple is used as a dict key.
    """
    # Hash computed over (parent block hash, *token_ids).
    hash_value: int
    # All token IDs in the block. The block holds a variable number of
    # tokens, so the annotation is Tuple[int, ...] (arbitrary length),
    # not Tuple[int] (which would mean exactly one element).
    token_ids: Tuple[int, ...]


@dataclass
Expand Down Expand Up @@ -171,8 +178,8 @@ def hash_block_tokens(parent_block_hash: Optional[int],
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
return (hash(
(parent_block_hash, *curr_block_token_ids)), curr_block_token_ids)
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
curr_block_token_ids)


def hash_request_tokens(block_size: int,
Expand All @@ -188,14 +195,15 @@ def hash_request_tokens(block_size: int,
The list of computed hash values.
"""
ret = []
parent_block_hash = None
parent_block_hash_value = None
for start in range(0, len(token_ids), block_size):
end = start + block_size
block_token_ids = tuple(token_ids[start:end])
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break
block_hash = hash_block_tokens(parent_block_hash, block_token_ids)
block_hash = hash_block_tokens(parent_block_hash_value,
block_token_ids)
ret.append(block_hash)
parent_block_hash = block_hash
parent_block_hash_value = block_hash.hash_value
return ret
Loading