[core][misc] remove logical block (vllm-project#5882)
youkaichao authored and jimpang committed Jul 8, 2024
1 parent 8233bca commit 8cce630
Showing 3 changed files with 16 additions and 120 deletions.
82 changes: 1 addition & 81 deletions vllm/block.py
@@ -1,90 +1,10 @@
 """Token blocks."""
-import weakref
-from collections import defaultdict
-from typing import Dict, List
+from typing import List

 from vllm.utils import Device

-_BLANK_TOKEN_ID = -1
-
 DEFAULT_LAST_ACCESSED_TIME = -1

-TokensBlock = List[int]
-
-
-class BlockPool:
-    """A pool of logical blocks.
-    When requests come, we create a lot of logical blocks;
-    when requests are done, we destroy a lot of logical blocks.
-    It turns out that creating and destroying logical blocks can be expensive,
-    especially for the `token_ids` field, which is a list of integers.
-    To avoid this overhead, we use a pool to manage the logical blocks.
-    When an old request is done and a new request comes, we can reuse the
-    logical blocks from the old request to feed the new request.
-    """
-
-    def __init__(self) -> None:
-        # block size to list of token blocks
-        self.pool: Dict[int, List[TokensBlock]] = defaultdict(list)
-
-    def alloc_block(self, block_size: int) -> TokensBlock:
-        if block_size in self.pool and self.pool[block_size]:
-            return self.pool[block_size].pop()
-        return [_BLANK_TOKEN_ID] * block_size
-
-    def del_block(self, block: TokensBlock) -> None:
-        self.pool[len(block)].append(block)
-
-
-_BLOCK_POOL = BlockPool()
-
-
-class LogicalTokenBlock:
-    """A block that stores a contiguous chunk of tokens from left to right.
-    Logical blocks are used to represent the states of the corresponding
-    physical blocks in the KV cache.
-    """
-
-    def __init__(
-        self,
-        block_number: int,
-        block_size: int,
-    ) -> None:
-        self.block_number = block_number
-        self.block_size = block_size
-
-        self.token_ids = _BLOCK_POOL.alloc_block(block_size)
-        # this finalizer is used to return the block to the pool when the object is deleted  # noqa
-        # NOTE: don't use __del__ because it cannot guarantee the order of finalization,  # noqa
-        # i.e. `self.token_ids` may be deleted before `self`, and we lose
-        # the opportunity to return the block to the pool
-        self._finalizer = weakref.finalize(self, _BLOCK_POOL.del_block,
-                                           self.token_ids)
-        self.num_tokens = 0
-
-    def is_empty(self) -> bool:
-        return self.num_tokens == 0
-
-    def get_num_empty_slots(self) -> int:
-        return self.block_size - self.num_tokens
-
-    def is_full(self) -> bool:
-        return self.num_tokens == self.block_size
-
-    def append_tokens(self, token_ids: List[int]) -> None:
-        assert len(token_ids) <= self.get_num_empty_slots()
-        curr_idx = self.num_tokens
-        self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
-        self.num_tokens += len(token_ids)
-
-    def get_token_ids(self) -> List[int]:
-        return self.token_ids[:self.num_tokens]
-
-    def get_last_token_id(self) -> int:
-        assert self.num_tokens > 0
-        return self.token_ids[self.num_tokens - 1]
-
-
 class PhysicalTokenBlock:
     """Represents the state of a block in the KV cache."""
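Note: the deleted `BlockPool` recycled `token_ids` lists through `weakref.finalize`, which, unlike `__del__`, is guaranteed to fire while `self.token_ids` is still reachable. A minimal standalone sketch of that pattern (`ListPool` and `Block` are illustrative names, not vLLM APIs):

import weakref
from collections import defaultdict
from typing import Dict, List


class ListPool:
    """Recycles fixed-size lists so hot paths skip reallocation."""

    def __init__(self) -> None:
        # Maps block size -> free lists of that size.
        self._free: Dict[int, List[List[int]]] = defaultdict(list)

    def alloc(self, size: int) -> List[int]:
        if self._free[size]:
            return self._free[size].pop()
        return [-1] * size

    def dealloc(self, buf: List[int]) -> None:
        self._free[len(buf)].append(buf)


_POOL = ListPool()


class Block:
    def __init__(self, size: int) -> None:
        self.token_ids = _POOL.alloc(size)
        # finalize holds its own reference to the list, so the callback
        # can still return it to the pool after `self` is collected;
        # __del__ offers no such ordering guarantee.
        self._finalizer = weakref.finalize(self, _POOL.dealloc,
                                           self.token_ids)


block = Block(16)
recycled = block.token_ids
del block  # finalizer fires (CPython refcounting); list re-enters the pool
assert _POOL.alloc(16) is recycled  # the same list object is reused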
19 changes: 9 additions & 10 deletions vllm/core/block_manager_v1.py
@@ -262,8 +262,7 @@ def __init__(
         self.cross_block_tables: Dict[str, BlockTable] = {}

     def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
-        return 0 if seq is None \
-            else len(seq.logical_token_blocks)
+        return 0 if seq is None else seq.n_blocks

     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         # FIXME(woosuk): Here we assume that all sequences in the group share
@@ -298,7 +297,7 @@ def _allocate_sequence(self, \
                            ref_count: int, \
                            is_encoder_decoder: bool = True) -> BlockTable:
         # Allocate new physical token blocks that will store the prompt tokens.
-        num_prompt_blocks = len(seq.logical_token_blocks)
+        num_prompt_blocks = seq.n_blocks

         block_table: BlockTable = []
         for logical_idx in range(num_prompt_blocks):
@@ -367,7 +366,7 @@ def _promote_last_block(

         # Compute a new hash for the block so that it can be shared by other
         # Sequences
-        new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
+        new_hash = seq.hash_of_block(seq.n_blocks - 1)

         # if new_hash is already in the cached table, then free last_block
         # and return the cached version
@@ -407,10 +406,10 @@ def _allocate_last_physical_block(
         if not self.enable_caching:
             return self.gpu_allocator.allocate()
         block_hash: Optional[int] = None
+        n_blocks = seq.n_blocks
         if (self._is_last_block_full(seq)):
-            block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
-            num_hashed_tokens = seq.num_hashed_tokens_of_block(
-                len(seq.logical_token_blocks) - 1)
+            block_hash = seq.hash_of_block(n_blocks - 1)
+            num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1)

         # num_hashed_tokens is used to compute future hashes
         # (e.g. in the hashing function, it is used to ask the sequence for
@@ -429,12 +428,12 @@ def append_slots(
         num_lookahead_slots: int = 0,
     ) -> List[Tuple[int, int]]:
         """Allocate a physical slot for a new token."""
-        logical_blocks = seq.logical_token_blocks
+        n_blocks = seq.n_blocks
         block_table = self.block_tables[seq.seq_id]
         # If we need to allocate a new physical block
-        if len(block_table) < len(logical_blocks):
+        if len(block_table) < n_blocks:
             # Currently this code only supports adding one physical block
-            assert len(block_table) == len(logical_blocks) - 1
+            assert len(block_table) == n_blocks - 1

             if (self.block_sliding_window
                     and len(block_table) >= self.block_sliding_window):
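Note: each call site above now derives the block count arithmetically instead of measuring a materialized list. A sketch of the `append_slots` invariant in those terms, assuming a `block_size` of 16 (the helper function is illustrative, not vLLM code):

import math


def needs_new_physical_block(num_tokens: int, block_table_len: int,
                             block_size: int = 16) -> bool:
    # Same formula as the new Sequence.n_blocks property.
    n_blocks = math.ceil(num_tokens / block_size)
    if block_table_len < n_blocks:
        # Decoding appends one token at a time, so the physical table
        # can only ever trail the logical count by exactly one block.
        assert block_table_len == n_blocks - 1
        return True
    return False


assert needs_new_physical_block(num_tokens=17, block_table_len=1)
assert not needs_new_physical_block(num_tokens=16, block_table_len=1)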
35 changes: 6 additions & 29 deletions vllm/sequence.py
@@ -1,13 +1,13 @@
 """Sequence and its related classes."""
 import copy
 import enum
+import math
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import torch

-from vllm.block import LogicalTokenBlock
 from vllm.inputs import LLMInputs
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
@@ -236,9 +236,6 @@ def __init__(
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""

-        self.logical_token_blocks: List[LogicalTokenBlock] = []
-        # Initialize the logical token blocks with the prompt token ids.
-        self._append_tokens_to_blocks(self.prompt_token_ids)
         self.status = SequenceStatus.WAITING
         self.stop_reason: Union[int, str, None] = None

@@ -248,6 +245,10 @@ def __init__(
         # Input + output tokens
         self.tokens: Optional[List[str]] = None

+    @property
+    def n_blocks(self) -> int:
+        return math.ceil(self.get_len() / self.block_size)
+
     @property
     def prompt(self) -> Optional[str]:
         return self.inputs.get("prompt")
@@ -287,36 +288,12 @@ def reset_state_for_recompute(self):
         """Reset the sequence states for recomputation."""
         self.data.reset_state_for_recompute()

-    def _append_logical_block(self) -> None:
-        block = LogicalTokenBlock(
-            block_number=len(self.logical_token_blocks),
-            block_size=self.block_size,
-        )
-        self.logical_token_blocks.append(block)
-
-    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
-        cursor = 0
-        while cursor < len(token_ids):
-            if not self.logical_token_blocks:
-                self._append_logical_block()
-
-            last_block = self.logical_token_blocks[-1]
-            if last_block.is_full():
-                self._append_logical_block()
-                last_block = self.logical_token_blocks[-1]
-
-            num_empty_slots = last_block.get_num_empty_slots()
-            last_block.append_tokens(token_ids[cursor:cursor +
-                                               num_empty_slots])
-            cursor += num_empty_slots
-
     def append_token_id(
         self,
         token_id: int,
         logprobs: Dict[int, Logprob],
     ) -> None:
         assert token_id in logprobs
-        self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
         self.data.append_token_id(token_id, logprobs[token_id].logprob)
@@ -388,7 +365,7 @@ def is_prefill(self) -> bool:
     def __repr__(self) -> str:
         return (f"Sequence(seq_id={self.seq_id}, "
                 f"status={self.status.name}, "
-                f"num_blocks={len(self.logical_token_blocks)})")
+                f"num_blocks={self.n_blocks}, ")


 @dataclass
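Note: with `logical_token_blocks` gone, the block count is computed on demand from the token count, and `append_token_id` does no block bookkeeping at all. A toy model of the new behavior (`MiniSeq` is illustrative, not the vLLM `Sequence` class), assuming `block_size=16`:

import math


class MiniSeq:
    """Toy stand-in for Sequence after this commit."""

    def __init__(self, prompt_len: int, block_size: int = 16) -> None:
        self.block_size = block_size
        self._num_tokens = prompt_len  # no blocks to initialize anymore

    def get_len(self) -> int:
        return self._num_tokens

    @property
    def n_blocks(self) -> int:
        # Mirrors the new property: ceil(tokens / block size).
        return math.ceil(self.get_len() / self.block_size)

    def append_token_id(self) -> None:
        self._num_tokens += 1  # appending is now pure counting


seq = MiniSeq(prompt_len=16)
assert seq.n_blocks == 1   # 16 tokens fill exactly one block
seq.append_token_id()
assert seq.n_blocks == 2   # the 17th token spills into a new block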
