[Misc]: Implement CPU/GPU swapping in BlockManagerV2 (vllm-project#3834)
Kaiyang-Chen authored and blinkbear committed Jun 6, 2024
1 parent 6034663 · commit 13e8f41
Showing 16 changed files with 520 additions and 38 deletions.
format.sh (2 changes: 1 addition & 1 deletion)
@@ -118,7 +118,7 @@ mypy vllm/model_executor --config-file pyproject.toml
 # https://github.com/codespell-project/codespell/issues/1915
 # Avoiding the "./" prefix and using "/**" globs for directories appears to solve the problem
 CODESPELL_EXCLUDES=(
-    '--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,tests/lora/data/**,build/**'
+    '--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**'
 )
 
 # check spelling of specified files
tests/core/block/e2e/test_correctness.py (49 changes: 40 additions & 9 deletions)
@@ -24,7 +24,13 @@
 @pytest.mark.parametrize("baseline_llm_kwargs", [{
     "use_v2_block_manager": False
 }])
-@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "use_v2_block_manager": True,
+    "preemption_mode": "swap"
+}, {
+    "use_v2_block_manager": True,
+    "preemption_mode": "recompute"
+}])
 @pytest.mark.parametrize("batch_size", [10])
 @pytest.mark.parametrize("seed", [1])
 def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
@@ -95,7 +101,13 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{
     "use_v2_block_manager": False
 }])
-@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "use_v2_block_manager": True,
+    "preemption_mode": "swap"
+}, {
+    "use_v2_block_manager": True,
+    "preemption_mode": "recompute"
+}])
 @pytest.mark.parametrize("batch_size", [10])
 @pytest.mark.parametrize("seed", [1])
 def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
@@ -179,11 +191,18 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
 }])
 @pytest.mark.parametrize(
     "test_llm_kwargs",
-    [{
-        # We run one test with block_size < lookahead_slots, one test with
-        # block_size > lookahead_slots
-        "num_lookahead_slots": 10,
-    }])
+    [
+        {
+            # We run one test with block_size < lookahead_slots, one test with
+            # block_size > lookahead_slots
+            "num_lookahead_slots": 10,
+            "preemption_mode": "swap",
+        },
+        {
+            "num_lookahead_slots": 10,
+            "preemption_mode": "recompute",
+        }
+    ])
 @pytest.mark.parametrize("batch_size", [4])
 @pytest.mark.parametrize("seed", [1])
 def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
@@ -322,7 +341,13 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
 @pytest.mark.parametrize("baseline_llm_kwargs", [{
     "use_v2_block_manager": False
 }])
-@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "use_v2_block_manager": True,
+    "preemption_mode": "swap"
+}, {
+    "use_v2_block_manager": True,
+    "preemption_mode": "recompute"
+}])
 @pytest.mark.parametrize("batch_size", [10])
 @pytest.mark.parametrize("seed", [1])
 def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
@@ -397,7 +422,13 @@ def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
 @pytest.mark.parametrize("baseline_llm_kwargs", [{
     "enable_prefix_caching": False
 }])
-@pytest.mark.parametrize("test_llm_kwargs", [{"enable_prefix_caching": True}])
+@pytest.mark.parametrize("test_llm_kwargs", [{
+    "enable_prefix_caching": True,
+    "preemption_mode": "swap"
+}, {
+    "enable_prefix_caching": True,
+    "preemption_mode": "recompute"
+}])
 @pytest.mark.parametrize("batch_size", [10])
 @pytest.mark.parametrize("seed", [1])
 def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
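
The parametrizations above run each correctness test once per preemption mode against the BlockManagerV1 baseline. For reference, a minimal sketch of how these kwargs reach the engine: the test fixtures forward `test_llm_kwargs` to the `vllm.LLM` constructor. The model name and prompt below are illustrative, and `preemption_mode` assumes the engine-argument plumbing added elsewhere in this commit.

    from vllm import LLM, SamplingParams

    # Run the same greedy request under both preemption strategies; with
    # use_v2_block_manager=True, "swap" exercises the new CPU/GPU swapping
    # path, while "recompute" drops preempted sequences and re-prefills them.
    for mode in ("swap", "recompute"):
        llm = LLM(model="facebook/opt-125m",  # illustrative model choice
                  use_v2_block_manager=True,
                  preemption_mode=mode)
        outputs = llm.generate(["The future of AI is"],
                               SamplingParams(temperature=0.0, max_tokens=32))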
tests/core/block/test_block_manager_v2.py (58 changes: 57 additions & 1 deletion)
@@ -7,7 +7,8 @@
 from vllm.sequence import Logprob, SequenceStatus
 from vllm.utils import chunk_list
 
-from ..utils import create_seq_group, create_seq_group_encoder_decoder
+from ..utils import (create_dummy_prompt, create_seq_group,
+                     create_seq_group_encoder_decoder)
 
 
 @pytest.mark.parametrize("block_size", [16])
@@ -255,6 +256,61 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append,
     assert num_consumed_blocks == expected_consumed_blocks
 
 
+@pytest.mark.parametrize("block_size", [8])
+@pytest.mark.parametrize("num_cpu_blocks", [4])
+@pytest.mark.parametrize("num_gpu_blocks", [4])
+@pytest.mark.parametrize("num_lookahead_slots", [0, 2, 10])
+@pytest.mark.parametrize("enable_caching", [False, True])
+def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots,
+              enable_caching):
+    """Verify that the number of blocks on the src/dest device is correct
+    after swapping a sequence group in/out (no missing or extra blocks).
+    """
+    block_manager = BlockSpaceManagerV2(block_size,
+                                        num_cpu_blocks,
+                                        num_gpu_blocks,
+                                        watermark=0,
+                                        enable_caching=enable_caching)
+    prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1)
+    prompt.status = SequenceStatus.WAITING
+    block_manager.allocate(seq_group)
+    # Emulate a forward pass by appending a single token.
+    # The block manager then knows how many unprocessed
+    # tokens will be written in the next forward pass.
+    token_id = 0
+    prompt.status = SequenceStatus.RUNNING
+    prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
+
+    # Swap seq group from GPU -> CPU.
+    gpu_blocks = block_manager.get_block_table(prompt)
+    assert block_manager.can_swap_out(seq_group)
+    before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    mapping = block_manager.swap_out(seq_group)
+    mapping_keys = [key for key, _ in mapping]
+    assert mapping_keys == gpu_blocks
+    after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
+    assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks
+    prompt.status = SequenceStatus.SWAPPED
+
+    # Swap seq group from CPU -> GPU.
+    assert block_manager.can_swap_in(seq_group, num_lookahead_slots)
+    before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    mapping = block_manager.swap_in(seq_group)
+    cpu_blocks = block_manager.get_block_table(prompt)
+    mapping_keys = [key for key, _ in mapping]
+    assert mapping_keys == [cpu_blocks[0]]
+    after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
+    after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
+    assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
+
+
+# TODO(cade/kaiyang): add comprehensive tests for swapping at allocator level.
+
+
 @pytest.mark.parametrize("block_size", [8, 16])
 @pytest.mark.parametrize("prompt_len", [10, 300, 1000])
 @pytest.mark.parametrize("num_slots_to_append", [50])
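
To make the direction of the free-block assertions in test_swap concrete, here is the accounting worked through for the parametrization block_size=8, num_cpu_blocks=4, num_gpu_blocks=4. This is plain-Python arithmetic for illustration, not a vLLM call; the before/after numbers follow from the test setup.

    # The 7-token prompt plus one appended token fits in a single 8-slot block.
    num_blocks = 1
    # After allocating that one block, 3 of 4 GPU blocks and all 4 CPU blocks
    # are free. Swap-out claims one CPU block and releases one GPU block.
    before_cpu, before_gpu = 4, 3
    after_cpu, after_gpu = before_cpu - num_blocks, before_gpu + num_blocks
    assert before_cpu == after_cpu + num_blocks  # CPU lost exactly num_blocks
    assert before_gpu + num_blocks == after_gpu  # GPU gained exactly num_blocks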
vllm/config.py (8 changes: 8 additions & 0 deletions)
@@ -651,6 +651,12 @@ class SchedulerConfig:
         enable_chunked_prefill: If True, prefill requests can be chunked based
             on the remaining max_num_batched_tokens.
         embedding_mode: Whether the running model is for embedding.
+        preemption_mode: Whether to perform preemption by swapping or
+            recomputation. If not specified, we determine the mode as follows:
+            We use recomputation by default since it incurs lower overhead than
+            swapping. However, when the sequence group has multiple sequences
+            (e.g., beam search), recomputation is not currently supported. In
+            such a case, we use swapping instead.
     """
 
     def __init__(
@@ -664,6 +670,7 @@ def __init__(
         enable_chunked_prefill: bool = False,
         policy: str = "fcfs",
         embedding_mode: Optional[bool] = False,
+        preemption_mode: Optional[str] = None
     ) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
@@ -691,6 +698,7 @@ def __init__(
         self.chunked_prefill_enabled = enable_chunked_prefill
         self.policy = policy
         self.embedding_mode = embedding_mode
+        self.preemption_mode = preemption_mode
 
         self._verify_args()
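
The selection policy described in the new docstring, sketched as standalone Python. The helper name is hypothetical and for illustration only; the scheduler's actual decision code is part of this commit but not shown above.

    from typing import Optional

    def choose_preemption_mode(preemption_mode: Optional[str],
                               num_seqs_in_group: int) -> str:
        """Hypothetical helper mirroring the documented policy."""
        if preemption_mode is not None:
            # Explicit override from SchedulerConfig: "swap" or "recompute".
            return preemption_mode
        if num_seqs_in_group > 1:
            # e.g. beam search: recomputation is not currently supported.
            return "swap"
        # Default: recomputation incurs lower overhead than swapping.
        return "recompute"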
vllm/core/block/block_table.py (4 changes: 4 additions & 0 deletions)
@@ -283,6 +283,10 @@ def _get_all_token_ids(self) -> List[int]:
     def _is_allocated(self) -> bool:
         return len(self._blocks) > 0
 
+    @property
+    def blocks(self) -> Optional[List[Block]]:
+        return self._blocks
+
     @property
     def _num_empty_slots(self) -> int:
         assert self._is_allocated
vllm/core/block/common.py (1 change: 0 additions & 1 deletion)
@@ -140,7 +140,6 @@ def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
         assert refcount != 0
         if refcount > 1:
             src_block_id = block_id
-
             # Decrement refcount of the old block.
             self._allocator.free(block)
vllm/core/block/cpu_gpu_block_allocator.py (82 changes: 77 additions & 5 deletions)
@@ -90,11 +90,8 @@ def create(
             gpu_block_allocator=gpu_allocator,
         )
 
-    def __init__(
-        self,
-        cpu_block_allocator: BlockAllocator,
-        gpu_block_allocator: BlockAllocator,
-    ):
+    def __init__(self, cpu_block_allocator: BlockAllocator,
+                 gpu_block_allocator: BlockAllocator):
         assert not (
             cpu_block_allocator.all_block_ids
             & gpu_block_allocator.all_block_ids
@@ -105,6 +102,7 @@ def __init__(self, cpu_block_allocator: BlockAllocator,
             Device.GPU: gpu_block_allocator,
         }
 
+        self._swap_mapping: Dict[int, int] = {}
         self._null_block: Optional[Block] = None
 
         self._block_ids_to_allocator: Dict[int, BlockAllocator] = {}
@@ -198,6 +196,68 @@ def get_num_free_blocks(self, device: Device) -> int:
     def get_num_total_blocks(self, device: Device) -> int:
         return self._allocators[device].get_num_total_blocks()
 
+    def get_physical_block_id(self, device: Device, absolute_id: int) -> int:
+        """Returns the zero-offset block id on a certain device given the
+        absolute block id.
+        Args:
+            device (Device): The device for which to query the relative
+                block id.
+            absolute_id (int): The absolute block id for the block in the
+                whole allocator.
+        Returns:
+            int: The zero-offset block id on the given device.
+        """
+        return self._allocators[device].get_physical_block_id(absolute_id)
+
+    def swap(self, blocks: List[Block], source_device: Device,
+             dest_device: Device) -> Dict[int, int]:
+        """Execute the swap for the given blocks from source_device
+        to dest_device, save the current swap mapping, and append it
+        to the accumulated `self._swap_mapping` for each scheduling move.
+        Args:
+            blocks: List of blocks to be swapped.
+            source_device (Device): Device to swap the 'blocks' from.
+            dest_device (Device): Device to swap the 'blocks' to.
+        Returns:
+            Dict[int, int]: Swap mapping from source_device to dest_device.
+        """
+        source_block_ids = [block.block_id for block in blocks]
+        self._allocators[source_device].swap_out(blocks)
+        self._allocators[dest_device].swap_in(blocks)
+        dest_block_ids = [block.block_id for block in blocks]
+
+        current_swap_mapping: Dict[int, int] = {}
+        for src, dest in zip(source_block_ids, dest_block_ids):
+            if src is not None and dest is not None:
+                self._swap_mapping[src] = dest
+                current_swap_mapping[src] = dest
+        return current_swap_mapping
+
+    def get_num_blocks_touched(self,
+                               blocks: List[Block],
+                               device: Device,
+                               num_lookahead_slots: int = 0) -> int:
+        """Returns the number of blocks that will be touched by
+        swapping in/out the given blocks to the 'device'.
+        Args:
+            blocks: List of blocks to be swapped.
+            device (Device): Device to swap the 'blocks' on.
+            num_lookahead_slots (int): Number of lookahead slots used in
+                speculative decoding, defaults to 0.
+        Returns:
+            int: The number of blocks that will be touched by
+                swapping in/out the given blocks to the 'device'.
+        """
+        return self._allocators[device].get_num_blocks_touched(
+            blocks, num_lookahead_slots)
+
     def clear_copy_on_writes(self) -> List[Tuple[int, int]]:
         """Clears the copy-on-write (CoW) state and returns the mapping of
         source to destination block IDs.
@@ -240,6 +300,18 @@ def promote_to_immutable_block(self, block: Block) -> BlockId:
     def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
         raise NotImplementedError
 
+    def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
+        """Returns and clears the mapping of source to destination block IDs.
+        For now this is called after every swap operation; once
+        BlockManagerV2 becomes the default it will be called after every
+        schedule step. The result is not yet consumed by callers.
+        Returns:
+            List[Tuple[int, int]]: A mapping of source to destination block
+                IDs.
+        """
+        mapping = self._swap_mapping.copy()
+        self._swap_mapping.clear()
+        return list(mapping.items())
+
+
 class NullBlock(Block):
     """
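
A rough usage sketch of the new allocator-level swap path, assuming the existing CpuGpuBlockAllocator.create() factory and the allocator method names as of this change; the block counts and token ids are illustrative.

    from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
    from vllm.utils import Device

    allocator = CpuGpuBlockAllocator.create(allocator_type="naive",
                                            num_gpu_blocks=4,
                                            num_cpu_blocks=4,
                                            block_size=8)
    # Allocate one GPU block and fill it, as a forward pass would.
    block = allocator.allocate_mutable(prev_block=None, device=Device.GPU)
    block.append_token_ids(list(range(8)))

    # GPU -> CPU: swap_out frees the id on the source allocator, swap_in
    # claims an id on the destination; the id mapping is returned and also
    # accumulated in self._swap_mapping.
    mapping = allocator.swap([block], Device.GPU, Device.CPU)

    # Later drained (e.g. to build swap instructions for the cache engine),
    # which also clears the accumulated mapping.
    swap_list = allocator.get_and_reset_swaps()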
vllm/core/block/interfaces.py (36 changes: 35 additions & 1 deletion)
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import FrozenSet, List, Optional, Protocol, Tuple
+from typing import Dict, FrozenSet, List, Optional, Protocol, Tuple
 
 from vllm.utils import Device
 
@@ -116,6 +116,18 @@ def get_num_total_blocks(self) -> int:
     def get_num_free_blocks(self) -> int:
         pass
 
+    @abstractmethod
+    def get_physical_block_id(self, absolute_id: int) -> int:
+        pass
+
+    @abstractmethod
+    def swap_out(self, blocks: List[Block]) -> None:
+        pass
+
+    @abstractmethod
+    def swap_in(self, blocks: List[Block]) -> None:
+        pass
+
     @property
     @abstractmethod
     def all_block_ids(self) -> FrozenSet[int]:
@@ -149,6 +161,12 @@ def promote_to_immutable_block(self, block: Block) -> BlockId:
         """NOTE: This should not be used besides Block"""
         pass
 
+    @abstractmethod
+    def get_num_blocks_touched(self,
+                               blocks: List[Block],
+                               num_lookahead_slots: int = 0) -> int:
+        pass
+
     class NoFreeBlocksError(ValueError):
         pass
 
@@ -204,6 +222,22 @@ def get_common_computed_block_ids(
         self, seq_block_ids: List[List[int]]) -> List[int]:
         pass
 
+    @abstractmethod
+    def get_num_blocks_touched(self,
+                               blocks: List[Block],
+                               device: Device,
+                               num_lookahead_slots: int = 0) -> int:
+        pass
+
+    @abstractmethod
+    def swap(self, blocks: List[Block], source_device: Device,
+             dest_device: Device) -> Dict[int, int]:
+        pass
+
+    @abstractmethod
+    def get_physical_block_id(self, device: Device, absolute_id: int) -> int:
+        pass
+
     @abstractmethod
     def allocate_or_get_null_block(self) -> Block:
         """
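
For downstream implementers, the shape of the new obligations on a concrete BlockAllocator, as a stub sketch. The subclass is hypothetical and the remaining abstract methods are omitted; real implementations translate absolute ids to per-device zero-offset ids and reassign block ids on swap.

    from typing import List

    from vllm.core.block.interfaces import Block, BlockAllocator

    class SketchAllocator(BlockAllocator):  # hypothetical; other methods omitted
        def get_physical_block_id(self, absolute_id: int) -> int:
            # e.g. absolute_id minus this allocator's id-range offset
            raise NotImplementedError

        def swap_out(self, blocks: List[Block]) -> None:
            # Release this device's ids backing `blocks`.
            raise NotImplementedError

        def swap_in(self, blocks: List[Block]) -> None:
            # Claim ids on this device and assign them onto `blocks`.
            raise NotImplementedError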
[Diffs for the remaining 8 of the 16 changed files are not shown.]