From 4a3d12c1eb2e67f3957691fb497cf1d88547e871 Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Wed, 29 May 2024 12:09:13 -0400 Subject: [PATCH] [Core] Cross-attention KV caching and memory-management (towards eventual encoder/decoder model support) (#4837) --- tests/core/block/test_block_manager_v2.py | 154 ++++++++++++++- tests/core/test_block_manager.py | 220 +++++++++++++++++++++- tests/core/utils.py | 99 +++++++++- vllm/core/block/utils.py | 56 ++++++ vllm/core/block_manager_v1.py | 187 ++++++++++++------ vllm/core/block_manager_v2.py | 65 ++++++- vllm/sequence.py | 23 +++ 7 files changed, 735 insertions(+), 69 deletions(-) create mode 100644 vllm/core/block/utils.py diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 91b047f0e183e..f98fc0e217278 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -1,11 +1,13 @@ import pytest +from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) from vllm.core.block_manager_v2 import BlockSpaceManagerV2 from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, SequenceStatus from vllm.utils import chunk_list -from ..utils import create_seq_group +from ..utils import create_seq_group, create_seq_group_encoder_decoder @pytest.mark.parametrize("block_size", [16]) @@ -52,6 +54,156 @@ def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, assert can_allocate_result == AllocStatus.LATER +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [16, 80, 160]) +@pytest.mark.parametrize("num_seqs_per_group", [1, 4]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_can_allocate_seq_group_encoder_decoder(block_size: int, + num_seqs_per_group: int, + num_gpu_blocks: int, + watermark: float): + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + ) + num_watermark_blocks = int(watermark * num_gpu_blocks) + + num_output_blocks_per_seq = 1 + + # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but + # the current implementation assumes all seqs are new prompts / don't have + # different output lens. 
+    num_output_blocks = num_output_blocks_per_seq
+
+    for bdx, num_prompt_blocks in enumerate(
+            range(1, num_gpu_blocks - num_output_blocks)):
+        num_cross_blocks_per_seq = num_prompt_blocks
+
+        seq_group = create_seq_group_encoder_decoder(
+            seq_prompt_len=block_size * num_prompt_blocks,
+            seq_output_lens=[
+                block_size * num_output_blocks_per_seq
+                for _ in range(num_seqs_per_group)
+            ],
+            request_id=str(bdx))
+
+        assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks
+
+        can_allocate_result = block_manager.can_allocate(seq_group)
+
+        num_required_blocks = num_prompt_blocks + \
+            num_output_blocks + \
+            num_cross_blocks_per_seq
+
+        if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
+            assert can_allocate_result == AllocStatus.NEVER
+        elif num_gpu_blocks >= num_required_blocks:
+            assert can_allocate_result == AllocStatus.OK
+        else:
+            assert can_allocate_result == AllocStatus.LATER
+
+
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("num_gpu_blocks", [16])
+@pytest.mark.parametrize("num_seqs_per_group", [1])
+@pytest.mark.parametrize("watermark", [0.0, 0.5])
+def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int,
+                                                     num_seqs_per_group: int,
+                                                     num_gpu_blocks: int,
+                                                     watermark: float):
+    '''
+    SWA is short for Sliding Window Attention.
+
+    At the time of writing, block manager v2 does not support SWA.
+
+    However, even when SWA is implemented for block manager v2,
+    there will still most likely be a separate workstream required
+    to enable SWA for encoder/decoder models.
+
+    Therefore this test enforces that one of the following cases
+    holds true:
+    1. Block manager v2 does not support SWA at all (true at time of writing)
+    2. Block manager v2 fails with NotImplementedError when SWA is enabled
+       AND a SequenceGroup with an encoder sequence (i.e. in support of an
+       encoder/decoder model) is passed into can_allocate() as an argument
+
+    The setup for this test is a stripped-down version of
+    test_can_allocate_seq_group_encoder_decoder()
+    '''
+
+    with pytest.raises((NotImplementedError, AssertionError)) as exc_info:
+        block_manager = BlockSpaceManagerV2(
+            block_size=block_size,
+            num_gpu_blocks=num_gpu_blocks,
+            num_cpu_blocks=1024,
+            watermark=watermark,
+            sliding_window=5  # SWA
+        )
+
+        num_output_blocks_per_seq = 1
+        num_prompt_blocks = 1
+        num_output_blocks = num_output_blocks_per_seq
+        seq_group = create_seq_group_encoder_decoder(
+            seq_prompt_len=block_size * num_prompt_blocks,
+            seq_output_lens=[
+                block_size * num_output_blocks_per_seq
+                for _ in range(num_seqs_per_group)
+            ],
+            request_id="0")
+
+        assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks
+        block_manager.can_allocate(seq_group)
+
+    # Assert that either
+    # 1. Block manager v2 constructor fails with assertion that sliding window
+    #    is not yet supported (most likely near-term outcome at time of
+    #    writing), or
+    # 2.
can_allocate() fails with NotImplementedError due to combination of + # encoder/decoder and sliding window attention + if isinstance(exc_info.value, NotImplementedError): + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA + elif isinstance(exc_info.value, AssertionError): + assert str(exc_info.value) == "Sliding window not yet supported" + + +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [16]) +@pytest.mark.parametrize("num_seqs_per_group", [1]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_can_allocate_encoder_decoder_fails_with_prefix_cache( + block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, + watermark: float): + + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + enable_caching=True # Prefix cache + ) + + num_output_blocks_per_seq = 1 + num_prompt_blocks = 1 + num_output_blocks = num_output_blocks_per_seq + seq_group = create_seq_group_encoder_decoder( + seq_prompt_len=block_size * num_prompt_blocks, + seq_output_lens=[ + block_size * num_output_blocks_per_seq + for _ in range(num_seqs_per_group) + ], + request_id="0") + + assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks + + # Assert that either can_allocate() fails with NotImplementedError + # due to combination of encoder/decoder and prefix cache + with pytest.raises(NotImplementedError) as exc_info: + block_manager.can_allocate(seq_group) + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE + + @pytest.mark.parametrize("block_size", [1, 8]) @pytest.mark.parametrize("prompt_len", [1, 7, 8]) @pytest.mark.parametrize("num_slots_to_append", [1, 8, 129]) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 88cd4f98091f9..ddd843174f7b1 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -6,13 +6,15 @@ from vllm import SamplingParams from vllm.block import PhysicalTokenBlock +from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) from vllm.core.block_manager_v1 import (BlockSpaceManagerV1, UncachedBlockAllocator) from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device -from .utils import create_dummy_prompt +from .utils import create_dummy_prompt, create_dummy_prompt_encoder_decoder def test_block_allocator_allocate(): @@ -73,7 +75,7 @@ def test_allocate(): # Allocate same sequence group to all available gpu blocks. 
for i in range(num_gpu_blocks): _, seq_group = create_dummy_prompt(str(i), block_size) - assert block_manager.can_allocate(seq_group) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK @@ -85,11 +87,107 @@ def test_allocate(): watermark=1 / num_gpu_blocks) for i in range(num_gpu_blocks - 1): _, seq_group = create_dummy_prompt(str(i), block_size) - assert block_manager.can_allocate(seq_group) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK +def test_allocate_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_req_per_seq_group = 2 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate same sequence group to all available gpu blocks. + for i in range(num_gpu_blocks // block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder( + str(i), + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK + block_manager.allocate(seq_group) + assert block_manager.can_allocate(seq_group) != AllocStatus.OK + + # Allocate same sequence group to all available gpu blocks. + # Use watermark to reserve one gpu block. + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=1 / num_gpu_blocks) + for i in range((num_gpu_blocks - 1) // block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder( + str(i), + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK + block_manager.allocate(seq_group) + assert block_manager.can_allocate(seq_group) != AllocStatus.OK + + +def test_allocate_encoder_decoder_fails_with_swa(): + # SWA short for sliding window attention + + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + sliding_window=5) # swa + + # Allocate same sequence group to all available gpu blocks. + _, _, seq_group = create_dummy_prompt_encoder_decoder( + "0", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + + # Assert that can_allocate() fails due to SWA + with pytest.raises(NotImplementedError) as exc_info: + block_manager.can_allocate(seq_group) + + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA + + # Assert that allocate() fails due to SWA + with pytest.raises(NotImplementedError) as exc_info: + block_manager.allocate(seq_group) + + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA + + +def test_allocate_encoder_decoder_fails_with_prefix_caching(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=True) # Prefix cache + + # Allocate same sequence group to all available gpu blocks. 
+ _, _, seq_group = create_dummy_prompt_encoder_decoder( + "0", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + + # Assert that can_allocate() fails due to prefix caching + with pytest.raises(NotImplementedError) as exc_info: + block_manager.can_allocate(seq_group) + + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE + + # Assert that allocate() fails due to prefix caching + with pytest.raises(NotImplementedError) as exc_info: + block_manager.allocate(seq_group) + + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE + + def test_append_slot_single_seq(): block_size = 4 num_cpu_blocks = 4 @@ -244,6 +342,62 @@ def test_swap(): assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) +def test_swap_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + decoder_prompt, encoder_prompt, seq_group = \ + create_dummy_prompt_encoder_decoder( + "1", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + decoder_prompt.status = SequenceStatus.WAITING + encoder_prompt.status = SequenceStatus.WAITING + block_manager.allocate(seq_group) + + # Emulate a forward pass by appending a single token. + # The block manager then knows how many unprocessed + # tokens will be written in the next forward pass. + token_id = 0 + decoder_prompt.status = SequenceStatus.RUNNING + decoder_prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) + + # Swap encoder/decoder seq group from GPU -> CPU. + decoder_gpu_blocks = block_manager.get_block_table(decoder_prompt) + cross_gpu_blocks = block_manager.get_cross_block_table(seq_group) + gpu_blocks = decoder_gpu_blocks + cross_gpu_blocks + assert block_manager.can_swap_out(seq_group) + before_cpu_blocks = block_manager.get_num_free_cpu_blocks() + before_gpu_blocks = block_manager.get_num_free_gpu_blocks() + mapping = block_manager.swap_out(seq_group) + assert [x[0] for x in mapping] == gpu_blocks + #assert list(mapping.keys()) == gpu_blocks + after_cpu_blocks = block_manager.get_num_free_cpu_blocks() + after_gpu_blocks = block_manager.get_num_free_gpu_blocks() + assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) + assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks + decoder_prompt.status = SequenceStatus.SWAPPED + + # Swap encoder/decoder seq group from CPU -> GPU. 
+ decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt) + cross_cpu_blocks = block_manager.get_cross_block_table(seq_group) + cpu_blocks = decoder_cpu_blocks + cross_cpu_blocks + assert block_manager.can_swap_in(seq_group) == AllocStatus.OK + before_cpu_blocks = block_manager.get_num_free_cpu_blocks() + before_gpu_blocks = block_manager.get_num_free_gpu_blocks() + mapping = block_manager.swap_in(seq_group) + assert [x[0] for x in mapping] == cpu_blocks + after_cpu_blocks = block_manager.get_num_free_cpu_blocks() + after_gpu_blocks = block_manager.get_num_free_gpu_blocks() + assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks + assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) + + def test_free(): block_size = 4 num_cpu_blocks = 4 @@ -268,6 +422,41 @@ def test_free(): block_manager.get_block_table(prompt) +def test_free_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + decoder_prompt, encoder_prompt, seq_group = \ + create_dummy_prompt_encoder_decoder( + "1", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + block_manager.allocate(seq_group) + + # Free allocated seq. + decoder_prompt_blocks = len(block_manager.get_block_table(decoder_prompt)) + encoder_prompt_blocks = len(block_manager.get_cross_block_table(seq_group)) + prompt_blocks = decoder_prompt_blocks + encoder_prompt_blocks + before_blocks = block_manager.get_num_free_gpu_blocks() + block_manager.free(decoder_prompt) + block_manager.free_cross(seq_group) + after_blocks = block_manager.get_num_free_gpu_blocks() + assert after_blocks == before_blocks + prompt_blocks + + # Block table for freed encoder & decoder seq's are deleted. + with pytest.raises(KeyError): + block_manager.get_block_table(decoder_prompt) + + # Block table for freed encoder & decoder seq's are deleted. + with pytest.raises(KeyError): + block_manager.get_block_table(encoder_prompt) + + def test_reset(): block_size = 4 num_cpu_blocks = 4 @@ -289,6 +478,31 @@ def test_reset(): assert block_manager.get_num_free_gpu_blocks() == original_blocks +def test_reset_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_req_per_seq_group = 2 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate same seq group on all available gpu blocks. + original_blocks = block_manager.get_num_free_gpu_blocks() + for i in range(num_gpu_blocks // block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder( + f"{i}", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + block_manager.allocate(seq_group) + assert block_manager.get_num_free_gpu_blocks() == 0 + + # Resetting block manager frees all allocated blocks. 
+ block_manager.reset() + assert block_manager.get_num_free_gpu_blocks() == original_blocks + + def test_sliding_window_multi_seq(): """ Tests that memory allocation and deallocation is handled diff --git a/tests/core/utils.py b/tests/core/utils.py index 1c5724090b69b..cd2045b8a1889 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -39,6 +39,52 @@ def create_dummy_prompt( return prompt, seq_group +def create_dummy_prompt_encoder_decoder( + request_id: str, + decoder_prompt_length: int, + encoder_prompt_length: int, + block_size: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + use_beam_search: bool = False, + best_of: int = 1, +) -> Tuple[Sequence, SequenceGroup]: + if not block_size: + block_size = decoder_prompt_length + + # Create dummy prompt sequence with tokens 0...block_size-1 + # and prompt "0 ... block_size". + decoder_prompt_tokens = list(range(decoder_prompt_length)) + decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens]) + + decoder_prompt = Sequence(int(request_id), + inputs={ + "prompt": decoder_prompt_str, + "prompt_token_ids": decoder_prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size) + + encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) + encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) + encoder_prompt = Sequence(int(request_id), + inputs={ + "prompt": encoder_prompt_str, + "prompt_token_ids": encoder_prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size) + seq_group = SequenceGroup(request_id=request_id, + seqs=[decoder_prompt], + sampling_params=SamplingParams( + use_beam_search=use_beam_search, + best_of=best_of), + arrival_time=time.time(), + lora_request=lora_request, + encoder_seq=encoder_prompt) + + return decoder_prompt, encoder_prompt, seq_group + + def create_seq_group( seq_prompt_len: int = 1024, seq_output_lens: Iterable[int] = (128, ), @@ -82,5 +128,56 @@ def create_seq_group( return seq_group +def create_seq_group_encoder_decoder( + seq_prompt_len: int = 1024, + seq_output_lens: Iterable[int] = (128, ), + request_id: str = '0', + seq_id_start: int = 0, + sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: + + assert len(seq_output_lens) > 0 + + if sampling_params is None: + sampling_params = SamplingParams() + + prompt_token_ids = [0] * seq_prompt_len + + seqs = [] + for seq_id_offset, output_len in enumerate(seq_output_lens): + seq = Sequence( + seq_id=seq_id_start + seq_id_offset, + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, + block_size=16, + ) + + for i in range(output_len): + seq.append_token_id( + token_id=i, + logprobs={i: Logprob(0.0)}, + ) + seqs.append(seq) + + # Encoder sequence + encoder_seq = Sequence( + seq_id=seq_id_start + len(seq_output_lens), + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, + block_size=16, + ) + + return SequenceGroup(request_id=request_id, + seqs=seqs, + sampling_params=sampling_params, + arrival_time=time.time(), + encoder_seq=encoder_seq) + + def round_up_to_next_block(seq_len: int, block_size: int) -> int: - return (seq_len + block_size - 1) // block_size + return (seq_len + block_size - 1) // block_size \ No newline at end of file diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py new file mode 100644 index 0000000000000..2c412a8f472e0 --- /dev/null +++ b/vllm/core/block/utils.py @@ -0,0 +1,56 @@ +"""Block manager utils.""" +from vllm.sequence import 
SequenceGroup + +# Exception strings for non-implemented block manager enc/dec scenarios + +STR_NOT_IMPL_ENC_DEC_SWA = \ + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ + "Prefix caching for encoder/decoder models " + \ + "is not currently supported." + + +def _get_block_mgr_sliding_window_attr(block_mgr): + ''' + BlockManagerV1 and BlockManagerV2 have slightly different + members related to sliding window attention (SWA). This + function extracts the appropriate member to use for determining + whether SWA is enabled. + + Arguments: + + * block_mgr: BlockManagerV1 or BlockManagerV2 instance + ''' + + if hasattr(block_mgr, 'block_sliding_window'): + return block_mgr.block_sliding_window + if hasattr(block_mgr, 'max_block_sliding_window'): + return block_mgr.max_block_sliding_window + + raise AttributeError("Block manager instance has neither " + \ + "block_sliding_window nor " + \ + "max_block_sliding_window attributes.") + + +def check_no_caching_or_swa_for_blockmgr_encdec( + block_mgr, seq_group: SequenceGroup) -> None: + ''' + Enforce that prefix caching & sliding-window attention (SWA) + are currently unsupported *specifically* for encoder/decoder models. + + Raises NotImplementedError if unsupported scenario is detected. + + Arguments: + + * block_mgr: BlockSpaceManager instance + * seq_group: SequenceGroup passed to block_mgr + ''' + + if seq_group.is_encoder_decoder(): + if _get_block_mgr_sliding_window_attr(block_mgr) is not None: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) + + if block_mgr.enable_caching: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 52a170d79e4e7..201cba309f6ef 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -8,6 +8,7 @@ from typing import Set, Tuple from vllm.block import BlockTable, PhysicalTokenBlock +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger @@ -255,14 +256,30 @@ def __init__( Device.CPU, block_size, num_cpu_blocks) # Mapping: seq_id -> BlockTable. self.block_tables: Dict[int, BlockTable] = {} + # Mapping: req_id -> BlockTable + # Note that each SequenceGroup has a unique + # request ID + self.cross_block_tables: Dict[str, BlockTable] = {} + + def _get_seq_num_required_blocks(self, seq: Sequence) -> int: + return 0 if seq is None \ + else len(seq.logical_token_blocks) def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. 
- seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - num_required_blocks = len(seq.logical_token_blocks) + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + self_num_required_blocks = self._get_seq_num_required_blocks( + seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) + cross_num_required_blocks = self._get_seq_num_required_blocks( + seq_group.get_encoder_seq()) + num_required_blocks = self_num_required_blocks + \ + cross_num_required_blocks if self.block_sliding_window is not None: + num_required_blocks = min(num_required_blocks, self.block_sliding_window) num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() @@ -276,11 +293,10 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def allocate(self, seq_group: SequenceGroup) -> None: - # NOTE: Here we assume that all sequences in the group have the same - # prompt. - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - + def _allocate_sequence(self, \ + seq: Sequence, \ + ref_count: int, \ + is_encoder_decoder: bool = True) -> BlockTable: # Allocate new physical token blocks that will store the prompt tokens. num_prompt_blocks = len(seq.logical_token_blocks) @@ -290,21 +306,46 @@ def allocate(self, seq_group: SequenceGroup) -> None: and logical_idx >= self.block_sliding_window): block = block_table[logical_idx % self.block_sliding_window] # Set the reference counts of the token blocks. - block.ref_count = seq_group.num_seqs() - elif self.enable_caching: + block.ref_count = ref_count + elif not is_encoder_decoder and self.enable_caching: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), seq.num_hashed_tokens_of_block(logical_idx)) else: block = self.gpu_allocator.allocate() # Set the reference counts of the token blocks. - block.ref_count = seq_group.num_seqs() + block.ref_count = ref_count block_table.append(block) - # Assign the block table for each sequence. + return block_table + + def allocate(self, seq_group: SequenceGroup) -> None: + is_encoder_decoder = seq_group.is_encoder_decoder() + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + # Allocate decoder sequences + # + # NOTE: Here we assume that all sequences in the group have the same + # decoder prompt. + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + block_table: BlockTable = \ + self._allocate_sequence(seq, + seq_group.num_seqs(), + is_encoder_decoder) + + # Assign the self-attention block tables for each sequence. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() + # Allocate encoder sequence + if is_encoder_decoder: + # A SequenceGroup has only a single encoder sequence (at most), + # thus allocate with a ref count of 1 + block_table = self._allocate_sequence(seq_group.get_encoder_seq(), + 1, is_encoder_decoder) + # Assign the cross-attention block table for the SequenceGroup. + self.cross_block_tables[seq_group.request_id] = block_table + def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots: int = 0) -> bool: @@ -443,13 +484,18 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: def _get_physical_blocks( self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: + # NOTE: Here, we assume that the physical blocks are only shared by # the sequences in the same group. 
+ request_id = seq_group.request_id blocks: Set[PhysicalTokenBlock] = set() for seq in seq_group.get_seqs(): if seq.is_finished(): continue blocks.update(self.block_tables[seq.seq_id]) + # Cross-attention blocks + if seq_group.is_encoder_decoder(): + blocks.update(self.cross_block_tables[request_id]) return list(blocks) def can_swap_in(self, @@ -457,8 +503,11 @@ def can_swap_in(self, num_lookahead_slots: int = 0) -> AllocStatus: assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" + blocks = self._get_physical_blocks(seq_group) num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) + if seq_group.is_encoder_decoder(): + num_swapped_seqs += 1 num_free_blocks = self.gpu_allocator.get_num_free_blocks() # NOTE: Conservatively, we assume that every sequence will allocate # at least one free block right after the swap-in. @@ -471,70 +520,81 @@ def can_swap_in(self, else: return AllocStatus.LATER + def _swap_block_table( + self, block_table: BlockTable, src_allocator: BlockAllocatorBase, + dest_allocator: BlockAllocatorBase, + mapping: Dict[PhysicalTokenBlock, + PhysicalTokenBlock]) -> BlockTable: + new_block_table = [] + + for from_block in block_table: + if from_block in mapping: + to_block = mapping[from_block] + to_block.ref_count += 1 + else: + to_block = dest_allocator.allocate( + from_block.block_hash, from_block.num_hashed_tokens) + mapping[from_block] = to_block + new_block_table.append(to_block) + # Free the source block swapped in to destination. + src_allocator.free(from_block) + + return new_block_table + def swap_in(self, seq_group: SequenceGroup, num_lookahead_slots: int = 0) -> List[Tuple[int, int]]: assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" + request_id = seq_group.request_id + # CPU block -> GPU block. # dict is efficient in lookup `if cpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - new_block_table: BlockTable = [] - block_table = self.block_tables[seq.seq_id] - - for cpu_block in block_table: - if cpu_block in mapping: - gpu_block = mapping[cpu_block] - gpu_block.ref_count += 1 - else: - gpu_block = self.gpu_allocator.allocate( - cpu_block.block_hash, cpu_block.num_hashed_tokens) - mapping[cpu_block] = gpu_block - new_block_table.append(gpu_block) - # Free the CPU block swapped in to GPU. - self.cpu_allocator.free(cpu_block) - self.block_tables[seq.seq_id] = new_block_table - - block_number_mapping = { - cpu_block.block_number: gpu_block.block_number - for cpu_block, gpu_block in mapping.items() - } - # convert to list of tuples once here - return list(block_number_mapping.items()) + self.block_tables[seq.seq_id] = \ + self._swap_block_table(self.block_tables[seq.seq_id], + self.cpu_allocator, + self.gpu_allocator, + mapping) + + if seq_group.is_encoder_decoder(): + self.cross_block_tables[request_id] = \ + self._swap_block_table(self.cross_block_tables[request_id], + self.cpu_allocator, + self.gpu_allocator, + mapping) + + return [(cpu_block.block_number, gpu_block.block_number) + for cpu_block, gpu_block in mapping.items()] def can_swap_out(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) return len(blocks) <= self.cpu_allocator.get_num_free_blocks() def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + request_id = seq_group.request_id + # GPU block -> CPU block. 
# dict is efficient in lookup `if gpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - new_block_table: BlockTable = [] - block_table = self.block_tables[seq.seq_id] - - for gpu_block in block_table: - if gpu_block in mapping: - cpu_block = mapping[gpu_block] - cpu_block.ref_count += 1 - else: - cpu_block = self.cpu_allocator.allocate( - gpu_block.block_hash, gpu_block.num_hashed_tokens) - mapping[gpu_block] = cpu_block - new_block_table.append(cpu_block) - # Free the GPU block swapped out to CPU. - self.gpu_allocator.free(gpu_block) - self.block_tables[seq.seq_id] = new_block_table - - block_number_mapping = { - gpu_block.block_number: cpu_block.block_number - for gpu_block, cpu_block in mapping.items() - } - # convert to list of tuples once here - return list(block_number_mapping.items()) + self.block_tables[seq.seq_id] = \ + self._swap_block_table(self.block_tables[seq.seq_id], + self.gpu_allocator, + self.cpu_allocator, + mapping) + + if seq_group.is_encoder_decoder(): + self.cross_block_tables[request_id] = \ + self._swap_block_table(self.cross_block_tables[request_id], + self.gpu_allocator, + self.cpu_allocator, + mapping) + + return [(cpu_block.block_number, gpu_block.block_number) + for cpu_block, gpu_block in mapping.items()] def _free_block_table(self, block_table: BlockTable) -> None: # when using a sliding window, each seq will only use up @@ -559,15 +619,32 @@ def free(self, seq: Sequence) -> None: self._free_block_table(block_table) del self.block_tables[seq.seq_id] + def free_cross(self, seq_group: SequenceGroup) -> None: + if seq_group.request_id not in self.cross_block_tables: + # Already freed or hasn't ben scheduled yet. + return + block_table = self.cross_block_tables[seq_group.request_id] + self._free_block_table(block_table) + del self.cross_block_tables[seq_group.request_id] + def reset(self) -> None: + # Free decoder block tables for block_table in self.block_tables.values(): self._free_block_table(block_table) self.block_tables.clear() + # Free cross-attention block tables + for block_table in self.cross_block_tables.values(): + self._free_block_table(block_table) + self.cross_block_tables.clear() def get_block_table(self, seq: Sequence) -> List[int]: block_table = self.block_tables[seq.seq_id] return [block.block_number for block in block_table] + def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: + block_table = self.cross_block_tables[seq_group.request_id] + return [block.block_number for block in block_table] + def get_num_free_gpu_blocks(self) -> int: return self.gpu_allocator.get_num_free_blocks() diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 834436c25e160..cad42ab3c1ba2 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -5,11 +5,13 @@ from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device SeqId = int +EncoderSeqId = str class BlockSpaceManagerV2(BlockSpaceManager): @@ -94,17 +96,26 @@ def __init__( ) self.block_tables: Dict[SeqId, BlockTable] = {} + self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # 
FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = BlockTable.get_num_required_blocks( seq.get_token_ids(), block_size=self.block_size, ) + if seq_group.is_encoder_decoder(): + num_required_blocks += BlockTable.get_num_required_blocks( + seq_group.get_encoder_seq().get_token_ids(), + block_size=self.block_size, + ) + if self.max_block_sliding_window is not None: num_required_blocks = min(num_required_blocks, self.max_block_sliding_window) @@ -121,7 +132,19 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER + def _allocate_sequence(self, seq: Sequence) -> BlockTable: + block_table = BlockTable( + block_size=self.block_size, + block_allocator=self.block_allocator, + max_block_sliding_window=self.max_block_sliding_window, + ) + block_table.allocate(seq.get_token_ids()) + + return block_table + def allocate(self, seq_group: SequenceGroup) -> None: + + # Allocate self-attention block tables for decoder sequences waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) assert not (set(seq.seq_id for seq in waiting_seqs) & self.block_tables.keys()), "block table already exists" @@ -129,20 +152,29 @@ def allocate(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # prompt. seq = waiting_seqs[0] - - block_table = BlockTable( - block_size=self.block_size, - block_allocator=self.block_allocator, - max_block_sliding_window=self.max_block_sliding_window, - ) - - block_table.allocate(seq.get_token_ids()) + block_table: BlockTable = self._allocate_sequence(seq) self.block_tables[seq.seq_id] = block_table # Assign the block table for each sequence. for seq in waiting_seqs[1:]: self.block_tables[seq.seq_id] = block_table.fork() + # Allocate cross-attention block table for encoder sequence + # + # NOTE: Here we assume that all sequences in the group have the same + # encoder prompt. + request_id = seq_group.request_id + + assert (request_id + not in self.cross_block_tables), \ + "block table already exists" + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + if seq_group.is_encoder_decoder(): + block_table = self._allocate_sequence(seq_group.get_encoder_seq()) + self.cross_block_tables[request_id] = block_table + def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots: int) -> bool: """Determine if there is enough space in the GPU KV cache to continue @@ -197,12 +229,27 @@ def free(self, seq: Sequence) -> None: self.block_tables[seq.seq_id].free() del self.block_tables[seq.seq_id] + def free_cross(self, seq_group: SequenceGroup) -> None: + request_id = seq_group.request_id + if request_id not in self.cross_block_tables: + # Already freed or hasn't been scheduled yet. 
+ return + self.cross_block_tables[request_id].free() + del self.cross_block_tables[request_id] + def get_block_table(self, seq: Sequence) -> List[int]: assert seq.seq_id in self.block_tables block_ids = self.block_tables[seq.seq_id].physical_block_ids assert all(b is not None for b in block_ids) return block_ids # type: ignore + def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: + request_id = seq_group.request_id + assert request_id in self.cross_block_tables + block_ids = self.cross_block_tables[request_id].physical_block_ids + assert all(b is not None for b in block_ids) + return block_ids # type: ignore + def access_all_blocks_in_seq(self, seq: Sequence, now: float): # Update the last accessed time of all the blocks accessed # in this step. diff --git a/vllm/sequence.py b/vllm/sequence.py index f8e9da6c7965a..ee8c94bbf06f7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -430,6 +430,8 @@ class SequenceGroup: for an embedding model. pooling_params: The pooling parameters used to generate the pooling for an embedding model. + encoder_seq: Optional, the single encoder sequence. Should be None + unless you are working with an encoder/decoder model. """ def __init__( @@ -441,6 +443,7 @@ def __init__( lora_request: Optional[LoRARequest] = None, embeddings: Optional[List[float]] = None, pooling_params: Optional[PoolingParams] = None, + encoder_seq: Optional[Sequence] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -455,6 +458,7 @@ def __init__( self.state = SequenceGroupState() self.embeddings = embeddings self.pooling_params = pooling_params + self.encoder_seq = encoder_seq @property def prompt(self) -> Optional[str]: @@ -538,6 +542,12 @@ def get_seqs( seq for seq in self.seqs_dict.values() if seq.status == status ] + def is_encoder_decoder(self) -> bool: + return self.encoder_seq is not None + + def get_encoder_seq(self) -> Optional[Sequence]: + return self.encoder_seq + def get_unfinished_seqs(self) -> List[Sequence]: return [ seq for seq in self.seqs_dict.values() if not seq.is_finished() @@ -621,6 +631,15 @@ class SequenceGroupMetadata: used in prefix caching. state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. + encoder_seq_data: Optional sequence data for encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. + cross_block_table: Optional cross-attention block table associated + with the encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. """ def __init__( @@ -637,6 +656,8 @@ def __init__( computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, multi_modal_data: Optional[MultiModalData] = None, + encoder_seq_data: Optional[SequenceData] = None, + cross_block_table: Optional[List[int]] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt @@ -648,6 +669,8 @@ def __init__( self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state + self.encoder_seq_data = encoder_seq_data + self.cross_block_table = cross_block_table self._token_chunk_size = token_chunk_size self.do_sample = do_sample
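
The patch above adds the encoder/decoder plumbing but contains no end-to-end usage illustration, so the following is a minimal sketch of how the new pieces fit together. It is not part of the diff: it is assembled from the tests in tests/core/test_block_manager.py and the helpers in tests/core/utils.py above, and it assumes a vLLM build that already contains this change (the constructor arguments for Sequence, SequenceGroup, and BlockSpaceManagerV1, and the positional argument order for the block manager, are copied from those tests and may differ in other versions).

    # Sketch only: mirrors the encoder/decoder block-manager tests in this patch.
    import time

    from vllm import SamplingParams
    from vllm.core.block_manager_v1 import BlockSpaceManagerV1
    from vllm.core.interfaces import AllocStatus
    from vllm.sequence import Sequence, SequenceGroup

    block_size = 4
    num_cpu_blocks = 4
    num_gpu_blocks = 4

    # Decoder prompt sequence (backed by the self-attention KV cache).
    decoder_seq = Sequence(seq_id=0,
                           inputs={
                               "prompt": "0 1 2 3",
                               "prompt_token_ids": [0, 1, 2, 3],
                               "multi_modal_data": None,
                           },
                           block_size=block_size)

    # Encoder prompt sequence (backed by the cross-attention KV cache).
    encoder_seq = Sequence(seq_id=1,
                           inputs={
                               "prompt": "3 2 1 0",
                               "prompt_token_ids": [3, 2, 1, 0],
                               "multi_modal_data": None,
                           },
                           block_size=block_size)

    # Passing encoder_seq makes SequenceGroup.is_encoder_decoder() return True.
    seq_group = SequenceGroup(request_id="0",
                              seqs=[decoder_seq],
                              sampling_params=SamplingParams(),
                              arrival_time=time.time(),
                              encoder_seq=encoder_seq)

    # Sliding-window attention and prefix caching must stay disabled for
    # encoder/decoder groups; enabling either makes can_allocate()/allocate()
    # raise NotImplementedError (see vllm/core/block/utils.py in this patch).
    block_manager = BlockSpaceManagerV1(block_size,
                                        num_cpu_blocks,
                                        num_gpu_blocks,
                                        watermark=0)

    assert block_manager.can_allocate(seq_group) == AllocStatus.OK
    block_manager.allocate(seq_group)

    # Decoder blocks are keyed by seq_id; cross-attention blocks by request_id.
    print(block_manager.get_block_table(decoder_seq))
    print(block_manager.get_cross_block_table(seq_group))

    # Cross-attention blocks are freed per sequence group, not per sequence.
    block_manager.free(decoder_seq)
    block_manager.free_cross(seq_group)

The split visible here reflects the design of the patch: self-attention block tables remain keyed per decoder sequence (seq_id), while the cross-attention blocks derived from the single encoder prompt are keyed per request_id, since a SequenceGroup has at most one encoder sequence; this is also why teardown requires both free() and free_cross(). Unsupported combinations are rejected up front by check_no_caching_or_swa_for_blockmgr_encdec(), which raises NotImplementedError with the STR_NOT_IMPL_ENC_DEC_SWA / STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE messages defined in vllm/core/block/utils.py.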