diff --git a/vllm/block.py b/vllm/block.py
index bd00c07adc0d..0b8ef7d4b73d 100644
--- a/vllm/block.py
+++ b/vllm/block.py
@@ -1,90 +1,10 @@
 """Token blocks."""
-import weakref
-from collections import defaultdict
-from typing import Dict, List
+from typing import List
 
 from vllm.utils import Device
 
-_BLANK_TOKEN_ID = -1
-
 DEFAULT_LAST_ACCESSED_TIME = -1
 
-TokensBlock = List[int]
-
-
-class BlockPool:
-    """A pool of logical blocks.
-    When requests come, we create a lot of logical blocks;
-    when requests are done, we destroy a lot of logical blocks.
-    It turns out that creating and destroying logical blocks can be expensive,
-    especially for the `token_ids` field, which is a list of integers.
-    To avoid this overhead, we use a pool to manage the logical blocks.
-    When an old request is done and a new request comes, we can reuse the
-    logical blocks from the old request to feed the new request.
-    """
-
-    def __init__(self) -> None:
-        # block size to list of token blocks
-        self.pool: Dict[int, List[TokensBlock]] = defaultdict(list)
-
-    def alloc_block(self, block_size: int) -> TokensBlock:
-        if block_size in self.pool and self.pool[block_size]:
-            return self.pool[block_size].pop()
-        return [_BLANK_TOKEN_ID] * block_size
-
-    def del_block(self, block: TokensBlock) -> None:
-        self.pool[len(block)].append(block)
-
-
-_BLOCK_POOL = BlockPool()
-
-
-class LogicalTokenBlock:
-    """A block that stores a contiguous chunk of tokens from left to right.
-
-    Logical blocks are used to represent the states of the corresponding
-    physical blocks in the KV cache.
-    """
-
-    def __init__(
-        self,
-        block_number: int,
-        block_size: int,
-    ) -> None:
-        self.block_number = block_number
-        self.block_size = block_size
-
-        self.token_ids = _BLOCK_POOL.alloc_block(block_size)
-        # this finalizer is used to return the block to the pool when the object is deleted # noqa
-        # NOTE: don't use __del__ because it cannot guarantee the order of finalization, # noqa
-        # i.e. `self.token_ids` may be deleted before `self`, and we lose
-        # the opportunity to return the block to the pool
-        self._finalizer = weakref.finalize(self, _BLOCK_POOL.del_block,
-                                           self.token_ids)
-        self.num_tokens = 0
-
-    def is_empty(self) -> bool:
-        return self.num_tokens == 0
-
-    def get_num_empty_slots(self) -> int:
-        return self.block_size - self.num_tokens
-
-    def is_full(self) -> bool:
-        return self.num_tokens == self.block_size
-
-    def append_tokens(self, token_ids: List[int]) -> None:
-        assert len(token_ids) <= self.get_num_empty_slots()
-        curr_idx = self.num_tokens
-        self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
-        self.num_tokens += len(token_ids)
-
-    def get_token_ids(self) -> List[int]:
-        return self.token_ids[:self.num_tokens]
-
-    def get_last_token_id(self) -> int:
-        assert self.num_tokens > 0
-        return self.token_ids[self.num_tokens - 1]
-
 
 class PhysicalTokenBlock:
     """Represents the state of a block in the KV cache."""
diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py
index 4010aaf02b82..995ea04a5b3d 100644
--- a/vllm/core/block_manager_v1.py
+++ b/vllm/core/block_manager_v1.py
@@ -262,8 +262,7 @@ def __init__(
         self.cross_block_tables: Dict[str, BlockTable] = {}
 
     def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
-        return 0 if seq is None \
-            else len(seq.logical_token_blocks)
+        return 0 if seq is None else seq.n_blocks
 
     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         # FIXME(woosuk): Here we assume that all sequences in the group share
@@ -298,7 +297,7 @@ def _allocate_sequence(self, \
                            ref_count: int, \
                            is_encoder_decoder: bool = True) -> BlockTable:
         # Allocate new physical token blocks that will store the prompt tokens.
-        num_prompt_blocks = len(seq.logical_token_blocks)
+        num_prompt_blocks = seq.n_blocks
 
         block_table: BlockTable = []
         for logical_idx in range(num_prompt_blocks):
@@ -367,7 +366,7 @@ def _promote_last_block(
 
         # Compute a new hash for the block so that it can be shared by other
         # Sequences
-        new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
+        new_hash = seq.hash_of_block(seq.n_blocks - 1)
 
         # if new_hash is already in the cached table, then free last_block
         # and return the cached version
@@ -407,10 +406,10 @@ def _allocate_last_physical_block(
         if not self.enable_caching:
             return self.gpu_allocator.allocate()
         block_hash: Optional[int] = None
+        n_blocks = seq.n_blocks
         if (self._is_last_block_full(seq)):
-            block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
-        num_hashed_tokens = seq.num_hashed_tokens_of_block(
-            len(seq.logical_token_blocks) - 1)
+            block_hash = seq.hash_of_block(n_blocks - 1)
+        num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1)
 
         # num_hashed_tokens is used to compute future hashes
         # (e.g. in the hashing function, it is used to ask the sequence for
@@ -429,12 +428,12 @@ def append_slots(
         num_lookahead_slots: int = 0,
     ) -> List[Tuple[int, int]]:
         """Allocate a physical slot for a new token."""
-        logical_blocks = seq.logical_token_blocks
+        n_blocks = seq.n_blocks
         block_table = self.block_tables[seq.seq_id]
         # If we need to allocate a new physical block
-        if len(block_table) < len(logical_blocks):
+        if len(block_table) < n_blocks:
             # Currently this code only supports adding one physical block
-            assert len(block_table) == len(logical_blocks) - 1
+            assert len(block_table) == n_blocks - 1
 
             if (self.block_sliding_window
                     and len(block_table) >= self.block_sliding_window):
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 0925d15461fd..c618c3692611 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -1,13 +1,13 @@
 """Sequence and its related classes."""
 import copy
 import enum
+import math
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import torch
 
-from vllm.block import LogicalTokenBlock
 from vllm.inputs import LLMInputs
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
@@ -236,9 +236,6 @@ def __init__(
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
 
-        self.logical_token_blocks: List[LogicalTokenBlock] = []
-        # Initialize the logical token blocks with the prompt token ids.
-        self._append_tokens_to_blocks(self.prompt_token_ids)
         self.status = SequenceStatus.WAITING
         self.stop_reason: Union[int, str, None] = None
 
@@ -248,6 +245,10 @@ def __init__(
         # Input + output tokens
         self.tokens: Optional[List[str]] = None
 
+    @property
+    def n_blocks(self) -> int:
+        return math.ceil(self.get_len() / self.block_size)
+
     @property
     def prompt(self) -> Optional[str]:
         return self.inputs.get("prompt")
@@ -287,36 +288,12 @@ def reset_state_for_recompute(self):
         """Reset the sequence states for recomputation."""
         self.data.reset_state_for_recompute()
 
-    def _append_logical_block(self) -> None:
-        block = LogicalTokenBlock(
-            block_number=len(self.logical_token_blocks),
-            block_size=self.block_size,
-        )
-        self.logical_token_blocks.append(block)
-
-    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
-        cursor = 0
-        while cursor < len(token_ids):
-            if not self.logical_token_blocks:
-                self._append_logical_block()
-
-            last_block = self.logical_token_blocks[-1]
-            if last_block.is_full():
-                self._append_logical_block()
-                last_block = self.logical_token_blocks[-1]
-
-            num_empty_slots = last_block.get_num_empty_slots()
-            last_block.append_tokens(token_ids[cursor:cursor +
-                                               num_empty_slots])
-            cursor += num_empty_slots
-
     def append_token_id(
         self,
         token_id: int,
         logprobs: Dict[int, Logprob],
     ) -> None:
         assert token_id in logprobs
-        self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
         self.data.append_token_id(token_id, logprobs[token_id].logprob)
 
@@ -388,7 +365,7 @@ def is_prefill(self) -> bool:
     def __repr__(self) -> str:
         return (f"Sequence(seq_id={self.seq_id}, "
                 f"status={self.status.name}, "
-                f"num_blocks={len(self.logical_token_blocks)})")
+                f"num_blocks={self.n_blocks})")
 
 
 @dataclass