From 7eb0e0d7a42b3ac64a7912faf1f2822601da5f2a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 15 May 2024 09:44:51 -0400 Subject: [PATCH 001/239] added block manager tests --- tests/core/test_block_manager.py | 132 ++++++++++++++++++++++++++++++- 1 file changed, 131 insertions(+), 1 deletion(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 22a9f0cf47d32..6b2fa21f2ef46 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -12,7 +12,7 @@ from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device -from .utils import create_dummy_prompt +from .utils import create_dummy_prompt, create_dummy_prompt_encoder_decoder def test_block_allocator_allocate(): @@ -89,6 +89,34 @@ def test_allocate(): block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK +def test_allocate_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_req_per_seq_group = 2 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate same sequence group to all available gpu blocks. + for i in range(num_gpu_blocks//block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder(str(i), block_size, block_size) + assert block_manager.can_allocate(seq_group) + block_manager.allocate(seq_group) + assert block_manager.can_allocate(seq_group) != AllocStatus.OK + + # Allocate same sequence group to all available gpu blocks. + # Use watermark to reserve one gpu block. + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=1 / num_gpu_blocks) + for i in range((num_gpu_blocks - 1)//block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder(str(i), block_size//2, block_size//2) + assert block_manager.can_allocate(seq_group) + block_manager.allocate(seq_group) + assert block_manager.can_allocate(seq_group) != AllocStatus.OK def test_append_slot_single_seq(): block_size = 4 @@ -240,6 +268,58 @@ def test_swap(): assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) +def test_swap_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + decoder_prompt, encoder_prompt, seq_group = create_dummy_prompt_encoder_decoder("1", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + decoder_prompt.status = SequenceStatus.WAITING + encoder_prompt.status = SequenceStatus.WAITING + block_manager.allocate(seq_group) + + # Emulate a forward pass by appending a single token. + # The block manager then knows how many unprocessed + # tokens will be written in the next forward pass. + token_id = 0 + decoder_prompt.status = SequenceStatus.RUNNING + decoder_prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) + + # Swap encoder/decoder seq group from GPU -> CPU. + decoder_gpu_blocks = block_manager.get_block_table(decoder_prompt) + encoder_gpu_blocks = block_manager.get_encoder_block_table(seq_group) + gpu_blocks = decoder_gpu_blocks + encoder_gpu_blocks + assert block_manager.can_swap_out(seq_group) + before_cpu_blocks = block_manager.get_num_free_cpu_blocks() + before_gpu_blocks = block_manager.get_num_free_gpu_blocks() + mapping = block_manager.swap_out(seq_group) + assert [x[0] for x in mapping] == gpu_blocks + #assert list(mapping.keys()) == gpu_blocks + after_cpu_blocks = block_manager.get_num_free_cpu_blocks() + after_gpu_blocks = block_manager.get_num_free_gpu_blocks() + assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) + assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks + decoder_prompt.status = SequenceStatus.SWAPPED + + # Swap decoder seq group from CPU -> GPU. + decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt) + encoder_cpu_blocks = block_manager.get_encoder_block_table(seq_group) + cpu_blocks = decoder_cpu_blocks + encoder_cpu_blocks + assert block_manager.can_swap_in(seq_group) == AllocStatus.OK + before_cpu_blocks = block_manager.get_num_free_cpu_blocks() + before_gpu_blocks = block_manager.get_num_free_gpu_blocks() + mapping = block_manager.swap_in(seq_group) + assert [x[0] for x in mapping] == cpu_blocks + after_cpu_blocks = block_manager.get_num_free_cpu_blocks() + after_gpu_blocks = block_manager.get_num_free_gpu_blocks() + assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks + assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) def test_free(): block_size = 4 @@ -264,6 +344,34 @@ def test_free(): with pytest.raises(KeyError): block_manager.get_block_table(prompt) +def test_free_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + decoder_prompt, encoder_prompt, seq_group = create_dummy_prompt_encoder_decoder("1", + decoder_prompt_length=block_size//2, + encoder_prompt_length=block_size//2) + block_manager.allocate(seq_group) + + # Free allocated seq. + decoder_prompt_blocks = len(block_manager.get_block_table(decoder_prompt)) + encoder_prompt_blocks = len(block_manager.get_encoder_block_table(seq_group)) + prompt_blocks = decoder_prompt_blocks + encoder_prompt_blocks + before_blocks = block_manager.get_num_free_gpu_blocks() + block_manager.free(decoder_prompt) + block_manager.free_encoder(seq_group) + after_blocks = block_manager.get_num_free_gpu_blocks() + assert after_blocks == before_blocks + prompt_blocks + + # Block table for freed encoder & decoder seq's are deleted. + with pytest.raises(KeyError): + block_manager.get_block_table(decoder_prompt) + block_manager.get_block_table(encoder_prompt) def test_reset(): block_size = 4 @@ -285,6 +393,28 @@ def test_reset(): block_manager.reset() assert block_manager.get_num_free_gpu_blocks() == original_blocks +def test_reset_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_req_per_seq_group = 2 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate same seq group on all available gpu blocks. + original_blocks = block_manager.get_num_free_gpu_blocks() + for i in range(num_gpu_blocks//block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder(f"{i}", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + block_manager.allocate(seq_group) + assert block_manager.get_num_free_gpu_blocks() == 0 + + # Resetting block manager frees all allocated blocks. + block_manager.reset() + assert block_manager.get_num_free_gpu_blocks() == original_blocks def test_sliding_window_multi_seq(): """ From 6e41c39b24e8bdcff76ebbab0b95e16c0603e0b3 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 15 May 2024 09:52:03 -0400 Subject: [PATCH 002/239] passing block manager encoder/decoder test --- tests/core/utils.py | 29 ++++++++ vllm/core/block_manager_v1.py | 130 ++++++++++++++++++++++++++++++++-- vllm/sequence.py | 12 ++++ 3 files changed, 166 insertions(+), 5 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index 8fb13177a2d6c..170bf9fff3dd2 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -32,6 +32,35 @@ def create_dummy_prompt( return prompt, seq_group +def create_dummy_prompt_encoder_decoder( + request_id: str, + decoder_prompt_length: int, + encoder_prompt_length: int, + block_size: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + use_beam_search: bool = False, + best_of: int = 1, +) -> Tuple[Sequence, SequenceGroup]: + if not block_size: + block_size = decoder_prompt_length + + # Create dummy prompt sequence with tokens 0...block_size-1 + # and prompt "0 ... block_size". + decoder_prompt_tokens = list(range(decoder_prompt_length)) + decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens]) + decoder_prompt = Sequence(int(request_id), decoder_prompt_str, decoder_prompt_tokens, block_size) + encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) + encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) + encoder_prompt = Sequence(int(request_id), encoder_prompt_str, encoder_prompt_tokens, block_size) + seq_group = SequenceGroup( + request_id=request_id, + seqs=[decoder_prompt], + sampling_params=SamplingParams(use_beam_search=use_beam_search, best_of=best_of), + arrival_time=time.time(), + lora_request=lora_request, + encoder_seq=encoder_prompt) + + return decoder_prompt, encoder_prompt, seq_group def create_seq_group( seq_prompt_len: int = 1024, diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 52a170d79e4e7..bd2ccbbb86572 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -255,12 +255,23 @@ def __init__( Device.CPU, block_size, num_cpu_blocks) # Mapping: seq_id -> BlockTable. self.block_tables: Dict[int, BlockTable] = {} + # Mapping: req_id -> BlockTable + # Note that each SequenceGroup has a unique + # request ID + self.encoder_block_tables: Dict[str, BlockTable] = {} + + def get_seq_num_required_blocks(self, seq: Sequence) -> int: + if seq is None: + return 0 + return len(seq.logical_token_blocks) def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - num_required_blocks = len(seq.logical_token_blocks) + + decoder_num_required_blocks = self.get_seq_num_required_blocks(seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) + encoder_num_required_blocks = self.get_seq_num_required_blocks(seq_group.get_encoder_seq()) + num_required_blocks = decoder_num_required_blocks+encoder_num_required_blocks if self.block_sliding_window is not None: num_required_blocks = min(num_required_blocks, @@ -276,9 +287,9 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def allocate(self, seq_group: SequenceGroup) -> None: + def allocate_decoder(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same - # prompt. + # decoder prompt. seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] # Allocate new physical token blocks that will store the prompt tokens. @@ -301,10 +312,46 @@ def allocate(self, seq_group: SequenceGroup) -> None: block.ref_count = seq_group.num_seqs() block_table.append(block) - # Assign the block table for each sequence. + # Assign the decoder block table for each sequence. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() + def allocate_encoder(self, seq_group: SequenceGroup) -> None: + # NOTE: Here we assume that all sequences in the group have the same + # encoder prompt. + seq = seq_group.get_encoder_seq() + + # Allocate new physical token blocks that will store the prompt tokens. + block_table: BlockTable = [] + if seq is None: + # Assign empty encoder block table for the SequenceGroup + self.encoder_block_tables[seq_group.request_id] = block_table + else: + num_prompt_blocks = len(seq.logical_token_blocks) + for logical_idx in range(num_prompt_blocks): + if (self.block_sliding_window is not None + and logical_idx >= self.block_sliding_window): + block = block_table[logical_idx % self.block_sliding_window] + # Set the reference counts of the token blocks. + block.ref_count = seq_group.num_seqs() + elif self.enable_caching: + block = self.gpu_allocator.allocate( + seq.hash_of_block(logical_idx), + seq.num_hashed_tokens_of_block(logical_idx)) + else: + block = self.gpu_allocator.allocate() + # Set the reference counts of the token blocks. + # TODO: feature not supported with encoder/decoder + block.ref_count = seq_group.num_seqs() + block_table.append(block) + + # Assign the encoder block table for the SequenceGroup. + self.encoder_block_tables[seq_group.request_id] = block_table + + def allocate(self, seq_group: SequenceGroup) -> None: + self.allocate_decoder(seq_group) + self.allocate_encoder(seq_group) + def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots: int = 0) -> bool: @@ -445,11 +492,15 @@ def _get_physical_blocks( self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: # NOTE: Here, we assume that the physical blocks are only shared by # the sequences in the same group. + request_id = seq_group.request_id blocks: Set[PhysicalTokenBlock] = set() for seq in seq_group.get_seqs(): if seq.is_finished(): continue blocks.update(self.block_tables[seq.seq_id]) + # Encoder blocks + if seq_group.encoder_seq is not None: + blocks.update(self.encoder_block_tables[request_id]) return list(blocks) def can_swap_in(self, @@ -459,6 +510,8 @@ def can_swap_in(self, ), "BlockSpaceManagerV1 does not support lookahead allocation" blocks = self._get_physical_blocks(seq_group) num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) + if seq_group.encoder_seq is not None: + num_swapped_seqs += 1 num_free_blocks = self.gpu_allocator.get_num_free_blocks() # NOTE: Conservatively, we assume that every sequence will allocate # at least one free block right after the swap-in. @@ -477,6 +530,8 @@ def swap_in(self, assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" + request_id = seq_group.request_id + # CPU block -> GPU block. # dict is efficient in lookup `if cpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} @@ -497,6 +552,23 @@ def swap_in(self, self.cpu_allocator.free(cpu_block) self.block_tables[seq.seq_id] = new_block_table + if seq_group.encoder_seq is not None: + new_block_table: BlockTable = [] + block_table = self.encoder_block_tables[request_id] + + for cpu_block in block_table: + if cpu_block in mapping: + gpu_block = mapping[cpu_block] + gpu_block.ref_count += 1 + else: + gpu_block = self.gpu_allocator.allocate( + cpu_block.block_hash, cpu_block.num_hashed_tokens) + mapping[cpu_block] = gpu_block + new_block_table.append(gpu_block) + # Free the CPU block swapped in to GPU. + self.cpu_allocator.free(cpu_block) + self.encoder_block_tables[request_id] = new_block_table + block_number_mapping = { cpu_block.block_number: gpu_block.block_number for cpu_block, gpu_block in mapping.items() @@ -509,6 +581,8 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool: return len(blocks) <= self.cpu_allocator.get_num_free_blocks() def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + request_id = seq_group.request_id + # GPU block -> CPU block. # dict is efficient in lookup `if gpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} @@ -529,6 +603,23 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: self.gpu_allocator.free(gpu_block) self.block_tables[seq.seq_id] = new_block_table + if seq_group.encoder_seq is not None: + new_block_table: BlockTable = [] + block_table = self.encoder_block_tables[request_id] + + for gpu_block in block_table: + if gpu_block in mapping: + cpu_block = mapping[gpu_block] + cpu_block.ref_count += 1 + else: + cpu_block = self.cpu_allocator.allocate( + gpu_block.block_hash, gpu_block.num_hashed_tokens) + mapping[gpu_block] = cpu_block + new_block_table.append(cpu_block) + # Free the GPU block swapped out to CPU. + self.gpu_allocator.free(gpu_block) + self.encoder_block_tables[request_id] = new_block_table + block_number_mapping = { gpu_block.block_number: cpu_block.block_number for gpu_block, cpu_block in mapping.items() @@ -559,15 +650,32 @@ def free(self, seq: Sequence) -> None: self._free_block_table(block_table) del self.block_tables[seq.seq_id] + def free_encoder(self, seq_group: SequenceGroup) -> None: + if seq_group.request_id not in self.encoder_block_tables: + # Already freed or hasn't ben scheduled yet. + return + block_table = self.encoder_block_tables[seq_group.request_id] + self._free_block_table(block_table) + del self.encoder_block_tables[seq_group.request_id] + def reset(self) -> None: + # Free decoder block tables for block_table in self.block_tables.values(): self._free_block_table(block_table) self.block_tables.clear() + # Free encoder block tables + for block_table in self.encoder_block_tables.values(): + self._free_block_table(block_table) + self.encoder_block_tables.clear() def get_block_table(self, seq: Sequence) -> List[int]: block_table = self.block_tables[seq.seq_id] return [block.block_number for block in block_table] + def get_encoder_block_table(self, seq_group: SequenceGroup) -> List[int]: + block_table = self.encoder_block_tables[seq_group.request_id] + return [block.block_number for block in block_table] + def get_num_free_gpu_blocks(self) -> int: return self.gpu_allocator.get_num_free_blocks() @@ -586,6 +694,18 @@ def access_all_blocks_in_seq( for block in block_table: block.last_accessed = access_time + def access_all_encoder_blocks_in_seq_group( + self, + seq_group: SequenceGroup, + access_time: float, + ) -> None: + if self.enable_caching: + # Update the last accessed time of all the blocks accessed + # in this step. + block_table = self.encoder_block_tables[seq_group.request_id] + for block in block_table: + block.last_accessed = access_time + def compute_full_blocks_in_seq(self, seq: Sequence): if seq.seq_id not in self.block_tables: return diff --git a/vllm/sequence.py b/vllm/sequence.py index aa759448d82b1..ca2de3ef0d774 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -420,6 +420,7 @@ class SequenceGroup: for an embedding model. pooling_params: The pooling parameters used to generate the pooling for an embedding model. + encoder_seq: Optional, the single encoder sequence. """ def __init__( @@ -432,6 +433,7 @@ def __init__( multi_modal_data: Optional[MultiModalData] = None, embeddings: Optional[List[float]] = None, pooling_params: Optional[PoolingParams] = None, + encoder_seq: Optional[Sequence] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -447,6 +449,7 @@ def __init__( self.multi_modal_data = multi_modal_data self.embeddings = embeddings self.pooling_params = pooling_params + self.encoder_seq = encoder_seq @property def prompt(self) -> str: @@ -524,6 +527,9 @@ def get_seqs( seq for seq in self.seqs_dict.values() if seq.status == status ] + def get_encoder_seq(self) -> Sequence: + return self.encoder_seq + def get_unfinished_seqs(self) -> List[Sequence]: return [ seq for seq in self.seqs_dict.values() if not seq.is_finished() @@ -607,6 +613,8 @@ class SequenceGroupMetadata: used in prefix caching. state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. + encoder_seq_data: Optional, the sequence data for the single encoder prompt. + encoder_block_table: Optional, the block table for the single encoder prompt. """ def __init__( @@ -623,6 +631,8 @@ def __init__( computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, multi_modal_data: Optional[MultiModalData] = None, + encoder_seq_data: Optional[SequenceData] = None, + encoder_block_table: Optional[Dict[int, List[int]]] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt @@ -634,6 +644,8 @@ def __init__( self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state + self.encoder_seq_data = encoder_seq_data + self.encoder_block_table = encoder_block_table self._token_chunk_size = token_chunk_size self.do_sample = do_sample From f04ee73114eb50dbf03cb1d2a9ecd238705db035 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 15 May 2024 14:22:04 -0400 Subject: [PATCH 003/239] block manager v2 changes to pass test_can_allocate_seq_group_encoder_decoder --- tests/core/block/test_block_manager_v2.py | 49 ++++++++++++++++++++++- tests/core/utils.py | 49 +++++++++++++++++++++++ vllm/core/block_manager_v2.py | 6 +++ 3 files changed, 103 insertions(+), 1 deletion(-) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 1e8e4ccdfb151..6cb2f3708199f 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -5,7 +5,7 @@ from vllm.sequence import Logprob, SequenceStatus from vllm.utils import chunk_list -from ..utils import create_seq_group +from ..utils import create_seq_group, create_seq_group_encoder_decoder @pytest.mark.parametrize("block_size", [16]) @@ -52,6 +52,53 @@ def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, assert can_allocate_result == AllocStatus.LATER +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [16, 80, 160]) +@pytest.mark.parametrize("num_seqs_per_group", [1, 4]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_can_allocate_seq_group_encoder_decoder(block_size: int, num_seqs_per_group: int, + num_gpu_blocks: int, watermark: float): + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + ) + num_watermark_blocks = int(watermark * num_gpu_blocks) + + num_output_blocks_per_seq = 1 + + # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but + # the current implementation assumes all seqs are new prompts / don't have + # different output lens. + num_output_blocks = num_output_blocks_per_seq + + for bdx,num_prompt_blocks in enumerate(range(1, num_gpu_blocks - num_output_blocks)): + num_encoder_blocks_per_seq = num_prompt_blocks + + seq_group = create_seq_group_encoder_decoder( + seq_prompt_len=block_size * num_prompt_blocks, + seq_output_lens=[ + block_size * num_output_blocks_per_seq + for _ in range(num_seqs_per_group) + ], + request_id=str(bdx) + ) + + assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks + + can_allocate_result = block_manager.can_allocate(seq_group) + + num_required_blocks = num_prompt_blocks + num_output_blocks + num_encoder_blocks_per_seq + + if num_gpu_blocks - num_required_blocks < num_watermark_blocks: + assert can_allocate_result == AllocStatus.NEVER + elif num_gpu_blocks >= num_required_blocks: + assert can_allocate_result == AllocStatus.OK + else: + assert can_allocate_result == AllocStatus.LATER + + @pytest.mark.parametrize("block_size", [1, 8]) @pytest.mark.parametrize("prompt_len", [1, 7, 8]) @pytest.mark.parametrize("num_slots_to_append", [1, 8, 129]) diff --git a/tests/core/utils.py b/tests/core/utils.py index 170bf9fff3dd2..91930457bd25b 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -102,5 +102,54 @@ def create_seq_group( return seq_group +def create_seq_group_encoder_decoder( + seq_prompt_len: int = 1024, + seq_output_lens: Iterable[int] = (128, ), + request_id: str = '0', + seq_id_start: int = 0, + sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: + + assert len(seq_output_lens) > 0 + + if sampling_params is None: + sampling_params = SamplingParams() + + prompt_token_ids = [0] * seq_prompt_len + + seqs = [] + for seq_id_offset, output_len in enumerate(seq_output_lens): + seq = Sequence( + seq_id=seq_id_start + seq_id_offset, + prompt="", + prompt_token_ids=prompt_token_ids, + block_size=16, + ) + + for i in range(output_len): + seq.append_token_id( + token_id=i, + logprobs={i: Logprob(0.0)}, + ) + seqs.append(seq) + + # Encoder sequence + encoder_seq = Sequence( + seq_id=seq_id_start + len(seq_output_lens), + prompt="", + prompt_token_ids=prompt_token_ids, + block_size=16, + ) + + seq_group = SequenceGroup( + request_id=request_id, + seqs=seqs, + sampling_params=sampling_params, + arrival_time=time.time(), + encoder_seq=encoder_seq + ) + + return seq_group + + def round_up_to_next_block(seq_len: int, block_size: int) -> int: return (seq_len + block_size - 1) // block_size diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index f0bc96564050a..06bfbba78dce6 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -96,6 +96,12 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: block_size=self.block_size, ) + if seq_group.encoder_seq is not None: + num_required_blocks += BlockTable.get_num_required_blocks( + seq_group.encoder_seq.get_token_ids(), + block_size=self.block_size, + ) + assert self.block_sliding_window is None if self.block_sliding_window is not None: num_required_blocks = min(num_required_blocks, From 07bbd8ac4c44f50f42137350ee928483842d02ee Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 15 May 2024 14:47:47 -0400 Subject: [PATCH 004/239] block manager v2 support for encoder/decoder --- vllm/core/block_manager_v1.py | 9 ++---- vllm/core/block_manager_v2.py | 59 ++++++++++++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index bd2ccbbb86572..812d1ee3197a5 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -319,14 +319,11 @@ def allocate_decoder(self, seq_group: SequenceGroup) -> None: def allocate_encoder(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # encoder prompt. - seq = seq_group.get_encoder_seq() # Allocate new physical token blocks that will store the prompt tokens. - block_table: BlockTable = [] - if seq is None: - # Assign empty encoder block table for the SequenceGroup - self.encoder_block_tables[seq_group.request_id] = block_table - else: + seq = seq_group.get_encoder_seq() + if seq is not None: + block_table: BlockTable = [] num_prompt_blocks = len(seq.logical_token_blocks) for logical_idx in range(num_prompt_blocks): if (self.block_sliding_window is not None diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 06bfbba78dce6..2f7a11bacc1a1 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -10,6 +10,7 @@ from vllm.utils import Device SeqId = int +EncoderSeqId = str class BlockSpaceManagerV2(BlockSpaceManager): @@ -85,6 +86,7 @@ def __init__( ) self.block_tables: Dict[SeqId, BlockTable] = {} + self.encoder_block_tables: Dict[EncoderSeqId, BlockTable] = {} def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share @@ -119,7 +121,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def allocate(self, seq_group: SequenceGroup) -> None: + def allocate_decoder(self, seq_group: SequenceGroup) -> None: waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) assert not (set(seq.seq_id for seq in waiting_seqs) & self.block_tables.keys()), "block table already exists" @@ -140,6 +142,28 @@ def allocate(self, seq_group: SequenceGroup) -> None: for seq in waiting_seqs[1:]: self.block_tables[seq.seq_id] = block_table.fork() + def allocate_encoder(self, seq_group: SequenceGroup) -> None: + # NOTE: Here we assume that all sequences in the group have the same + # prompt. + request_id = seq_group.request_id + seq = seq_group.encoder_seq + + assert not (request_id in self.encoder_block_tables), "block table already exists" + + seq = seq_group.get_encoder_seq() + if seq is not None: + block_table = BlockTable( + block_size=self.block_size, + block_allocator=self.block_allocator, + ) + assert self.block_sliding_window is None + block_table.allocate(seq.get_token_ids()) + self.encoder_block_tables[request_id] = block_table + + def allocate(self, seq_group: SequenceGroup) -> None: + self.allocate_decoder(seq_group) + self.allocate_encoder(seq_group) + def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots: int) -> bool: """Determine if there is enough space in the GPU KV cache to continue @@ -193,12 +217,29 @@ def free(self, seq: Sequence) -> None: self.block_tables[seq.seq_id].free() del self.block_tables[seq.seq_id] + def free_encoder(self, seq_group: SequenceGroup) -> None: + request_id = seq_group.request_id + if request_id not in self.encoder_block_tables: + # Already freed or hasn't ben scheduled yet. + return + self.encoder_block_tables[request_id].free() + del self.encoder_block_tables[request_id] + + del self.encoder_block_tables[seq_group.request_id] + def get_block_table(self, seq: Sequence) -> List[int]: assert seq.seq_id in self.block_tables block_ids = self.block_tables[seq.seq_id].physical_block_ids assert all(b is not None for b in block_ids) return block_ids # type: ignore + def get_encoder_block_table(self, seq_group: SequenceGroup) -> List[int]: + request_id = seq_group.request_id + assert request_id in self.encoder_block_tables + block_ids = self.block_tables[request_id].physical_block_ids + assert all(b is not None for b in block_ids) + return block_ids + def access_all_blocks_in_seq(self, seq: Sequence, now: float): # Update the last accessed time of all the blocks accessed # in this step. @@ -215,6 +256,22 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float): block_ids, # type: ignore now) + def access_all_encoder_blocks_in_seq_group( + self, + seq_group: SequenceGroup, + now: float, + ) -> None: + if self.enable_caching: + # Update the last accessed time of all the blocks accessed + # in this step. + block_table = self.encoder_block_tables[seq_group.request_id] + block_ids = [] + for block_id in block_table.physical_block_ids: + block_ids.append(block_id) + self.block_allocator.mark_blocks_as_accessed( + block_ids, # type: ignore + now) + def mark_blocks_as_computed(self, seq_group: SequenceGroup): # The only need for mark block as computed is for prefix caching, # while currently we could determine whether one block is computed From 3e95602f9c408f82628e881f30540ac82b3cb5f7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 15 May 2024 15:11:35 -0400 Subject: [PATCH 005/239] renamed encoder to cross in block manager v2, regarding block tables --- vllm/core/block_manager_v2.py | 32 ++++++++++++++++---------------- vllm/sequence.py | 6 +++--- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 2f7a11bacc1a1..426612f615508 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -86,7 +86,7 @@ def __init__( ) self.block_tables: Dict[SeqId, BlockTable] = {} - self.encoder_block_tables: Dict[EncoderSeqId, BlockTable] = {} + self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share @@ -121,7 +121,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def allocate_decoder(self, seq_group: SequenceGroup) -> None: + def allocate_self_block_tables(self, seq_group: SequenceGroup) -> None: waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) assert not (set(seq.seq_id for seq in waiting_seqs) & self.block_tables.keys()), "block table already exists" @@ -142,13 +142,13 @@ def allocate_decoder(self, seq_group: SequenceGroup) -> None: for seq in waiting_seqs[1:]: self.block_tables[seq.seq_id] = block_table.fork() - def allocate_encoder(self, seq_group: SequenceGroup) -> None: + def allocate_cross_block_table(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # prompt. request_id = seq_group.request_id seq = seq_group.encoder_seq - assert not (request_id in self.encoder_block_tables), "block table already exists" + assert not (request_id in self.cross_block_tables), "block table already exists" seq = seq_group.get_encoder_seq() if seq is not None: @@ -158,11 +158,11 @@ def allocate_encoder(self, seq_group: SequenceGroup) -> None: ) assert self.block_sliding_window is None block_table.allocate(seq.get_token_ids()) - self.encoder_block_tables[request_id] = block_table + self.cross_block_tables[request_id] = block_table def allocate(self, seq_group: SequenceGroup) -> None: - self.allocate_decoder(seq_group) - self.allocate_encoder(seq_group) + self.allocate_self_block_tables(seq_group) + self.allocate_cross_block_table(seq_group) def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots: int) -> bool: @@ -217,15 +217,15 @@ def free(self, seq: Sequence) -> None: self.block_tables[seq.seq_id].free() del self.block_tables[seq.seq_id] - def free_encoder(self, seq_group: SequenceGroup) -> None: + def free_cross(self, seq_group: SequenceGroup) -> None: request_id = seq_group.request_id - if request_id not in self.encoder_block_tables: + if request_id not in self.cross_block_tables: # Already freed or hasn't ben scheduled yet. return - self.encoder_block_tables[request_id].free() - del self.encoder_block_tables[request_id] + self.cross_block_tables[request_id].free() + del self.cross_block_tables[request_id] - del self.encoder_block_tables[seq_group.request_id] + del self.cross_block_tables[seq_group.request_id] def get_block_table(self, seq: Sequence) -> List[int]: assert seq.seq_id in self.block_tables @@ -233,9 +233,9 @@ def get_block_table(self, seq: Sequence) -> List[int]: assert all(b is not None for b in block_ids) return block_ids # type: ignore - def get_encoder_block_table(self, seq_group: SequenceGroup) -> List[int]: + def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: request_id = seq_group.request_id - assert request_id in self.encoder_block_tables + assert request_id in self.cross_block_tables block_ids = self.block_tables[request_id].physical_block_ids assert all(b is not None for b in block_ids) return block_ids @@ -256,7 +256,7 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float): block_ids, # type: ignore now) - def access_all_encoder_blocks_in_seq_group( + def access_all_cross_blocks_in_seq_group( self, seq_group: SequenceGroup, now: float, @@ -264,7 +264,7 @@ def access_all_encoder_blocks_in_seq_group( if self.enable_caching: # Update the last accessed time of all the blocks accessed # in this step. - block_table = self.encoder_block_tables[seq_group.request_id] + block_table = self.cross_block_tables[seq_group.request_id] block_ids = [] for block_id in block_table.physical_block_ids: block_ids.append(block_id) diff --git a/vllm/sequence.py b/vllm/sequence.py index ca2de3ef0d774..a73e70c1ae69d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -614,7 +614,7 @@ class SequenceGroupMetadata: state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. encoder_seq_data: Optional, the sequence data for the single encoder prompt. - encoder_block_table: Optional, the block table for the single encoder prompt. + cross_block_table: Optional, the cross-attention block table associated with the single encoder prompt. """ def __init__( @@ -632,7 +632,7 @@ def __init__( state: Optional[SequenceGroupState] = None, multi_modal_data: Optional[MultiModalData] = None, encoder_seq_data: Optional[SequenceData] = None, - encoder_block_table: Optional[Dict[int, List[int]]] = None, + cross_block_table: Optional[Dict[int, List[int]]] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt @@ -645,7 +645,7 @@ def __init__( self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state self.encoder_seq_data = encoder_seq_data - self.encoder_block_table = encoder_block_table + self.cross_block_table = cross_block_table self._token_chunk_size = token_chunk_size self.do_sample = do_sample From 04f38a819445c0141246feeb6969cc4b1e67891f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 15 May 2024 15:22:53 -0400 Subject: [PATCH 006/239] renamed encoder to cross where appropriate --- tests/core/block/test_block_manager_v2.py | 4 +- tests/core/test_block_manager.py | 12 ++--- vllm/core/block_manager_v1.py | 54 +++++++++++------------ 3 files changed, 35 insertions(+), 35 deletions(-) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 6cb2f3708199f..9b1c6cd68a15a 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -74,7 +74,7 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, num_seqs_per_gr num_output_blocks = num_output_blocks_per_seq for bdx,num_prompt_blocks in enumerate(range(1, num_gpu_blocks - num_output_blocks)): - num_encoder_blocks_per_seq = num_prompt_blocks + num_cross_blocks_per_seq = num_prompt_blocks seq_group = create_seq_group_encoder_decoder( seq_prompt_len=block_size * num_prompt_blocks, @@ -89,7 +89,7 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, num_seqs_per_gr can_allocate_result = block_manager.can_allocate(seq_group) - num_required_blocks = num_prompt_blocks + num_output_blocks + num_encoder_blocks_per_seq + num_required_blocks = num_prompt_blocks + num_output_blocks + num_cross_blocks_per_seq if num_gpu_blocks - num_required_blocks < num_watermark_blocks: assert can_allocate_result == AllocStatus.NEVER diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 6b2fa21f2ef46..62b7132e40462 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -293,8 +293,8 @@ def test_swap_encoder_decoder(): # Swap encoder/decoder seq group from GPU -> CPU. decoder_gpu_blocks = block_manager.get_block_table(decoder_prompt) - encoder_gpu_blocks = block_manager.get_encoder_block_table(seq_group) - gpu_blocks = decoder_gpu_blocks + encoder_gpu_blocks + cross_gpu_blocks = block_manager.get_cross_block_table(seq_group) + gpu_blocks = decoder_gpu_blocks + cross_gpu_blocks assert block_manager.can_swap_out(seq_group) before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() @@ -309,8 +309,8 @@ def test_swap_encoder_decoder(): # Swap decoder seq group from CPU -> GPU. decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt) - encoder_cpu_blocks = block_manager.get_encoder_block_table(seq_group) - cpu_blocks = decoder_cpu_blocks + encoder_cpu_blocks + cross_cpu_blocks = block_manager.get_cross_block_table(seq_group) + cpu_blocks = decoder_cpu_blocks + cross_cpu_blocks assert block_manager.can_swap_in(seq_group) == AllocStatus.OK before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() @@ -360,11 +360,11 @@ def test_free_encoder_decoder(): # Free allocated seq. decoder_prompt_blocks = len(block_manager.get_block_table(decoder_prompt)) - encoder_prompt_blocks = len(block_manager.get_encoder_block_table(seq_group)) + encoder_prompt_blocks = len(block_manager.get_cross_block_table(seq_group)) prompt_blocks = decoder_prompt_blocks + encoder_prompt_blocks before_blocks = block_manager.get_num_free_gpu_blocks() block_manager.free(decoder_prompt) - block_manager.free_encoder(seq_group) + block_manager.free_cross(seq_group) after_blocks = block_manager.get_num_free_gpu_blocks() assert after_blocks == before_blocks + prompt_blocks diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 812d1ee3197a5..11a52b3618b44 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -258,7 +258,7 @@ def __init__( # Mapping: req_id -> BlockTable # Note that each SequenceGroup has a unique # request ID - self.encoder_block_tables: Dict[str, BlockTable] = {} + self.cross_block_tables: Dict[str, BlockTable] = {} def get_seq_num_required_blocks(self, seq: Sequence) -> int: if seq is None: @@ -269,9 +269,9 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - decoder_num_required_blocks = self.get_seq_num_required_blocks(seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) - encoder_num_required_blocks = self.get_seq_num_required_blocks(seq_group.get_encoder_seq()) - num_required_blocks = decoder_num_required_blocks+encoder_num_required_blocks + self_num_required_blocks = self.get_seq_num_required_blocks(seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) + cross_num_required_blocks = self.get_seq_num_required_blocks(seq_group.get_encoder_seq()) + num_required_blocks = self_num_required_blocks+cross_num_required_blocks if self.block_sliding_window is not None: num_required_blocks = min(num_required_blocks, @@ -287,7 +287,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def allocate_decoder(self, seq_group: SequenceGroup) -> None: + def allocate_self_block_tables(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # decoder prompt. seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] @@ -316,7 +316,7 @@ def allocate_decoder(self, seq_group: SequenceGroup) -> None: for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() - def allocate_encoder(self, seq_group: SequenceGroup) -> None: + def allocate_cross_block_table(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # encoder prompt. @@ -342,12 +342,12 @@ def allocate_encoder(self, seq_group: SequenceGroup) -> None: block.ref_count = seq_group.num_seqs() block_table.append(block) - # Assign the encoder block table for the SequenceGroup. - self.encoder_block_tables[seq_group.request_id] = block_table + # Assign the cross-attention block table for the SequenceGroup. + self.cross_block_tables[seq_group.request_id] = block_table def allocate(self, seq_group: SequenceGroup) -> None: - self.allocate_decoder(seq_group) - self.allocate_encoder(seq_group) + self.allocate_self_block_tables(seq_group) + self.allocate_cross_block_table(seq_group) def can_append_slots(self, seq_group: SequenceGroup, @@ -495,9 +495,9 @@ def _get_physical_blocks( if seq.is_finished(): continue blocks.update(self.block_tables[seq.seq_id]) - # Encoder blocks + # Cross-attention blocks if seq_group.encoder_seq is not None: - blocks.update(self.encoder_block_tables[request_id]) + blocks.update(self.cross_block_tables[request_id]) return list(blocks) def can_swap_in(self, @@ -551,7 +551,7 @@ def swap_in(self, if seq_group.encoder_seq is not None: new_block_table: BlockTable = [] - block_table = self.encoder_block_tables[request_id] + block_table = self.cross_block_tables[request_id] for cpu_block in block_table: if cpu_block in mapping: @@ -564,7 +564,7 @@ def swap_in(self, new_block_table.append(gpu_block) # Free the CPU block swapped in to GPU. self.cpu_allocator.free(cpu_block) - self.encoder_block_tables[request_id] = new_block_table + self.cross_block_tables[request_id] = new_block_table block_number_mapping = { cpu_block.block_number: gpu_block.block_number @@ -602,7 +602,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: if seq_group.encoder_seq is not None: new_block_table: BlockTable = [] - block_table = self.encoder_block_tables[request_id] + block_table = self.cross_block_tables[request_id] for gpu_block in block_table: if gpu_block in mapping: @@ -615,7 +615,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: new_block_table.append(cpu_block) # Free the GPU block swapped out to CPU. self.gpu_allocator.free(gpu_block) - self.encoder_block_tables[request_id] = new_block_table + self.cross_block_tables[request_id] = new_block_table block_number_mapping = { gpu_block.block_number: cpu_block.block_number @@ -647,30 +647,30 @@ def free(self, seq: Sequence) -> None: self._free_block_table(block_table) del self.block_tables[seq.seq_id] - def free_encoder(self, seq_group: SequenceGroup) -> None: - if seq_group.request_id not in self.encoder_block_tables: + def free_cross(self, seq_group: SequenceGroup) -> None: + if seq_group.request_id not in self.cross_block_tables: # Already freed or hasn't ben scheduled yet. return - block_table = self.encoder_block_tables[seq_group.request_id] + block_table = self.cross_block_tables[seq_group.request_id] self._free_block_table(block_table) - del self.encoder_block_tables[seq_group.request_id] + del self.cross_block_tables[seq_group.request_id] def reset(self) -> None: # Free decoder block tables for block_table in self.block_tables.values(): self._free_block_table(block_table) self.block_tables.clear() - # Free encoder block tables - for block_table in self.encoder_block_tables.values(): + # Free cross-attention block tables + for block_table in self.cross_block_tables.values(): self._free_block_table(block_table) - self.encoder_block_tables.clear() + self.cross_block_tables.clear() def get_block_table(self, seq: Sequence) -> List[int]: block_table = self.block_tables[seq.seq_id] return [block.block_number for block in block_table] - def get_encoder_block_table(self, seq_group: SequenceGroup) -> List[int]: - block_table = self.encoder_block_tables[seq_group.request_id] + def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: + block_table = self.cross_block_tables[seq_group.request_id] return [block.block_number for block in block_table] def get_num_free_gpu_blocks(self) -> int: @@ -691,7 +691,7 @@ def access_all_blocks_in_seq( for block in block_table: block.last_accessed = access_time - def access_all_encoder_blocks_in_seq_group( + def access_all_cross_blocks_in_seq_group( self, seq_group: SequenceGroup, access_time: float, @@ -699,7 +699,7 @@ def access_all_encoder_blocks_in_seq_group( if self.enable_caching: # Update the last accessed time of all the blocks accessed # in this step. - block_table = self.encoder_block_tables[seq_group.request_id] + block_table = self.cross_block_tables[seq_group.request_id] for block in block_table: block.last_accessed = access_time From 2dcd663d40bdcc1cf2aca19b9cec64395ac6d528 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 15 May 2024 15:45:12 -0400 Subject: [PATCH 007/239] formatting --- tests/core/block/test_block_manager_v2.py | 16 +++++--- tests/core/test_block_manager.py | 43 +++++++++++++++------- tests/core/utils.py | 45 ++++++++++++----------- vllm/core/block_manager_v1.py | 18 +++++---- vllm/core/block_manager_v2.py | 8 ++-- vllm/sequence.py | 9 +++-- 6 files changed, 85 insertions(+), 54 deletions(-) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 9b1c6cd68a15a..06c3389cfa0f0 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -56,8 +56,10 @@ def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, @pytest.mark.parametrize("num_gpu_blocks", [16, 80, 160]) @pytest.mark.parametrize("num_seqs_per_group", [1, 4]) @pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_can_allocate_seq_group_encoder_decoder(block_size: int, num_seqs_per_group: int, - num_gpu_blocks: int, watermark: float): +def test_can_allocate_seq_group_encoder_decoder(block_size: int, + num_seqs_per_group: int, + num_gpu_blocks: int, + watermark: float): block_manager = BlockSpaceManagerV2( block_size=block_size, num_gpu_blocks=num_gpu_blocks, @@ -73,7 +75,8 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, num_seqs_per_gr # different output lens. num_output_blocks = num_output_blocks_per_seq - for bdx,num_prompt_blocks in enumerate(range(1, num_gpu_blocks - num_output_blocks)): + for bdx, num_prompt_blocks in enumerate( + range(1, num_gpu_blocks - num_output_blocks)): num_cross_blocks_per_seq = num_prompt_blocks seq_group = create_seq_group_encoder_decoder( @@ -82,14 +85,15 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, num_seqs_per_gr block_size * num_output_blocks_per_seq for _ in range(num_seqs_per_group) ], - request_id=str(bdx) - ) + request_id=str(bdx)) assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks can_allocate_result = block_manager.can_allocate(seq_group) - num_required_blocks = num_prompt_blocks + num_output_blocks + num_cross_blocks_per_seq + num_required_blocks = num_prompt_blocks + \ + num_output_blocks + \ + num_cross_blocks_per_seq if num_gpu_blocks - num_required_blocks < num_watermark_blocks: assert can_allocate_result == AllocStatus.NEVER diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 62b7132e40462..d6ab246699903 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -89,6 +89,7 @@ def test_allocate(): block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK + def test_allocate_encoder_decoder(): block_size = 4 num_cpu_blocks = 4 @@ -100,8 +101,9 @@ def test_allocate_encoder_decoder(): watermark=0) # Allocate same sequence group to all available gpu blocks. - for i in range(num_gpu_blocks//block_req_per_seq_group): - _, _, seq_group = create_dummy_prompt_encoder_decoder(str(i), block_size, block_size) + for i in range(num_gpu_blocks // block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder( + str(i), block_size, block_size) assert block_manager.can_allocate(seq_group) block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK @@ -112,12 +114,14 @@ def test_allocate_encoder_decoder(): num_cpu_blocks, num_gpu_blocks, watermark=1 / num_gpu_blocks) - for i in range((num_gpu_blocks - 1)//block_req_per_seq_group): - _, _, seq_group = create_dummy_prompt_encoder_decoder(str(i), block_size//2, block_size//2) + for i in range((num_gpu_blocks - 1) // block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder( + str(i), block_size // 2, block_size // 2) assert block_manager.can_allocate(seq_group) block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK + def test_append_slot_single_seq(): block_size = 4 num_cpu_blocks = 4 @@ -268,6 +272,7 @@ def test_swap(): assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) + def test_swap_encoder_decoder(): block_size = 4 num_cpu_blocks = 4 @@ -277,9 +282,11 @@ def test_swap_encoder_decoder(): num_gpu_blocks, watermark=0) - decoder_prompt, encoder_prompt, seq_group = create_dummy_prompt_encoder_decoder("1", - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) + decoder_prompt, encoder_prompt, seq_group = \ + create_dummy_prompt_encoder_decoder( + "1", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) decoder_prompt.status = SequenceStatus.WAITING encoder_prompt.status = SequenceStatus.WAITING block_manager.allocate(seq_group) @@ -321,6 +328,7 @@ def test_swap_encoder_decoder(): assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) + def test_free(): block_size = 4 num_cpu_blocks = 4 @@ -344,6 +352,7 @@ def test_free(): with pytest.raises(KeyError): block_manager.get_block_table(prompt) + def test_free_encoder_decoder(): block_size = 4 num_cpu_blocks = 4 @@ -353,9 +362,11 @@ def test_free_encoder_decoder(): num_gpu_blocks, watermark=0) - decoder_prompt, encoder_prompt, seq_group = create_dummy_prompt_encoder_decoder("1", - decoder_prompt_length=block_size//2, - encoder_prompt_length=block_size//2) + decoder_prompt, encoder_prompt, seq_group = \ + create_dummy_prompt_encoder_decoder( + "1", + decoder_prompt_length=block_size // 2, + encoder_prompt_length=block_size // 2) block_manager.allocate(seq_group) # Free allocated seq. @@ -373,6 +384,7 @@ def test_free_encoder_decoder(): block_manager.get_block_table(decoder_prompt) block_manager.get_block_table(encoder_prompt) + def test_reset(): block_size = 4 num_cpu_blocks = 4 @@ -393,6 +405,7 @@ def test_reset(): block_manager.reset() assert block_manager.get_num_free_gpu_blocks() == original_blocks + def test_reset_encoder_decoder(): block_size = 4 num_cpu_blocks = 4 @@ -405,10 +418,11 @@ def test_reset_encoder_decoder(): # Allocate same seq group on all available gpu blocks. original_blocks = block_manager.get_num_free_gpu_blocks() - for i in range(num_gpu_blocks//block_req_per_seq_group): - _, _, seq_group = create_dummy_prompt_encoder_decoder(f"{i}", - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) + for i in range(num_gpu_blocks // block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder( + f"{i}", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) block_manager.allocate(seq_group) assert block_manager.get_num_free_gpu_blocks() == 0 @@ -416,6 +430,7 @@ def test_reset_encoder_decoder(): block_manager.reset() assert block_manager.get_num_free_gpu_blocks() == original_blocks + def test_sliding_window_multi_seq(): """ Tests that memory allocation and deallocation is handled diff --git a/tests/core/utils.py b/tests/core/utils.py index 91930457bd25b..376af0f0eac4f 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -32,6 +32,7 @@ def create_dummy_prompt( return prompt, seq_group + def create_dummy_prompt_encoder_decoder( request_id: str, decoder_prompt_length: int, @@ -48,20 +49,24 @@ def create_dummy_prompt_encoder_decoder( # and prompt "0 ... block_size". decoder_prompt_tokens = list(range(decoder_prompt_length)) decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens]) - decoder_prompt = Sequence(int(request_id), decoder_prompt_str, decoder_prompt_tokens, block_size) + decoder_prompt = Sequence(int(request_id), decoder_prompt_str, + decoder_prompt_tokens, block_size) encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) - encoder_prompt = Sequence(int(request_id), encoder_prompt_str, encoder_prompt_tokens, block_size) - seq_group = SequenceGroup( - request_id=request_id, - seqs=[decoder_prompt], - sampling_params=SamplingParams(use_beam_search=use_beam_search, best_of=best_of), - arrival_time=time.time(), - lora_request=lora_request, - encoder_seq=encoder_prompt) + encoder_prompt = Sequence(int(request_id), encoder_prompt_str, + encoder_prompt_tokens, block_size) + seq_group = SequenceGroup(request_id=request_id, + seqs=[decoder_prompt], + sampling_params=SamplingParams( + use_beam_search=use_beam_search, + best_of=best_of), + arrival_time=time.time(), + lora_request=lora_request, + encoder_seq=encoder_prompt) return decoder_prompt, encoder_prompt, seq_group + def create_seq_group( seq_prompt_len: int = 1024, seq_output_lens: Iterable[int] = (128, ), @@ -134,20 +139,18 @@ def create_seq_group_encoder_decoder( # Encoder sequence encoder_seq = Sequence( - seq_id=seq_id_start + len(seq_output_lens), - prompt="", - prompt_token_ids=prompt_token_ids, - block_size=16, - ) - - seq_group = SequenceGroup( - request_id=request_id, - seqs=seqs, - sampling_params=sampling_params, - arrival_time=time.time(), - encoder_seq=encoder_seq + seq_id=seq_id_start + len(seq_output_lens), + prompt="", + prompt_token_ids=prompt_token_ids, + block_size=16, ) + seq_group = SequenceGroup(request_id=request_id, + seqs=seqs, + sampling_params=sampling_params, + arrival_time=time.time(), + encoder_seq=encoder_seq) + return seq_group diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 11a52b3618b44..03eba2e80c78d 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -263,15 +263,18 @@ def __init__( def get_seq_num_required_blocks(self, seq: Sequence) -> int: if seq is None: return 0 - return len(seq.logical_token_blocks) + return len(seq.logical_token_blocks) def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - self_num_required_blocks = self.get_seq_num_required_blocks(seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) - cross_num_required_blocks = self.get_seq_num_required_blocks(seq_group.get_encoder_seq()) - num_required_blocks = self_num_required_blocks+cross_num_required_blocks + self_num_required_blocks = self.get_seq_num_required_blocks( + seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) + cross_num_required_blocks = self.get_seq_num_required_blocks( + seq_group.get_encoder_seq()) + num_required_blocks = self_num_required_blocks + \ + cross_num_required_blocks if self.block_sliding_window is not None: num_required_blocks = min(num_required_blocks, @@ -328,7 +331,8 @@ def allocate_cross_block_table(self, seq_group: SequenceGroup) -> None: for logical_idx in range(num_prompt_blocks): if (self.block_sliding_window is not None and logical_idx >= self.block_sliding_window): - block = block_table[logical_idx % self.block_sliding_window] + block = block_table[logical_idx % + self.block_sliding_window] # Set the reference counts of the token blocks. block.ref_count = seq_group.num_seqs() elif self.enable_caching: @@ -550,7 +554,7 @@ def swap_in(self, self.block_tables[seq.seq_id] = new_block_table if seq_group.encoder_seq is not None: - new_block_table: BlockTable = [] + new_block_table = [] block_table = self.cross_block_tables[request_id] for cpu_block in block_table: @@ -601,7 +605,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: self.block_tables[seq.seq_id] = new_block_table if seq_group.encoder_seq is not None: - new_block_table: BlockTable = [] + new_block_table = [] block_table = self.cross_block_tables[request_id] for gpu_block in block_table: diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 426612f615508..4ae3361e7b234 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -148,7 +148,9 @@ def allocate_cross_block_table(self, seq_group: SequenceGroup) -> None: request_id = seq_group.request_id seq = seq_group.encoder_seq - assert not (request_id in self.cross_block_tables), "block table already exists" + assert (request_id + not in self.cross_block_tables), \ + "block table already exists" seq = seq_group.get_encoder_seq() if seq is not None: @@ -236,9 +238,9 @@ def get_block_table(self, seq: Sequence) -> List[int]: def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: request_id = seq_group.request_id assert request_id in self.cross_block_tables - block_ids = self.block_tables[request_id].physical_block_ids + block_ids = self.cross_block_tables[request_id].physical_block_ids assert all(b is not None for b in block_ids) - return block_ids + return block_ids # type: ignore def access_all_blocks_in_seq(self, seq: Sequence, now: float): # Update the last accessed time of all the blocks accessed diff --git a/vllm/sequence.py b/vllm/sequence.py index a73e70c1ae69d..a11c411876ea8 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -528,7 +528,7 @@ def get_seqs( ] def get_encoder_seq(self) -> Sequence: - return self.encoder_seq + return self.encoder_seq # type: ignore def get_unfinished_seqs(self) -> List[Sequence]: return [ @@ -613,8 +613,11 @@ class SequenceGroupMetadata: used in prefix caching. state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. - encoder_seq_data: Optional, the sequence data for the single encoder prompt. - cross_block_table: Optional, the cross-attention block table associated with the single encoder prompt. + encoder_seq_data: Optional, the sequence data + for the single encoder prompt. + cross_block_table: Optional, the cross-attention + block table associated with + the single encoder prompt. """ def __init__( From 8ff1ddf224d742cf3aea6b7fa9f55409b5815b7b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 15 May 2024 19:40:45 -0400 Subject: [PATCH 008/239] attention test & xformers backend changes --- tests/layer/test_self_and_cross_attn.py | 466 ++++++++++++++++++++++++ vllm/attention/backends/xformers.py | 29 +- 2 files changed, 489 insertions(+), 6 deletions(-) create mode 100644 tests/layer/test_self_and_cross_attn.py diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py new file mode 100644 index 0000000000000..c29f7224ab57f --- /dev/null +++ b/tests/layer/test_self_and_cross_attn.py @@ -0,0 +1,466 @@ +import random +from typing import List, Optional +import itertools + +import pytest +import torch +import copy +from vllm.attention import Attention, AttentionMetadata, AttentionMetadataPerStage + +from vllm.attention.backends.xformers import XFormersBackend +from vllm.attention.backends.abstract import AttentionBackend + +from vllm.attention.ops.paged_attn import PagedAttention + +from vllm.utils import get_max_shared_memory_bytes +from vllm.utils import is_hip +from vllm.utils import make_tensor_with_pad + +from vllm.attention.layer import Attention + +import random + +# FlashAttention forward only supports head dimension at most 128 +# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 +HEAD_SIZES = [64] + +# [64, 80, 96, 112, 128, 256 +# ] if not is_hip() else [64, 80, 96, 112, 128] + +NUM_HEADS = [1] + +BATCH_SIZES = [16] +BLOCK_SIZES = [16] +#KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] +BACKEND_NAMES = ["xformers"] +#CUDA_DEVICES = [ +# f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +#] + +PROMPT_LENS = [32] + +def build_causal_mask(q_max_prompt_len, k_max_prompt_len): + # Create a matrix where entry (i, j) is True if i >= j + mask = torch.triu(torch.ones(q_max_prompt_len, k_max_prompt_len), diagonal=1) #.transpose(0, 1) + # Replace True with float('-inf') and False with 0 + mask = mask.masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, 0.0) + return mask + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + #query=query.unsqueeze(-2) + #key=key.unsqueeze(-2) + #value=value.unsqueeze(-2) + #assert False,f"{query.shape} ; {key.shape}" + attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() + #assert False,f"{query.shape} ; {key.shape} ; {attn_weights.shape}" + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) + #assert False, f"{attn_weights.shape} ; {value.shape} ; {out.shape}" + return out + +def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True, force_max_len=True): + if force_max_len: + q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] + kv_prompt_lens = None + if not is_cross_attn: + # K,V prompt lens match Q for self-attention + kv_prompt_lens = q_prompt_lens + else: + # K,V prompt lens come from K,V operands + kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] + else: + q_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] + kv_prompt_lens = None + if not is_cross_attn: + # K,V prompt lens match Q for self-attention + kv_prompt_lens = q_prompt_lens + else: + # K,V prompt lens come from K,V operands + kv_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] + + query=torch.rand((batch_size,max_q_prompt_len,head_size)) + key=torch.rand((batch_size,max_kv_prompt_len,head_size)) + value=torch.rand((batch_size,max_kv_prompt_len,head_size)) + + for bdx,(q_prompt_len,kv_prompt_len) in enumerate(zip(q_prompt_lens,kv_prompt_lens)): + query[bdx,q_prompt_len:] = 0 + key[bdx,kv_prompt_len:] = 0 + value[bdx,kv_prompt_len:] = 0 + + query=query.unsqueeze(-2) + key=key.unsqueeze(-2) + value=value.unsqueeze(-2) + + return query,key,value,q_prompt_lens,kv_prompt_lens + +def pack_tensor(unpacked_tensor,prompt_lens, device='cuda:0'): + num_tok = sum(prompt_lens) + num_heads = unpacked_tensor.shape[-2] + head_size = unpacked_tensor.shape[-1] + start_loc_list = [0]+list(itertools.accumulate(prompt_lens)) + packed_tensor = torch.zeros((num_tok,num_heads,head_size), + device=device) + + #assert False, f"{start_loc_list}" + + #assert False, f"{packed_tensor.shape} ; {unpacked_tensor.shape}" + + for bdx,(prompt_len,start_loc) in enumerate(zip(prompt_lens,start_loc_list)): + try: + packed_tensor[start_loc:(start_loc+prompt_len),:,:] = unpacked_tensor[bdx,:prompt_len,:,:] + except: + assert False, f"{start_loc} ; {prompt_len} ; {packed_tensor.shape} ; {unpacked_tensor.shape}" + + return packed_tensor,start_loc_list + +def pack_qkv(query,key,value,q_prompt_lens,kv_prompt_lens): + packed_query,q_start_loc_list = pack_tensor(query,q_prompt_lens) + packed_key,kv_start_loc_list = pack_tensor(key,kv_prompt_lens) + packed_value,_ = pack_tensor(value,kv_prompt_lens) + packed_query=packed_query.view(-1,packed_query.shape[-1]*packed_query.shape[-2]) + packed_key=packed_key.view(-1,packed_key.shape[-1]*packed_key.shape[-2]) + packed_value=packed_value.view(-1,packed_value.shape[-1]*packed_value.shape[-2]) + return packed_query,packed_key,packed_value,q_start_loc_list,kv_start_loc_list + +def make_backend(backend_name: str) -> AttentionBackend: + if backend_name == "xformers": + return XFormersBackend() + assert False, f"Unrecognized backend_name {backend_name} for unit test" + +def make_stage_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, device='cuda:0', cross_prompt_lens:Optional[List[int]] = None) -> AttentionMetadataPerStage: + ''' + Assumptions: + * No chunked prefill + * No (automatic) prefix caching + * Packed variable-length sequences + ''' + prompt_lens_tensor=torch.tensor(prompt_lens, + dtype=torch.int, + device=device) + context_lens_tensor=None if context_lens is None else torch.tensor(context_lens, + dtype=torch.int, + device=device) + max_query_len=None if prompt_lens is None else max(prompt_lens) + max_context_len=None if context_lens is None else max(context_lens) + max_prompt_len=max_query_len + + seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + + torch.cumsum(prompt_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + query_start_loc = copy.deepcopy(seq_start_loc) + + return attn_backend.make_metadata( + is_prompt=is_prompt, + is_cross_attn=is_cross_attn, + seq_lens=prompt_lens, + seq_lens_tensor=prompt_lens_tensor, + cross_seq_lens=cross_prompt_lens, + max_query_len=max_query_len, + #max_context_len=max_context_len, + max_seq_len=max_prompt_len, + subquery_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + ) + +def make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0', default_val=0.0): + #key_cache = torch.rand((num_blocks, num_heads, head_size//key_read_width, block_size, key_read_width),device=device) + #val_cache = torch.rand((num_blocks, num_heads, head_size, block_size),device=device) + kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to(device) + if default_val is not None: + kv_cache[:,:,:] = default_val + return kv_cache + +def num_tokens_to_min_blocks(num_tokens,block_size): + return (num_tokens+block_size)//block_size + +def make_flat_block_tables_slot_mapping(block_size,prompt_lens): + ''' + Naive block table: + * For each batch element... + * Block table has + ''' + num_tokens = sum(prompt_lens) + num_blocks = num_tokens_to_min_blocks(num_tokens,block_size) + block_tables = list(range(num_blocks*100)) + slot_mapping = [(idx % block_size) + block_tables[idx//block_size]*block_size for idx in range(num_tokens)] + prefill_block_tables_tensor = torch.tensor( + [], + device='cuda:0' + ) + block_tables_tensor = torch.tensor( + block_tables, + device='cuda:0' + ) + slot_mapping_tensor = torch.tensor( + slot_mapping, + dtype=torch.long, + device='cuda:0' + ) + + return block_tables_tensor, slot_mapping_tensor, prefill_block_tables_tensor + +def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): + ''' + Naive block table: + * For each batch element... + * Block table has + ''' + num_prompts = len(prompt_lens) + total_num_tokens = sum(prompt_lens) + # Over-provision block table blocks by 1 + num_blocks_list = [num_tokens_to_min_blocks(num_tokens,block_size)+1 for num_tokens in prompt_lens] + max_block_table_len = max(num_blocks_list) + #block_tables = [list(range(num_blocks*10)) for num_blocks in num_blocks_list] + block_table_pad_tokens = 10 + + block_tables = [] + slot_mapping = [] + block_base_idx = sum(num_blocks_list)*2-1 # Support more blocks than needed + #seq_base_idx = 0 + for sdx,num_tokens in enumerate(prompt_lens): + #num_blocks = num_tokens_to_min_blocks(num_tokens,block_size) + num_blocks = num_blocks_list[sdx] + block_table = list(range(block_base_idx,block_base_idx-num_blocks,-1)) + for idx in range(num_tokens): + slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) + + #seq_base_idx += num_tokens + block_base_idx -= num_blocks + block_tables.append(block_table) + + prefill_block_tables_tensor = torch.tensor( + [], + device='cuda:0' + ) + block_tables_tensor = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len+block_table_pad_tokens, + pad=0, + dtype=torch.int, + device=device, + ) + slot_mapping_tensor = torch.tensor( + slot_mapping, + dtype=torch.long, + device=device + ) + + return block_tables_tensor, slot_mapping_tensor, prefill_block_tables_tensor + + +def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, slot_mapping, device='cuda:0', kv_cache_dtype='auto', cross_prompt_lens:Optional[List[int]] = None): + ''' + Assumptions: + * No chunked prefill -> a batch is 100% prefill or 100% decode, never both + ''' + + if is_prompt: + num_prefills = len(prompt_lens) + num_prefill_tokens = sum(prompt_lens) + num_decode_tokens = 0 + + # make_stage_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, device='cuda:0', cross_prompt_lens:Optional[List[int]] = None) + stage_metadata:AttentionMetadataPerStage = make_stage_metadata(attn_backend, is_prompt, is_cross_attn, prompt_lens, context_lens, block_tables, device=device, cross_prompt_lens=cross_prompt_lens) + + return AttentionMetadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + prefill_metadata=stage_metadata, + decode_metadata=None, + kv_cache_dtype=kv_cache_dtype, + ) + + else: # not is_prompt + + num_prefills = 0 + num_prefill_tokens = 0 + num_decode_tokens = sum(context_lens) + + stage_metadata:AttentionMetadataPerStage = make_stage_metadata(attn_backend, is_prompt, is_cross_attn, prompt_lens, context_lens, block_tables, device=device, cross_prompt_lens=cross_prompt_lens) + + return AttentionMetadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + prefill_metadata=None, + decode_metadata=stage_metadata, + kv_cache_dtype=kv_cache_dtype, + ) + +def make_attention(num_heads: int, head_size: int, scale: float): + # Attention operator instance + return Attention(num_heads, + head_size, + scale=scale,) + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size",BLOCK_SIZES) +@pytest.mark.parametrize("max_prompt_len",PROMPT_LENS) +def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_prompt_len: int) -> None: + # Attention operator instance + is_cross_attn=False + device='cuda:0' + kv_cache_dtype='auto' + is_prompt = True + max_q_prompt_len = max_prompt_len + max_kv_prompt_len = max_q_prompt_len + context_lens = None + key_read_width = 4 + num_blocks = 4096 + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, num_heads, head_size) + #(key_cache, value_cache) = kv_cache + scale = float(1.0 / (head_size**0.5)) + attn = make_attention(num_heads, head_size, scale) + attn_backend = make_backend(backend_name) + query,key,value,q_prompt_lens,kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=False) + #block_tables, slot_mapping = make_block_tables_slot_mapping(block_size,q_prompt_lens) + #prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) + causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) + ideal_output = ref_masked_attention( + query, + key, + value, + scale=scale, + attn_mask=causal_mask + ) + + prefill_query = query[:,:-1] + prefill_key = key[:,:-1] + prefill_value = value[:,:-1] + decode_query = query[:,-1:] + decode_key = key[:,-1:] + decode_value = value[:,-1:] + prefill_q_prompt_lens = [plen-1 for plen in q_prompt_lens] + prefill_kv_prompt_lens = [plen-1 for plen in kv_prompt_lens] + decode_q_prompt_lens = [1 for _ in q_prompt_lens] + decode_kv_prompt_lens = [1 for _ in kv_prompt_lens] + prefill_ideal_output = ideal_output[:,:-1] + prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) + decode_ideal_output = ideal_output[:,-1:] + decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) + + block_tables, slot_mapping, prefill_block_tables = make_block_tables_slot_mapping(block_size,prefill_q_prompt_lens) + prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) + + prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) + + prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) + + # eval correctness of prefill output + assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) + + # Put KVs in KV cache + # Deprecated - handled automatically inside attention + # PagedAttention.write_to_paged_cache(key, value, key_cache, + # value_cache, + # prefill_attn_metadata.slot_mapping, + # prefill_attn_metadata.kv_cache_dtype, + # scale) + + is_prompt = False + context_lens = [1 for _ in range(batch_size)] + decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype) + + decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) + + decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) + + # eval correctness of decode output + assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size",BLOCK_SIZES) +@pytest.mark.parametrize("max_q_prompt_len",PROMPT_LENS) +@pytest.mark.parametrize("max_kv_prompt_len",PROMPT_LENS) +def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: + # Attention operator instance + is_cross_attn=True + device='cuda:0' + kv_cache_dtype='auto' + is_prompt = True + #max_q_prompt_len = max_prompt_len + #max_kv_prompt_len = max_prompt_len + context_lens = None + key_read_width = 4 + num_blocks = 4096 + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') + # key_cache, value_cache = PagedAttention.split_kv_cache( + # kv_cache, num_heads, head_size) + #(key_cache, value_cache) = kv_cache + scale = float(1.0 / (head_size**0.5)) + attn = make_attention(num_heads, head_size, scale) + attn_backend = make_backend(backend_name) + query,key,value,q_prompt_lens,kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=True) + #block_tables, slot_mapping = make_block_tables_slot_mapping(block_size,q_prompt_lens) + #prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) + #causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) + ideal_output = ref_masked_attention( + query, + key, + value, + scale=scale, + #attn_mask=causal_mask + ) + + prefill_query = query[:,:-1] + prefill_key = key #key[:,:-1] + prefill_value = value #value[:,:-1] + decode_query = query[:,-1:] + decode_key = key #key[:,-1:] + decode_value = value #value[:,-1:] + prefill_q_prompt_lens = [plen-1 for plen in q_prompt_lens] + prefill_kv_prompt_lens = kv_prompt_lens + decode_q_prompt_lens = [1 for _ in q_prompt_lens] + decode_kv_prompt_lens = kv_prompt_lens + prefill_ideal_output = ideal_output[:,:-1] + prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) + decode_ideal_output = ideal_output[:,-1:] + decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) + + block_tables, slot_mapping, prefill_block_tables = make_block_tables_slot_mapping(block_size,prefill_kv_prompt_lens) + prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=prefill_kv_prompt_lens) + + prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) + + prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) + + # eval correctness of prefill output + assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) + + is_prompt = False + context_lens = [1 for _ in range(batch_size)] + decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) + + decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) + + decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) + + # eval correctness of decode output + assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) \ No newline at end of file diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index fc46af054de4f..0a35a41a69a93 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -5,6 +5,7 @@ import torch from xformers import ops as xops from xformers.ops.fmha.attn_bias import (AttentionBias, + BlockDiagonalMask, BlockDiagonalCausalMask, LowerTriangularMaskWithTensorBias) @@ -108,6 +109,14 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): _cached_prefill_metadata: Optional["XFormersMetadata"] = None _cached_decode_metadata: Optional["XFormersMetadata"] = None + # Need to make KV cache read-only for cross-attention + is_cross_attn: bool = False + + # (batch_size,). The "cross-sequence-length" per sequence,i.e. the key/value + # sequence length (usually encoder sequence length) in the cross-attention + # computation. None if this is self-attention + cross_seq_lens: Optional[List[int]] = None + def __post_init__(self): # Set during the execution of the first attention op. # It is a list because it is needed to set per prompt @@ -270,16 +279,20 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + is_cross_attn = (attn_metadata.prefill_metadata is not None and attn_metadata.prefill_metadata.is_cross_attn) or (attn_metadata.decode_metadata is not None and attn_metadata.decode_metadata.is_cross_attn) + assert is_cross_attn or (key.shape[0] == num_prefill_tokens + num_decode_tokens) + assert is_cross_attn or (value.shape[0] == num_prefill_tokens + num_decode_tokens) output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] # QKV for prefill. query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] + + if not is_cross_attn: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens @@ -374,8 +387,12 @@ def _run_memory_efficient_xformers_forward( # FIXME(woosuk): This is a hack. if attn_metadata.attn_bias is None: if self.alibi_slopes is None: - attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.seq_lens) + if attn_metadata.is_cross_attn: + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens,attn_metadata.cross_seq_lens) + else: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + attn_metadata.seq_lens) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) From 685afc07eaa6560821676e8bb212d88ed09a206b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 16 May 2024 12:52:49 -0400 Subject: [PATCH 009/239] wip attn tests --- tests/layer/__init__.py | 0 tests/layer/test_self_and_cross_attn.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 tests/layer/__init__.py diff --git a/tests/layer/__init__.py b/tests/layer/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index c29f7224ab57f..2da312b03afaf 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -5,7 +5,7 @@ import pytest import torch import copy -from vllm.attention import Attention, AttentionMetadata, AttentionMetadataPerStage +from vllm.attention import Attention, AttentionMetadata #, AttentionMetadataPerStage from vllm.attention.backends.xformers import XFormersBackend from vllm.attention.backends.abstract import AttentionBackend From 5278f13426fb17574b6060fb62afcc56e1b20077 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 17 May 2024 10:58:48 -0400 Subject: [PATCH 010/239] wip self attention test --- tests/layer/test_self_and_cross_attn.py | 128 +++++++++++++++++------- 1 file changed, 92 insertions(+), 36 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 2da312b03afaf..96419a54b7ad6 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -135,7 +135,7 @@ def make_backend(backend_name: str) -> AttentionBackend: return XFormersBackend() assert False, f"Unrecognized backend_name {backend_name} for unit test" -def make_stage_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, device='cuda:0', cross_prompt_lens:Optional[List[int]] = None) -> AttentionMetadataPerStage: +def make_metadata_tensors(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, device='cuda:0', cross_prompt_lens:Optional[List[int]] = None) -> tuple: ''' Assumptions: * No chunked prefill @@ -156,27 +156,35 @@ def make_stage_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_ dtype=torch.int32, device=device) - torch.cumsum(prompt_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) + # torch.cumsum(prompt_lens_tensor, + # dim=0, + # dtype=seq_start_loc.dtype, + # out=seq_start_loc[1:]) query_start_loc = copy.deepcopy(seq_start_loc) - return attn_backend.make_metadata( - is_prompt=is_prompt, - is_cross_attn=is_cross_attn, - seq_lens=prompt_lens, - seq_lens_tensor=prompt_lens_tensor, - cross_seq_lens=cross_prompt_lens, - max_query_len=max_query_len, - #max_context_len=max_context_len, - max_seq_len=max_prompt_len, - subquery_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - ) + return prompt_lens_tensor, \ + context_lens_tensor, \ + max_query_len, \ + max_context_len, \ + max_prompt_len, \ + seq_start_loc, \ + query_start_loc + + # return attn_backend.make_metadata( + # is_prompt=is_prompt, + # is_cross_attn=is_cross_attn, + # seq_lens=prompt_lens, + # seq_lens_tensor=prompt_lens_tensor, + # cross_seq_lens=cross_prompt_lens, + # max_query_len=max_query_len, + # #max_context_len=max_context_len, + # max_seq_len=max_prompt_len, + # subquery_start_loc=query_start_loc, + # seq_start_loc=seq_start_loc, + # context_lens_tensor=context_lens_tensor, + # block_tables=block_tables, + # use_cuda_graph=False, + # ) def make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0', default_val=0.0): #key_cache = torch.rand((num_blocks, num_heads, head_size//key_read_width, block_size, key_read_width),device=device) @@ -275,17 +283,40 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b num_prefill_tokens = sum(prompt_lens) num_decode_tokens = 0 - # make_stage_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, device='cuda:0', cross_prompt_lens:Optional[List[int]] = None) - stage_metadata:AttentionMetadataPerStage = make_stage_metadata(attn_backend, is_prompt, is_cross_attn, prompt_lens, context_lens, block_tables, device=device, cross_prompt_lens=cross_prompt_lens) - - return AttentionMetadata( + prompt_lens_tensor, \ + context_lens_tensor, \ + max_query_len, \ + max_context_len, \ + max_prompt_len, \ + seq_start_loc, \ + query_start_loc = make_metadata_tensors(attn_backend, + is_prompt, + is_cross_attn, + prompt_lens, + context_lens, + block_tables, + device=device, + cross_prompt_lens=cross_prompt_lens) + + slot_mapping_tensor=torch.tensor(slot_mapping, + dtype=torch.long, + device=device) + + return attn_backend.make_metadata( num_prefills=num_prefills, - slot_mapping=slot_mapping, + slot_mapping=slot_mapping_tensor, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, - prefill_metadata=stage_metadata, - decode_metadata=None, - kv_cache_dtype=kv_cache_dtype, + seq_lens=prompt_lens, + seq_lens_tensor=prompt_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max(prompt_lens), + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, ) else: # not is_prompt @@ -294,16 +325,40 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b num_prefill_tokens = 0 num_decode_tokens = sum(context_lens) - stage_metadata:AttentionMetadataPerStage = make_stage_metadata(attn_backend, is_prompt, is_cross_attn, prompt_lens, context_lens, block_tables, device=device, cross_prompt_lens=cross_prompt_lens) - - return AttentionMetadata( + prompt_lens_tensor, \ + context_lens_tensor, \ + max_query_len, \ + max_context_len, \ + max_prompt_len, \ + seq_start_loc, \ + query_start_loc = make_metadata_tensors(attn_backend, + is_prompt, + is_cross_attn, + prompt_lens, + context_lens, + block_tables, + device=device, + cross_prompt_lens=cross_prompt_lens) + + slot_mapping_tensor=torch.tensor(slot_mapping, + dtype=torch.long, + device=device) + + return attn_backend.make_metadata( num_prefills=num_prefills, - slot_mapping=slot_mapping, + slot_mapping=slot_mapping_tensor, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, - prefill_metadata=None, - decode_metadata=stage_metadata, - kv_cache_dtype=kv_cache_dtype, + seq_lens=prompt_lens, + seq_lens_tensor=prompt_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max(prompt_lens), + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, ) def make_attention(num_heads: int, head_size: int, scale: float): @@ -326,7 +381,7 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n is_prompt = True max_q_prompt_len = max_prompt_len max_kv_prompt_len = max_q_prompt_len - context_lens = None + context_lens = [0 for _ in range(batch_size)] key_read_width = 4 num_blocks = 4096 kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') @@ -392,6 +447,7 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n # eval correctness of decode output assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) +@pytest.mark.skip @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) From 1824a9523064f6410f766cc22a032db54b7a429a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 17 May 2024 13:45:59 -0400 Subject: [PATCH 011/239] tests run but do not pass --- tests/layer/test_self_and_cross_attn.py | 31 ++++++++++++++++--------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 96419a54b7ad6..61ab4075e0597 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -148,19 +148,28 @@ def make_metadata_tensors(attn_backend:AttentionBackend, is_prompt:bool, is_cros context_lens_tensor=None if context_lens is None else torch.tensor(context_lens, dtype=torch.int, device=device) - max_query_len=None if prompt_lens is None else max(prompt_lens) max_context_len=None if context_lens is None else max(context_lens) - max_prompt_len=max_query_len + max_prompt_len=None if prompt_lens is None else max(prompt_lens) seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) - # torch.cumsum(prompt_lens_tensor, - # dim=0, - # dtype=seq_start_loc.dtype, - # out=seq_start_loc[1:]) - query_start_loc = copy.deepcopy(seq_start_loc) + torch.cumsum(prompt_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + if is_prompt: + # Prefill: query_start_loc matches seq_start_loc + query_start_loc = copy.deepcopy(seq_start_loc) + max_query_len=max_prompt_len + else: + # Decode: one new query input token per batch + # element, thus query_start_loc is the cumsum + # of [1,1,1,...] + query_start_loc = list(range(len(seq_start_loc))) + max_query_len = 1 return prompt_lens_tensor, \ context_lens_tensor, \ @@ -323,7 +332,7 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b num_prefills = 0 num_prefill_tokens = 0 - num_decode_tokens = sum(context_lens) + num_decode_tokens = len(prompt_lens) prompt_lens_tensor, \ context_lens_tensor, \ @@ -352,8 +361,8 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b seq_lens=prompt_lens, seq_lens_tensor=prompt_lens_tensor, max_query_len=max_query_len, - max_prefill_seq_len=max(prompt_lens), - max_decode_seq_len=0, + max_prefill_seq_len=0, + max_decode_seq_len=max(prompt_lens), query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, @@ -437,7 +446,7 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n # scale) is_prompt = False - context_lens = [1 for _ in range(batch_size)] + context_lens = copy.deepcopy(prefill_kv_prompt_lens) decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype) decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) From 64b7b6154c077c78d84cc26f67011e3a04a3d9b0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 17 May 2024 15:06:38 -0400 Subject: [PATCH 012/239] passing self-attention --- tests/layer/test_self_and_cross_attn.py | 38 ++++++++++++------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 61ab4075e0597..fd0682996c8da 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -12,8 +12,6 @@ from vllm.attention.ops.paged_attn import PagedAttention -from vllm.utils import get_max_shared_memory_bytes -from vllm.utils import is_hip from vllm.utils import make_tensor_with_pad from vllm.attention.layer import Attention @@ -247,15 +245,18 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): block_table_pad_tokens = 10 block_tables = [] - slot_mapping = [] + prefill_slot_mapping = [] + decode_slot_mapping = [] block_base_idx = sum(num_blocks_list)*2-1 # Support more blocks than needed #seq_base_idx = 0 for sdx,num_tokens in enumerate(prompt_lens): #num_blocks = num_tokens_to_min_blocks(num_tokens,block_size) num_blocks = num_blocks_list[sdx] block_table = list(range(block_base_idx,block_base_idx-num_blocks,-1)) - for idx in range(num_tokens): - slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) + for idx in range(num_tokens-1): + prefill_slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) + idx = num_tokens-1 + decode_slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) #seq_base_idx += num_tokens block_base_idx -= num_blocks @@ -265,20 +266,25 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): [], device='cuda:0' ) - block_tables_tensor = make_tensor_with_pad( + decode_block_tables_tensor = make_tensor_with_pad( block_tables, max_len=max_block_table_len+block_table_pad_tokens, pad=0, dtype=torch.int, device=device, ) - slot_mapping_tensor = torch.tensor( - slot_mapping, + prefill_slot_mapping_tensor = torch.tensor( + prefill_slot_mapping, + dtype=torch.long, + device=device + ) + decode_slot_mapping_tensor = torch.tensor( + decode_slot_mapping, dtype=torch.long, device=device ) - return block_tables_tensor, slot_mapping_tensor, prefill_block_tables_tensor + return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, slot_mapping, device='cuda:0', kv_cache_dtype='auto', cross_prompt_lens:Optional[List[int]] = None): @@ -427,8 +433,8 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n decode_ideal_output = ideal_output[:,-1:] decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) - block_tables, slot_mapping, prefill_block_tables = make_block_tables_slot_mapping(block_size,prefill_q_prompt_lens) - prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) + decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables = make_block_tables_slot_mapping(block_size,q_prompt_lens) + prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) @@ -437,17 +443,9 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n # eval correctness of prefill output assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) - # Put KVs in KV cache - # Deprecated - handled automatically inside attention - # PagedAttention.write_to_paged_cache(key, value, key_cache, - # value_cache, - # prefill_attn_metadata.slot_mapping, - # prefill_attn_metadata.kv_cache_dtype, - # scale) - is_prompt = False context_lens = copy.deepcopy(prefill_kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype) + decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype) decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) From c99aa0dcc579a22fcaf698c30169c5d7bfb69898 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 17 May 2024 16:01:50 -0400 Subject: [PATCH 013/239] passing self-attention test with variable lengths! --- tests/layer/test_self_and_cross_attn.py | 124 ++++++++++++++++-------- 1 file changed, 85 insertions(+), 39 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index fd0682996c8da..5cabef7e5f8b6 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -64,40 +64,79 @@ def ref_masked_attention( #assert False, f"{attn_weights.shape} ; {value.shape} ; {out.shape}" return out -def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True, force_max_len=True): - if force_max_len: - q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] - kv_prompt_lens = None - if not is_cross_attn: - # K,V prompt lens match Q for self-attention - kv_prompt_lens = q_prompt_lens - else: - # K,V prompt lens come from K,V operands - kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] +def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True): + q_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] + kv_prompt_lens = None + if not is_cross_attn: + # K,V prompt lens match Q for self-attention + kv_prompt_lens = q_prompt_lens else: - q_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] - kv_prompt_lens = None - if not is_cross_attn: - # K,V prompt lens match Q for self-attention - kv_prompt_lens = q_prompt_lens - else: - # K,V prompt lens come from K,V operands - kv_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] - + # K,V prompt lens come from K,V operands + kv_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] + + actual_max_q_prompt_len = max(q_prompt_lens) + actual_max_kv_prompt_len = max(kv_prompt_lens) + query=torch.rand((batch_size,max_q_prompt_len,head_size)) key=torch.rand((batch_size,max_kv_prompt_len,head_size)) value=torch.rand((batch_size,max_kv_prompt_len,head_size)) + prefill_query=torch.zeros((batch_size,max_q_prompt_len-1,head_size)) + prefill_key=torch.zeros((batch_size,max_kv_prompt_len-1,head_size)) + prefill_value=torch.zeros((batch_size,max_kv_prompt_len-1,head_size)) + + decode_query=torch.zeros((batch_size,1,head_size)) + decode_key=torch.zeros((batch_size,1,head_size)) + decode_value=torch.zeros((batch_size,1,head_size)) + for bdx,(q_prompt_len,kv_prompt_len) in enumerate(zip(q_prompt_lens,kv_prompt_lens)): - query[bdx,q_prompt_len:] = 0 - key[bdx,kv_prompt_len:] = 0 - value[bdx,kv_prompt_len:] = 0 + query[bdx,q_prompt_len:,:] = 0 + key[bdx,kv_prompt_len:,:] = 0 + value[bdx,kv_prompt_len:,:] = 0 + + prefill_query[bdx,0:(q_prompt_len-1),:] = query[bdx,0:(q_prompt_len-1),:] + prefill_key[bdx,0:(kv_prompt_len-1),:] = key[bdx,0:(kv_prompt_len-1),:] + prefill_value[bdx,0:(kv_prompt_len-1),:] = value[bdx,0:(kv_prompt_len-1),:] + + decode_query[bdx,:,:] = query[bdx,(q_prompt_len-1):q_prompt_len,:] + decode_key[bdx,:,:] = key[bdx,(kv_prompt_len-1):kv_prompt_len,:] + decode_value[bdx,:,:] = value[bdx,(kv_prompt_len-1):kv_prompt_len,:] + + prefill_q_prompt_lens = [plen - 1 for plen in q_prompt_lens] + prefill_kv_prompt_lens = [plen - 1 for plen in kv_prompt_lens] + + decode_q_prompt_lens = [1 for _ in q_prompt_lens] + decode_kv_prompt_lens = [1 for _ in kv_prompt_lens] query=query.unsqueeze(-2) key=key.unsqueeze(-2) value=value.unsqueeze(-2) - return query,key,value,q_prompt_lens,kv_prompt_lens + prefill_query=prefill_query.unsqueeze(-2) + prefill_key=prefill_key.unsqueeze(-2) + prefill_value=prefill_value.unsqueeze(-2) + + decode_query=decode_query.unsqueeze(-2) + decode_key=decode_key.unsqueeze(-2) + decode_value=decode_value.unsqueeze(-2) + + return query, \ + key, \ + value, \ + prefill_query, \ + prefill_key, \ + prefill_value, \ + decode_query, \ + decode_key, \ + decode_value, \ + q_prompt_lens, \ + kv_prompt_lens, \ + actual_max_q_prompt_len, \ + actual_max_kv_prompt_len, \ + prefill_q_prompt_lens, \ + prefill_kv_prompt_lens, \ + decode_q_prompt_lens, \ + decode_kv_prompt_lens def pack_tensor(unpacked_tensor,prompt_lens, device='cuda:0'): num_tok = sum(prompt_lens) @@ -400,13 +439,26 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n key_read_width = 4 num_blocks = 4096 kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, num_heads, head_size) - #(key_cache, value_cache) = kv_cache scale = float(1.0 / (head_size**0.5)) attn = make_attention(num_heads, head_size, scale) attn_backend = make_backend(backend_name) - query,key,value,q_prompt_lens,kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=False) + query, \ + key, \ + value, \ + prefill_query, \ + prefill_key, \ + prefill_value, \ + decode_query, \ + decode_key, \ + decode_value, \ + q_prompt_lens, \ + kv_prompt_lens, \ + actual_max_q_prompt_len, \ + actual_max_kv_prompt_len, \ + prefill_q_prompt_lens, \ + prefill_kv_prompt_lens, \ + decode_q_prompt_lens, \ + decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=False) #block_tables, slot_mapping = make_block_tables_slot_mapping(block_size,q_prompt_lens) #prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) @@ -418,19 +470,13 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n attn_mask=causal_mask ) - prefill_query = query[:,:-1] - prefill_key = key[:,:-1] - prefill_value = value[:,:-1] - decode_query = query[:,-1:] - decode_key = key[:,-1:] - decode_value = value[:,-1:] - prefill_q_prompt_lens = [plen-1 for plen in q_prompt_lens] - prefill_kv_prompt_lens = [plen-1 for plen in kv_prompt_lens] - decode_q_prompt_lens = [1 for _ in q_prompt_lens] - decode_kv_prompt_lens = [1 for _ in kv_prompt_lens] - prefill_ideal_output = ideal_output[:,:-1] + prefill_ideal_output = torch.zeros_like(ideal_output) + decode_ideal_output = torch.zeros_like(ideal_output[:,0:1]) + for bdx,prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): + prefill_ideal_output[bdx,:prefill_q_prompt_len] = ideal_output[bdx,:prefill_q_prompt_len] + decode_ideal_output[bdx,:] = ideal_output[bdx,prefill_q_prompt_len:(prefill_q_prompt_len+1)] + prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) - decode_ideal_output = ideal_output[:,-1:] decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables = make_block_tables_slot_mapping(block_size,q_prompt_lens) From 270d95e78af170692d19bc9d37d6badec5a91869 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 17 May 2024 17:16:44 -0400 Subject: [PATCH 014/239] wip cross-attention; is_cross_atn and cross_seq_lens is transferred from parent metadata struct to child metadata structs; cross-attn test runs without functional errors but fails all_close --- tests/layer/test_self_and_cross_attn.py | 147 ++++++++++-------------- vllm/attention/backends/xformers.py | 6 +- 2 files changed, 68 insertions(+), 85 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 5cabef7e5f8b6..d5ca2c5ccbc5b 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -146,10 +146,6 @@ def pack_tensor(unpacked_tensor,prompt_lens, device='cuda:0'): packed_tensor = torch.zeros((num_tok,num_heads,head_size), device=device) - #assert False, f"{start_loc_list}" - - #assert False, f"{packed_tensor.shape} ; {unpacked_tensor.shape}" - for bdx,(prompt_len,start_loc) in enumerate(zip(prompt_lens,start_loc_list)): try: packed_tensor[start_loc:(start_loc+prompt_len),:,:] = unpacked_tensor[bdx,:prompt_len,:,:] @@ -216,22 +212,6 @@ def make_metadata_tensors(attn_backend:AttentionBackend, is_prompt:bool, is_cros seq_start_loc, \ query_start_loc - # return attn_backend.make_metadata( - # is_prompt=is_prompt, - # is_cross_attn=is_cross_attn, - # seq_lens=prompt_lens, - # seq_lens_tensor=prompt_lens_tensor, - # cross_seq_lens=cross_prompt_lens, - # max_query_len=max_query_len, - # #max_context_len=max_context_len, - # max_seq_len=max_prompt_len, - # subquery_start_loc=query_start_loc, - # seq_start_loc=seq_start_loc, - # context_lens_tensor=context_lens_tensor, - # block_tables=block_tables, - # use_cuda_graph=False, - # ) - def make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0', default_val=0.0): #key_cache = torch.rand((num_blocks, num_heads, head_size//key_read_width, block_size, key_read_width),device=device) #val_cache = torch.rand((num_blocks, num_heads, head_size, block_size),device=device) @@ -243,32 +223,6 @@ def make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, def num_tokens_to_min_blocks(num_tokens,block_size): return (num_tokens+block_size)//block_size -def make_flat_block_tables_slot_mapping(block_size,prompt_lens): - ''' - Naive block table: - * For each batch element... - * Block table has - ''' - num_tokens = sum(prompt_lens) - num_blocks = num_tokens_to_min_blocks(num_tokens,block_size) - block_tables = list(range(num_blocks*100)) - slot_mapping = [(idx % block_size) + block_tables[idx//block_size]*block_size for idx in range(num_tokens)] - prefill_block_tables_tensor = torch.tensor( - [], - device='cuda:0' - ) - block_tables_tensor = torch.tensor( - block_tables, - device='cuda:0' - ) - slot_mapping_tensor = torch.tensor( - slot_mapping, - dtype=torch.long, - device='cuda:0' - ) - - return block_tables_tensor, slot_mapping_tensor, prefill_block_tables_tensor - def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): ''' Naive block table: @@ -286,6 +240,7 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): block_tables = [] prefill_slot_mapping = [] decode_slot_mapping = [] + slot_mapping = [] block_base_idx = sum(num_blocks_list)*2-1 # Support more blocks than needed #seq_base_idx = 0 for sdx,num_tokens in enumerate(prompt_lens): @@ -294,8 +249,10 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): block_table = list(range(block_base_idx,block_base_idx-num_blocks,-1)) for idx in range(num_tokens-1): prefill_slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) + slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) idx = num_tokens-1 decode_slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) + slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) #seq_base_idx += num_tokens block_base_idx -= num_blocks @@ -322,8 +279,18 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): dtype=torch.long, device=device ) + slot_mapping_tensor = torch.tensor( + slot_mapping, + dtype=torch.long, + device=device + ) + empty_slot_mapping_tensor = torch.tensor( + [], + dtype=torch.long, + device=device + ) - return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor + return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor, slot_mapping_tensor, empty_slot_mapping_tensor def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, slot_mapping, device='cuda:0', kv_cache_dtype='auto', cross_prompt_lens:Optional[List[int]] = None): @@ -371,6 +338,8 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, + is_cross_attn=is_cross_attn, + cross_seq_lens=cross_prompt_lens ) else: # not is_prompt @@ -413,6 +382,8 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, + is_cross_attn=is_cross_attn, + cross_seq_lens=cross_prompt_lens ) def make_attention(num_heads: int, head_size: int, scale: float): @@ -421,6 +392,7 @@ def make_attention(num_heads: int, head_size: int, scale: float): head_size, scale=scale,) +@pytest.mark.skip @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -452,15 +424,14 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n decode_key, \ decode_value, \ q_prompt_lens, \ - kv_prompt_lens, \ - actual_max_q_prompt_len, \ - actual_max_kv_prompt_len, \ + _, \ + _, \ + _, \ prefill_q_prompt_lens, \ prefill_kv_prompt_lens, \ decode_q_prompt_lens, \ decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=False) - #block_tables, slot_mapping = make_block_tables_slot_mapping(block_size,q_prompt_lens) - #prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) + causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) ideal_output = ref_masked_attention( query, @@ -479,10 +450,10 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) - decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables = make_block_tables_slot_mapping(block_size,q_prompt_lens) + decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _ = make_block_tables_slot_mapping(block_size,q_prompt_lens) prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) - prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) + prefill_packed_query,prefill_packed_key,prefill_packed_value,_,_ = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) @@ -493,14 +464,13 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n context_lens = copy.deepcopy(prefill_kv_prompt_lens) decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype) - decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) + decode_packed_query,decode_packed_key,decode_packed_value,_,_ = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) # eval correctness of decode output assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) -@pytest.mark.skip @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -514,21 +484,32 @@ def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_ device='cuda:0' kv_cache_dtype='auto' is_prompt = True - #max_q_prompt_len = max_prompt_len - #max_kv_prompt_len = max_prompt_len - context_lens = None + context_lens = [0 for _ in range(batch_size)] key_read_width = 4 num_blocks = 4096 kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') - # key_cache, value_cache = PagedAttention.split_kv_cache( - # kv_cache, num_heads, head_size) - #(key_cache, value_cache) = kv_cache scale = float(1.0 / (head_size**0.5)) attn = make_attention(num_heads, head_size, scale) attn_backend = make_backend(backend_name) - query,key,value,q_prompt_lens,kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=True) - #block_tables, slot_mapping = make_block_tables_slot_mapping(block_size,q_prompt_lens) - #prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) + + query, \ + key, \ + value, \ + prefill_query, \ + _, \ + _, \ + decode_query, \ + _, \ + _, \ + q_prompt_lens, \ + kv_prompt_lens, \ + actual_max_q_prompt_len, \ + actual_max_kv_prompt_len, \ + prefill_q_prompt_lens, \ + _, \ + decode_q_prompt_lens, \ + _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=is_cross_attn) + #causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) ideal_output = ref_masked_attention( query, @@ -538,25 +519,23 @@ def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_ #attn_mask=causal_mask ) - prefill_query = query[:,:-1] - prefill_key = key #key[:,:-1] - prefill_value = value #value[:,:-1] - decode_query = query[:,-1:] - decode_key = key #key[:,-1:] - decode_value = value #value[:,-1:] - prefill_q_prompt_lens = [plen-1 for plen in q_prompt_lens] - prefill_kv_prompt_lens = kv_prompt_lens - decode_q_prompt_lens = [1 for _ in q_prompt_lens] - decode_kv_prompt_lens = kv_prompt_lens - prefill_ideal_output = ideal_output[:,:-1] + prefill_ideal_output = torch.zeros_like(ideal_output) + decode_ideal_output = torch.zeros_like(ideal_output[:,0:1]) + for bdx,prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): + prefill_ideal_output[bdx,:prefill_q_prompt_len] = ideal_output[bdx,:prefill_q_prompt_len] + decode_ideal_output[bdx,:] = ideal_output[bdx,prefill_q_prompt_len:(prefill_q_prompt_len+1)] + prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) - decode_ideal_output = ideal_output[:,-1:] decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) - block_tables, slot_mapping, prefill_block_tables = make_block_tables_slot_mapping(block_size,prefill_kv_prompt_lens) - prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=prefill_kv_prompt_lens) + # Unlike self-attention: + # - Prefill slot-mapping includes all key slots + # - Decode slot-mapping is empty + decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping = make_block_tables_slot_mapping(block_size,kv_prompt_lens) + + prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) - prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) + prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,key,value,prefill_q_prompt_lens,kv_prompt_lens) prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) @@ -564,12 +543,12 @@ def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_ assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) is_prompt = False - context_lens = [1 for _ in range(batch_size)] - decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, block_tables, slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) + context_lens = copy.deepcopy(kv_prompt_lens) + decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) - decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) + decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,key,value,decode_q_prompt_lens,kv_prompt_lens) - decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) + decode_packed_actual_output=attn.forward(decode_packed_query,None,None,kv_cache,decode_attn_metadata,scale) # eval correctness of decode output assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) \ No newline at end of file diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 3f5752ab4d445..0ef80ca410355 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -154,6 +154,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=self.context_lens_tensor[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, + is_cross_attn=self.is_cross_attn, + cross_seq_lens=self.cross_seq_lens ) return self._cached_prefill_metadata @@ -182,6 +184,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, + is_cross_attn=self.is_cross_attn, + cross_seq_lens=self.cross_seq_lens ) return self._cached_decode_metadata @@ -280,7 +284,7 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - is_cross_attn = (attn_metadata.prefill_metadata is not None and attn_metadata.prefill_metadata.is_cross_attn) or (attn_metadata.decode_metadata is not None and attn_metadata.decode_metadata.is_cross_attn) + is_cross_attn = attn_metadata.is_cross_attn assert is_cross_attn or (key.shape[0] == num_prefill_tokens + num_decode_tokens) assert is_cross_attn or (value.shape[0] == num_prefill_tokens + num_decode_tokens) From c41079153e5f2d4b1f9a21dc92d4e6e7f99b2182 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 18 May 2024 15:04:41 -0400 Subject: [PATCH 015/239] moved ideal to cuda --- tests/layer/test_self_and_cross_attn.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index d5ca2c5ccbc5b..feab7aad97d8a 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -27,7 +27,7 @@ NUM_HEADS = [1] -BATCH_SIZES = [16] +BATCH_SIZES = [1] BLOCK_SIZES = [16] #KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] BACKEND_NAMES = ["xformers"] @@ -77,17 +77,17 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a actual_max_q_prompt_len = max(q_prompt_lens) actual_max_kv_prompt_len = max(kv_prompt_lens) - query=torch.rand((batch_size,max_q_prompt_len,head_size)) - key=torch.rand((batch_size,max_kv_prompt_len,head_size)) - value=torch.rand((batch_size,max_kv_prompt_len,head_size)) + query=torch.rand((batch_size,max_q_prompt_len,head_size)).cuda() + key=torch.rand((batch_size,max_kv_prompt_len,head_size)).cuda() + value=torch.rand((batch_size,max_kv_prompt_len,head_size)).cuda() - prefill_query=torch.zeros((batch_size,max_q_prompt_len-1,head_size)) - prefill_key=torch.zeros((batch_size,max_kv_prompt_len-1,head_size)) - prefill_value=torch.zeros((batch_size,max_kv_prompt_len-1,head_size)) + prefill_query=torch.zeros((batch_size,max_q_prompt_len-1,head_size)).cuda() + prefill_key=torch.zeros((batch_size,max_kv_prompt_len-1,head_size)).cuda() + prefill_value=torch.zeros((batch_size,max_kv_prompt_len-1,head_size)).cuda() - decode_query=torch.zeros((batch_size,1,head_size)) - decode_key=torch.zeros((batch_size,1,head_size)) - decode_value=torch.zeros((batch_size,1,head_size)) + decode_query=torch.zeros((batch_size,1,head_size)).cuda() + decode_key=torch.zeros((batch_size,1,head_size)).cuda() + decode_value=torch.zeros((batch_size,1,head_size)).cuda() for bdx,(q_prompt_len,kv_prompt_len) in enumerate(zip(q_prompt_lens,kv_prompt_lens)): query[bdx,q_prompt_len:,:] = 0 @@ -432,7 +432,7 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n decode_q_prompt_lens, \ decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=False) - causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) + causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len).cuda() ideal_output = ref_masked_attention( query, key, From 3719f5c2dff4532720703119bb141ac5dd3c9053 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 18 May 2024 16:18:36 -0400 Subject: [PATCH 016/239] wip cross-attention; appears to be power-of-two key-length requirement? --- tests/layer/test_self_and_cross_attn.py | 22 ++++++++++++++++------ vllm/attention/backends/xformers.py | 4 ++-- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index feab7aad97d8a..1621ee99d4a50 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -35,7 +35,11 @@ # f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) #] -PROMPT_LENS = [32] +PROMPT_LENS = [8] + +Q_PROMPT_LENS = [7] + +K_PROMPT_LENS = [32] def build_causal_mask(q_max_prompt_len, k_max_prompt_len): # Create a matrix where entry (i, j) is True if i >= j @@ -64,15 +68,21 @@ def ref_masked_attention( #assert False, f"{attn_weights.shape} ; {value.shape} ; {out.shape}" return out -def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True): - q_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] +def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True, force_max_len=False): + if force_max_len: + q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] + else: + q_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] kv_prompt_lens = None if not is_cross_attn: # K,V prompt lens match Q for self-attention kv_prompt_lens = q_prompt_lens else: # K,V prompt lens come from K,V operands - kv_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] + if force_max_len: + kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] + else: + kv_prompt_lens = [random.randint(1, max_kv_prompt_len) for _ in range(batch_size)] actual_max_q_prompt_len = max(q_prompt_lens) actual_max_kv_prompt_len = max(kv_prompt_lens) @@ -476,8 +486,8 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("block_size",BLOCK_SIZES) -@pytest.mark.parametrize("max_q_prompt_len",PROMPT_LENS) -@pytest.mark.parametrize("max_kv_prompt_len",PROMPT_LENS) +@pytest.mark.parametrize("max_q_prompt_len",Q_PROMPT_LENS) +@pytest.mark.parametrize("max_kv_prompt_len",K_PROMPT_LENS) def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: # Attention operator instance is_cross_attn=True diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 0ef80ca410355..46862d72cd7f9 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -392,8 +392,8 @@ def _run_memory_efficient_xformers_forward( if attn_metadata.attn_bias is None: if self.alibi_slopes is None: if attn_metadata.is_cross_attn: - attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.seq_lens,attn_metadata.cross_seq_lens) + attn_bias = None #BlockDiagonalMask.from_seqlens( + # attn_metadata.seq_lens,attn_metadata.cross_seq_lens) else: attn_bias = BlockDiagonalCausalMask.from_seqlens( attn_metadata.seq_lens) From 96082e155e501951890c25fbefa387805a817452 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 18 May 2024 17:35:01 -0400 Subject: [PATCH 017/239] trying to debug cross-attention issue --- tests/layer/test_self_and_cross_attn.py | 13 ++++++++----- vllm/attention/backends/xformers.py | 3 ++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 1621ee99d4a50..8c3b3b0c0e359 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -37,9 +37,9 @@ PROMPT_LENS = [8] -Q_PROMPT_LENS = [7] +Q_PROMPT_LENS = [128] -K_PROMPT_LENS = [32] +K_PROMPT_LENS = [128] def build_causal_mask(q_max_prompt_len, k_max_prompt_len): # Create a matrix where entry (i, j) is True if i >= j @@ -68,7 +68,9 @@ def ref_masked_attention( #assert False, f"{attn_weights.shape} ; {value.shape} ; {out.shape}" return out -def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True, force_max_len=False): +def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True, force_max_len=True): + assert max_kv_prompt_len >= max_q_prompt_len + if force_max_len: q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] else: @@ -82,7 +84,8 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a if force_max_len: kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] else: - kv_prompt_lens = [random.randint(1, max_kv_prompt_len) for _ in range(batch_size)] + kv_prompt_lens = [16*((q_prompt_len + random.randint(0, max_kv_prompt_len-q_prompt_len))//16) + for q_prompt_len,_ in zip(q_prompt_lens,range(batch_size))] actual_max_q_prompt_len = max(q_prompt_lens) actual_max_kv_prompt_len = max(kv_prompt_lens) @@ -547,7 +550,7 @@ def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_ prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,key,value,prefill_q_prompt_lens,kv_prompt_lens) - prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) + prefill_packed_actual_output=attn.forward(prefill_packed_query.contiguous(),prefill_packed_key.contiguous(),prefill_packed_value.contiguous(),kv_cache,prefill_attn_metadata,scale) # eval correctness of prefill output assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 46862d72cd7f9..21d828edefc78 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -420,7 +420,8 @@ def _run_memory_efficient_xformers_forward( value, attn_bias=attn_metadata.attn_bias[0], p=0.0, - scale=self.scale) + scale=self.scale, + op=xops.MemoryEfficientAttentionOp()) return out.view_as(original_query) # Attention with alibi slopes. From d99c5d94164dc7695b5933465cce15c74d39609a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 20 May 2024 11:07:04 -0400 Subject: [PATCH 018/239] wip --- tests/layer/test_self_and_cross_attn.py | 251 +++++++++++++++++++++++- vllm/attention/backends/xformers.py | 3 +- 2 files changed, 243 insertions(+), 11 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 8c3b3b0c0e359..ddf67c04bfef5 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -18,6 +18,13 @@ import random +from xformers import ops as xops + +from xformers.ops.fmha.attn_bias import (AttentionBias, + BlockDiagonalMask, + BlockDiagonalCausalMask, + LowerTriangularMaskWithTensorBias) + # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 HEAD_SIZES = [64] @@ -27,7 +34,7 @@ NUM_HEADS = [1] -BATCH_SIZES = [1] +BATCH_SIZES = [2] BLOCK_SIZES = [16] #KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] BACKEND_NAMES = ["xformers"] @@ -35,9 +42,9 @@ # f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) #] -PROMPT_LENS = [8] +PROMPT_LENS = [128] -Q_PROMPT_LENS = [128] +Q_PROMPT_LENS = [129] K_PROMPT_LENS = [128] @@ -68,13 +75,13 @@ def ref_masked_attention( #assert False, f"{attn_weights.shape} ; {value.shape} ; {out.shape}" return out -def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True, force_max_len=True): - assert max_kv_prompt_len >= max_q_prompt_len +def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True, force_max_len=False): + #assert max_kv_prompt_len >= max_q_prompt_len if force_max_len: q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] else: - q_prompt_lens = [random.randint(1, max_q_prompt_len) for _ in range(batch_size)] + q_prompt_lens = [random.randint(3, max_q_prompt_len) for _ in range(batch_size)] kv_prompt_lens = None if not is_cross_attn: # K,V prompt lens match Q for self-attention @@ -84,8 +91,8 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a if force_max_len: kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] else: - kv_prompt_lens = [16*((q_prompt_len + random.randint(0, max_kv_prompt_len-q_prompt_len))//16) - for q_prompt_len,_ in zip(q_prompt_lens,range(batch_size))] + kv_prompt_lens = [min(q_prompt_len-1,max_kv_prompt_len) + for q_prompt_len in q_prompt_lens] actual_max_q_prompt_len = max(q_prompt_lens) actual_max_kv_prompt_len = max(kv_prompt_lens) @@ -405,7 +412,115 @@ def make_attention(num_heads: int, head_size: int, scale: float): head_size, scale=scale,) -@pytest.mark.skip +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size",BLOCK_SIZES) +@pytest.mark.parametrize("max_q_prompt_len",Q_PROMPT_LENS) +@pytest.mark.parametrize("max_kv_prompt_len",K_PROMPT_LENS) +def test_xops_memory_efficient_attention_forward_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int): + # Attention operator instance + is_cross_attn=True + device='cuda:0' + kv_cache_dtype='auto' + is_prompt = True + context_lens = [0 for _ in range(batch_size)] + key_read_width = 4 + num_blocks = 4096 + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') + scale = float(1.0 / (head_size**0.5)) + attn = make_attention(num_heads, head_size, scale) + attn_backend = make_backend(backend_name) + + query, \ + key, \ + value, \ + prefill_query, \ + _, \ + _, \ + decode_query, \ + _, \ + _, \ + q_prompt_lens, \ + kv_prompt_lens, \ + actual_max_q_prompt_len, \ + actual_max_kv_prompt_len, \ + prefill_q_prompt_lens, \ + _, \ + decode_q_prompt_lens, \ + _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=is_cross_attn) + + #causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) + ideal_output = ref_masked_attention( + query, + key, + value, + scale=scale, + #attn_mask=causal_mask + ) + + original_query = query + #query = query.unsqueeze(0) + #key = key.unsqueeze(0) + #value = value.unsqueeze(0) + xops_out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=None, + p=0.0, + scale=scale) + + assert torch.allclose(ideal_output,xops_out) + + prefill_ideal_output = torch.zeros_like(ideal_output) + decode_ideal_output = torch.zeros_like(ideal_output[:,0:1]) + for bdx,prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): + prefill_ideal_output[bdx,:prefill_q_prompt_len] = ideal_output[bdx,:prefill_q_prompt_len] + decode_ideal_output[bdx,:] = ideal_output[bdx,prefill_q_prompt_len:(prefill_q_prompt_len+1)] + + prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) + decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) + + # Unlike self-attention: + # - Prefill slot-mapping includes all key slots + # - Decode slot-mapping is empty + decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping = make_block_tables_slot_mapping(block_size,kv_prompt_lens) + + prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) + + prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,key,value,prefill_q_prompt_lens,kv_prompt_lens) + + prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) + + xops_out = xops.memory_efficient_attention_forward( + prefill_packed_query.unsqueeze(0), + prefill_packed_key.unsqueeze(0), + prefill_packed_value.unsqueeze(0), + attn_bias=None, + p=0.0, + scale=scale) + xops_out=xops_out.view_as(prefill_packed_query) + + # eval correctness of prefill output + assert torch.allclose(prefill_packed_actual_output,xops_out) + + # eval correctness of prefill output + assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) + + is_prompt = False + context_lens = copy.deepcopy(kv_prompt_lens) + decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) + + decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,key,value,decode_q_prompt_lens,kv_prompt_lens) + + decode_packed_actual_output=attn.forward(decode_packed_query,None,None,kv_cache,decode_attn_metadata,scale) + + # eval correctness of decode output + assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) + +#@pytest.mark.skip @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -470,6 +585,18 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) + # attn_bias = BlockDiagonalCausalMask.from_seqlens(prefill_q_prompt_lens) + # xops_out = xops.memory_efficient_attention_forward( + # prefill_packed_query.unsqueeze(0), + # prefill_packed_key.unsqueeze(0), + # prefill_packed_value.unsqueeze(0), + # attn_bias=attn_bias, + # p=0.0, + # scale=scale) + # xops_out=xops_out.view_as(prefill_packed_query) + + # assert torch.allclose(xops_out,prefill_packed_ideal_output[:,0,:]) + # eval correctness of prefill output assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) @@ -484,6 +611,7 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n # eval correctness of decode output assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -492,6 +620,111 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n @pytest.mark.parametrize("max_q_prompt_len",Q_PROMPT_LENS) @pytest.mark.parametrize("max_kv_prompt_len",K_PROMPT_LENS) def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: + # Attention operator instance + is_cross_attn=False + device='cuda:0' + kv_cache_dtype='auto' + is_prompt = True + #max_q_prompt_len = max_prompt_len + max_kv_prompt_len = max_q_prompt_len + context_lens = [0 for _ in range(batch_size)] + key_read_width = 4 + num_blocks = 4096 + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') + scale = float(1.0 / (head_size**0.5)) + attn = make_attention(num_heads, head_size, scale) + attn_backend = make_backend(backend_name) + query, \ + key, \ + value, \ + prefill_query, \ + prefill_key, \ + prefill_value, \ + decode_query, \ + decode_key, \ + decode_value, \ + q_prompt_lens, \ + _, \ + _, \ + _, \ + prefill_q_prompt_lens, \ + prefill_kv_prompt_lens, \ + decode_q_prompt_lens, \ + decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=False) + + causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len).cuda() + ideal_output = ref_masked_attention( + query, + key, + value, + scale=scale, + attn_mask=causal_mask + ) + + prefill_ideal_output = torch.zeros_like(ideal_output) + decode_ideal_output = torch.zeros_like(ideal_output[:,0:1]) + for bdx,prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): + prefill_ideal_output[bdx,:prefill_q_prompt_len] = ideal_output[bdx,:prefill_q_prompt_len] + decode_ideal_output[bdx,:] = ideal_output[bdx,prefill_q_prompt_len:(prefill_q_prompt_len+1)] + + prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) + decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) + + decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _ = make_block_tables_slot_mapping(block_size,q_prompt_lens) + prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) + + prefill_packed_query,prefill_packed_key,prefill_packed_value,_,_ = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) + + prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) + + + + shorten_amt = 17 + new_ideal_output = ref_masked_attention( + query[:,:(prefill_q_prompt_lens[0]-shorten_amt),:,:], + key[:,:prefill_kv_prompt_lens[0],:,:], + value[:,:prefill_kv_prompt_lens[0],:,:], + scale=scale, + attn_mask=None #causal_mask[:(prefill_q_prompt_lens[0]-shorten_amt),:prefill_kv_prompt_lens[0]] + ) + + attn_bias = None #BlockDiagonalCausalMask.from_seqlens([(prefill_q_prompt_lens[0]-shorten_amt)],[prefill_kv_prompt_lens[0]]) + + xops_out = xops.memory_efficient_attention_forward( + prefill_packed_query.view(-1, num_heads, head_size)[:-shorten_amt,:,:].unsqueeze(0), + prefill_packed_key.view(-1, num_heads, head_size).unsqueeze(0), + prefill_packed_value.view(-1, num_heads, head_size).unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale) + #xops_out=xops_out.view_as(prefill_packed_query) + + assert torch.allclose(xops_out,new_ideal_output.view_as(xops_out)) + + # eval correctness of prefill output + assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) + + is_prompt = False + context_lens = copy.deepcopy(prefill_kv_prompt_lens) + decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype) + + decode_packed_query,decode_packed_key,decode_packed_value,_,_ = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) + + decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) + + # eval correctness of decode output + assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) + + +@pytest.mark.skip +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size",BLOCK_SIZES) +@pytest.mark.parametrize("max_q_prompt_len",Q_PROMPT_LENS) +@pytest.mark.parametrize("max_kv_prompt_len",K_PROMPT_LENS) +def test_prefill_decode_cross_attention_old(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: # Attention operator instance is_cross_attn=True device='cuda:0' diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 21d828edefc78..46862d72cd7f9 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -420,8 +420,7 @@ def _run_memory_efficient_xformers_forward( value, attn_bias=attn_metadata.attn_bias[0], p=0.0, - scale=self.scale, - op=xops.MemoryEfficientAttentionOp()) + scale=self.scale) return out.view_as(original_query) # Attention with alibi slopes. From 7880b0ea9ecf60d49802d9944a20466b404a7e1b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 20 May 2024 16:31:39 -0400 Subject: [PATCH 019/239] cross-attention prefill works! --- tests/layer/test_self_and_cross_attn.py | 255 +++--------------------- vllm/attention/backends/xformers.py | 4 +- 2 files changed, 29 insertions(+), 230 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index ddf67c04bfef5..954e56d8bcc39 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -10,21 +10,12 @@ from vllm.attention.backends.xformers import XFormersBackend from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.ops.paged_attn import PagedAttention - from vllm.utils import make_tensor_with_pad from vllm.attention.layer import Attention import random -from xformers import ops as xops - -from xformers.ops.fmha.attn_bias import (AttentionBias, - BlockDiagonalMask, - BlockDiagonalCausalMask, - LowerTriangularMaskWithTensorBias) - # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 HEAD_SIZES = [64] @@ -60,7 +51,9 @@ def ref_masked_attention( key: torch.Tensor, value: torch.Tensor, scale: float, - attn_mask: Optional[torch.Tensor] = None, + custom_mask: Optional[torch.Tensor] = None, + q_prompt_lens: Optional[List] = None, + kv_prompt_lens: Optional[List] = None ) -> torch.Tensor: #query=query.unsqueeze(-2) #key=key.unsqueeze(-2) @@ -68,8 +61,23 @@ def ref_masked_attention( #assert False,f"{query.shape} ; {key.shape}" attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() #assert False,f"{query.shape} ; {key.shape} ; {attn_weights.shape}" - if attn_mask is not None: + + # Lowest-level attention mask, derived from prompt lens + if (q_prompt_lens is not None) or (kv_prompt_lens is not None): + attn_mask = torch.zeros_like(attn_weights) + if q_prompt_lens is not None: + for bdx,plen in enumerate(q_prompt_lens): + attn_mask[bdx,:,plen:,:] = -torch.inf + if kv_prompt_lens is not None: + for bdx,plen in enumerate(kv_prompt_lens): + attn_mask[bdx,:,:,plen:] = -torch.inf + attn_weights = attn_weights + attn_mask.float() + + # Custom attention mask + if custom_mask is not None: + attn_weights = attn_weights + custom_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) #assert False, f"{attn_weights.shape} ; {value.shape} ; {out.shape}" @@ -412,114 +420,6 @@ def make_attention(num_heads: int, head_size: int, scale: float): head_size, scale=scale,) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size",BLOCK_SIZES) -@pytest.mark.parametrize("max_q_prompt_len",Q_PROMPT_LENS) -@pytest.mark.parametrize("max_kv_prompt_len",K_PROMPT_LENS) -def test_xops_memory_efficient_attention_forward_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int): - # Attention operator instance - is_cross_attn=True - device='cuda:0' - kv_cache_dtype='auto' - is_prompt = True - context_lens = [0 for _ in range(batch_size)] - key_read_width = 4 - num_blocks = 4096 - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') - scale = float(1.0 / (head_size**0.5)) - attn = make_attention(num_heads, head_size, scale) - attn_backend = make_backend(backend_name) - - query, \ - key, \ - value, \ - prefill_query, \ - _, \ - _, \ - decode_query, \ - _, \ - _, \ - q_prompt_lens, \ - kv_prompt_lens, \ - actual_max_q_prompt_len, \ - actual_max_kv_prompt_len, \ - prefill_q_prompt_lens, \ - _, \ - decode_q_prompt_lens, \ - _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=is_cross_attn) - - #causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) - ideal_output = ref_masked_attention( - query, - key, - value, - scale=scale, - #attn_mask=causal_mask - ) - - original_query = query - #query = query.unsqueeze(0) - #key = key.unsqueeze(0) - #value = value.unsqueeze(0) - xops_out = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=None, - p=0.0, - scale=scale) - - assert torch.allclose(ideal_output,xops_out) - - prefill_ideal_output = torch.zeros_like(ideal_output) - decode_ideal_output = torch.zeros_like(ideal_output[:,0:1]) - for bdx,prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): - prefill_ideal_output[bdx,:prefill_q_prompt_len] = ideal_output[bdx,:prefill_q_prompt_len] - decode_ideal_output[bdx,:] = ideal_output[bdx,prefill_q_prompt_len:(prefill_q_prompt_len+1)] - - prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) - decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) - - # Unlike self-attention: - # - Prefill slot-mapping includes all key slots - # - Decode slot-mapping is empty - decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping = make_block_tables_slot_mapping(block_size,kv_prompt_lens) - - prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) - - prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,key,value,prefill_q_prompt_lens,kv_prompt_lens) - - prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) - - xops_out = xops.memory_efficient_attention_forward( - prefill_packed_query.unsqueeze(0), - prefill_packed_key.unsqueeze(0), - prefill_packed_value.unsqueeze(0), - attn_bias=None, - p=0.0, - scale=scale) - xops_out=xops_out.view_as(prefill_packed_query) - - # eval correctness of prefill output - assert torch.allclose(prefill_packed_actual_output,xops_out) - - # eval correctness of prefill output - assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) - - is_prompt = False - context_lens = copy.deepcopy(kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) - - decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,key,value,decode_q_prompt_lens,kv_prompt_lens) - - decode_packed_actual_output=attn.forward(decode_packed_query,None,None,kv_cache,decode_attn_metadata,scale) - - # eval correctness of decode output - assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) - #@pytest.mark.skip @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -552,7 +452,7 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n decode_key, \ decode_value, \ q_prompt_lens, \ - _, \ + kv_prompt_lens, \ _, \ _, \ prefill_q_prompt_lens, \ @@ -566,7 +466,9 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n key, value, scale=scale, - attn_mask=causal_mask + custom_mask=causal_mask, + q_prompt_lens=q_prompt_lens, + kv_prompt_lens=kv_prompt_lens ) prefill_ideal_output = torch.zeros_like(ideal_output) @@ -609,9 +511,10 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) # eval correctness of decode output - assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) + assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) +#@pytest.mark.skip @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -620,111 +523,6 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n @pytest.mark.parametrize("max_q_prompt_len",Q_PROMPT_LENS) @pytest.mark.parametrize("max_kv_prompt_len",K_PROMPT_LENS) def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: - # Attention operator instance - is_cross_attn=False - device='cuda:0' - kv_cache_dtype='auto' - is_prompt = True - #max_q_prompt_len = max_prompt_len - max_kv_prompt_len = max_q_prompt_len - context_lens = [0 for _ in range(batch_size)] - key_read_width = 4 - num_blocks = 4096 - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') - scale = float(1.0 / (head_size**0.5)) - attn = make_attention(num_heads, head_size, scale) - attn_backend = make_backend(backend_name) - query, \ - key, \ - value, \ - prefill_query, \ - prefill_key, \ - prefill_value, \ - decode_query, \ - decode_key, \ - decode_value, \ - q_prompt_lens, \ - _, \ - _, \ - _, \ - prefill_q_prompt_lens, \ - prefill_kv_prompt_lens, \ - decode_q_prompt_lens, \ - decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=False) - - causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len).cuda() - ideal_output = ref_masked_attention( - query, - key, - value, - scale=scale, - attn_mask=causal_mask - ) - - prefill_ideal_output = torch.zeros_like(ideal_output) - decode_ideal_output = torch.zeros_like(ideal_output[:,0:1]) - for bdx,prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): - prefill_ideal_output[bdx,:prefill_q_prompt_len] = ideal_output[bdx,:prefill_q_prompt_len] - decode_ideal_output[bdx,:] = ideal_output[bdx,prefill_q_prompt_len:(prefill_q_prompt_len+1)] - - prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) - decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) - - decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _ = make_block_tables_slot_mapping(block_size,q_prompt_lens) - prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) - - prefill_packed_query,prefill_packed_key,prefill_packed_value,_,_ = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) - - prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) - - - - shorten_amt = 17 - new_ideal_output = ref_masked_attention( - query[:,:(prefill_q_prompt_lens[0]-shorten_amt),:,:], - key[:,:prefill_kv_prompt_lens[0],:,:], - value[:,:prefill_kv_prompt_lens[0],:,:], - scale=scale, - attn_mask=None #causal_mask[:(prefill_q_prompt_lens[0]-shorten_amt),:prefill_kv_prompt_lens[0]] - ) - - attn_bias = None #BlockDiagonalCausalMask.from_seqlens([(prefill_q_prompt_lens[0]-shorten_amt)],[prefill_kv_prompt_lens[0]]) - - xops_out = xops.memory_efficient_attention_forward( - prefill_packed_query.view(-1, num_heads, head_size)[:-shorten_amt,:,:].unsqueeze(0), - prefill_packed_key.view(-1, num_heads, head_size).unsqueeze(0), - prefill_packed_value.view(-1, num_heads, head_size).unsqueeze(0), - attn_bias=attn_bias, - p=0.0, - scale=scale) - #xops_out=xops_out.view_as(prefill_packed_query) - - assert torch.allclose(xops_out,new_ideal_output.view_as(xops_out)) - - # eval correctness of prefill output - assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) - - is_prompt = False - context_lens = copy.deepcopy(prefill_kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype) - - decode_packed_query,decode_packed_key,decode_packed_value,_,_ = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) - - decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) - - # eval correctness of decode output - assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) - - -@pytest.mark.skip -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size",BLOCK_SIZES) -@pytest.mark.parametrize("max_q_prompt_len",Q_PROMPT_LENS) -@pytest.mark.parametrize("max_kv_prompt_len",K_PROMPT_LENS) -def test_prefill_decode_cross_attention_old(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: # Attention operator instance is_cross_attn=True device='cuda:0' @@ -762,7 +560,8 @@ def test_prefill_decode_cross_attention_old(num_heads: int, head_size: int, back key, value, scale=scale, - #attn_mask=causal_mask + q_prompt_lens=q_prompt_lens, + kv_prompt_lens=kv_prompt_lens ) prefill_ideal_output = torch.zeros_like(ideal_output) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 46862d72cd7f9..0ef80ca410355 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -392,8 +392,8 @@ def _run_memory_efficient_xformers_forward( if attn_metadata.attn_bias is None: if self.alibi_slopes is None: if attn_metadata.is_cross_attn: - attn_bias = None #BlockDiagonalMask.from_seqlens( - # attn_metadata.seq_lens,attn_metadata.cross_seq_lens) + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens,attn_metadata.cross_seq_lens) else: attn_bias = BlockDiagonalCausalMask.from_seqlens( attn_metadata.seq_lens) From 2ad68f1932d50a4cbd71068efa1b06df8f18eb54 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 20 May 2024 16:35:17 -0400 Subject: [PATCH 020/239] reintroduced completely random Q/K sequence lengths --- tests/layer/test_self_and_cross_attn.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 954e56d8bcc39..42dec244b5c9c 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -25,7 +25,7 @@ NUM_HEADS = [1] -BATCH_SIZES = [2] +BATCH_SIZES = [16] BLOCK_SIZES = [16] #KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] BACKEND_NAMES = ["xformers"] @@ -89,7 +89,7 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a if force_max_len: q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] else: - q_prompt_lens = [random.randint(3, max_q_prompt_len) for _ in range(batch_size)] + q_prompt_lens = [random.randint(2, max_q_prompt_len) for _ in range(batch_size)] kv_prompt_lens = None if not is_cross_attn: # K,V prompt lens match Q for self-attention @@ -99,8 +99,7 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a if force_max_len: kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] else: - kv_prompt_lens = [min(q_prompt_len-1,max_kv_prompt_len) - for q_prompt_len in q_prompt_lens] + kv_prompt_lens = [random.randint(2, max_kv_prompt_len) for _ in range(batch_size)] actual_max_q_prompt_len = max(q_prompt_lens) actual_max_kv_prompt_len = max(kv_prompt_lens) @@ -420,7 +419,6 @@ def make_attention(num_heads: int, head_size: int, scale: float): head_size, scale=scale,) -#@pytest.mark.skip @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -487,18 +485,6 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) - # attn_bias = BlockDiagonalCausalMask.from_seqlens(prefill_q_prompt_lens) - # xops_out = xops.memory_efficient_attention_forward( - # prefill_packed_query.unsqueeze(0), - # prefill_packed_key.unsqueeze(0), - # prefill_packed_value.unsqueeze(0), - # attn_bias=attn_bias, - # p=0.0, - # scale=scale) - # xops_out=xops_out.view_as(prefill_packed_query) - - # assert torch.allclose(xops_out,prefill_packed_ideal_output[:,0,:]) - # eval correctness of prefill output assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) From 93e96d493c13484038f7874c7ee7e8fe4119751c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 20 May 2024 17:21:29 -0400 Subject: [PATCH 021/239] cross-attention works for both prefill and decode! --- tests/layer/test_self_and_cross_attn.py | 4 +-- vllm/attention/backends/xformers.py | 35 +++++++++++++++---------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 42dec244b5c9c..d024b28cbf13b 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -499,8 +499,6 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n # eval correctness of decode output assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) - -#@pytest.mark.skip @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -568,7 +566,7 @@ def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_ prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,key,value,prefill_q_prompt_lens,kv_prompt_lens) - prefill_packed_actual_output=attn.forward(prefill_packed_query.contiguous(),prefill_packed_key.contiguous(),prefill_packed_value.contiguous(),kv_cache,prefill_attn_metadata,scale) + prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) # eval correctness of prefill output assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 0ef80ca410355..57316f9e0c967 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -266,20 +266,27 @@ def forward( shape = [num_tokens, num_heads * head_size] """ query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if kv_cache is not None: + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size) + + if (kv_cache is not None): + # Even if there are no new key/value pairs to cache, + # we still need to break out key_cache and value_cache + # i.e. for later use by paged attention key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - PagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, kv_scale) + if (key is not None) and (value is not None): + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -294,7 +301,7 @@ def forward( # QKV for prefill. query = query[:num_prefill_tokens] - if not is_cross_attn: + if not is_cross_attn and key is not None and value is not None: key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] @@ -339,8 +346,8 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.seq_lens_tensor, - decode_meta.max_decode_seq_len, + decode_meta.seq_lens_tensor if not is_cross_attn else torch.tensor(decode_meta.cross_seq_lens,dtype=decode_meta.seq_lens_tensor.dtype,device=decode_meta.seq_lens_tensor.device), + decode_meta.max_decode_seq_len if not is_cross_attn else max(decode_meta.cross_seq_lens), self.kv_cache_dtype, self.num_kv_heads, self.scale, From 5d91c94c990d531a059ce3e42b131b88d8c0f0bd Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 20 May 2024 19:01:31 -0400 Subject: [PATCH 022/239] test refactoring: adding function comments, removing unnecessary comments & arguments --- tests/layer/test_self_and_cross_attn.py | 222 ++++++++++++++---------- 1 file changed, 131 insertions(+), 91 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index d024b28cbf13b..751f6b3d15f5d 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -5,7 +5,7 @@ import pytest import torch import copy -from vllm.attention import Attention, AttentionMetadata #, AttentionMetadataPerStage +from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.xformers import XFormersBackend from vllm.attention.backends.abstract import AttentionBackend @@ -23,25 +23,33 @@ # [64, 80, 96, 112, 128, 256 # ] if not is_hip() else [64, 80, 96, 112, 128] -NUM_HEADS = [1] +NUM_HEADS = [16] BATCH_SIZES = [16] BLOCK_SIZES = [16] -#KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] BACKEND_NAMES = ["xformers"] -#CUDA_DEVICES = [ -# f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -#] +CUDA_DEVICE="cuda:0" PROMPT_LENS = [128] -Q_PROMPT_LENS = [129] +Q_PROMPT_LENS = [128] K_PROMPT_LENS = [128] -def build_causal_mask(q_max_prompt_len, k_max_prompt_len): +def build_causal_mask(q_max_prompt_len, kv_max_prompt_len): + ''' + Create a q_max_prompt_len x kv_max_prompt_len causal mask + + Arguments: + * q_max_prompt_len: query max prompt len + * kv_max_prompt_len: key/value max prompt len + + Returns: + * 2D tensor, q_max_prompt_len x kv_max_prompt_len + ''' + # Create a matrix where entry (i, j) is True if i >= j - mask = torch.triu(torch.ones(q_max_prompt_len, k_max_prompt_len), diagonal=1) #.transpose(0, 1) + mask = torch.triu(torch.ones(q_max_prompt_len, kv_max_prompt_len), diagonal=1) # Replace True with float('-inf') and False with 0 mask = mask.masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, 0.0) return mask @@ -55,14 +63,31 @@ def ref_masked_attention( q_prompt_lens: Optional[List] = None, kv_prompt_lens: Optional[List] = None ) -> torch.Tensor: - #query=query.unsqueeze(-2) - #key=key.unsqueeze(-2) - #value=value.unsqueeze(-2) - #assert False,f"{query.shape} ; {key.shape}" + ''' + "Golden" masked attention reference. Supports two types of masking: + * Basic attention mask, utilizing {q,kv}_prompt_lens args to mask out padding elements + * Custom attention mask, which can force an arbitrary mask tensor, i.e. causal + + Arguments: + * query: batch_size x q_padded_seq_len x num_heads x head_size + * key: batch_size x kv_padded_seq_len x num_heads x head_size + * value: batch_size x kv_padded_seq_len x num_heads x head_size + * scale: Attention scale factor + * Custom mask: custom attention mask; good place to inject a causal attention mask + * q_prompt_lens: list of unpadded query seq_lens for each batch index + * kv_prompt_lens: list of unpadded key/value seq_lens for each batch index + + Returns: + * Attention result, batch_size x q_padded_seq_len x num_heads x head_size + ''' + + batch_size = query.shape[0] + assert(len(q_prompt_lens) == batch_size) + assert(len(kv_prompt_lens) == batch_size) + attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() - #assert False,f"{query.shape} ; {key.shape} ; {attn_weights.shape}" - # Lowest-level attention mask, derived from prompt lens + # Basic attention mask, derived from prompt lens if (q_prompt_lens is not None) or (kv_prompt_lens is not None): attn_mask = torch.zeros_like(attn_weights) if q_prompt_lens is not None: @@ -80,11 +105,48 @@ def ref_masked_attention( attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) - #assert False, f"{attn_weights.shape} ; {value.shape} ; {out.shape}" return out -def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_attn=True, force_max_len=False): - #assert max_kv_prompt_len >= max_q_prompt_len +def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size, is_cross_attn=True, force_max_len=False, device=CUDA_DEVICE): + ''' + Construct QKV test tensors for self- and cross-attention. + + Generates three query/key/value triplets: + * "Baseline" query/key/value (for input to reference attention function) + * "Prefill" query/key/value (last sequence offset zero'd out, for use as input to prefill kernel) + * "Decode" query/key/value (only the last sequence offset from baseline, for use as input to decode kernel) + + Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v seqlens + + Arguments: + * batch_size + * max_q_prompt_len: max query prompt len + * max_kv_prompt_len: max key/value prompt len + * num_heads + * head_size + * is_cross_attn: if True, query seqlen may differ from key/value seqlen (as is often the case for cross-attention); o/w, query/key/value seqlens match at each batch index (max_kv_prompt_len is unused) + * force_max_len: if True, all query seqlens are max_q_prompt_len; o/w query seqlens are random in [2,max_q_prompt_lens]. Same for key/value seqlens and max_kv_prompt_len, unless forced by is_cross_attn=False + * device: CPU or CUDA device + + Returns: + * query: "baseline" query; batch_size x max_q_prompt_len x num_heads x head_size + * key: "baseline" key; batch_size x max_kv_prompt_len x num_heads x head_size + * value: "baseline" value; batch_size x max_kv_prompt_len x num_heads x head_size + * prefill_query: batch_size x (max_q_prompt_len-1) x num_heads x head_size + * prefill_key: batch_size x (max_kv_prompt_len-1) x num_heads x head_size + * prefill_value: batch_size x (max_kv_prompt_len-1) x num_heads x head_size + * decode_query: batch_size x 1 x num_heads x head_size + * decode_key: batch_size x 1 x num_heads x head_size + * decode_value: batch_size x 1 x num_heads x head_size + * q_prompt_lens: "baseline" query seqlen list + * kv_prompt_lens: "baseline" key/value seqlen list + * actual_max_q_prompt_len: actual "baseline" query max prompt len (may be <= max_q_prompt_len due to randomness) + * actual_max_kv_prompt_len: actual "baseline" key/value max prompt len (may be <= max_kv_prompt_len due to randomness) + * prefill_q_prompt_lens: "prefill" query seqlen list + * prefill_kv_prompt_lens: "prefill" key/value seqlen list + * decode_q_prompt_lens: "decode" query seqlen list (all ones) + * decode_kv_prompt_lens: "decode" key/value seqlen list + ''' if force_max_len: q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] @@ -95,7 +157,7 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a # K,V prompt lens match Q for self-attention kv_prompt_lens = q_prompt_lens else: - # K,V prompt lens come from K,V operands + # K,V prompt lens are distinct from Q prompt lens & random if force_max_len: kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] else: @@ -104,17 +166,17 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a actual_max_q_prompt_len = max(q_prompt_lens) actual_max_kv_prompt_len = max(kv_prompt_lens) - query=torch.rand((batch_size,max_q_prompt_len,head_size)).cuda() - key=torch.rand((batch_size,max_kv_prompt_len,head_size)).cuda() - value=torch.rand((batch_size,max_kv_prompt_len,head_size)).cuda() + query=torch.rand((batch_size,max_q_prompt_len,num_heads*head_size)).to(device) + key=torch.rand((batch_size,max_kv_prompt_len,num_heads*head_size)).to(device) + value=torch.rand((batch_size,max_kv_prompt_len,num_heads*head_size)).to(device) - prefill_query=torch.zeros((batch_size,max_q_prompt_len-1,head_size)).cuda() - prefill_key=torch.zeros((batch_size,max_kv_prompt_len-1,head_size)).cuda() - prefill_value=torch.zeros((batch_size,max_kv_prompt_len-1,head_size)).cuda() + prefill_query=torch.zeros((batch_size,max_q_prompt_len,num_heads*head_size)).to(device) + prefill_key=torch.zeros((batch_size,max_kv_prompt_len,num_heads*head_size)).to(device) + prefill_value=torch.zeros((batch_size,max_kv_prompt_len,num_heads*head_size)).to(device) - decode_query=torch.zeros((batch_size,1,head_size)).cuda() - decode_key=torch.zeros((batch_size,1,head_size)).cuda() - decode_value=torch.zeros((batch_size,1,head_size)).cuda() + decode_query=torch.zeros((batch_size,1,num_heads*head_size)).to(device) + decode_key=torch.zeros((batch_size,1,num_heads*head_size)).to(device) + decode_value=torch.zeros((batch_size,1,num_heads*head_size)).to(device) for bdx,(q_prompt_len,kv_prompt_len) in enumerate(zip(q_prompt_lens,kv_prompt_lens)): query[bdx,q_prompt_len:,:] = 0 @@ -135,17 +197,17 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a decode_q_prompt_lens = [1 for _ in q_prompt_lens] decode_kv_prompt_lens = [1 for _ in kv_prompt_lens] - query=query.unsqueeze(-2) - key=key.unsqueeze(-2) - value=value.unsqueeze(-2) + query=query.view(batch_size,query.shape[1],num_heads,head_size) + key=key.view(batch_size,key.shape[1],num_heads,head_size) + value=value.view(batch_size,value.shape[1],num_heads,head_size) - prefill_query=prefill_query.unsqueeze(-2) - prefill_key=prefill_key.unsqueeze(-2) - prefill_value=prefill_value.unsqueeze(-2) + prefill_query=prefill_query.view(batch_size,prefill_query.shape[1],num_heads,head_size) + prefill_key=prefill_key.view(batch_size,prefill_key.shape[1],num_heads,head_size) + prefill_value=prefill_value.view(batch_size,prefill_value.shape[1],num_heads,head_size) - decode_query=decode_query.unsqueeze(-2) - decode_key=decode_key.unsqueeze(-2) - decode_value=decode_value.unsqueeze(-2) + decode_query=decode_query.view(batch_size,decode_query.shape[1],num_heads,head_size) + decode_key=decode_key.view(batch_size,decode_key.shape[1],num_heads,head_size) + decode_value=decode_value.view(batch_size,decode_value.shape[1],num_heads,head_size) return query, \ key, \ @@ -165,7 +227,7 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size, is_cross_a decode_q_prompt_lens, \ decode_kv_prompt_lens -def pack_tensor(unpacked_tensor,prompt_lens, device='cuda:0'): +def pack_tensor(unpacked_tensor,prompt_lens, device=CUDA_DEVICE): num_tok = sum(prompt_lens) num_heads = unpacked_tensor.shape[-2] head_size = unpacked_tensor.shape[-1] @@ -195,7 +257,7 @@ def make_backend(backend_name: str) -> AttentionBackend: return XFormersBackend() assert False, f"Unrecognized backend_name {backend_name} for unit test" -def make_metadata_tensors(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, device='cuda:0', cross_prompt_lens:Optional[List[int]] = None) -> tuple: +def make_metadata_tensors(is_prompt:bool, prompt_lens:List[int], context_lens:List[int], device=CUDA_DEVICE) -> tuple: ''' Assumptions: * No chunked prefill @@ -239,9 +301,7 @@ def make_metadata_tensors(attn_backend:AttentionBackend, is_prompt:bool, is_cros seq_start_loc, \ query_start_loc -def make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0', default_val=0.0): - #key_cache = torch.rand((num_blocks, num_heads, head_size//key_read_width, block_size, key_read_width),device=device) - #val_cache = torch.rand((num_blocks, num_heads, head_size, block_size),device=device) +def make_kv_cache(num_blocks, num_heads, head_size, block_size, device=CUDA_DEVICE, default_val=0.0): kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to(device) if default_val is not None: kv_cache[:,:,:] = default_val @@ -250,18 +310,16 @@ def make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, def num_tokens_to_min_blocks(num_tokens,block_size): return (num_tokens+block_size)//block_size -def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): +def make_block_tables_slot_mapping(block_size,prompt_lens,device=CUDA_DEVICE): ''' Naive block table: * For each batch element... * Block table has ''' - num_prompts = len(prompt_lens) - total_num_tokens = sum(prompt_lens) + # Over-provision block table blocks by 1 num_blocks_list = [num_tokens_to_min_blocks(num_tokens,block_size)+1 for num_tokens in prompt_lens] max_block_table_len = max(num_blocks_list) - #block_tables = [list(range(num_blocks*10)) for num_blocks in num_blocks_list] block_table_pad_tokens = 10 block_tables = [] @@ -269,9 +327,7 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): decode_slot_mapping = [] slot_mapping = [] block_base_idx = sum(num_blocks_list)*2-1 # Support more blocks than needed - #seq_base_idx = 0 for sdx,num_tokens in enumerate(prompt_lens): - #num_blocks = num_tokens_to_min_blocks(num_tokens,block_size) num_blocks = num_blocks_list[sdx] block_table = list(range(block_base_idx,block_base_idx-num_blocks,-1)) for idx in range(num_tokens-1): @@ -281,13 +337,12 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): decode_slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) - #seq_base_idx += num_tokens block_base_idx -= num_blocks block_tables.append(block_table) prefill_block_tables_tensor = torch.tensor( [], - device='cuda:0' + device=CUDA_DEVICE ) decode_block_tables_tensor = make_tensor_with_pad( block_tables, @@ -320,7 +375,7 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device='cuda:0'): return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor, slot_mapping_tensor, empty_slot_mapping_tensor -def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, slot_mapping, device='cuda:0', kv_cache_dtype='auto', cross_prompt_lens:Optional[List[int]] = None): +def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, slot_mapping, device=CUDA_DEVICE, cross_prompt_lens:Optional[List[int]] = None): ''' Assumptions: * No chunked prefill -> a batch is 100% prefill or 100% decode, never both @@ -334,17 +389,13 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b prompt_lens_tensor, \ context_lens_tensor, \ max_query_len, \ - max_context_len, \ - max_prompt_len, \ + _, \ + _, \ seq_start_loc, \ - query_start_loc = make_metadata_tensors(attn_backend, - is_prompt, - is_cross_attn, + query_start_loc = make_metadata_tensors(is_prompt, prompt_lens, context_lens, - block_tables, - device=device, - cross_prompt_lens=cross_prompt_lens) + device=device) slot_mapping_tensor=torch.tensor(slot_mapping, dtype=torch.long, @@ -378,17 +429,13 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b prompt_lens_tensor, \ context_lens_tensor, \ max_query_len, \ - max_context_len, \ - max_prompt_len, \ + _, \ + _, \ seq_start_loc, \ - query_start_loc = make_metadata_tensors(attn_backend, - is_prompt, - is_cross_attn, + query_start_loc = make_metadata_tensors(is_prompt, prompt_lens, context_lens, - block_tables, - device=device, - cross_prompt_lens=cross_prompt_lens) + device=device) slot_mapping_tensor=torch.tensor(slot_mapping, dtype=torch.long, @@ -428,15 +475,12 @@ def make_attention(num_heads: int, head_size: int, scale: float): def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_prompt_len: int) -> None: # Attention operator instance is_cross_attn=False - device='cuda:0' - kv_cache_dtype='auto' is_prompt = True max_q_prompt_len = max_prompt_len max_kv_prompt_len = max_q_prompt_len context_lens = [0 for _ in range(batch_size)] - key_read_width = 4 num_blocks = 4096 - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) scale = float(1.0 / (head_size**0.5)) attn = make_attention(num_heads, head_size, scale) attn_backend = make_backend(backend_name) @@ -456,9 +500,9 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n prefill_q_prompt_lens, \ prefill_kv_prompt_lens, \ decode_q_prompt_lens, \ - decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=False) + decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=False) - causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len).cuda() + causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len).to(CUDA_DEVICE) ideal_output = ref_masked_attention( query, key, @@ -479,25 +523,25 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _ = make_block_tables_slot_mapping(block_size,q_prompt_lens) - prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=None) + prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, cross_prompt_lens=None) prefill_packed_query,prefill_packed_key,prefill_packed_value,_,_ = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) # eval correctness of prefill output - assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) + assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output.view_as(prefill_packed_actual_output)) is_prompt = False context_lens = copy.deepcopy(prefill_kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype) + decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping) decode_packed_query,decode_packed_key,decode_packed_value,_,_ = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) # eval correctness of decode output - assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) + assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output.view_as(decode_packed_actual_output)) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -509,13 +553,10 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: # Attention operator instance is_cross_attn=True - device='cuda:0' - kv_cache_dtype='auto' is_prompt = True context_lens = [0 for _ in range(batch_size)] - key_read_width = 4 num_blocks = 4096 - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size, key_read_width, device='cuda:0') + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) scale = float(1.0 / (head_size**0.5)) attn = make_attention(num_heads, head_size, scale) attn_backend = make_backend(backend_name) @@ -531,14 +572,13 @@ def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_ _, \ q_prompt_lens, \ kv_prompt_lens, \ - actual_max_q_prompt_len, \ - actual_max_kv_prompt_len, \ + _, \ + _, \ prefill_q_prompt_lens, \ _, \ decode_q_prompt_lens, \ - _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,head_size,is_cross_attn=is_cross_attn) + _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=is_cross_attn) - #causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len) ideal_output = ref_masked_attention( query, key, @@ -562,22 +602,22 @@ def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_ # - Decode slot-mapping is empty decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping = make_block_tables_slot_mapping(block_size,kv_prompt_lens) - prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) + prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, cross_prompt_lens=kv_prompt_lens) - prefill_packed_query,prefill_packed_key,prefill_packed_value,prefill_q_start_loc_list,prefill_kv_start_loc_list = pack_qkv(prefill_query,key,value,prefill_q_prompt_lens,kv_prompt_lens) + prefill_packed_query,prefill_packed_key,prefill_packed_value,_,_ = pack_qkv(prefill_query,key,value,prefill_q_prompt_lens,kv_prompt_lens) prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) # eval correctness of prefill output - assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output[:,0,:]) + assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output.view_as(prefill_packed_actual_output)) is_prompt = False context_lens = copy.deepcopy(kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, device=device, kv_cache_dtype=kv_cache_dtype, cross_prompt_lens=kv_prompt_lens) + decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, cross_prompt_lens=kv_prompt_lens) - decode_packed_query,decode_packed_key,decode_packed_value,decode_q_start_loc_list,decode_kv_start_loc_list = pack_qkv(decode_query,key,value,decode_q_prompt_lens,kv_prompt_lens) + decode_packed_query,_,_,_,_ = pack_qkv(decode_query,key,value,decode_q_prompt_lens,kv_prompt_lens) decode_packed_actual_output=attn.forward(decode_packed_query,None,None,kv_cache,decode_attn_metadata,scale) # eval correctness of decode output - assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output[:,0,:]) \ No newline at end of file + assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output.view_as(decode_packed_actual_output)) \ No newline at end of file From 7cc88f780598d86a51563b20a1acf552eddc1afd Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 20 May 2024 19:08:02 -0400 Subject: [PATCH 023/239] formatting --- tests/layer/test_self_and_cross_attn.py | 556 +++++++++++++++--------- vllm/attention/backends/xformers.py | 28 +- 2 files changed, 361 insertions(+), 223 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 751f6b3d15f5d..18afff46d8371 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -18,7 +18,7 @@ # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 -HEAD_SIZES = [64] +HEAD_SIZES = [64] # [64, 80, 96, 112, 128, 256 # ] if not is_hip() else [64, 80, 96, 112, 128] @@ -28,7 +28,7 @@ BATCH_SIZES = [16] BLOCK_SIZES = [16] BACKEND_NAMES = ["xformers"] -CUDA_DEVICE="cuda:0" +CUDA_DEVICE = "cuda:0" PROMPT_LENS = [128] @@ -36,6 +36,7 @@ K_PROMPT_LENS = [128] + def build_causal_mask(q_max_prompt_len, kv_max_prompt_len): ''' Create a q_max_prompt_len x kv_max_prompt_len causal mask @@ -49,20 +50,22 @@ def build_causal_mask(q_max_prompt_len, kv_max_prompt_len): ''' # Create a matrix where entry (i, j) is True if i >= j - mask = torch.triu(torch.ones(q_max_prompt_len, kv_max_prompt_len), diagonal=1) + mask = torch.triu(torch.ones(q_max_prompt_len, kv_max_prompt_len), + diagonal=1) # Replace True with float('-inf') and False with 0 - mask = mask.masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, 0.0) + mask = mask.masked_fill(mask == 1, + float('-inf')).masked_fill(mask == 0, 0.0) return mask + def ref_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - custom_mask: Optional[torch.Tensor] = None, - q_prompt_lens: Optional[List] = None, - kv_prompt_lens: Optional[List] = None -) -> torch.Tensor: + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + custom_mask: Optional[torch.Tensor] = None, + q_prompt_lens: Optional[List] = None, + kv_prompt_lens: Optional[List] = None) -> torch.Tensor: ''' "Golden" masked attention reference. Supports two types of masking: * Basic attention mask, utilizing {q,kv}_prompt_lens args to mask out padding elements @@ -82,8 +85,8 @@ def ref_masked_attention( ''' batch_size = query.shape[0] - assert(len(q_prompt_lens) == batch_size) - assert(len(kv_prompt_lens) == batch_size) + assert (len(q_prompt_lens) == batch_size) + assert (len(kv_prompt_lens) == batch_size) attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() @@ -91,11 +94,11 @@ def ref_masked_attention( if (q_prompt_lens is not None) or (kv_prompt_lens is not None): attn_mask = torch.zeros_like(attn_weights) if q_prompt_lens is not None: - for bdx,plen in enumerate(q_prompt_lens): - attn_mask[bdx,:,plen:,:] = -torch.inf + for bdx, plen in enumerate(q_prompt_lens): + attn_mask[bdx, :, plen:, :] = -torch.inf if kv_prompt_lens is not None: - for bdx,plen in enumerate(kv_prompt_lens): - attn_mask[bdx,:,:,plen:] = -torch.inf + for bdx, plen in enumerate(kv_prompt_lens): + attn_mask[bdx, :, :, plen:] = -torch.inf attn_weights = attn_weights + attn_mask.float() @@ -107,7 +110,15 @@ def ref_masked_attention( out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) return out -def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size, is_cross_attn=True, force_max_len=False, device=CUDA_DEVICE): + +def make_qkv(batch_size, + max_q_prompt_len, + max_kv_prompt_len, + num_heads, + head_size, + is_cross_attn=True, + force_max_len=False, + device=CUDA_DEVICE): ''' Construct QKV test tensors for self- and cross-attention. @@ -151,7 +162,9 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size, if force_max_len: q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] else: - q_prompt_lens = [random.randint(2, max_q_prompt_len) for _ in range(batch_size)] + q_prompt_lens = [ + random.randint(2, max_q_prompt_len) for _ in range(batch_size) + ] kv_prompt_lens = None if not is_cross_attn: # K,V prompt lens match Q for self-attention @@ -161,35 +174,53 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size, if force_max_len: kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] else: - kv_prompt_lens = [random.randint(2, max_kv_prompt_len) for _ in range(batch_size)] + kv_prompt_lens = [ + random.randint(2, max_kv_prompt_len) for _ in range(batch_size) + ] actual_max_q_prompt_len = max(q_prompt_lens) actual_max_kv_prompt_len = max(kv_prompt_lens) - query=torch.rand((batch_size,max_q_prompt_len,num_heads*head_size)).to(device) - key=torch.rand((batch_size,max_kv_prompt_len,num_heads*head_size)).to(device) - value=torch.rand((batch_size,max_kv_prompt_len,num_heads*head_size)).to(device) - - prefill_query=torch.zeros((batch_size,max_q_prompt_len,num_heads*head_size)).to(device) - prefill_key=torch.zeros((batch_size,max_kv_prompt_len,num_heads*head_size)).to(device) - prefill_value=torch.zeros((batch_size,max_kv_prompt_len,num_heads*head_size)).to(device) - - decode_query=torch.zeros((batch_size,1,num_heads*head_size)).to(device) - decode_key=torch.zeros((batch_size,1,num_heads*head_size)).to(device) - decode_value=torch.zeros((batch_size,1,num_heads*head_size)).to(device) - - for bdx,(q_prompt_len,kv_prompt_len) in enumerate(zip(q_prompt_lens,kv_prompt_lens)): - query[bdx,q_prompt_len:,:] = 0 - key[bdx,kv_prompt_len:,:] = 0 - value[bdx,kv_prompt_len:,:] = 0 - - prefill_query[bdx,0:(q_prompt_len-1),:] = query[bdx,0:(q_prompt_len-1),:] - prefill_key[bdx,0:(kv_prompt_len-1),:] = key[bdx,0:(kv_prompt_len-1),:] - prefill_value[bdx,0:(kv_prompt_len-1),:] = value[bdx,0:(kv_prompt_len-1),:] - - decode_query[bdx,:,:] = query[bdx,(q_prompt_len-1):q_prompt_len,:] - decode_key[bdx,:,:] = key[bdx,(kv_prompt_len-1):kv_prompt_len,:] - decode_value[bdx,:,:] = value[bdx,(kv_prompt_len-1):kv_prompt_len,:] + query = torch.rand( + (batch_size, max_q_prompt_len, num_heads * head_size)).to(device) + key = torch.rand( + (batch_size, max_kv_prompt_len, num_heads * head_size)).to(device) + value = torch.rand( + (batch_size, max_kv_prompt_len, num_heads * head_size)).to(device) + + prefill_query = torch.zeros( + (batch_size, max_q_prompt_len, num_heads * head_size)).to(device) + prefill_key = torch.zeros( + (batch_size, max_kv_prompt_len, num_heads * head_size)).to(device) + prefill_value = torch.zeros( + (batch_size, max_kv_prompt_len, num_heads * head_size)).to(device) + + decode_query = torch.zeros( + (batch_size, 1, num_heads * head_size)).to(device) + decode_key = torch.zeros((batch_size, 1, num_heads * head_size)).to(device) + decode_value = torch.zeros( + (batch_size, 1, num_heads * head_size)).to(device) + + for bdx, (q_prompt_len, + kv_prompt_len) in enumerate(zip(q_prompt_lens, kv_prompt_lens)): + query[bdx, q_prompt_len:, :] = 0 + key[bdx, kv_prompt_len:, :] = 0 + value[bdx, kv_prompt_len:, :] = 0 + + prefill_query[bdx, + 0:(q_prompt_len - 1), :] = query[bdx, + 0:(q_prompt_len - 1), :] + prefill_key[bdx, + 0:(kv_prompt_len - 1), :] = key[bdx, + 0:(kv_prompt_len - 1), :] + prefill_value[bdx, 0:(kv_prompt_len - + 1), :] = value[bdx, 0:(kv_prompt_len - 1), :] + + decode_query[bdx, :, :] = query[bdx, + (q_prompt_len - 1):q_prompt_len, :] + decode_key[bdx, :, :] = key[bdx, (kv_prompt_len - 1):kv_prompt_len, :] + decode_value[bdx, :, :] = value[bdx, + (kv_prompt_len - 1):kv_prompt_len, :] prefill_q_prompt_lens = [plen - 1 for plen in q_prompt_lens] prefill_kv_prompt_lens = [plen - 1 for plen in kv_prompt_lens] @@ -197,17 +228,23 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size, decode_q_prompt_lens = [1 for _ in q_prompt_lens] decode_kv_prompt_lens = [1 for _ in kv_prompt_lens] - query=query.view(batch_size,query.shape[1],num_heads,head_size) - key=key.view(batch_size,key.shape[1],num_heads,head_size) - value=value.view(batch_size,value.shape[1],num_heads,head_size) + query = query.view(batch_size, query.shape[1], num_heads, head_size) + key = key.view(batch_size, key.shape[1], num_heads, head_size) + value = value.view(batch_size, value.shape[1], num_heads, head_size) - prefill_query=prefill_query.view(batch_size,prefill_query.shape[1],num_heads,head_size) - prefill_key=prefill_key.view(batch_size,prefill_key.shape[1],num_heads,head_size) - prefill_value=prefill_value.view(batch_size,prefill_value.shape[1],num_heads,head_size) + prefill_query = prefill_query.view(batch_size, prefill_query.shape[1], + num_heads, head_size) + prefill_key = prefill_key.view(batch_size, prefill_key.shape[1], num_heads, + head_size) + prefill_value = prefill_value.view(batch_size, prefill_value.shape[1], + num_heads, head_size) - decode_query=decode_query.view(batch_size,decode_query.shape[1],num_heads,head_size) - decode_key=decode_key.view(batch_size,decode_key.shape[1],num_heads,head_size) - decode_value=decode_value.view(batch_size,decode_value.shape[1],num_heads,head_size) + decode_query = decode_query.view(batch_size, decode_query.shape[1], + num_heads, head_size) + decode_key = decode_key.view(batch_size, decode_key.shape[1], num_heads, + head_size) + decode_value = decode_value.view(batch_size, decode_value.shape[1], + num_heads, head_size) return query, \ key, \ @@ -227,51 +264,62 @@ def make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size, decode_q_prompt_lens, \ decode_kv_prompt_lens -def pack_tensor(unpacked_tensor,prompt_lens, device=CUDA_DEVICE): + +def pack_tensor(unpacked_tensor, prompt_lens, device=CUDA_DEVICE): num_tok = sum(prompt_lens) num_heads = unpacked_tensor.shape[-2] head_size = unpacked_tensor.shape[-1] - start_loc_list = [0]+list(itertools.accumulate(prompt_lens)) - packed_tensor = torch.zeros((num_tok,num_heads,head_size), - device=device) + start_loc_list = [0] + list(itertools.accumulate(prompt_lens)) + packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device) - for bdx,(prompt_len,start_loc) in enumerate(zip(prompt_lens,start_loc_list)): + for bdx, (prompt_len, + start_loc) in enumerate(zip(prompt_lens, start_loc_list)): try: - packed_tensor[start_loc:(start_loc+prompt_len),:,:] = unpacked_tensor[bdx,:prompt_len,:,:] + packed_tensor[start_loc:( + start_loc + + prompt_len), :, :] = unpacked_tensor[bdx, :prompt_len, :, :] except: assert False, f"{start_loc} ; {prompt_len} ; {packed_tensor.shape} ; {unpacked_tensor.shape}" - return packed_tensor,start_loc_list - -def pack_qkv(query,key,value,q_prompt_lens,kv_prompt_lens): - packed_query,q_start_loc_list = pack_tensor(query,q_prompt_lens) - packed_key,kv_start_loc_list = pack_tensor(key,kv_prompt_lens) - packed_value,_ = pack_tensor(value,kv_prompt_lens) - packed_query=packed_query.view(-1,packed_query.shape[-1]*packed_query.shape[-2]) - packed_key=packed_key.view(-1,packed_key.shape[-1]*packed_key.shape[-2]) - packed_value=packed_value.view(-1,packed_value.shape[-1]*packed_value.shape[-2]) - return packed_query,packed_key,packed_value,q_start_loc_list,kv_start_loc_list + return packed_tensor, start_loc_list + + +def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): + packed_query, q_start_loc_list = pack_tensor(query, q_prompt_lens) + packed_key, kv_start_loc_list = pack_tensor(key, kv_prompt_lens) + packed_value, _ = pack_tensor(value, kv_prompt_lens) + packed_query = packed_query.view( + -1, packed_query.shape[-1] * packed_query.shape[-2]) + packed_key = packed_key.view(-1, + packed_key.shape[-1] * packed_key.shape[-2]) + packed_value = packed_value.view( + -1, packed_value.shape[-1] * packed_value.shape[-2]) + return packed_query, packed_key, packed_value, q_start_loc_list, kv_start_loc_list + def make_backend(backend_name: str) -> AttentionBackend: if backend_name == "xformers": return XFormersBackend() assert False, f"Unrecognized backend_name {backend_name} for unit test" -def make_metadata_tensors(is_prompt:bool, prompt_lens:List[int], context_lens:List[int], device=CUDA_DEVICE) -> tuple: + +def make_metadata_tensors(is_prompt: bool, + prompt_lens: List[int], + context_lens: List[int], + device=CUDA_DEVICE) -> tuple: ''' Assumptions: * No chunked prefill * No (automatic) prefix caching * Packed variable-length sequences ''' - prompt_lens_tensor=torch.tensor(prompt_lens, - dtype=torch.int, - device=device) - context_lens_tensor=None if context_lens is None else torch.tensor(context_lens, - dtype=torch.int, - device=device) - max_context_len=None if context_lens is None else max(context_lens) - max_prompt_len=None if prompt_lens is None else max(prompt_lens) + prompt_lens_tensor = torch.tensor(prompt_lens, + dtype=torch.int, + device=device) + context_lens_tensor = None if context_lens is None else torch.tensor( + context_lens, dtype=torch.int, device=device) + max_context_len = None if context_lens is None else max(context_lens) + max_prompt_len = None if prompt_lens is None else max(prompt_lens) seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, dtype=torch.int32, @@ -285,7 +333,7 @@ def make_metadata_tensors(is_prompt:bool, prompt_lens:List[int], context_lens:Li if is_prompt: # Prefill: query_start_loc matches seq_start_loc query_start_loc = copy.deepcopy(seq_start_loc) - max_query_len=max_prompt_len + max_query_len = max_prompt_len else: # Decode: one new query input token per batch # element, thus query_start_loc is the cumsum @@ -301,16 +349,27 @@ def make_metadata_tensors(is_prompt:bool, prompt_lens:List[int], context_lens:Li seq_start_loc, \ query_start_loc -def make_kv_cache(num_blocks, num_heads, head_size, block_size, device=CUDA_DEVICE, default_val=0.0): - kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to(device) + +def make_kv_cache(num_blocks, + num_heads, + head_size, + block_size, + device=CUDA_DEVICE, + default_val=0.0): + kv_cache = torch.rand( + (2, num_blocks, block_size * num_heads * head_size)).to(device) if default_val is not None: - kv_cache[:,:,:] = default_val + kv_cache[:, :, :] = default_val return kv_cache -def num_tokens_to_min_blocks(num_tokens,block_size): - return (num_tokens+block_size)//block_size -def make_block_tables_slot_mapping(block_size,prompt_lens,device=CUDA_DEVICE): +def num_tokens_to_min_blocks(num_tokens, block_size): + return (num_tokens + block_size) // block_size + + +def make_block_tables_slot_mapping(block_size, + prompt_lens, + device=CUDA_DEVICE): ''' Naive block table: * For each batch element... @@ -318,7 +377,10 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device=CUDA_DEVICE): ''' # Over-provision block table blocks by 1 - num_blocks_list = [num_tokens_to_min_blocks(num_tokens,block_size)+1 for num_tokens in prompt_lens] + num_blocks_list = [ + num_tokens_to_min_blocks(num_tokens, block_size) + 1 + for num_tokens in prompt_lens + ] max_block_table_len = max(num_blocks_list) block_table_pad_tokens = 10 @@ -326,56 +388,60 @@ def make_block_tables_slot_mapping(block_size,prompt_lens,device=CUDA_DEVICE): prefill_slot_mapping = [] decode_slot_mapping = [] slot_mapping = [] - block_base_idx = sum(num_blocks_list)*2-1 # Support more blocks than needed - for sdx,num_tokens in enumerate(prompt_lens): + block_base_idx = sum( + num_blocks_list) * 2 - 1 # Support more blocks than needed + for sdx, num_tokens in enumerate(prompt_lens): num_blocks = num_blocks_list[sdx] - block_table = list(range(block_base_idx,block_base_idx-num_blocks,-1)) - for idx in range(num_tokens-1): - prefill_slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) - slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) - idx = num_tokens-1 - decode_slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) - slot_mapping.append((idx % block_size) + block_table[idx//block_size]*block_size) + block_table = list( + range(block_base_idx, block_base_idx - num_blocks, -1)) + for idx in range(num_tokens - 1): + prefill_slot_mapping.append((idx % block_size) + + block_table[idx // block_size] * + block_size) + slot_mapping.append((idx % block_size) + + block_table[idx // block_size] * block_size) + idx = num_tokens - 1 + decode_slot_mapping.append((idx % block_size) + + block_table[idx // block_size] * block_size) + slot_mapping.append((idx % block_size) + + block_table[idx // block_size] * block_size) block_base_idx -= num_blocks block_tables.append(block_table) - - prefill_block_tables_tensor = torch.tensor( - [], - device=CUDA_DEVICE - ) + + prefill_block_tables_tensor = torch.tensor([], device=CUDA_DEVICE) decode_block_tables_tensor = make_tensor_with_pad( block_tables, - max_len=max_block_table_len+block_table_pad_tokens, + max_len=max_block_table_len + block_table_pad_tokens, pad=0, dtype=torch.int, device=device, ) - prefill_slot_mapping_tensor = torch.tensor( - prefill_slot_mapping, - dtype=torch.long, - device=device - ) - decode_slot_mapping_tensor = torch.tensor( - decode_slot_mapping, - dtype=torch.long, - device=device - ) - slot_mapping_tensor = torch.tensor( - slot_mapping, - dtype=torch.long, - device=device - ) - empty_slot_mapping_tensor = torch.tensor( - [], - dtype=torch.long, - device=device - ) + prefill_slot_mapping_tensor = torch.tensor(prefill_slot_mapping, + dtype=torch.long, + device=device) + decode_slot_mapping_tensor = torch.tensor(decode_slot_mapping, + dtype=torch.long, + device=device) + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=device) + empty_slot_mapping_tensor = torch.tensor([], + dtype=torch.long, + device=device) return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor, slot_mapping_tensor, empty_slot_mapping_tensor - - -def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:bool, prompt_lens:List[int], context_lens:List[int], block_tables, slot_mapping, device=CUDA_DEVICE, cross_prompt_lens:Optional[List[int]] = None): + + +def make_metadata(attn_backend: AttentionBackend, + is_prompt: bool, + is_cross_attn: bool, + prompt_lens: List[int], + context_lens: List[int], + block_tables, + slot_mapping, + device=CUDA_DEVICE, + cross_prompt_lens: Optional[List[int]] = None): ''' Assumptions: * No chunked prefill -> a batch is 100% prefill or 100% decode, never both @@ -392,14 +458,14 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b _, \ _, \ seq_start_loc, \ - query_start_loc = make_metadata_tensors(is_prompt, - prompt_lens, - context_lens, + query_start_loc = make_metadata_tensors(is_prompt, + prompt_lens, + context_lens, device=device) - slot_mapping_tensor=torch.tensor(slot_mapping, - dtype=torch.long, - device=device) + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=device) return attn_backend.make_metadata( num_prefills=num_prefills, @@ -417,10 +483,9 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b block_tables=block_tables, use_cuda_graph=False, is_cross_attn=is_cross_attn, - cross_seq_lens=cross_prompt_lens - ) + cross_seq_lens=cross_prompt_lens) - else: # not is_prompt + else: # not is_prompt num_prefills = 0 num_prefill_tokens = 0 @@ -432,14 +497,14 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b _, \ _, \ seq_start_loc, \ - query_start_loc = make_metadata_tensors(is_prompt, - prompt_lens, - context_lens, + query_start_loc = make_metadata_tensors(is_prompt, + prompt_lens, + context_lens, device=device) - slot_mapping_tensor=torch.tensor(slot_mapping, - dtype=torch.long, - device=device) + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=device) return attn_backend.make_metadata( num_prefills=num_prefills, @@ -457,30 +522,36 @@ def make_metadata(attn_backend:AttentionBackend, is_prompt:bool, is_cross_attn:b block_tables=block_tables, use_cuda_graph=False, is_cross_attn=is_cross_attn, - cross_seq_lens=cross_prompt_lens - ) + cross_seq_lens=cross_prompt_lens) + def make_attention(num_heads: int, head_size: int, scale: float): # Attention operator instance - return Attention(num_heads, - head_size, - scale=scale,) + return Attention( + num_heads, + head_size, + scale=scale, + ) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size",BLOCK_SIZES) -@pytest.mark.parametrize("max_prompt_len",PROMPT_LENS) -def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_prompt_len: int) -> None: +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("max_prompt_len", PROMPT_LENS) +def test_prefill_decode_self_attention(num_heads: int, head_size: int, + backend_name: str, batch_size: int, + block_size: int, + max_prompt_len: int) -> None: # Attention operator instance - is_cross_attn=False + is_cross_attn = False is_prompt = True max_q_prompt_len = max_prompt_len max_kv_prompt_len = max_q_prompt_len context_lens = [0 for _ in range(batch_size)] num_blocks = 4096 - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) scale = float(1.0 / (head_size**0.5)) attn = make_attention(num_heads, head_size, scale) attn_backend = make_backend(backend_name) @@ -502,61 +573,94 @@ def test_prefill_decode_self_attention(num_heads: int, head_size: int, backend_n decode_q_prompt_lens, \ decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=False) - causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len).to(CUDA_DEVICE) - ideal_output = ref_masked_attention( - query, - key, - value, - scale=scale, - custom_mask=causal_mask, - q_prompt_lens=q_prompt_lens, - kv_prompt_lens=kv_prompt_lens - ) + causal_mask = build_causal_mask(max_q_prompt_len, + max_kv_prompt_len).to(CUDA_DEVICE) + ideal_output = ref_masked_attention(query, + key, + value, + scale=scale, + custom_mask=causal_mask, + q_prompt_lens=q_prompt_lens, + kv_prompt_lens=kv_prompt_lens) prefill_ideal_output = torch.zeros_like(ideal_output) - decode_ideal_output = torch.zeros_like(ideal_output[:,0:1]) - for bdx,prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): - prefill_ideal_output[bdx,:prefill_q_prompt_len] = ideal_output[bdx,:prefill_q_prompt_len] - decode_ideal_output[bdx,:] = ideal_output[bdx,prefill_q_prompt_len:(prefill_q_prompt_len+1)] - - prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) - decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) - - decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _ = make_block_tables_slot_mapping(block_size,q_prompt_lens) - prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, cross_prompt_lens=None) + decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) + for bdx, prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): + prefill_ideal_output[bdx, :prefill_q_prompt_len] = ideal_output[ + bdx, :prefill_q_prompt_len] + decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_prompt_len:( + prefill_q_prompt_len + 1)] + + prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, + prefill_q_prompt_lens) + decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, + [1 for _ in range(batch_size)]) + + decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _ = make_block_tables_slot_mapping( + block_size, q_prompt_lens) + prefill_attn_metadata: AttentionMetadata = make_metadata( + attn_backend, + is_prompt, + is_cross_attn, + prefill_q_prompt_lens, + context_lens, + prefill_block_tables, + prefill_slot_mapping, + cross_prompt_lens=None) - prefill_packed_query,prefill_packed_key,prefill_packed_value,_,_ = pack_qkv(prefill_query,prefill_key,prefill_value,prefill_q_prompt_lens,prefill_kv_prompt_lens) + prefill_packed_query, prefill_packed_key, prefill_packed_value, _, _ = pack_qkv( + prefill_query, prefill_key, prefill_value, prefill_q_prompt_lens, + prefill_kv_prompt_lens) - prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) + prefill_packed_actual_output = attn.forward(prefill_packed_query, + prefill_packed_key, + prefill_packed_value, kv_cache, + prefill_attn_metadata, scale) # eval correctness of prefill output - assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output.view_as(prefill_packed_actual_output)) + assert torch.allclose( + prefill_packed_actual_output, + prefill_packed_ideal_output.view_as(prefill_packed_actual_output)) is_prompt = False context_lens = copy.deepcopy(prefill_kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping) - - decode_packed_query,decode_packed_key,decode_packed_value,_,_ = pack_qkv(decode_query,decode_key,decode_value,decode_q_prompt_lens,decode_kv_prompt_lens) + decode_attn_metadata = make_metadata(attn_backend, is_prompt, + is_cross_attn, q_prompt_lens, + context_lens, decode_block_tables, + decode_slot_mapping) + + decode_packed_query, decode_packed_key, decode_packed_value, _, _ = pack_qkv( + decode_query, decode_key, decode_value, decode_q_prompt_lens, + decode_kv_prompt_lens) - decode_packed_actual_output=attn.forward(decode_packed_query,decode_packed_key,decode_packed_value,kv_cache,decode_attn_metadata,scale) + decode_packed_actual_output = attn.forward(decode_packed_query, + decode_packed_key, + decode_packed_value, kv_cache, + decode_attn_metadata, scale) # eval correctness of decode output - assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output.view_as(decode_packed_actual_output)) + assert torch.allclose( + decode_packed_actual_output, + decode_packed_ideal_output.view_as(decode_packed_actual_output)) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size",BLOCK_SIZES) -@pytest.mark.parametrize("max_q_prompt_len",Q_PROMPT_LENS) -@pytest.mark.parametrize("max_kv_prompt_len",K_PROMPT_LENS) -def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("max_q_prompt_len", Q_PROMPT_LENS) +@pytest.mark.parametrize("max_kv_prompt_len", K_PROMPT_LENS) +def test_prefill_decode_cross_attention(num_heads: int, head_size: int, + backend_name: str, batch_size: int, + block_size: int, max_q_prompt_len: int, + max_kv_prompt_len: int) -> None: # Attention operator instance - is_cross_attn=True + is_cross_attn = True is_prompt = True context_lens = [0 for _ in range(batch_size)] num_blocks = 4096 - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) scale = float(1.0 / (head_size**0.5)) attn = make_attention(num_heads, head_size, scale) attn_backend = make_backend(backend_name) @@ -579,45 +683,75 @@ def test_prefill_decode_cross_attention(num_heads: int, head_size: int, backend_ decode_q_prompt_lens, \ _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=is_cross_attn) - ideal_output = ref_masked_attention( - query, - key, - value, - scale=scale, - q_prompt_lens=q_prompt_lens, - kv_prompt_lens=kv_prompt_lens - ) + ideal_output = ref_masked_attention(query, + key, + value, + scale=scale, + q_prompt_lens=q_prompt_lens, + kv_prompt_lens=kv_prompt_lens) prefill_ideal_output = torch.zeros_like(ideal_output) - decode_ideal_output = torch.zeros_like(ideal_output[:,0:1]) - for bdx,prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): - prefill_ideal_output[bdx,:prefill_q_prompt_len] = ideal_output[bdx,:prefill_q_prompt_len] - decode_ideal_output[bdx,:] = ideal_output[bdx,prefill_q_prompt_len:(prefill_q_prompt_len+1)] - - prefill_packed_ideal_output,_ = pack_tensor(prefill_ideal_output,prefill_q_prompt_lens) - decode_packed_ideal_output,_ = pack_tensor(decode_ideal_output,[1 for _ in range(batch_size)]) + decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) + for bdx, prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): + prefill_ideal_output[bdx, :prefill_q_prompt_len] = ideal_output[ + bdx, :prefill_q_prompt_len] + decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_prompt_len:( + prefill_q_prompt_len + 1)] + + prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, + prefill_q_prompt_lens) + decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, + [1 for _ in range(batch_size)]) # Unlike self-attention: # - Prefill slot-mapping includes all key slots # - Decode slot-mapping is empty - decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping = make_block_tables_slot_mapping(block_size,kv_prompt_lens) - - prefill_attn_metadata:AttentionMetadata = make_metadata(attn_backend, is_prompt, is_cross_attn,prefill_q_prompt_lens, context_lens, prefill_block_tables, prefill_slot_mapping, cross_prompt_lens=kv_prompt_lens) + decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping = make_block_tables_slot_mapping( + block_size, kv_prompt_lens) + + prefill_attn_metadata: AttentionMetadata = make_metadata( + attn_backend, + is_prompt, + is_cross_attn, + prefill_q_prompt_lens, + context_lens, + prefill_block_tables, + prefill_slot_mapping, + cross_prompt_lens=kv_prompt_lens) - prefill_packed_query,prefill_packed_key,prefill_packed_value,_,_ = pack_qkv(prefill_query,key,value,prefill_q_prompt_lens,kv_prompt_lens) + prefill_packed_query, prefill_packed_key, prefill_packed_value, _, _ = pack_qkv( + prefill_query, key, value, prefill_q_prompt_lens, kv_prompt_lens) - prefill_packed_actual_output=attn.forward(prefill_packed_query,prefill_packed_key,prefill_packed_value,kv_cache,prefill_attn_metadata,scale) + prefill_packed_actual_output = attn.forward(prefill_packed_query, + prefill_packed_key, + prefill_packed_value, kv_cache, + prefill_attn_metadata, scale) # eval correctness of prefill output - assert torch.allclose(prefill_packed_actual_output,prefill_packed_ideal_output.view_as(prefill_packed_actual_output)) + assert torch.allclose( + prefill_packed_actual_output, + prefill_packed_ideal_output.view_as(prefill_packed_actual_output)) is_prompt = False context_lens = copy.deepcopy(kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, is_prompt, is_cross_attn, q_prompt_lens, context_lens, decode_block_tables, decode_slot_mapping, cross_prompt_lens=kv_prompt_lens) - - decode_packed_query,_,_,_,_ = pack_qkv(decode_query,key,value,decode_q_prompt_lens,kv_prompt_lens) - - decode_packed_actual_output=attn.forward(decode_packed_query,None,None,kv_cache,decode_attn_metadata,scale) + decode_attn_metadata = make_metadata(attn_backend, + is_prompt, + is_cross_attn, + q_prompt_lens, + context_lens, + decode_block_tables, + decode_slot_mapping, + cross_prompt_lens=kv_prompt_lens) + + decode_packed_query, _, _, _, _ = pack_qkv(decode_query, key, value, + decode_q_prompt_lens, + kv_prompt_lens) + + decode_packed_actual_output = attn.forward(decode_packed_query, None, None, + kv_cache, decode_attn_metadata, + scale) # eval correctness of decode output - assert torch.allclose(decode_packed_actual_output,decode_packed_ideal_output.view_as(decode_packed_actual_output)) \ No newline at end of file + assert torch.allclose( + decode_packed_actual_output, + decode_packed_ideal_output.view_as(decode_packed_actual_output)) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 57316f9e0c967..3dfe363cbe7f2 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -4,8 +4,7 @@ import torch from xformers import ops as xops -from xformers.ops.fmha.attn_bias import (AttentionBias, - BlockDiagonalMask, +from xformers.ops.fmha.attn_bias import (AttentionBias, BlockDiagonalMask, BlockDiagonalCausalMask, LowerTriangularMaskWithTensorBias) @@ -155,8 +154,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, is_cross_attn=self.is_cross_attn, - cross_seq_lens=self.cross_seq_lens - ) + cross_seq_lens=self.cross_seq_lens) return self._cached_prefill_metadata @property @@ -185,8 +183,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, is_cross_attn=self.is_cross_attn, - cross_seq_lens=self.cross_seq_lens - ) + cross_seq_lens=self.cross_seq_lens) return self._cached_decode_metadata @@ -286,14 +283,17 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - self.kv_cache_dtype, kv_scale) + self.kv_cache_dtype, + kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens is_cross_attn = attn_metadata.is_cross_attn - assert is_cross_attn or (key.shape[0] == num_prefill_tokens + num_decode_tokens) - assert is_cross_attn or (value.shape[0] == num_prefill_tokens + num_decode_tokens) + assert is_cross_attn or (key.shape[0] + == num_prefill_tokens + num_decode_tokens) + assert is_cross_attn or (value.shape[0] + == num_prefill_tokens + num_decode_tokens) output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. @@ -346,8 +346,12 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.seq_lens_tensor if not is_cross_attn else torch.tensor(decode_meta.cross_seq_lens,dtype=decode_meta.seq_lens_tensor.dtype,device=decode_meta.seq_lens_tensor.device), - decode_meta.max_decode_seq_len if not is_cross_attn else max(decode_meta.cross_seq_lens), + decode_meta.seq_lens_tensor if not is_cross_attn else + torch.tensor(decode_meta.cross_seq_lens, + dtype=decode_meta.seq_lens_tensor.dtype, + device=decode_meta.seq_lens_tensor.device), + decode_meta.max_decode_seq_len + if not is_cross_attn else max(decode_meta.cross_seq_lens), self.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -400,7 +404,7 @@ def _run_memory_efficient_xformers_forward( if self.alibi_slopes is None: if attn_metadata.is_cross_attn: attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.seq_lens,attn_metadata.cross_seq_lens) + attn_metadata.seq_lens, attn_metadata.cross_seq_lens) else: attn_bias = BlockDiagonalCausalMask.from_seqlens( attn_metadata.seq_lens) From 86591214f2116d04b38688cc691e93b1ecd33c71 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 20 May 2024 22:44:15 -0400 Subject: [PATCH 024/239] Self & cross attention tests pass with new cross-compatible attention metadata structure! --- tests/layer/test_self_and_cross_attn.py | 430 +++++++++++++++++++++++- vllm/attention/backends/xformers.py | 288 ++++++++++++---- 2 files changed, 647 insertions(+), 71 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 18afff46d8371..5f1e35ceb21d7 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -285,11 +285,16 @@ def pack_tensor(unpacked_tensor, prompt_lens, device=CUDA_DEVICE): def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): - packed_query, q_start_loc_list = pack_tensor(query, q_prompt_lens) + if query is None: + packed_query = None + q_start_loc_list = None + else: + packed_query, q_start_loc_list = pack_tensor(query, q_prompt_lens) packed_key, kv_start_loc_list = pack_tensor(key, kv_prompt_lens) packed_value, _ = pack_tensor(value, kv_prompt_lens) - packed_query = packed_query.view( - -1, packed_query.shape[-1] * packed_query.shape[-2]) + if packed_query is not None: + packed_query = packed_query.view( + -1, packed_query.shape[-1] * packed_query.shape[-2]) packed_key = packed_key.view(-1, packed_key.shape[-1] * packed_key.shape[-2]) packed_value = packed_value.view( @@ -369,6 +374,7 @@ def num_tokens_to_min_blocks(num_tokens, block_size): def make_block_tables_slot_mapping(block_size, prompt_lens, + block_base_addr=0, device=CUDA_DEVICE): ''' Naive block table: @@ -388,8 +394,8 @@ def make_block_tables_slot_mapping(block_size, prefill_slot_mapping = [] decode_slot_mapping = [] slot_mapping = [] - block_base_idx = sum( - num_blocks_list) * 2 - 1 # Support more blocks than needed + block_base_idx = block_base_addr + sum(num_blocks_list) * 2 - 1 # Support more blocks than needed + max_block_idx = block_base_idx for sdx, num_tokens in enumerate(prompt_lens): num_blocks = num_blocks_list[sdx] block_table = list( @@ -430,7 +436,7 @@ def make_block_tables_slot_mapping(block_size, dtype=torch.long, device=device) - return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor, slot_mapping_tensor, empty_slot_mapping_tensor + return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor, slot_mapping_tensor, empty_slot_mapping_tensor, max_block_idx def make_metadata(attn_backend: AttentionBackend, @@ -525,6 +531,102 @@ def make_metadata(attn_backend: AttentionBackend, cross_seq_lens=cross_prompt_lens) +def make_metadata_self_cross(attn_backend: AttentionBackend, + is_prompt: bool, + prompt_lens: List[int], + context_lens: List[int], + block_tables, + slot_mapping, + device=CUDA_DEVICE, + cross_seq_lens: Optional[List[int]] = None, + cross_block_tables: Optional[torch.Tensor] = None, + cross_slot_mapping: Optional[List[int]] = None,): + ''' + Assumptions: + * No chunked prefill -> a batch is 100% prefill or 100% decode, never both + ''' + + if is_prompt: + num_prefills = len(prompt_lens) + num_prefill_tokens = sum(prompt_lens) + num_decode_tokens = 0 + + prompt_lens_tensor, \ + context_lens_tensor, \ + max_query_len, \ + _, \ + _, \ + seq_start_loc, \ + query_start_loc = make_metadata_tensors(is_prompt, + prompt_lens, + context_lens, + device=device) + + slot_mapping_tensor = slot_mapping + + cross_slot_mapping_tensor = cross_slot_mapping + + return attn_backend.make_metadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=prompt_lens, + seq_lens_tensor=prompt_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max(prompt_lens), + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + is_cross_attn=False, + cross_seq_lens=cross_seq_lens, + cross_slot_mapping=cross_slot_mapping_tensor, + cross_block_tables=cross_block_tables) + + else: # not is_prompt + + num_prefills = 0 + num_prefill_tokens = 0 + num_decode_tokens = len(prompt_lens) + + prompt_lens_tensor, \ + context_lens_tensor, \ + max_query_len, \ + _, \ + _, \ + seq_start_loc, \ + query_start_loc = make_metadata_tensors(is_prompt, + prompt_lens, + context_lens, + device=device) + + slot_mapping_tensor = slot_mapping + + cross_slot_mapping_tensor = cross_slot_mapping + + return attn_backend.make_metadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=prompt_lens, + seq_lens_tensor=prompt_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=max(prompt_lens), + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + is_cross_attn=False, + cross_seq_lens=cross_seq_lens, + cross_slot_mapping=cross_slot_mapping_tensor, + cross_block_tables=cross_block_tables) + def make_attention(num_heads: int, head_size: int, scale: float): # Attention operator instance return Attention( @@ -534,6 +636,322 @@ def make_attention(num_heads: int, head_size: int, scale: float): ) +def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): + scale = float(1.0 / (head_size**0.5)) + attn_backend = make_backend(backend_name) + attn = make_attention(num_heads, head_size, scale) + kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) + return scale, attn_backend, attn, kv_cache + +def self_attn_setup(batch_size, num_heads, head_size, block_size, scale, max_q_prompt_len, block_base_addr=0): + + max_kv_prompt_len = max_q_prompt_len + + query, \ + key, \ + value, \ + prefill_query, \ + prefill_key, \ + prefill_value, \ + decode_query, \ + decode_key, \ + decode_value, \ + q_prompt_lens, \ + kv_prompt_lens, \ + _, \ + _, \ + prefill_q_prompt_lens, \ + prefill_kv_prompt_lens, \ + decode_q_prompt_lens, \ + decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=False) + + causal_mask = build_causal_mask(max_q_prompt_len, + max_kv_prompt_len).to(CUDA_DEVICE) + + ideal_output = ref_masked_attention(query, + key, + value, + scale=scale, + custom_mask=causal_mask, + q_prompt_lens=q_prompt_lens, + kv_prompt_lens=kv_prompt_lens) + + prefill_ideal_output = torch.zeros_like(ideal_output) + decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) + for bdx, prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): + prefill_ideal_output[bdx, :prefill_q_prompt_len] = ideal_output[ + bdx, :prefill_q_prompt_len] + decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_prompt_len:( + prefill_q_prompt_len + 1)] + + prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, + prefill_q_prompt_lens) + decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, + [1 for _ in range(batch_size)]) + + decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _, max_block_idx = make_block_tables_slot_mapping( + block_size, q_prompt_lens, block_base_addr=block_base_addr) + + prefill_packed_query, prefill_packed_key, prefill_packed_value, _, _ = pack_qkv( + prefill_query, prefill_key, prefill_value, prefill_q_prompt_lens, + prefill_kv_prompt_lens) + + decode_packed_query, decode_packed_key, decode_packed_value, _, _ = pack_qkv( + decode_query, decode_key, decode_value, decode_q_prompt_lens, + decode_kv_prompt_lens) + + return query, \ + prefill_packed_query, \ + prefill_packed_key, \ + prefill_packed_value, \ + prefill_packed_ideal_output, \ + prefill_q_prompt_lens, \ + prefill_kv_prompt_lens, \ + decode_packed_query, \ + decode_packed_key, \ + decode_packed_value, \ + decode_packed_ideal_output, \ + decode_q_prompt_lens, \ + decode_kv_prompt_lens, \ + q_prompt_lens, \ + kv_prompt_lens, \ + decode_block_tables, \ + decode_slot_mapping, \ + prefill_slot_mapping, \ + prefill_block_tables, \ + max_block_idx + + +def cross_attn_setup_reuses_query(query, q_prompt_lens, prefill_q_prompt_lens, batch_size, num_heads, head_size, block_size, scale, max_q_prompt_len, max_kv_prompt_len, block_base_addr=0): + + _, \ + key, \ + value, \ + _, \ + _, \ + _, \ + _, \ + _, \ + _, \ + _, \ + kv_prompt_lens, \ + _, \ + _, \ + _, \ + _, \ + _, \ + _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=True) + + ideal_output = ref_masked_attention(query, + key, + value, + scale=scale, + q_prompt_lens=q_prompt_lens, + kv_prompt_lens=kv_prompt_lens) + + prefill_ideal_output = torch.zeros_like(ideal_output) + decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) + for bdx, prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): + prefill_ideal_output[bdx, :prefill_q_prompt_len] = ideal_output[ + bdx, :prefill_q_prompt_len] + decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_prompt_len:( + prefill_q_prompt_len + 1)] + + prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, + prefill_q_prompt_lens) + decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, + [1 for _ in range(batch_size)]) + + # Unlike self-attention: + # - Prefill slot-mapping includes all key slots + # - Decode slot-mapping is empty + decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping, max_block_idx = make_block_tables_slot_mapping( + block_size, kv_prompt_lens, block_base_addr=block_base_addr) + + _, prefill_packed_key, prefill_packed_value, _, _ = pack_qkv( + None, key, value, prefill_q_prompt_lens, kv_prompt_lens) + + return prefill_packed_key, \ + prefill_packed_value, \ + prefill_packed_ideal_output, \ + decode_packed_ideal_output, \ + kv_prompt_lens, \ + decode_block_tables, \ + decode_slot_mapping, \ + prefill_slot_mapping, \ + prefill_block_tables, \ + max_block_idx + +def run_self_attention_test(attn,packed_query,packed_key,packed_value,kv_cache,attn_metadata:AttentionMetadata,scale): + attn_metadata.do_cross_attn = False + return attn.forward(packed_query, + packed_key, + packed_value, + kv_cache, + attn_metadata, + scale) + +def run_cross_attention_test(attn,packed_query,packed_key,packed_value,kv_cache,attn_metadata:AttentionMetadata,scale): + attn_metadata.do_cross_attn = True + return attn.forward(packed_query, + packed_key, + packed_value, + kv_cache, + attn_metadata, + scale) + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("max_q_prompt_len", Q_PROMPT_LENS) +@pytest.mark.parametrize("max_kv_prompt_len", K_PROMPT_LENS) +def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, + backend_name: str, batch_size: int, + block_size: int, max_q_prompt_len: int, + max_kv_prompt_len: int) -> None: + + # Num KV cache blocks + num_blocks = 4096 + + # Attention scale factor, + # attention backend instance, + # attention wrapper instance, + # KV cache init + scale, \ + attn_backend, \ + attn, \ + kv_cache = basic_setup(num_heads, + head_size, + num_blocks, + block_size, + backend_name) + + # Self-attention setup + + self_block_base_addr=0 + + query, \ + prefill_packed_query, \ + self_prefill_packed_key, \ + self_prefill_packed_value, \ + self_prefill_packed_ideal_output, \ + prefill_q_prompt_lens, \ + self_prefill_kv_prompt_lens, \ + decode_packed_query, \ + self_decode_packed_key, \ + self_decode_packed_value, \ + self_decode_packed_ideal_output, \ + decode_q_prompt_lens, \ + self_decode_kv_prompt_lens, \ + q_prompt_lens, \ + self_kv_prompt_lens, \ + self_decode_block_tables, \ + self_decode_slot_mapping, \ + self_prefill_slot_mapping, \ + self_prefill_block_tables, \ + cross_block_base_addr = self_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_prompt_len, + block_base_addr=self_block_base_addr) + + # Cross-attention setup + + cross_prefill_packed_key, \ + cross_prefill_packed_value, \ + cross_prefill_packed_ideal_output, \ + cross_decode_packed_ideal_output, \ + cross_kv_prompt_lens, \ + cross_decode_block_tables, \ + cross_decode_slot_mapping, \ + cross_prefill_slot_mapping, \ + cross_prefill_block_tables, \ + final_max_block_idx = cross_attn_setup_reuses_query(query, + q_prompt_lens, + prefill_q_prompt_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_prompt_len, + max_kv_prompt_len, + block_base_addr=cross_block_base_addr) + + # PREFILL: self- and cross-attention tests + + context_lens = [0 for _ in range(batch_size)] + + prefill_attn_metadata: AttentionMetadata = make_metadata_self_cross(attn_backend, + True, + prefill_q_prompt_lens, + context_lens, + self_prefill_block_tables, + self_prefill_slot_mapping, + cross_seq_lens = cross_kv_prompt_lens, + cross_block_tables = cross_prefill_block_tables, + cross_slot_mapping = cross_prefill_slot_mapping,) + + self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test(attn, + prefill_packed_query, + self_prefill_packed_key, + self_prefill_packed_value, + kv_cache, + prefill_attn_metadata, + scale) + + # - Prefill self-attention correct? + assert torch.allclose(self_prefill_packed_ideal_output,self_prefill_packed_actual_output.view_as(self_prefill_packed_ideal_output)) + + cross_prefill_packed_actual_output: torch.Tensor = run_cross_attention_test(attn, + prefill_packed_query, + cross_prefill_packed_key, + cross_prefill_packed_value, + kv_cache, + prefill_attn_metadata, + scale) + + # - Prefill cross-attention correct? + assert torch.allclose(cross_prefill_packed_ideal_output,cross_prefill_packed_actual_output.view_as(cross_prefill_packed_ideal_output)) + + context_lens = copy.deepcopy(self_prefill_kv_prompt_lens) + + # DECODE: self- and cross-attention tests + + decode_attn_metadata: AttentionMetadata = make_metadata_self_cross(attn_backend, + False, + q_prompt_lens, + context_lens, + self_decode_block_tables, + self_decode_slot_mapping, + cross_seq_lens = cross_kv_prompt_lens, + cross_block_tables = cross_decode_block_tables, + cross_slot_mapping = cross_decode_slot_mapping,) + + self_decode_packed_actual_output: torch.Tensor = run_self_attention_test(attn, + decode_packed_query, + self_decode_packed_key, + self_decode_packed_value, + kv_cache, + decode_attn_metadata, + scale) + + assert torch.allclose(self_decode_packed_ideal_output,self_decode_packed_actual_output.view_as(self_decode_packed_ideal_output)) + + cross_decode_packed_actual_output: torch.Tensor = run_cross_attention_test(attn, + decode_packed_query, + None, + None, + kv_cache, + decode_attn_metadata, + scale) + + assert torch.allclose(cross_decode_packed_ideal_output,cross_decode_packed_actual_output.view_as(cross_decode_packed_ideal_output)) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 3dfe363cbe7f2..5b6d2ac0e144f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -105,16 +105,36 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool - _cached_prefill_metadata: Optional["XFormersMetadata"] = None - _cached_decode_metadata: Optional["XFormersMetadata"] = None - # Need to make KV cache read-only for cross-attention + # Self-attention prefill/decode metadata cache + _self_cached_prefill_metadata: Optional["XFormersMetadata"] = None + _self_cached_decode_metadata: Optional["XFormersMetadata"] = None + # Cross-attention prefill/decode metadata cache + _cross_cached_prefill_metadata: Optional["XFormersMetadata"] = None + _cross_cached_decode_metadata: Optional["XFormersMetadata"] = None + + # Begin cross-attention fields... + + # If True, prefill_metadata() and decode_metadata() will return + # seqlen & memory-mapping data structures for cross-attention; + # otherwise, self-attention data structures will be returned. is_cross_attn: bool = False # (batch_size,). The "cross-sequence-length" per sequence,i.e. the key/value # sequence length (usually encoder sequence length) in the cross-attention # computation. None if this is self-attention cross_seq_lens: Optional[List[int]] = None + cross_seq_lens_tensor: Optional[torch.Tensor] = None + + # The maximum cross-sequence-length, if cross_seq_lens is specified. + # Note that for cross-attention there is no difference in key/value + # sequence length between prefill and decode + max_cross_seq_len: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None def __post_init__(self): # Set during the execution of the first attention op. @@ -124,67 +144,183 @@ def __post_init__(self): # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[AttentionBias]] = None + @property + def has_valid_cross_attn_metadata(self): + # No cross-attention metadata is present whatsoever + no_md = (self.cross_seq_lens is None) and (self.cross_slot_mapping is None) and (self.cross_block_tables is None) + # If any cross-attention metadata is present, it is invalid + invalid_md_if_not_no_md = (self.cross_seq_lens is None) or (self.cross_slot_mapping is None) or (self.cross_block_tables is None) + + if no_md: + return False + + assert (not invalid_md_if_not_no_md), "Invalid cross-attention metadata" + + return True + + @property + def do_cross_attn(self): + return self.is_cross_attn + + @do_cross_attn.setter + def do_cross_attn(self,state:bool): + + if state: + assert self.has_valid_cross_attn_metadata, "Must have self.cross_seq_lens not None in order to enable cross-attention" + + # Infer implicit cross-attention fields from user-provided fields, if needed + if self.cross_seq_lens_tensor is None: + self.cross_seq_lens_tensor = torch.tensor(self.cross_seq_lens, + dtype=self.seq_lens_tensor.dtype, + device=self.seq_lens_tensor.device) + if self.max_cross_seq_len is None: + self.max_cross_seq_len = max(self.cross_seq_lens) + + self.is_cross_attn = True + else: + self.is_cross_attn = False + @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self.num_prefills == 0: return None - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.block_tables is not None - - self._cached_prefill_metadata = XFormersMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - is_cross_attn=self.is_cross_attn, - cross_seq_lens=self.cross_seq_lens) - return self._cached_prefill_metadata + if not self.do_cross_attn: + # Self-attention prefill + + if self._self_cached_prefill_metadata is not None: + return self._self_cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + + self._self_cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None, + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + is_cross_attn=False, # Begin cross-attention fields below... + cross_seq_lens=None, + cross_seq_lens_tensor=None, + max_cross_seq_len=None, + cross_block_tables=None, + cross_slot_mapping=None) + return self._self_cached_prefill_metadata + + else: + # Cross-attention prefill + + if self._cross_cached_prefill_metadata is not None: + return self._cross_cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + + self._cross_cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None, + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + is_cross_attn=True, # Begin cross-attention fields below... + cross_seq_lens=self.cross_seq_lens, + cross_seq_lens_tensor=self.cross_seq_lens_tensor, + max_cross_seq_len=self.max_cross_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cross_cached_prefill_metadata @property def decode_metadata(self) -> Optional["XFormersMetadata"]: if self.num_decode_tokens == 0: return None - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None - - self._cached_decode_metadata = XFormersMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - is_cross_attn=self.is_cross_attn, - cross_seq_lens=self.cross_seq_lens) - return self._cached_decode_metadata + if not self.do_cross_attn: + # Self-attention decode + + if self._self_cached_decode_metadata is not None: + return self._self_cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._self_cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + is_cross_attn=False, # Begin cross-attention fields below... + cross_seq_lens=None, + cross_seq_lens_tensor=None, + max_cross_seq_len=None, + cross_block_tables=None, + cross_slot_mapping=None) + return self._self_cached_decode_metadata + + else: + # Cross-attention decode + + if self._cross_cached_decode_metadata is not None: + return self._cross_cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cross_cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + is_cross_attn=True, # Begin cross-attention fields below... + cross_seq_lens=self.cross_seq_lens, + cross_seq_lens_tensor=self.cross_seq_lens_tensor, + max_cross_seq_len=self.max_cross_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cross_cached_decode_metadata class XFormersImpl(AttentionImpl[XFormersMetadata]): @@ -268,6 +404,11 @@ def forward( if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size) + # Self-attention vs. cross-attention will impact + # which KV cache memory-mapping & which + # seqlen datastructures we utilize + do_cross_attn = attn_metadata.do_cross_attn + if (kv_cache is not None): # Even if there are no new key/value pairs to cache, # we still need to break out key_cache and value_cache @@ -277,22 +418,30 @@ def forward( if (key is not None) and (value is not None): + if do_cross_attn: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, - attn_metadata.slot_mapping, + updated_slot_mapping, self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - is_cross_attn = attn_metadata.is_cross_attn - assert is_cross_attn or (key.shape[0] + assert do_cross_attn or (key.shape[0] == num_prefill_tokens + num_decode_tokens) - assert is_cross_attn or (value.shape[0] + assert do_cross_attn or (value.shape[0] == num_prefill_tokens + num_decode_tokens) output = torch.empty_like(query) @@ -301,7 +450,7 @@ def forward( # QKV for prefill. query = query[:num_prefill_tokens] - if not is_cross_attn and key is not None and value is not None: + if not do_cross_attn and key is not None and value is not None: key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] @@ -323,6 +472,8 @@ def forward( # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, # to be addressed separately. + # + # TODO(afeldman-nm): support cross-attention out = PagedAttention.forward_prefix( query, key, @@ -341,17 +492,24 @@ def forward( output[:num_prefill_tokens] = out if decode_meta := attn_metadata.decode_metadata: + if do_cross_attn: + # Paged attention against cross-attention KV cache + seq_lens_arg = decode_meta.cross_seq_lens_tensor + max_seq_len_arg = decode_meta.max_cross_seq_len + block_tables_arg = decode_meta.cross_block_tables + else: + # Paged attention against self-attention KV cache + seq_lens_arg = decode_meta.seq_lens_tensor + max_seq_len_arg = decode_meta.max_decode_seq_len + block_tables_arg = decode_meta.block_tables + output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, key_cache, value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor if not is_cross_attn else - torch.tensor(decode_meta.cross_seq_lens, - dtype=decode_meta.seq_lens_tensor.dtype, - device=decode_meta.seq_lens_tensor.device), - decode_meta.max_decode_seq_len - if not is_cross_attn else max(decode_meta.cross_seq_lens), + block_tables_arg, + seq_lens_arg, + max_seq_len_arg, self.kv_cache_dtype, self.num_kv_heads, self.scale, From 78ce588834c17c7e9b73a5544b1c691492ab8c5c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 21 May 2024 00:15:13 -0400 Subject: [PATCH 025/239] refactoring --- tests/layer/test_self_and_cross_attn.py | 667 +++++++++++------------- 1 file changed, 311 insertions(+), 356 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 5f1e35ceb21d7..5cab054b61069 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -16,25 +16,22 @@ import random +# If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] +# +# TODO: # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 -HEAD_SIZES = [64] +HEAD_SIZES = [64,256] -# [64, 80, 96, 112, 128, 256 -# ] if not is_hip() else [64, 80, 96, 112, 128] +NUM_HEADS = [1,16] -NUM_HEADS = [16] - -BATCH_SIZES = [16] +BATCH_SIZES = [1,16] BLOCK_SIZES = [16] BACKEND_NAMES = ["xformers"] CUDA_DEVICE = "cuda:0" -PROMPT_LENS = [128] - -Q_PROMPT_LENS = [128] - -K_PROMPT_LENS = [128] +MAX_Q_PROMPT_LENS = [128] +MAX_K_PROMPT_LENS = [128] def build_causal_mask(q_max_prompt_len, kv_max_prompt_len): @@ -266,6 +263,22 @@ def make_qkv(batch_size, def pack_tensor(unpacked_tensor, prompt_lens, device=CUDA_DEVICE): + ''' + Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an + unpadded number_of_tokens x num_heads x head_size tensor, where + number_of_tokens = sum(prompt_lens) + + Arguments: + * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size + * prompt_lens: list of token counts for each prompt + * device: CPU or CUDA device + + Returns + * packed_tensor: number_of_tokens x num_heads x head_size + * start_loc_list: start idx of each batch elt in packed_tensor; + [0] + list(itertools.accumulate(prompt_lens)) + ''' + num_tok = sum(prompt_lens) num_heads = unpacked_tensor.shape[-2] head_size = unpacked_tensor.shape[-1] @@ -285,6 +298,30 @@ def pack_tensor(unpacked_tensor, prompt_lens, device=CUDA_DEVICE): def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): + ''' + Individually pack each of Q, K and V, each with dimensions + batch_size x padded_seq_len x num_heads x head_size, into + respective number_of_tokens x num_heads x head_size tensors. + + For Q, number_of_tokens = sum(q_prompt_lens). + + For K and V, number_of_tokens = sum(kv_prompt_lens) + + Arguments: + * query: batch_size x padded_seq_len x num_heads x head_size + * key: batch_size x padded_seq_len x num_heads x head_size + * value: batch_size x padded_seq_len x num_heads x head_size + * q_prompt_lens: list of token counts for each query + * kv_prompt_lens: list of token counts for each key/value + + Returns + * packed_query: number_of_tokens x num_heads x head_size + * packed_key: number_of_tokens x num_heads x head_size + * packed_value: number_of_tokens x num_heads x head_size + * q_start_loc_list: start idx of each query in packed_query + * kv_start_loc_list: start idx of each {key,value} in packed_{key,value} + ''' + if query is None: packed_query = None q_start_loc_list = None @@ -303,6 +340,16 @@ def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): def make_backend(backend_name: str) -> AttentionBackend: + ''' + Construct the backend instance determined by the backend_name string argument. + + "xformers" -> construct xformers backend + + TODO: flash attention backend + + Returns: + * Backend instance + ''' if backend_name == "xformers": return XFormersBackend() assert False, f"Unrecognized backend_name {backend_name} for unit test" @@ -313,10 +360,22 @@ def make_metadata_tensors(is_prompt: bool, context_lens: List[int], device=CUDA_DEVICE) -> tuple: ''' - Assumptions: - * No chunked prefill - * No (automatic) prefix caching - * Packed variable-length sequences + Build scalar & tensor values required to build attention metadata structure. + + Arguments: + * is_prompt: True -> Prefill, False -> Decode + * prompt_lens: list of token-counts for each prompt + * context_lens: list of context length values for each prompt + * device: CPU or CUDA device + + Returns: + * prompt_lens_tensor: prompt_lens list, as tensor + * context_lens_tensor: context_lens list, as tensor + * max_query_len: max(prompt_lens) if is_prompt, o/w 1 + * max_context_len: max(context_lens) + * max_prompt_len: max(prompt_lens) + * seq_start_loc: start idx of each sequence + * query_start_loc: start idx of each query ''' prompt_lens_tensor = torch.tensor(prompt_lens, dtype=torch.int, @@ -361,6 +420,21 @@ def make_kv_cache(num_blocks, block_size, device=CUDA_DEVICE, default_val=0.0): + ''' + Create a fake KV cache. + + Arguments: + * num_blocks: number of blocks in the KV cache + * num_heads: number of attention heads + * head_size: head dimension + * block_size: number of offsets within a block + * device: CPU or CUDA device + * default_val: initialization value for KV cache elements + + Returns: + * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size) + ''' + kv_cache = torch.rand( (2, num_blocks, block_size * num_heads * head_size)).to(device) if default_val is not None: @@ -369,6 +443,10 @@ def make_kv_cache(num_blocks, def num_tokens_to_min_blocks(num_tokens, block_size): + ''' + Compute the minimum number of blocks required + to hold num_tokens tokens, given block_size + ''' return (num_tokens + block_size) // block_size @@ -377,9 +455,29 @@ def make_block_tables_slot_mapping(block_size, block_base_addr=0, device=CUDA_DEVICE): ''' - Naive block table: - * For each batch element... - * Block table has + Construct fake block tables & slot mappings. + + The first block is at + + block_base_addr + sum(num_blocks_list) * 2 - 1 + + and subsequent blocks count downward toward + block_base_addr + + Arguments: + * block_size: number of offsets per block + * prompt_lens: list of token-counts for each sequence + * block_base_addr: the block table base address + * device: CPU or CUDA device + + Return: + * decode_block_tables_tensor: fake the state of the block tables during decode + * decode_slot_mapping_tensor: fake the state of the slot mapping during decode + * prefill_slot_mapping_tensor: fake the state of the slot mapping during prefill + * prefill_block_tables_tensor: fake the state of the block tables during prefill + * slot_mapping_tensor: union of prefill and decode slot mappings + * empty_slot_mapping_tensor: empty slot mapping (useful for decode phase cross attention) + * max_block_idx: the highest block address within this block table ''' # Over-provision block table blocks by 1 @@ -438,99 +536,6 @@ def make_block_tables_slot_mapping(block_size, return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor, slot_mapping_tensor, empty_slot_mapping_tensor, max_block_idx - -def make_metadata(attn_backend: AttentionBackend, - is_prompt: bool, - is_cross_attn: bool, - prompt_lens: List[int], - context_lens: List[int], - block_tables, - slot_mapping, - device=CUDA_DEVICE, - cross_prompt_lens: Optional[List[int]] = None): - ''' - Assumptions: - * No chunked prefill -> a batch is 100% prefill or 100% decode, never both - ''' - - if is_prompt: - num_prefills = len(prompt_lens) - num_prefill_tokens = sum(prompt_lens) - num_decode_tokens = 0 - - prompt_lens_tensor, \ - context_lens_tensor, \ - max_query_len, \ - _, \ - _, \ - seq_start_loc, \ - query_start_loc = make_metadata_tensors(is_prompt, - prompt_lens, - context_lens, - device=device) - - slot_mapping_tensor = torch.tensor(slot_mapping, - dtype=torch.long, - device=device) - - return attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=prompt_lens, - seq_lens_tensor=prompt_lens_tensor, - max_query_len=max_query_len, - max_prefill_seq_len=max(prompt_lens), - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - is_cross_attn=is_cross_attn, - cross_seq_lens=cross_prompt_lens) - - else: # not is_prompt - - num_prefills = 0 - num_prefill_tokens = 0 - num_decode_tokens = len(prompt_lens) - - prompt_lens_tensor, \ - context_lens_tensor, \ - max_query_len, \ - _, \ - _, \ - seq_start_loc, \ - query_start_loc = make_metadata_tensors(is_prompt, - prompt_lens, - context_lens, - device=device) - - slot_mapping_tensor = torch.tensor(slot_mapping, - dtype=torch.long, - device=device) - - return attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=prompt_lens, - seq_lens_tensor=prompt_lens_tensor, - max_query_len=max_query_len, - max_prefill_seq_len=0, - max_decode_seq_len=max(prompt_lens), - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - is_cross_attn=is_cross_attn, - cross_seq_lens=cross_prompt_lens) - - def make_metadata_self_cross(attn_backend: AttentionBackend, is_prompt: bool, prompt_lens: List[int], @@ -540,10 +545,29 @@ def make_metadata_self_cross(attn_backend: AttentionBackend, device=CUDA_DEVICE, cross_seq_lens: Optional[List[int]] = None, cross_block_tables: Optional[torch.Tensor] = None, - cross_slot_mapping: Optional[List[int]] = None,): + cross_slot_mapping: Optional[List[int]] = None,) -> AttentionMetadata: ''' + Construct fake attention metadata for a combined + self-/cross-attention scenario i.e. an encoder/decoder + model. + Assumptions: * No chunked prefill -> a batch is 100% prefill or 100% decode, never both + + Arguments: + * attn_backend: Backend for sourcing attention kernels + * is_prompt: prefill if True, o/w decode + * prompt_lens: list of token counts for each sequence + * context_lens: list of context lengths for each sequence + * block_tables: self-attention block tables + * slot_mapping: self-attention slot_mapping + * device: CPU or CUDA device + * cross_seq_lens: list of token counts for each encoder sequence, if any exist + * cross_block_tables: cross-attention block tables, if required + * cross_slot_mapping: cross-attention slot mapping, if required + + Return: + * AttentionMetadata structure supporting self- and cross-attention ''' if is_prompt: @@ -628,7 +652,13 @@ def make_metadata_self_cross(attn_backend: AttentionBackend, cross_block_tables=cross_block_tables) def make_attention(num_heads: int, head_size: int, scale: float): - # Attention operator instance + ''' + Construct an instance of the Attention wrapper, suited to + the number of attention heads and head dimension + (num_heads and head_size respectively) as well as the + attention scale factor (scale) + ''' + return Attention( num_heads, head_size, @@ -637,6 +667,23 @@ def make_attention(num_heads: int, head_size: int, scale: float): def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): + ''' + Compute & build entities required for the self-/cross-attention test. + + Arguments: + * num_heads: Number of attention heads + * head_size: Head dimension + * num_blocks: Number of KV cache blocks + * block_size: Number of offsets within a KV cache block + * backend_name: selection of backend + + Returns: + * scale: 1/sqrt(head_size) + * attn_backend: backend instance + * attn: Attention wrapper instance + * kv_cache: fake KV cache, 2 x num_blocks x (block_size * num_heads * head_size) + ''' + scale = float(1.0 / (head_size**0.5)) attn_backend = make_backend(backend_name) attn = make_attention(num_heads, head_size, scale) @@ -644,6 +691,64 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): return scale, attn_backend, attn, kv_cache def self_attn_setup(batch_size, num_heads, head_size, block_size, scale, max_q_prompt_len, block_base_addr=0): + ''' + Set up test vectors & data structures for self-attention test. + + A triplet of synthetic query/key/value tensors are constructed ("baseline" query/key/value). + Given this is a self-attention test, the key & value sequences will have the same length + as the corresponding queries. + + "Prefill" query/key/value tensors are derived by masking out the last value in each + baseline query/key/value. These tensors are used to test prefill & populate KV cache + for a subsequent decode test. + + "Decode" query/key/value tensors are derived by extracting *only* the last value from + each baseline query/key/value (i.e. complement of the prefill tensors.) These tensors + are used to test decode, conditional on the kv cache being populated during the + prefill test. + + The baseline query/key/value tensors are passed to an ideal reference self-attention implementation + to generate a "Baseline" ideal output tensor. This tensor is split into the "Prefill" + ideal output tensor (all but the last element of each output sequence) and the "Decode" + ideal output tensor (*only* the last element of each output sequence); the "Prefill" and + "Decode" ideal output tensors can be used to validate the prefill and decode test + results, respectively. + + This function also constructs the self-attention KV cache memory mapping + (slot mapping and block table), ensuring that the block table starts + at block_base_addr + + Arguments: + * batch_size + * num_heads: Number of attention heads + * head_size: Head dimension + * block_size: Number of offsets per KV cache block + * scale: attention scale parameter + * max_q_prompt_len: upper limit on query length for synthetic test vectors + * block_base_addr: self-attention block table base address + + Returns: + * query: "baseline" query; batch_size x padded_seq_len x num_heads x head_size + * prefill_packed_query: "prefill" query; number_of_tokens x num_heads x head_size + * prefill_packed_key: self-attn "prefill" key; number_of_tokens x num_heads x head_size + * prefill_packed_value: self-attn "prefill" value; number_of_tokens x num_heads x head_size + * prefill_packed_ideal_output: self-attn "prefill" ideal output; number_of_tokens x num_heads x head_size + * prefill_q_prompt_lens: list of token counts for each *prefill query* (one less than baseline query) + * prefill_kv_prompt_lens: list of token counts for each self-attn *prefill key/value* (should match prefill_q_prompt_lens) + * decode_packed_query: "decode" query; number_of_tokens x num_heads x head_size + * decode_packed_key: self-attn "decode" key; number_of_tokens x num_heads x head_size + * decode_packed_value: self-attn "decode" key; number_of_tokens x num_heads x head_size + * decode_packed_ideal_output: self-attn "decode" ideal output; number_of_tokens x num_heads x head_size + * decode_q_prompt_lens: list of token counts for each *decode query* (should be 1) + * decode_kv_prompt_lens: list of token counts for each self-attn *decode key/value* (should match decode_q_prompt_lens) + * q_prompt_lens: "baseline" query seq lens; number_of_tokens x num_heads x head_size + * kv_prompt_lens: self-attn "baseline" key/value seq lens; number_of_tokens x num_heads x head_size + * decode_block_tables: fake self-attn decode-phase block table + * decode_slot_mapping: fake self-attn decode-phase slot mapping + * prefill_slot_mapping: fake self-attn prefill-phase slot mapping + * prefill_block_tables: fake self-attn prefill-phase block table + * max_block_idx: highest block address in the self-attention block-table + ''' max_kv_prompt_len = max_q_prompt_len @@ -723,6 +828,55 @@ def self_attn_setup(batch_size, num_heads, head_size, block_size, scale, max_q_p def cross_attn_setup_reuses_query(query, q_prompt_lens, prefill_q_prompt_lens, batch_size, num_heads, head_size, block_size, scale, max_q_prompt_len, max_kv_prompt_len, block_base_addr=0): + ''' + Set up test vectors & data structures for cross-attention test. + + A triplet of synthetic cross-attention key/value tensors are constructed ("baseline" key/value). + Given this is a cross-attention test, we assume query tensors were already synthesized for a + prior self-attention test and will be reused for cross-attention. The key & value sequences + generated here will may have a different length than the corresponding queries (as is often + the case for cross-attention between decoder and encoder sequences.) + + Cross attention key & value tensors do not grow during autoregressive inference; thus + this function obtains a single key/value pair suitable for both prefill and decode. + + The "baseline" query tensor is received as an argument. The "baseline" query/key/value tensors + are passed to an ideal reference cross-attention implementation + to generate a "baseline" ideal output tensor. This tensor is split into the "Prefill" + ideal output tensor (all but the last element of each output sequence) and the "Decode" + ideal output tensor (*only* the last element of each output sequence); the "Prefill" and + "Decode" ideal output tensors can be used to validate the prefill and decode test + results, respectively. + + This function also constructs the cross-attention KV cache memory mapping + (slot mapping and block table), ensuring that the block table starts + at block_base_addr. + + Arguments: + * query: pre-existing "baseline" query; batch_size x padded_seq_len x num_heads x head_size + * q_prompt_lens: list of token-counts for each "baseline" query sequence + * prefill_q_prompt_lens: list of token-counts for each "prefill" query sequence + * batch_size + * num_heads: Number of attention heads + * head_size: Head dimension + * block_size: Number of offsets per KV cache block + * scale: attention scale parameter + * max_q_prompt_len: upper limit on query length for synthetic test vectors + * max_kv_prompt_len: upper limit on key/value length for synthetic test vectors + * block_base_addr: cross-attention block table base address + + Returns: + * packed_key: cross-attention key; number_of_tokens x num_heads x head_size + * packed_value: cross-attention value; number_of_tokens x num_heads x head_size + * prefill_packed_ideal_output: "prefill" ideal output; number_of_tokens x num_heads x head_size + * decode_packed_ideal_output: "decode" ideal output; number_of_tokens x num_heads x head_size + * kv_prompt_lens: list of token-counts for each key/value + * decode_block_tables: fake decode-phase block tables + * decode_slot_mapping: fake decode-phase slot mapping + * prefill_slot_mapping: fake prefill-phase slot mapping + * prefill_block_tables: fake prefill-phase block tables + * max_block_idx: highest block address in the cross-attention block-table + ''' _, \ key, \ @@ -768,11 +922,12 @@ def cross_attn_setup_reuses_query(query, q_prompt_lens, prefill_q_prompt_lens, b decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping, max_block_idx = make_block_tables_slot_mapping( block_size, kv_prompt_lens, block_base_addr=block_base_addr) - _, prefill_packed_key, prefill_packed_value, _, _ = pack_qkv( - None, key, value, prefill_q_prompt_lens, kv_prompt_lens) + # Packed key/value (query is already provided) + _, packed_key, packed_value, _, _ = pack_qkv( + None, key, value, None, kv_prompt_lens) - return prefill_packed_key, \ - prefill_packed_value, \ + return packed_key, \ + packed_value, \ prefill_packed_ideal_output, \ decode_packed_ideal_output, \ kv_prompt_lens, \ @@ -805,12 +960,32 @@ def run_cross_attention_test(attn,packed_query,packed_key,packed_value,kv_cache, @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_q_prompt_len", Q_PROMPT_LENS) -@pytest.mark.parametrize("max_kv_prompt_len", K_PROMPT_LENS) +@pytest.mark.parametrize("max_q_prompt_len", MAX_Q_PROMPT_LENS) +@pytest.mark.parametrize("max_kv_prompt_len", MAX_K_PROMPT_LENS) def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_prompt_len: int, max_kv_prompt_len: int) -> None: + ''' + Test: + * Construct fake test vectors for self- and cross-attention + * Construct attention metadata structure with self- and cross-attention attributes + * Test self- and cross-attention in the following order + * Prefill self-attention + * Prefill cross-attention + * Decode self-attention + * Decode cross-attention + * This order would exacerbate any accidental overlap in the self-/cross-attention block tables, + which we attempt to avoid + * Validate output correctness against ideal reference attention implementation + + Block tables are constructed such that cross-attention KV cache is in a higher, non-intersecting + address-space than self-attention KV cache. + + Self- and cross-attention share the same query tensor but not the K/V tensors. Self-attention + K/Vs must have the same seq len as Q while cross-attention K/Vs are allowed to differ in seq + len, as is often the case for cross-attention. + ''' # Num KV cache blocks num_blocks = 4096 @@ -843,10 +1018,10 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, self_decode_packed_key, \ self_decode_packed_value, \ self_decode_packed_ideal_output, \ - decode_q_prompt_lens, \ - self_decode_kv_prompt_lens, \ + _, \ + _, \ q_prompt_lens, \ - self_kv_prompt_lens, \ + _, \ self_decode_block_tables, \ self_decode_slot_mapping, \ self_prefill_slot_mapping, \ @@ -870,17 +1045,17 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - final_max_block_idx = cross_attn_setup_reuses_query(query, - q_prompt_lens, - prefill_q_prompt_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_prompt_len, - max_kv_prompt_len, - block_base_addr=cross_block_base_addr) + _ = cross_attn_setup_reuses_query(query, + q_prompt_lens, + prefill_q_prompt_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_prompt_len, + max_kv_prompt_len, + block_base_addr=cross_block_base_addr) # PREFILL: self- and cross-attention tests @@ -940,6 +1115,7 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, decode_attn_metadata, scale) + # - Decode self-attention correct? assert torch.allclose(self_decode_packed_ideal_output,self_decode_packed_actual_output.view_as(self_decode_packed_ideal_output)) cross_decode_packed_actual_output: torch.Tensor = run_cross_attention_test(attn, @@ -950,226 +1126,5 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, decode_attn_metadata, scale) - assert torch.allclose(cross_decode_packed_ideal_output,cross_decode_packed_actual_output.view_as(cross_decode_packed_ideal_output)) - -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_prompt_len", PROMPT_LENS) -def test_prefill_decode_self_attention(num_heads: int, head_size: int, - backend_name: str, batch_size: int, - block_size: int, - max_prompt_len: int) -> None: - # Attention operator instance - is_cross_attn = False - is_prompt = True - max_q_prompt_len = max_prompt_len - max_kv_prompt_len = max_q_prompt_len - context_lens = [0 for _ in range(batch_size)] - num_blocks = 4096 - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) - scale = float(1.0 / (head_size**0.5)) - attn = make_attention(num_heads, head_size, scale) - attn_backend = make_backend(backend_name) - query, \ - key, \ - value, \ - prefill_query, \ - prefill_key, \ - prefill_value, \ - decode_query, \ - decode_key, \ - decode_value, \ - q_prompt_lens, \ - kv_prompt_lens, \ - _, \ - _, \ - prefill_q_prompt_lens, \ - prefill_kv_prompt_lens, \ - decode_q_prompt_lens, \ - decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=False) - - causal_mask = build_causal_mask(max_q_prompt_len, - max_kv_prompt_len).to(CUDA_DEVICE) - ideal_output = ref_masked_attention(query, - key, - value, - scale=scale, - custom_mask=causal_mask, - q_prompt_lens=q_prompt_lens, - kv_prompt_lens=kv_prompt_lens) - - prefill_ideal_output = torch.zeros_like(ideal_output) - decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) - for bdx, prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): - prefill_ideal_output[bdx, :prefill_q_prompt_len] = ideal_output[ - bdx, :prefill_q_prompt_len] - decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_prompt_len:( - prefill_q_prompt_len + 1)] - - prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_prompt_lens) - decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, - [1 for _ in range(batch_size)]) - - decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _ = make_block_tables_slot_mapping( - block_size, q_prompt_lens) - prefill_attn_metadata: AttentionMetadata = make_metadata( - attn_backend, - is_prompt, - is_cross_attn, - prefill_q_prompt_lens, - context_lens, - prefill_block_tables, - prefill_slot_mapping, - cross_prompt_lens=None) - - prefill_packed_query, prefill_packed_key, prefill_packed_value, _, _ = pack_qkv( - prefill_query, prefill_key, prefill_value, prefill_q_prompt_lens, - prefill_kv_prompt_lens) - - prefill_packed_actual_output = attn.forward(prefill_packed_query, - prefill_packed_key, - prefill_packed_value, kv_cache, - prefill_attn_metadata, scale) - - # eval correctness of prefill output - assert torch.allclose( - prefill_packed_actual_output, - prefill_packed_ideal_output.view_as(prefill_packed_actual_output)) - - is_prompt = False - context_lens = copy.deepcopy(prefill_kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, is_prompt, - is_cross_attn, q_prompt_lens, - context_lens, decode_block_tables, - decode_slot_mapping) - - decode_packed_query, decode_packed_key, decode_packed_value, _, _ = pack_qkv( - decode_query, decode_key, decode_value, decode_q_prompt_lens, - decode_kv_prompt_lens) - - decode_packed_actual_output = attn.forward(decode_packed_query, - decode_packed_key, - decode_packed_value, kv_cache, - decode_attn_metadata, scale) - - # eval correctness of decode output - assert torch.allclose( - decode_packed_actual_output, - decode_packed_ideal_output.view_as(decode_packed_actual_output)) - - -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_q_prompt_len", Q_PROMPT_LENS) -@pytest.mark.parametrize("max_kv_prompt_len", K_PROMPT_LENS) -def test_prefill_decode_cross_attention(num_heads: int, head_size: int, - backend_name: str, batch_size: int, - block_size: int, max_q_prompt_len: int, - max_kv_prompt_len: int) -> None: - # Attention operator instance - is_cross_attn = True - is_prompt = True - context_lens = [0 for _ in range(batch_size)] - num_blocks = 4096 - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) - scale = float(1.0 / (head_size**0.5)) - attn = make_attention(num_heads, head_size, scale) - attn_backend = make_backend(backend_name) - - query, \ - key, \ - value, \ - prefill_query, \ - _, \ - _, \ - decode_query, \ - _, \ - _, \ - q_prompt_lens, \ - kv_prompt_lens, \ - _, \ - _, \ - prefill_q_prompt_lens, \ - _, \ - decode_q_prompt_lens, \ - _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=is_cross_attn) - - ideal_output = ref_masked_attention(query, - key, - value, - scale=scale, - q_prompt_lens=q_prompt_lens, - kv_prompt_lens=kv_prompt_lens) - - prefill_ideal_output = torch.zeros_like(ideal_output) - decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) - for bdx, prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): - prefill_ideal_output[bdx, :prefill_q_prompt_len] = ideal_output[ - bdx, :prefill_q_prompt_len] - decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_prompt_len:( - prefill_q_prompt_len + 1)] - - prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_prompt_lens) - decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, - [1 for _ in range(batch_size)]) - - # Unlike self-attention: - # - Prefill slot-mapping includes all key slots - # - Decode slot-mapping is empty - decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping = make_block_tables_slot_mapping( - block_size, kv_prompt_lens) - - prefill_attn_metadata: AttentionMetadata = make_metadata( - attn_backend, - is_prompt, - is_cross_attn, - prefill_q_prompt_lens, - context_lens, - prefill_block_tables, - prefill_slot_mapping, - cross_prompt_lens=kv_prompt_lens) - - prefill_packed_query, prefill_packed_key, prefill_packed_value, _, _ = pack_qkv( - prefill_query, key, value, prefill_q_prompt_lens, kv_prompt_lens) - - prefill_packed_actual_output = attn.forward(prefill_packed_query, - prefill_packed_key, - prefill_packed_value, kv_cache, - prefill_attn_metadata, scale) - - # eval correctness of prefill output - assert torch.allclose( - prefill_packed_actual_output, - prefill_packed_ideal_output.view_as(prefill_packed_actual_output)) - - is_prompt = False - context_lens = copy.deepcopy(kv_prompt_lens) - decode_attn_metadata = make_metadata(attn_backend, - is_prompt, - is_cross_attn, - q_prompt_lens, - context_lens, - decode_block_tables, - decode_slot_mapping, - cross_prompt_lens=kv_prompt_lens) - - decode_packed_query, _, _, _, _ = pack_qkv(decode_query, key, value, - decode_q_prompt_lens, - kv_prompt_lens) - - decode_packed_actual_output = attn.forward(decode_packed_query, None, None, - kv_cache, decode_attn_metadata, - scale) - - # eval correctness of decode output - assert torch.allclose( - decode_packed_actual_output, - decode_packed_ideal_output.view_as(decode_packed_actual_output)) + # - Decode cross-attention correct? + assert torch.allclose(cross_decode_packed_ideal_output,cross_decode_packed_actual_output.view_as(cross_decode_packed_ideal_output)) \ No newline at end of file From 92701a3bd21cd89c0e004c0bdee2b783a12ce846 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 21 May 2024 00:22:55 -0400 Subject: [PATCH 026/239] some format fixes --- tests/layer/test_self_and_cross_attn.py | 244 +++++++++++++----------- vllm/attention/backends/xformers.py | 40 ++-- 2 files changed, 157 insertions(+), 127 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 5cab054b61069..2fe5add8ae453 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -21,11 +21,11 @@ # TODO: # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 -HEAD_SIZES = [64,256] +HEAD_SIZES = [64, 256] -NUM_HEADS = [1,16] +NUM_HEADS = [1, 16] -BATCH_SIZES = [1,16] +BATCH_SIZES = [1, 16] BLOCK_SIZES = [16] BACKEND_NAMES = ["xformers"] CUDA_DEVICE = "cuda:0" @@ -492,7 +492,8 @@ def make_block_tables_slot_mapping(block_size, prefill_slot_mapping = [] decode_slot_mapping = [] slot_mapping = [] - block_base_idx = block_base_addr + sum(num_blocks_list) * 2 - 1 # Support more blocks than needed + block_base_idx = block_base_addr + sum( + num_blocks_list) * 2 - 1 # Support more blocks than needed max_block_idx = block_base_idx for sdx, num_tokens in enumerate(prompt_lens): num_blocks = num_blocks_list[sdx] @@ -536,16 +537,19 @@ def make_block_tables_slot_mapping(block_size, return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor, slot_mapping_tensor, empty_slot_mapping_tensor, max_block_idx -def make_metadata_self_cross(attn_backend: AttentionBackend, - is_prompt: bool, - prompt_lens: List[int], - context_lens: List[int], - block_tables, - slot_mapping, - device=CUDA_DEVICE, - cross_seq_lens: Optional[List[int]] = None, - cross_block_tables: Optional[torch.Tensor] = None, - cross_slot_mapping: Optional[List[int]] = None,) -> AttentionMetadata: + +def make_metadata_self_cross( + attn_backend: AttentionBackend, + is_prompt: bool, + prompt_lens: List[int], + context_lens: List[int], + block_tables, + slot_mapping, + device=CUDA_DEVICE, + cross_seq_lens: Optional[List[int]] = None, + cross_block_tables: Optional[torch.Tensor] = None, + cross_slot_mapping: Optional[List[int]] = None, +) -> AttentionMetadata: ''' Construct fake attention metadata for a combined self-/cross-attention scenario i.e. an encoder/decoder @@ -651,6 +655,7 @@ def make_metadata_self_cross(attn_backend: AttentionBackend, cross_slot_mapping=cross_slot_mapping_tensor, cross_block_tables=cross_block_tables) + def make_attention(num_heads: int, head_size: int, scale: float): ''' Construct an instance of the Attention wrapper, suited to @@ -690,7 +695,14 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) return scale, attn_backend, attn, kv_cache -def self_attn_setup(batch_size, num_heads, head_size, block_size, scale, max_q_prompt_len, block_base_addr=0): + +def self_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_prompt_len, + block_base_addr=0): ''' Set up test vectors & data structures for self-attention test. @@ -772,7 +784,7 @@ def self_attn_setup(batch_size, num_heads, head_size, block_size, scale, max_q_p causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len).to(CUDA_DEVICE) - + ideal_output = ref_masked_attention(query, key, value, @@ -827,7 +839,17 @@ def self_attn_setup(batch_size, num_heads, head_size, block_size, scale, max_q_p max_block_idx -def cross_attn_setup_reuses_query(query, q_prompt_lens, prefill_q_prompt_lens, batch_size, num_heads, head_size, block_size, scale, max_q_prompt_len, max_kv_prompt_len, block_base_addr=0): +def cross_attn_setup_reuses_query(query, + q_prompt_lens, + prefill_q_prompt_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_prompt_len, + max_kv_prompt_len, + block_base_addr=0): ''' Set up test vectors & data structures for cross-attention test. @@ -919,12 +941,12 @@ def cross_attn_setup_reuses_query(query, q_prompt_lens, prefill_q_prompt_lens, b # Unlike self-attention: # - Prefill slot-mapping includes all key slots # - Decode slot-mapping is empty - decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping, max_block_idx = make_block_tables_slot_mapping( + decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping, max_block_idx = make_block_tables_slot_mapping( block_size, kv_prompt_lens, block_base_addr=block_base_addr) - + # Packed key/value (query is already provided) - _, packed_key, packed_value, _, _ = pack_qkv( - None, key, value, None, kv_prompt_lens) + _, packed_key, packed_value, _, _ = pack_qkv(None, key, value, None, + kv_prompt_lens) return packed_key, \ packed_value, \ @@ -937,23 +959,21 @@ def cross_attn_setup_reuses_query(query, q_prompt_lens, prefill_q_prompt_lens, b prefill_block_tables, \ max_block_idx -def run_self_attention_test(attn,packed_query,packed_key,packed_value,kv_cache,attn_metadata:AttentionMetadata,scale): + +def run_self_attention_test(attn, packed_query, packed_key, packed_value, + kv_cache, attn_metadata: AttentionMetadata, scale): attn_metadata.do_cross_attn = False - return attn.forward(packed_query, - packed_key, - packed_value, - kv_cache, - attn_metadata, - scale) - -def run_cross_attention_test(attn,packed_query,packed_key,packed_value,kv_cache,attn_metadata:AttentionMetadata,scale): + return attn.forward(packed_query, packed_key, packed_value, kv_cache, + attn_metadata, scale) + + +def run_cross_attention_test(attn, packed_query, packed_key, packed_value, + kv_cache, attn_metadata: AttentionMetadata, + scale): attn_metadata.do_cross_attn = True - return attn.forward(packed_query, - packed_key, - packed_value, - kv_cache, - attn_metadata, - scale) + return attn.forward(packed_query, packed_key, packed_value, kv_cache, + attn_metadata, scale) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -962,10 +982,10 @@ def run_cross_attention_test(attn,packed_query,packed_key,packed_value,kv_cache, @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_q_prompt_len", MAX_Q_PROMPT_LENS) @pytest.mark.parametrize("max_kv_prompt_len", MAX_K_PROMPT_LENS) -def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, - backend_name: str, batch_size: int, - block_size: int, max_q_prompt_len: int, - max_kv_prompt_len: int) -> None: +def test_prefill_decode_self_and_cross_attention( + num_heads: int, head_size: int, backend_name: str, batch_size: int, + block_size: int, max_q_prompt_len: int, + max_kv_prompt_len: int) -> None: ''' Test: * Construct fake test vectors for self- and cross-attention @@ -997,15 +1017,15 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, scale, \ attn_backend, \ attn, \ - kv_cache = basic_setup(num_heads, - head_size, - num_blocks, - block_size, + kv_cache = basic_setup(num_heads, + head_size, + num_blocks, + block_size, backend_name) # Self-attention setup - self_block_base_addr=0 + self_block_base_addr = 0 query, \ prefill_packed_query, \ @@ -1026,11 +1046,11 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, self_decode_slot_mapping, \ self_prefill_slot_mapping, \ self_prefill_block_tables, \ - cross_block_base_addr = self_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, + cross_block_base_addr = self_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, max_q_prompt_len, block_base_addr=self_block_base_addr) @@ -1045,86 +1065,86 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, head_size: int, cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - _ = cross_attn_setup_reuses_query(query, - q_prompt_lens, - prefill_q_prompt_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_prompt_len, - max_kv_prompt_len, + _ = cross_attn_setup_reuses_query(query, + q_prompt_lens, + prefill_q_prompt_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_prompt_len, + max_kv_prompt_len, block_base_addr=cross_block_base_addr) # PREFILL: self- and cross-attention tests context_lens = [0 for _ in range(batch_size)] - prefill_attn_metadata: AttentionMetadata = make_metadata_self_cross(attn_backend, - True, - prefill_q_prompt_lens, - context_lens, - self_prefill_block_tables, - self_prefill_slot_mapping, - cross_seq_lens = cross_kv_prompt_lens, - cross_block_tables = cross_prefill_block_tables, - cross_slot_mapping = cross_prefill_slot_mapping,) - - self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test(attn, - prefill_packed_query, - self_prefill_packed_key, - self_prefill_packed_value, - kv_cache, - prefill_attn_metadata, - scale) + prefill_attn_metadata: AttentionMetadata = make_metadata_self_cross( + attn_backend, + True, + prefill_q_prompt_lens, + context_lens, + self_prefill_block_tables, + self_prefill_slot_mapping, + cross_seq_lens=cross_kv_prompt_lens, + cross_block_tables=cross_prefill_block_tables, + cross_slot_mapping=cross_prefill_slot_mapping, + ) + + self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test( + attn, prefill_packed_query, self_prefill_packed_key, + self_prefill_packed_value, kv_cache, prefill_attn_metadata, scale) # - Prefill self-attention correct? - assert torch.allclose(self_prefill_packed_ideal_output,self_prefill_packed_actual_output.view_as(self_prefill_packed_ideal_output)) + assert torch.allclose( + self_prefill_packed_ideal_output, + self_prefill_packed_actual_output.view_as( + self_prefill_packed_ideal_output)) - cross_prefill_packed_actual_output: torch.Tensor = run_cross_attention_test(attn, - prefill_packed_query, - cross_prefill_packed_key, - cross_prefill_packed_value, - kv_cache, - prefill_attn_metadata, - scale) + cross_prefill_packed_actual_output: torch.Tensor = run_cross_attention_test( + attn, prefill_packed_query, cross_prefill_packed_key, + cross_prefill_packed_value, kv_cache, prefill_attn_metadata, scale) # - Prefill cross-attention correct? - assert torch.allclose(cross_prefill_packed_ideal_output,cross_prefill_packed_actual_output.view_as(cross_prefill_packed_ideal_output)) + assert torch.allclose( + cross_prefill_packed_ideal_output, + cross_prefill_packed_actual_output.view_as( + cross_prefill_packed_ideal_output)) context_lens = copy.deepcopy(self_prefill_kv_prompt_lens) # DECODE: self- and cross-attention tests - decode_attn_metadata: AttentionMetadata = make_metadata_self_cross(attn_backend, - False, - q_prompt_lens, - context_lens, - self_decode_block_tables, - self_decode_slot_mapping, - cross_seq_lens = cross_kv_prompt_lens, - cross_block_tables = cross_decode_block_tables, - cross_slot_mapping = cross_decode_slot_mapping,) - - self_decode_packed_actual_output: torch.Tensor = run_self_attention_test(attn, - decode_packed_query, - self_decode_packed_key, - self_decode_packed_value, - kv_cache, - decode_attn_metadata, - scale) + decode_attn_metadata: AttentionMetadata = make_metadata_self_cross( + attn_backend, + False, + q_prompt_lens, + context_lens, + self_decode_block_tables, + self_decode_slot_mapping, + cross_seq_lens=cross_kv_prompt_lens, + cross_block_tables=cross_decode_block_tables, + cross_slot_mapping=cross_decode_slot_mapping, + ) + + self_decode_packed_actual_output: torch.Tensor = run_self_attention_test( + attn, decode_packed_query, self_decode_packed_key, + self_decode_packed_value, kv_cache, decode_attn_metadata, scale) # - Decode self-attention correct? - assert torch.allclose(self_decode_packed_ideal_output,self_decode_packed_actual_output.view_as(self_decode_packed_ideal_output)) + assert torch.allclose( + self_decode_packed_ideal_output, + self_decode_packed_actual_output.view_as( + self_decode_packed_ideal_output)) - cross_decode_packed_actual_output: torch.Tensor = run_cross_attention_test(attn, - decode_packed_query, - None, - None, - kv_cache, - decode_attn_metadata, - scale) + cross_decode_packed_actual_output: torch.Tensor = run_cross_attention_test( + attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata, + scale) # - Decode cross-attention correct? - assert torch.allclose(cross_decode_packed_ideal_output,cross_decode_packed_actual_output.view_as(cross_decode_packed_ideal_output)) \ No newline at end of file + assert torch.allclose( + cross_decode_packed_ideal_output, + cross_decode_packed_actual_output.view_as( + cross_decode_packed_ideal_output)) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 5b6d2ac0e144f..0c8db7e47a50d 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -115,7 +115,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Begin cross-attention fields... - # If True, prefill_metadata() and decode_metadata() will return + # If True, prefill_metadata() and decode_metadata() will return # seqlen & memory-mapping data structures for cross-attention; # otherwise, self-attention data structures will be returned. is_cross_attn: bool = False @@ -147,14 +147,19 @@ def __post_init__(self): @property def has_valid_cross_attn_metadata(self): # No cross-attention metadata is present whatsoever - no_md = (self.cross_seq_lens is None) and (self.cross_slot_mapping is None) and (self.cross_block_tables is None) + no_md = (self.cross_seq_lens is + None) and (self.cross_slot_mapping is + None) and (self.cross_block_tables is None) # If any cross-attention metadata is present, it is invalid - invalid_md_if_not_no_md = (self.cross_seq_lens is None) or (self.cross_slot_mapping is None) or (self.cross_block_tables is None) + invalid_md_if_not_no_md = (self.cross_seq_lens is None) or ( + self.cross_slot_mapping is None) or (self.cross_block_tables is + None) if no_md: return False - - assert (not invalid_md_if_not_no_md), "Invalid cross-attention metadata" + + assert ( + not invalid_md_if_not_no_md), "Invalid cross-attention metadata" return True @@ -163,17 +168,20 @@ def do_cross_attn(self): return self.is_cross_attn @do_cross_attn.setter - def do_cross_attn(self,state:bool): + def do_cross_attn(self, state: bool): if state: assert self.has_valid_cross_attn_metadata, "Must have self.cross_seq_lens not None in order to enable cross-attention" # Infer implicit cross-attention fields from user-provided fields, if needed if self.cross_seq_lens_tensor is None: - self.cross_seq_lens_tensor = torch.tensor(self.cross_seq_lens, - dtype=self.seq_lens_tensor.dtype, - device=self.seq_lens_tensor.device) + assert self.seq_lens_tensor is not None + self.cross_seq_lens_tensor = torch.tensor( + self.cross_seq_lens, + dtype=self.seq_lens_tensor.dtype, + device=self.seq_lens_tensor.device) if self.max_cross_seq_len is None: + assert self.cross_seq_lens is not None self.max_cross_seq_len = max(self.cross_seq_lens) self.is_cross_attn = True @@ -209,10 +217,11 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: max_decode_seq_len=0, query_start_loc=self.query_start_loc[:self.num_prefills + 1], seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + context_lens_tensor=self.context_lens_tensor[:self. + num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - is_cross_attn=False, # Begin cross-attention fields below... + is_cross_attn=False, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -244,10 +253,11 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: max_decode_seq_len=0, query_start_loc=self.query_start_loc[:self.num_prefills + 1], seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + context_lens_tensor=self.context_lens_tensor[:self. + num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - is_cross_attn=True, # Begin cross-attention fields below... + is_cross_attn=True, # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, max_cross_seq_len=self.max_cross_seq_len, @@ -283,7 +293,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - is_cross_attn=False, # Begin cross-attention fields below... + is_cross_attn=False, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -314,7 +324,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - is_cross_attn=True, # Begin cross-attention fields below... + is_cross_attn=True, # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, max_cross_seq_len=self.max_cross_seq_len, From 39788150b2efd4177fc2a0da6615ced032e8a3b0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 21 May 2024 00:56:59 -0400 Subject: [PATCH 027/239] refactored long lines in self/cross attn test --- tests/layer/test_self_and_cross_attn.py | 304 +++++++++++++++--------- 1 file changed, 187 insertions(+), 117 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 2fe5add8ae453..4fe9862de99a2 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -12,14 +12,9 @@ from vllm.utils import make_tensor_with_pad -from vllm.attention.layer import Attention - -import random - # If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] # -# TODO: -# FlashAttention forward only supports head dimension at most 128 +# TODO: FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 HEAD_SIZES = [64, 256] @@ -39,10 +34,12 @@ def build_causal_mask(q_max_prompt_len, kv_max_prompt_len): Create a q_max_prompt_len x kv_max_prompt_len causal mask Arguments: + * q_max_prompt_len: query max prompt len * kv_max_prompt_len: key/value max prompt len Returns: + * 2D tensor, q_max_prompt_len x kv_max_prompt_len ''' @@ -65,19 +62,25 @@ def ref_masked_attention( kv_prompt_lens: Optional[List] = None) -> torch.Tensor: ''' "Golden" masked attention reference. Supports two types of masking: - * Basic attention mask, utilizing {q,kv}_prompt_lens args to mask out padding elements - * Custom attention mask, which can force an arbitrary mask tensor, i.e. causal + + * Basic attention mask, utilizing {q,kv}_prompt_lens args to mask out + padding elements + * Custom attention mask, which can force an arbitrary mask tensor, i.e. + causal Arguments: + * query: batch_size x q_padded_seq_len x num_heads x head_size * key: batch_size x kv_padded_seq_len x num_heads x head_size * value: batch_size x kv_padded_seq_len x num_heads x head_size * scale: Attention scale factor - * Custom mask: custom attention mask; good place to inject a causal attention mask + * Custom mask: custom attention mask; good place to inject a causal + attention mask * q_prompt_lens: list of unpadded query seq_lens for each batch index * kv_prompt_lens: list of unpadded key/value seq_lens for each batch index Returns: + * Attention result, batch_size x q_padded_seq_len x num_heads x head_size ''' @@ -120,26 +123,39 @@ def make_qkv(batch_size, Construct QKV test tensors for self- and cross-attention. Generates three query/key/value triplets: + * "Baseline" query/key/value (for input to reference attention function) - * "Prefill" query/key/value (last sequence offset zero'd out, for use as input to prefill kernel) - * "Decode" query/key/value (only the last sequence offset from baseline, for use as input to decode kernel) + * "Prefill" query/key/value (last sequence offset zero'd out, for use as + input to prefill kernel) + * "Decode" query/key/value (only the last sequence offset from baseline, + for use as input to decode kernel) - Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v seqlens + Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v + seqlens Arguments: + * batch_size * max_q_prompt_len: max query prompt len * max_kv_prompt_len: max key/value prompt len * num_heads * head_size - * is_cross_attn: if True, query seqlen may differ from key/value seqlen (as is often the case for cross-attention); o/w, query/key/value seqlens match at each batch index (max_kv_prompt_len is unused) - * force_max_len: if True, all query seqlens are max_q_prompt_len; o/w query seqlens are random in [2,max_q_prompt_lens]. Same for key/value seqlens and max_kv_prompt_len, unless forced by is_cross_attn=False + * is_cross_attn: if True, query seqlen may differ from key/value seqlen (as + is often the case for cross-attention); o/w, query/key/value seqlens match + at each batch index (max_kv_prompt_len is unused) + * force_max_len: if True, all query seqlens are max_q_prompt_len; o/w query + seqlens are random in [2,max_q_prompt_lens]. Same for key/value seqlens + and max_kv_prompt_len, unless forced by is_cross_attn=False * device: CPU or CUDA device Returns: - * query: "baseline" query; batch_size x max_q_prompt_len x num_heads x head_size - * key: "baseline" key; batch_size x max_kv_prompt_len x num_heads x head_size - * value: "baseline" value; batch_size x max_kv_prompt_len x num_heads x head_size + + * query: "baseline" query; batch_size x max_q_prompt_len x num_heads x + head_size + * key: "baseline" key; batch_size x max_kv_prompt_len x num_heads x + head_size + * value: "baseline" value; batch_size x max_kv_prompt_len x num_heads x + head_size * prefill_query: batch_size x (max_q_prompt_len-1) x num_heads x head_size * prefill_key: batch_size x (max_kv_prompt_len-1) x num_heads x head_size * prefill_value: batch_size x (max_kv_prompt_len-1) x num_heads x head_size @@ -148,8 +164,10 @@ def make_qkv(batch_size, * decode_value: batch_size x 1 x num_heads x head_size * q_prompt_lens: "baseline" query seqlen list * kv_prompt_lens: "baseline" key/value seqlen list - * actual_max_q_prompt_len: actual "baseline" query max prompt len (may be <= max_q_prompt_len due to randomness) - * actual_max_kv_prompt_len: actual "baseline" key/value max prompt len (may be <= max_kv_prompt_len due to randomness) + * actual_max_q_prompt_len: actual "baseline" query max prompt len (may be <= + max_q_prompt_len due to randomness) + * actual_max_kv_prompt_len: actual "baseline" key/value max prompt len (may + be <= max_kv_prompt_len due to randomness) * prefill_q_prompt_lens: "prefill" query seqlen list * prefill_kv_prompt_lens: "prefill" key/value seqlen list * decode_q_prompt_lens: "decode" query seqlen list (all ones) @@ -264,19 +282,21 @@ def make_qkv(batch_size, def pack_tensor(unpacked_tensor, prompt_lens, device=CUDA_DEVICE): ''' - Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an - unpadded number_of_tokens x num_heads x head_size tensor, where + Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an + unpadded number_of_tokens x num_heads x head_size tensor, where number_of_tokens = sum(prompt_lens) Arguments: + * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size * prompt_lens: list of token counts for each prompt * device: CPU or CUDA device Returns + * packed_tensor: number_of_tokens x num_heads x head_size - * start_loc_list: start idx of each batch elt in packed_tensor; - [0] + list(itertools.accumulate(prompt_lens)) + * start_loc_list: start idx of each batch elt in packed_tensor; [0] + + list(itertools.accumulate(prompt_lens)) ''' num_tok = sum(prompt_lens) @@ -299,15 +319,16 @@ def pack_tensor(unpacked_tensor, prompt_lens, device=CUDA_DEVICE): def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): ''' - Individually pack each of Q, K and V, each with dimensions - batch_size x padded_seq_len x num_heads x head_size, into - respective number_of_tokens x num_heads x head_size tensors. + Individually pack each of Q, K and V, each with dimensions batch_size x + padded_seq_len x num_heads x head_size, into respective number_of_tokens x + num_heads x head_size tensors. For Q, number_of_tokens = sum(q_prompt_lens). For K and V, number_of_tokens = sum(kv_prompt_lens) Arguments: + * query: batch_size x padded_seq_len x num_heads x head_size * key: batch_size x padded_seq_len x num_heads x head_size * value: batch_size x padded_seq_len x num_heads x head_size @@ -315,6 +336,7 @@ def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): * kv_prompt_lens: list of token counts for each key/value Returns + * packed_query: number_of_tokens x num_heads x head_size * packed_key: number_of_tokens x num_heads x head_size * packed_value: number_of_tokens x num_heads x head_size @@ -341,13 +363,15 @@ def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): def make_backend(backend_name: str) -> AttentionBackend: ''' - Construct the backend instance determined by the backend_name string argument. + Construct the backend instance determined by the backend_name string + argument. "xformers" -> construct xformers backend TODO: flash attention backend Returns: + * Backend instance ''' if backend_name == "xformers": @@ -363,12 +387,14 @@ def make_metadata_tensors(is_prompt: bool, Build scalar & tensor values required to build attention metadata structure. Arguments: + * is_prompt: True -> Prefill, False -> Decode * prompt_lens: list of token-counts for each prompt * context_lens: list of context length values for each prompt * device: CPU or CUDA device Returns: + * prompt_lens_tensor: prompt_lens list, as tensor * context_lens_tensor: context_lens list, as tensor * max_query_len: max(prompt_lens) if is_prompt, o/w 1 @@ -399,9 +425,8 @@ def make_metadata_tensors(is_prompt: bool, query_start_loc = copy.deepcopy(seq_start_loc) max_query_len = max_prompt_len else: - # Decode: one new query input token per batch - # element, thus query_start_loc is the cumsum - # of [1,1,1,...] + # Decode: one new query input token per batch element, thus + # query_start_loc is the cumsum of [1,1,1,...] query_start_loc = list(range(len(seq_start_loc))) max_query_len = 1 @@ -424,6 +449,7 @@ def make_kv_cache(num_blocks, Create a fake KV cache. Arguments: + * num_blocks: number of blocks in the KV cache * num_heads: number of attention heads * head_size: head dimension @@ -432,6 +458,7 @@ def make_kv_cache(num_blocks, * default_val: initialization value for KV cache elements Returns: + * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size) ''' @@ -444,8 +471,8 @@ def make_kv_cache(num_blocks, def num_tokens_to_min_blocks(num_tokens, block_size): ''' - Compute the minimum number of blocks required - to hold num_tokens tokens, given block_size + Compute the minimum number of blocks required to hold num_tokens tokens, + given block_size ''' return (num_tokens + block_size) // block_size @@ -461,22 +488,28 @@ def make_block_tables_slot_mapping(block_size, block_base_addr + sum(num_blocks_list) * 2 - 1 - and subsequent blocks count downward toward - block_base_addr + and subsequent blocks count downward toward block_base_addr Arguments: + * block_size: number of offsets per block * prompt_lens: list of token-counts for each sequence * block_base_addr: the block table base address * device: CPU or CUDA device Return: - * decode_block_tables_tensor: fake the state of the block tables during decode - * decode_slot_mapping_tensor: fake the state of the slot mapping during decode - * prefill_slot_mapping_tensor: fake the state of the slot mapping during prefill - * prefill_block_tables_tensor: fake the state of the block tables during prefill + + * decode_block_tables_tensor: fake the state of the block tables during + decode + * decode_slot_mapping_tensor: fake the state of the slot mapping during + decode + * prefill_slot_mapping_tensor: fake the state of the slot mapping during + prefill + * prefill_block_tables_tensor: fake the state of the block tables during + prefill * slot_mapping_tensor: union of prefill and decode slot mappings - * empty_slot_mapping_tensor: empty slot mapping (useful for decode phase cross attention) + * empty_slot_mapping_tensor: empty slot mapping (useful for decode phase + cross attention) * max_block_idx: the highest block address within this block table ''' @@ -551,14 +584,15 @@ def make_metadata_self_cross( cross_slot_mapping: Optional[List[int]] = None, ) -> AttentionMetadata: ''' - Construct fake attention metadata for a combined - self-/cross-attention scenario i.e. an encoder/decoder - model. + Construct fake attention metadata for a combined self-/cross-attention + scenario i.e. an encoder/decoder model. Assumptions: + * No chunked prefill -> a batch is 100% prefill or 100% decode, never both Arguments: + * attn_backend: Backend for sourcing attention kernels * is_prompt: prefill if True, o/w decode * prompt_lens: list of token counts for each sequence @@ -566,11 +600,13 @@ def make_metadata_self_cross( * block_tables: self-attention block tables * slot_mapping: self-attention slot_mapping * device: CPU or CUDA device - * cross_seq_lens: list of token counts for each encoder sequence, if any exist + * cross_seq_lens: list of token counts for each encoder sequence, if any + exist * cross_block_tables: cross-attention block tables, if required * cross_slot_mapping: cross-attention slot mapping, if required Return: + * AttentionMetadata structure supporting self- and cross-attention ''' @@ -658,10 +694,9 @@ def make_metadata_self_cross( def make_attention(num_heads: int, head_size: int, scale: float): ''' - Construct an instance of the Attention wrapper, suited to - the number of attention heads and head dimension - (num_heads and head_size respectively) as well as the - attention scale factor (scale) + Construct an instance of the Attention wrapper, suited to the number of + attention heads and head dimension (num_heads and head_size respectively) as + well as the attention scale factor (scale) ''' return Attention( @@ -676,6 +711,7 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): Compute & build entities required for the self-/cross-attention test. Arguments: + * num_heads: Number of attention heads * head_size: Head dimension * num_blocks: Number of KV cache blocks @@ -683,10 +719,12 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): * backend_name: selection of backend Returns: + * scale: 1/sqrt(head_size) * attn_backend: backend instance * attn: Attention wrapper instance - * kv_cache: fake KV cache, 2 x num_blocks x (block_size * num_heads * head_size) + * kv_cache: fake KV cache, 2 x num_blocks x (block_size * num_heads * + head_size) ''' scale = float(1.0 / (head_size**0.5)) @@ -706,31 +744,33 @@ def self_attn_setup(batch_size, ''' Set up test vectors & data structures for self-attention test. - A triplet of synthetic query/key/value tensors are constructed ("baseline" query/key/value). - Given this is a self-attention test, the key & value sequences will have the same length - as the corresponding queries. + A triplet of synthetic query/key/value tensors are constructed ("baseline" + query/key/value). Given this is a self-attention test, the key & value + sequences will have the same length as the corresponding queries. - "Prefill" query/key/value tensors are derived by masking out the last value in each - baseline query/key/value. These tensors are used to test prefill & populate KV cache - for a subsequent decode test. + "Prefill" query/key/value tensors are derived by masking out the last value + in each baseline query/key/value. These tensors are used to test prefill & + populate KV cache for a subsequent decode test. - "Decode" query/key/value tensors are derived by extracting *only* the last value from - each baseline query/key/value (i.e. complement of the prefill tensors.) These tensors - are used to test decode, conditional on the kv cache being populated during the - prefill test. + "Decode" query/key/value tensors are derived by extracting *only* the last + value from each baseline query/key/value (i.e. complement of the prefill + tensors.) These tensors are used to test decode, conditional on the kv cache + being populated during the prefill test. - The baseline query/key/value tensors are passed to an ideal reference self-attention implementation - to generate a "Baseline" ideal output tensor. This tensor is split into the "Prefill" - ideal output tensor (all but the last element of each output sequence) and the "Decode" - ideal output tensor (*only* the last element of each output sequence); the "Prefill" and - "Decode" ideal output tensors can be used to validate the prefill and decode test - results, respectively. + The baseline query/key/value tensors are passed to an ideal reference + self-attention implementation to generate a "Baseline" ideal output tensor. + This tensor is split into the "Prefill" ideal output tensor (all but the + last element of each output sequence) and the "Decode" ideal output tensor + (*only* the last element of each output sequence); the "Prefill" and + "Decode" ideal output tensors can be used to validate the prefill and decode + test results, respectively. This function also constructs the self-attention KV cache memory mapping - (slot mapping and block table), ensuring that the block table starts - at block_base_addr + (slot mapping and block table), ensuring that the block table starts at + block_base_addr Arguments: + * batch_size * num_heads: Number of attention heads * head_size: Head dimension @@ -740,21 +780,37 @@ def self_attn_setup(batch_size, * block_base_addr: self-attention block table base address Returns: - * query: "baseline" query; batch_size x padded_seq_len x num_heads x head_size - * prefill_packed_query: "prefill" query; number_of_tokens x num_heads x head_size - * prefill_packed_key: self-attn "prefill" key; number_of_tokens x num_heads x head_size - * prefill_packed_value: self-attn "prefill" value; number_of_tokens x num_heads x head_size - * prefill_packed_ideal_output: self-attn "prefill" ideal output; number_of_tokens x num_heads x head_size - * prefill_q_prompt_lens: list of token counts for each *prefill query* (one less than baseline query) - * prefill_kv_prompt_lens: list of token counts for each self-attn *prefill key/value* (should match prefill_q_prompt_lens) - * decode_packed_query: "decode" query; number_of_tokens x num_heads x head_size - * decode_packed_key: self-attn "decode" key; number_of_tokens x num_heads x head_size - * decode_packed_value: self-attn "decode" key; number_of_tokens x num_heads x head_size - * decode_packed_ideal_output: self-attn "decode" ideal output; number_of_tokens x num_heads x head_size - * decode_q_prompt_lens: list of token counts for each *decode query* (should be 1) - * decode_kv_prompt_lens: list of token counts for each self-attn *decode key/value* (should match decode_q_prompt_lens) - * q_prompt_lens: "baseline" query seq lens; number_of_tokens x num_heads x head_size - * kv_prompt_lens: self-attn "baseline" key/value seq lens; number_of_tokens x num_heads x head_size + + * query: "baseline" query; batch_size x padded_seq_len x num_heads x + head_size + * prefill_packed_query: "prefill" query; number_of_tokens x num_heads x + head_size + * prefill_packed_key: self-attn "prefill" key; number_of_tokens x num_heads + x head_size + * prefill_packed_value: self-attn "prefill" value; number_of_tokens x + num_heads x head_size + * prefill_packed_ideal_output: self-attn "prefill" ideal output; + number_of_tokens x num_heads x head_size + * prefill_q_prompt_lens: list of token counts for each *prefill query* (one + less than baseline query) + * prefill_kv_prompt_lens: list of token counts for each self-attn *prefill + key/value* (should match prefill_q_prompt_lens) + * decode_packed_query: "decode" query; number_of_tokens x num_heads x + head_size + * decode_packed_key: self-attn "decode" key; number_of_tokens x num_heads x + head_size + * decode_packed_value: self-attn "decode" key; number_of_tokens x num_heads + x head_size + * decode_packed_ideal_output: self-attn "decode" ideal output; + number_of_tokens x num_heads x head_size + * decode_q_prompt_lens: list of token counts for each *decode query* (should + be 1) + * decode_kv_prompt_lens: list of token counts for each self-attn *decode + key/value* (should match decode_q_prompt_lens) + * q_prompt_lens: "baseline" query seq lens; number_of_tokens x num_heads x + head_size + * kv_prompt_lens: self-attn "baseline" key/value seq lens; number_of_tokens + x num_heads x head_size * decode_block_tables: fake self-attn decode-phase block table * decode_slot_mapping: fake self-attn decode-phase slot mapping * prefill_slot_mapping: fake self-attn prefill-phase slot mapping @@ -853,45 +909,56 @@ def cross_attn_setup_reuses_query(query, ''' Set up test vectors & data structures for cross-attention test. - A triplet of synthetic cross-attention key/value tensors are constructed ("baseline" key/value). - Given this is a cross-attention test, we assume query tensors were already synthesized for a - prior self-attention test and will be reused for cross-attention. The key & value sequences - generated here will may have a different length than the corresponding queries (as is often + A triplet of synthetic cross-attention key/value tensors are constructed + ("baseline" key/value). Given this is a cross-attention test, we assume + query tensors were already synthesized for a prior self-attention test and + will be reused for cross-attention. The key & value sequences generated here + will may have a different length than the corresponding queries (as is often the case for cross-attention between decoder and encoder sequences.) - Cross attention key & value tensors do not grow during autoregressive inference; thus - this function obtains a single key/value pair suitable for both prefill and decode. + Cross attention key & value tensors do not grow during autoregressive + inference; thus this function obtains a single key/value pair suitable for + both prefill and decode. - The "baseline" query tensor is received as an argument. The "baseline" query/key/value tensors - are passed to an ideal reference cross-attention implementation - to generate a "baseline" ideal output tensor. This tensor is split into the "Prefill" - ideal output tensor (all but the last element of each output sequence) and the "Decode" - ideal output tensor (*only* the last element of each output sequence); the "Prefill" and - "Decode" ideal output tensors can be used to validate the prefill and decode test - results, respectively. + The "baseline" query tensor is received as an argument. The "baseline" + query/key/value tensors are passed to an ideal reference cross-attention + implementation to generate a "baseline" ideal output tensor. This tensor is + split into the "Prefill" ideal output tensor (all but the last element of + each output sequence) and the "Decode" ideal output tensor (*only* the last + element of each output sequence); the "Prefill" and "Decode" ideal output + tensors can be used to validate the prefill and decode test results, + respectively. This function also constructs the cross-attention KV cache memory mapping - (slot mapping and block table), ensuring that the block table starts - at block_base_addr. + (slot mapping and block table), ensuring that the block table starts at + block_base_addr. Arguments: - * query: pre-existing "baseline" query; batch_size x padded_seq_len x num_heads x head_size + + * query: pre-existing "baseline" query; batch_size x padded_seq_len x + num_heads x head_size * q_prompt_lens: list of token-counts for each "baseline" query sequence - * prefill_q_prompt_lens: list of token-counts for each "prefill" query sequence + * prefill_q_prompt_lens: list of token-counts for each "prefill" query + sequence * batch_size * num_heads: Number of attention heads * head_size: Head dimension * block_size: Number of offsets per KV cache block * scale: attention scale parameter * max_q_prompt_len: upper limit on query length for synthetic test vectors - * max_kv_prompt_len: upper limit on key/value length for synthetic test vectors + * max_kv_prompt_len: upper limit on key/value length for synthetic test + vectors * block_base_addr: cross-attention block table base address Returns: + * packed_key: cross-attention key; number_of_tokens x num_heads x head_size - * packed_value: cross-attention value; number_of_tokens x num_heads x head_size - * prefill_packed_ideal_output: "prefill" ideal output; number_of_tokens x num_heads x head_size - * decode_packed_ideal_output: "decode" ideal output; number_of_tokens x num_heads x head_size + * packed_value: cross-attention value; number_of_tokens x num_heads x + head_size + * prefill_packed_ideal_output: "prefill" ideal output; number_of_tokens x + num_heads x head_size + * decode_packed_ideal_output: "decode" ideal output; number_of_tokens x + num_heads x head_size * kv_prompt_lens: list of token-counts for each key/value * decode_block_tables: fake decode-phase block tables * decode_slot_mapping: fake decode-phase slot mapping @@ -988,32 +1055,35 @@ def test_prefill_decode_self_and_cross_attention( max_kv_prompt_len: int) -> None: ''' Test: + * Construct fake test vectors for self- and cross-attention - * Construct attention metadata structure with self- and cross-attention attributes + * Construct attention metadata structure with self- and cross-attention + attributes * Test self- and cross-attention in the following order + * Prefill self-attention * Prefill cross-attention * Decode self-attention * Decode cross-attention - * This order would exacerbate any accidental overlap in the self-/cross-attention block tables, - which we attempt to avoid - * Validate output correctness against ideal reference attention implementation - - Block tables are constructed such that cross-attention KV cache is in a higher, non-intersecting - address-space than self-attention KV cache. - - Self- and cross-attention share the same query tensor but not the K/V tensors. Self-attention - K/Vs must have the same seq len as Q while cross-attention K/Vs are allowed to differ in seq - len, as is often the case for cross-attention. + * This order would exacerbate any accidental overlap in the + self-/cross-attention block tables, which we attempt to avoid + * Validate output correctness against ideal reference attention + implementation + + Block tables are constructed such that cross-attention KV cache is in a + higher, non-intersecting address-space than self-attention KV cache. + + Self- and cross-attention share the same query tensor but not the K/V + tensors. Self-attention K/Vs must have the same seq len as Q while + cross-attention K/Vs are allowed to differ in seq len, as is often the case + for cross-attention. ''' # Num KV cache blocks num_blocks = 4096 - # Attention scale factor, - # attention backend instance, - # attention wrapper instance, - # KV cache init + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init scale, \ attn_backend, \ attn, \ From 014a751da86ca922049a4229e753b63d0c5ad75e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 21 May 2024 01:04:25 -0400 Subject: [PATCH 028/239] formatting fixes --- tests/layer/test_self_and_cross_attn.py | 67 +++++++++++++++++++------ vllm/attention/backends/xformers.py | 10 ++-- 2 files changed, 59 insertions(+), 18 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 4fe9862de99a2..d1d0d0def15e9 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -307,12 +307,10 @@ def pack_tensor(unpacked_tensor, prompt_lens, device=CUDA_DEVICE): for bdx, (prompt_len, start_loc) in enumerate(zip(prompt_lens, start_loc_list)): - try: - packed_tensor[start_loc:( - start_loc + - prompt_len), :, :] = unpacked_tensor[bdx, :prompt_len, :, :] - except: - assert False, f"{start_loc} ; {prompt_len} ; {packed_tensor.shape} ; {unpacked_tensor.shape}" + + packed_tensor[start_loc:( + start_loc + + prompt_len), :, :] = unpacked_tensor[bdx, :prompt_len, :, :] return packed_tensor, start_loc_list @@ -358,7 +356,11 @@ def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): packed_key.shape[-1] * packed_key.shape[-2]) packed_value = packed_value.view( -1, packed_value.shape[-1] * packed_value.shape[-2]) - return packed_query, packed_key, packed_value, q_start_loc_list, kv_start_loc_list + return packed_query, \ + packed_key, \ + packed_value, \ + q_start_loc_list, \ + kv_start_loc_list def make_backend(backend_name: str) -> AttentionBackend: @@ -376,7 +378,8 @@ def make_backend(backend_name: str) -> AttentionBackend: ''' if backend_name == "xformers": return XFormersBackend() - assert False, f"Unrecognized backend_name {backend_name} for unit test" + raise AssertionError( + f"Unrecognized backend_name {backend_name} for unit test") def make_metadata_tensors(is_prompt: bool, @@ -568,7 +571,13 @@ def make_block_tables_slot_mapping(block_size, dtype=torch.long, device=device) - return decode_block_tables_tensor, decode_slot_mapping_tensor, prefill_slot_mapping_tensor, prefill_block_tables_tensor, slot_mapping_tensor, empty_slot_mapping_tensor, max_block_idx + return decode_block_tables_tensor, \ + decode_slot_mapping_tensor, \ + prefill_slot_mapping_tensor, \ + prefill_block_tables_tensor, \ + slot_mapping_tensor, \ + empty_slot_mapping_tensor, \ + max_block_idx def make_metadata_self_cross( @@ -836,7 +845,12 @@ def self_attn_setup(batch_size, prefill_q_prompt_lens, \ prefill_kv_prompt_lens, \ decode_q_prompt_lens, \ - decode_kv_prompt_lens = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=False) + decode_kv_prompt_lens = make_qkv(batch_size, + max_q_prompt_len, + max_kv_prompt_len, + num_heads, + head_size, + is_cross_attn=False) causal_mask = build_causal_mask(max_q_prompt_len, max_kv_prompt_len).to(CUDA_DEVICE) @@ -862,14 +876,26 @@ def self_attn_setup(batch_size, decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, [1 for _ in range(batch_size)]) - decode_block_tables, decode_slot_mapping, prefill_slot_mapping, prefill_block_tables, _, _, max_block_idx = make_block_tables_slot_mapping( + decode_block_tables, \ + decode_slot_mapping, \ + prefill_slot_mapping, \ + prefill_block_tables, \ + _, \ + _, \ + max_block_idx = make_block_tables_slot_mapping( block_size, q_prompt_lens, block_base_addr=block_base_addr) - prefill_packed_query, prefill_packed_key, prefill_packed_value, _, _ = pack_qkv( + prefill_packed_query, \ + prefill_packed_key, \ + prefill_packed_value, _, _ = pack_qkv( prefill_query, prefill_key, prefill_value, prefill_q_prompt_lens, prefill_kv_prompt_lens) - decode_packed_query, decode_packed_key, decode_packed_value, _, _ = pack_qkv( + decode_packed_query, \ + decode_packed_key, \ + decode_packed_value, \ + _, \ + _ = pack_qkv( decode_query, decode_key, decode_value, decode_q_prompt_lens, decode_kv_prompt_lens) @@ -983,7 +1009,12 @@ def cross_attn_setup_reuses_query(query, _, \ _, \ _, \ - _ = make_qkv(batch_size,max_q_prompt_len,max_kv_prompt_len,num_heads,head_size,is_cross_attn=True) + _ = make_qkv(batch_size, + max_q_prompt_len, + max_kv_prompt_len, + num_heads, + head_size, + is_cross_attn=True) ideal_output = ref_masked_attention(query, key, @@ -1008,7 +1039,13 @@ def cross_attn_setup_reuses_query(query, # Unlike self-attention: # - Prefill slot-mapping includes all key slots # - Decode slot-mapping is empty - decode_block_tables, _, _, prefill_block_tables, prefill_slot_mapping, decode_slot_mapping, max_block_idx = make_block_tables_slot_mapping( + decode_block_tables, \ + _, \ + _, \ + prefill_block_tables, \ + prefill_slot_mapping, \ + decode_slot_mapping, \ + max_block_idx = make_block_tables_slot_mapping( block_size, kv_prompt_lens, block_base_addr=block_base_addr) # Packed key/value (query is already provided) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 0c8db7e47a50d..b540f05c94d7a 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -171,9 +171,12 @@ def do_cross_attn(self): def do_cross_attn(self, state: bool): if state: - assert self.has_valid_cross_attn_metadata, "Must have self.cross_seq_lens not None in order to enable cross-attention" + assert self.has_valid_cross_attn_metadata, \ + "Must have self.cross_seq_lens not None " + \ + "in order to enable cross-attention" - # Infer implicit cross-attention fields from user-provided fields, if needed + # Infer implicit cross-attention fields + # from user-provided fields, if needed if self.cross_seq_lens_tensor is None: assert self.seq_lens_tensor is not None self.cross_seq_lens_tensor = torch.tensor( @@ -439,7 +442,8 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. + # not cached. This happens during the initial memory + # profiling run. PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, updated_slot_mapping, From 8dc501af797a8b618e841756f3dfdd0a6037c25e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 21 May 2024 01:06:28 -0400 Subject: [PATCH 029/239] isort --- tests/layer/test_self_and_cross_attn.py | 9 ++++----- vllm/attention/backends/xformers.py | 3 ++- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index d1d0d0def15e9..6945d15be39c8 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -1,15 +1,14 @@ +import copy +import itertools import random from typing import List, Optional -import itertools import pytest import torch -import copy -from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.xformers import XFormersBackend +from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend - +from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import make_tensor_with_pad # If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index b540f05c94d7a..36f1343e995df 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -4,8 +4,9 @@ import torch from xformers import ops as xops -from xformers.ops.fmha.attn_bias import (AttentionBias, BlockDiagonalMask, +from xformers.ops.fmha.attn_bias import (AttentionBias, BlockDiagonalCausalMask, + BlockDiagonalMask, LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, From 3ea10ea4d6a06df3a48def0dfbeac57174a08c78 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 21 May 2024 01:22:03 -0400 Subject: [PATCH 030/239] refactor: prompt -> seq where appropriate in test file --- tests/layer/test_self_and_cross_attn.py | 446 ++++++++++++------------ 1 file changed, 223 insertions(+), 223 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 6945d15be39c8..811878a347e97 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -24,26 +24,26 @@ BACKEND_NAMES = ["xformers"] CUDA_DEVICE = "cuda:0" -MAX_Q_PROMPT_LENS = [128] -MAX_K_PROMPT_LENS = [128] +MAX_Q_SEQ_LENS = [128] +MAX_K_SEQ_LENS = [128] -def build_causal_mask(q_max_prompt_len, kv_max_prompt_len): +def build_causal_mask(q_max_seq_len, kv_max_seq_len): ''' - Create a q_max_prompt_len x kv_max_prompt_len causal mask + Create a q_max_seq_len x kv_max_seq_len causal mask Arguments: - * q_max_prompt_len: query max prompt len - * kv_max_prompt_len: key/value max prompt len + * q_max_seq_len: query max seq len + * kv_max_seq_len: key/value max seq len Returns: - * 2D tensor, q_max_prompt_len x kv_max_prompt_len + * 2D tensor, q_max_seq_len x kv_max_seq_len ''' # Create a matrix where entry (i, j) is True if i >= j - mask = torch.triu(torch.ones(q_max_prompt_len, kv_max_prompt_len), + mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) # Replace True with float('-inf') and False with 0 mask = mask.masked_fill(mask == 1, @@ -57,12 +57,12 @@ def ref_masked_attention( value: torch.Tensor, scale: float, custom_mask: Optional[torch.Tensor] = None, - q_prompt_lens: Optional[List] = None, - kv_prompt_lens: Optional[List] = None) -> torch.Tensor: + q_seq_lens: Optional[List] = None, + kv_seq_lens: Optional[List] = None) -> torch.Tensor: ''' "Golden" masked attention reference. Supports two types of masking: - * Basic attention mask, utilizing {q,kv}_prompt_lens args to mask out + * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out padding elements * Custom attention mask, which can force an arbitrary mask tensor, i.e. causal @@ -75,8 +75,8 @@ def ref_masked_attention( * scale: Attention scale factor * Custom mask: custom attention mask; good place to inject a causal attention mask - * q_prompt_lens: list of unpadded query seq_lens for each batch index - * kv_prompt_lens: list of unpadded key/value seq_lens for each batch index + * q_seq_lens: list of unpadded query seq_lens for each batch index + * kv_seq_lens: list of unpadded key/value seq_lens for each batch index Returns: @@ -84,19 +84,19 @@ def ref_masked_attention( ''' batch_size = query.shape[0] - assert (len(q_prompt_lens) == batch_size) - assert (len(kv_prompt_lens) == batch_size) + assert (len(q_seq_lens) == batch_size) + assert (len(kv_seq_lens) == batch_size) attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() - # Basic attention mask, derived from prompt lens - if (q_prompt_lens is not None) or (kv_prompt_lens is not None): + # Basic attention mask, derived from seq lens + if (q_seq_lens is not None) or (kv_seq_lens is not None): attn_mask = torch.zeros_like(attn_weights) - if q_prompt_lens is not None: - for bdx, plen in enumerate(q_prompt_lens): + if q_seq_lens is not None: + for bdx, plen in enumerate(q_seq_lens): attn_mask[bdx, :, plen:, :] = -torch.inf - if kv_prompt_lens is not None: - for bdx, plen in enumerate(kv_prompt_lens): + if kv_seq_lens is not None: + for bdx, plen in enumerate(kv_seq_lens): attn_mask[bdx, :, :, plen:] = -torch.inf attn_weights = attn_weights + attn_mask.float() @@ -111,8 +111,8 @@ def ref_masked_attention( def make_qkv(batch_size, - max_q_prompt_len, - max_kv_prompt_len, + max_q_seq_len, + max_kv_seq_len, num_heads, head_size, is_cross_attn=True, @@ -135,79 +135,79 @@ def make_qkv(batch_size, Arguments: * batch_size - * max_q_prompt_len: max query prompt len - * max_kv_prompt_len: max key/value prompt len + * max_q_seq_len: max query seq len + * max_kv_seq_len: max key/value seq len * num_heads * head_size * is_cross_attn: if True, query seqlen may differ from key/value seqlen (as is often the case for cross-attention); o/w, query/key/value seqlens match - at each batch index (max_kv_prompt_len is unused) - * force_max_len: if True, all query seqlens are max_q_prompt_len; o/w query - seqlens are random in [2,max_q_prompt_lens]. Same for key/value seqlens - and max_kv_prompt_len, unless forced by is_cross_attn=False + at each batch index (max_kv_seq_len is unused) + * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query + seqlens are random in [2,max_q_seq_lens]. Same for key/value seqlens + and max_kv_seq_len, unless forced by is_cross_attn=False * device: CPU or CUDA device Returns: - * query: "baseline" query; batch_size x max_q_prompt_len x num_heads x + * query: "baseline" query; batch_size x max_q_seq_len x num_heads x head_size - * key: "baseline" key; batch_size x max_kv_prompt_len x num_heads x + * key: "baseline" key; batch_size x max_kv_seq_len x num_heads x head_size - * value: "baseline" value; batch_size x max_kv_prompt_len x num_heads x + * value: "baseline" value; batch_size x max_kv_seq_len x num_heads x head_size - * prefill_query: batch_size x (max_q_prompt_len-1) x num_heads x head_size - * prefill_key: batch_size x (max_kv_prompt_len-1) x num_heads x head_size - * prefill_value: batch_size x (max_kv_prompt_len-1) x num_heads x head_size + * prefill_query: batch_size x (max_q_seq_len-1) x num_heads x head_size + * prefill_key: batch_size x (max_kv_seq_len-1) x num_heads x head_size + * prefill_value: batch_size x (max_kv_seq_len-1) x num_heads x head_size * decode_query: batch_size x 1 x num_heads x head_size * decode_key: batch_size x 1 x num_heads x head_size * decode_value: batch_size x 1 x num_heads x head_size - * q_prompt_lens: "baseline" query seqlen list - * kv_prompt_lens: "baseline" key/value seqlen list - * actual_max_q_prompt_len: actual "baseline" query max prompt len (may be <= - max_q_prompt_len due to randomness) - * actual_max_kv_prompt_len: actual "baseline" key/value max prompt len (may - be <= max_kv_prompt_len due to randomness) - * prefill_q_prompt_lens: "prefill" query seqlen list - * prefill_kv_prompt_lens: "prefill" key/value seqlen list - * decode_q_prompt_lens: "decode" query seqlen list (all ones) - * decode_kv_prompt_lens: "decode" key/value seqlen list + * q_seq_lens: "baseline" query seqlen list + * kv_seq_lens: "baseline" key/value seqlen list + * actual_max_q_seq_len: actual "baseline" query max seq len (may be <= + max_q_seq_len due to randomness) + * actual_max_kv_seq_len: actual "baseline" key/value max seq len (may + be <= max_kv_seq_len due to randomness) + * prefill_q_seq_lens: "prefill" query seqlen list + * prefill_kv_seq_lens: "prefill" key/value seqlen list + * decode_q_seq_lens: "decode" query seqlen list (all ones) + * decode_kv_seq_lens: "decode" key/value seqlen list ''' if force_max_len: - q_prompt_lens = [max_q_prompt_len for _ in range(batch_size)] + q_seq_lens = [max_q_seq_len for _ in range(batch_size)] else: - q_prompt_lens = [ - random.randint(2, max_q_prompt_len) for _ in range(batch_size) + q_seq_lens = [ + random.randint(2, max_q_seq_len) for _ in range(batch_size) ] - kv_prompt_lens = None + kv_seq_lens = None if not is_cross_attn: - # K,V prompt lens match Q for self-attention - kv_prompt_lens = q_prompt_lens + # K,V seq lens match Q for self-attention + kv_seq_lens = q_seq_lens else: - # K,V prompt lens are distinct from Q prompt lens & random + # K,V seq lens are distinct from Q seq lens & random if force_max_len: - kv_prompt_lens = [max_kv_prompt_len for _ in range(batch_size)] + kv_seq_lens = [max_kv_seq_len for _ in range(batch_size)] else: - kv_prompt_lens = [ - random.randint(2, max_kv_prompt_len) for _ in range(batch_size) + kv_seq_lens = [ + random.randint(2, max_kv_seq_len) for _ in range(batch_size) ] - actual_max_q_prompt_len = max(q_prompt_lens) - actual_max_kv_prompt_len = max(kv_prompt_lens) + actual_max_q_seq_len = max(q_seq_lens) + actual_max_kv_seq_len = max(kv_seq_lens) query = torch.rand( - (batch_size, max_q_prompt_len, num_heads * head_size)).to(device) + (batch_size, max_q_seq_len, num_heads * head_size)).to(device) key = torch.rand( - (batch_size, max_kv_prompt_len, num_heads * head_size)).to(device) + (batch_size, max_kv_seq_len, num_heads * head_size)).to(device) value = torch.rand( - (batch_size, max_kv_prompt_len, num_heads * head_size)).to(device) + (batch_size, max_kv_seq_len, num_heads * head_size)).to(device) prefill_query = torch.zeros( - (batch_size, max_q_prompt_len, num_heads * head_size)).to(device) + (batch_size, max_q_seq_len, num_heads * head_size)).to(device) prefill_key = torch.zeros( - (batch_size, max_kv_prompt_len, num_heads * head_size)).to(device) + (batch_size, max_kv_seq_len, num_heads * head_size)).to(device) prefill_value = torch.zeros( - (batch_size, max_kv_prompt_len, num_heads * head_size)).to(device) + (batch_size, max_kv_seq_len, num_heads * head_size)).to(device) decode_query = torch.zeros( (batch_size, 1, num_heads * head_size)).to(device) @@ -215,32 +215,32 @@ def make_qkv(batch_size, decode_value = torch.zeros( (batch_size, 1, num_heads * head_size)).to(device) - for bdx, (q_prompt_len, - kv_prompt_len) in enumerate(zip(q_prompt_lens, kv_prompt_lens)): - query[bdx, q_prompt_len:, :] = 0 - key[bdx, kv_prompt_len:, :] = 0 - value[bdx, kv_prompt_len:, :] = 0 + for bdx, (q_seq_len, + kv_seq_len) in enumerate(zip(q_seq_lens, kv_seq_lens)): + query[bdx, q_seq_len:, :] = 0 + key[bdx, kv_seq_len:, :] = 0 + value[bdx, kv_seq_len:, :] = 0 prefill_query[bdx, - 0:(q_prompt_len - 1), :] = query[bdx, - 0:(q_prompt_len - 1), :] + 0:(q_seq_len - 1), :] = query[bdx, + 0:(q_seq_len - 1), :] prefill_key[bdx, - 0:(kv_prompt_len - 1), :] = key[bdx, - 0:(kv_prompt_len - 1), :] - prefill_value[bdx, 0:(kv_prompt_len - - 1), :] = value[bdx, 0:(kv_prompt_len - 1), :] + 0:(kv_seq_len - 1), :] = key[bdx, + 0:(kv_seq_len - 1), :] + prefill_value[bdx, 0:(kv_seq_len - + 1), :] = value[bdx, 0:(kv_seq_len - 1), :] decode_query[bdx, :, :] = query[bdx, - (q_prompt_len - 1):q_prompt_len, :] - decode_key[bdx, :, :] = key[bdx, (kv_prompt_len - 1):kv_prompt_len, :] + (q_seq_len - 1):q_seq_len, :] + decode_key[bdx, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :] decode_value[bdx, :, :] = value[bdx, - (kv_prompt_len - 1):kv_prompt_len, :] + (kv_seq_len - 1):kv_seq_len, :] - prefill_q_prompt_lens = [plen - 1 for plen in q_prompt_lens] - prefill_kv_prompt_lens = [plen - 1 for plen in kv_prompt_lens] + prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens] + prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens] - decode_q_prompt_lens = [1 for _ in q_prompt_lens] - decode_kv_prompt_lens = [1 for _ in kv_prompt_lens] + decode_q_seq_lens = [1 for _ in q_seq_lens] + decode_kv_seq_lens = [1 for _ in kv_seq_lens] query = query.view(batch_size, query.shape[1], num_heads, head_size) key = key.view(batch_size, key.shape[1], num_heads, head_size) @@ -269,68 +269,68 @@ def make_qkv(batch_size, decode_query, \ decode_key, \ decode_value, \ - q_prompt_lens, \ - kv_prompt_lens, \ - actual_max_q_prompt_len, \ - actual_max_kv_prompt_len, \ - prefill_q_prompt_lens, \ - prefill_kv_prompt_lens, \ - decode_q_prompt_lens, \ - decode_kv_prompt_lens + q_seq_lens, \ + kv_seq_lens, \ + actual_max_q_seq_len, \ + actual_max_kv_seq_len, \ + prefill_q_seq_lens, \ + prefill_kv_seq_lens, \ + decode_q_seq_lens, \ + decode_kv_seq_lens -def pack_tensor(unpacked_tensor, prompt_lens, device=CUDA_DEVICE): +def pack_tensor(unpacked_tensor, seq_lens, device=CUDA_DEVICE): ''' Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an unpadded number_of_tokens x num_heads x head_size tensor, where - number_of_tokens = sum(prompt_lens) + number_of_tokens = sum(seq_lens) Arguments: * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size - * prompt_lens: list of token counts for each prompt + * seq_lens: list of token counts for each seq * device: CPU or CUDA device Returns * packed_tensor: number_of_tokens x num_heads x head_size * start_loc_list: start idx of each batch elt in packed_tensor; [0] + - list(itertools.accumulate(prompt_lens)) + list(itertools.accumulate(seq_lens)) ''' - num_tok = sum(prompt_lens) + num_tok = sum(seq_lens) num_heads = unpacked_tensor.shape[-2] head_size = unpacked_tensor.shape[-1] - start_loc_list = [0] + list(itertools.accumulate(prompt_lens)) + start_loc_list = [0] + list(itertools.accumulate(seq_lens)) packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device) - for bdx, (prompt_len, - start_loc) in enumerate(zip(prompt_lens, start_loc_list)): + for bdx, (seq_len, + start_loc) in enumerate(zip(seq_lens, start_loc_list)): packed_tensor[start_loc:( start_loc + - prompt_len), :, :] = unpacked_tensor[bdx, :prompt_len, :, :] + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :] return packed_tensor, start_loc_list -def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): +def pack_qkv(query, key, value, q_seq_lens, kv_seq_lens): ''' Individually pack each of Q, K and V, each with dimensions batch_size x padded_seq_len x num_heads x head_size, into respective number_of_tokens x num_heads x head_size tensors. - For Q, number_of_tokens = sum(q_prompt_lens). + For Q, number_of_tokens = sum(q_seq_lens). - For K and V, number_of_tokens = sum(kv_prompt_lens) + For K and V, number_of_tokens = sum(kv_seq_lens) Arguments: * query: batch_size x padded_seq_len x num_heads x head_size * key: batch_size x padded_seq_len x num_heads x head_size * value: batch_size x padded_seq_len x num_heads x head_size - * q_prompt_lens: list of token counts for each query - * kv_prompt_lens: list of token counts for each key/value + * q_seq_lens: list of token counts for each query + * kv_seq_lens: list of token counts for each key/value Returns @@ -345,9 +345,9 @@ def pack_qkv(query, key, value, q_prompt_lens, kv_prompt_lens): packed_query = None q_start_loc_list = None else: - packed_query, q_start_loc_list = pack_tensor(query, q_prompt_lens) - packed_key, kv_start_loc_list = pack_tensor(key, kv_prompt_lens) - packed_value, _ = pack_tensor(value, kv_prompt_lens) + packed_query, q_start_loc_list = pack_tensor(query, q_seq_lens) + packed_key, kv_start_loc_list = pack_tensor(key, kv_seq_lens) + packed_value, _ = pack_tensor(value, kv_seq_lens) if packed_query is not None: packed_query = packed_query.view( -1, packed_query.shape[-1] * packed_query.shape[-2]) @@ -382,7 +382,7 @@ def make_backend(backend_name: str) -> AttentionBackend: def make_metadata_tensors(is_prompt: bool, - prompt_lens: List[int], + seq_lens: List[int], context_lens: List[int], device=CUDA_DEVICE) -> tuple: ''' @@ -391,33 +391,33 @@ def make_metadata_tensors(is_prompt: bool, Arguments: * is_prompt: True -> Prefill, False -> Decode - * prompt_lens: list of token-counts for each prompt - * context_lens: list of context length values for each prompt + * seq_lens: list of token-counts for each seq + * context_lens: list of context length values for each seq * device: CPU or CUDA device Returns: - * prompt_lens_tensor: prompt_lens list, as tensor + * seq_lens_tensor: seq_lens list, as tensor * context_lens_tensor: context_lens list, as tensor - * max_query_len: max(prompt_lens) if is_prompt, o/w 1 + * max_query_len: max(seq_lens) if is_seq, o/w 1 * max_context_len: max(context_lens) - * max_prompt_len: max(prompt_lens) + * max_seq_len: max(seq_lens) * seq_start_loc: start idx of each sequence * query_start_loc: start idx of each query ''' - prompt_lens_tensor = torch.tensor(prompt_lens, + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) context_lens_tensor = None if context_lens is None else torch.tensor( context_lens, dtype=torch.int, device=device) max_context_len = None if context_lens is None else max(context_lens) - max_prompt_len = None if prompt_lens is None else max(prompt_lens) + max_seq_len = None if seq_lens is None else max(seq_lens) - seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=device) - torch.cumsum(prompt_lens_tensor, + torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) @@ -425,18 +425,18 @@ def make_metadata_tensors(is_prompt: bool, if is_prompt: # Prefill: query_start_loc matches seq_start_loc query_start_loc = copy.deepcopy(seq_start_loc) - max_query_len = max_prompt_len + max_query_len = max_seq_len else: # Decode: one new query input token per batch element, thus # query_start_loc is the cumsum of [1,1,1,...] query_start_loc = list(range(len(seq_start_loc))) max_query_len = 1 - return prompt_lens_tensor, \ + return seq_lens_tensor, \ context_lens_tensor, \ max_query_len, \ max_context_len, \ - max_prompt_len, \ + max_seq_len, \ seq_start_loc, \ query_start_loc @@ -480,7 +480,7 @@ def num_tokens_to_min_blocks(num_tokens, block_size): def make_block_tables_slot_mapping(block_size, - prompt_lens, + seq_lens, block_base_addr=0, device=CUDA_DEVICE): ''' @@ -495,7 +495,7 @@ def make_block_tables_slot_mapping(block_size, Arguments: * block_size: number of offsets per block - * prompt_lens: list of token-counts for each sequence + * seq_lens: list of token-counts for each sequence * block_base_addr: the block table base address * device: CPU or CUDA device @@ -518,7 +518,7 @@ def make_block_tables_slot_mapping(block_size, # Over-provision block table blocks by 1 num_blocks_list = [ num_tokens_to_min_blocks(num_tokens, block_size) + 1 - for num_tokens in prompt_lens + for num_tokens in seq_lens ] max_block_table_len = max(num_blocks_list) block_table_pad_tokens = 10 @@ -530,7 +530,7 @@ def make_block_tables_slot_mapping(block_size, block_base_idx = block_base_addr + sum( num_blocks_list) * 2 - 1 # Support more blocks than needed max_block_idx = block_base_idx - for sdx, num_tokens in enumerate(prompt_lens): + for sdx, num_tokens in enumerate(seq_lens): num_blocks = num_blocks_list[sdx] block_table = list( range(block_base_idx, block_base_idx - num_blocks, -1)) @@ -582,7 +582,7 @@ def make_block_tables_slot_mapping(block_size, def make_metadata_self_cross( attn_backend: AttentionBackend, is_prompt: bool, - prompt_lens: List[int], + seq_lens: List[int], context_lens: List[int], block_tables, slot_mapping, @@ -603,7 +603,7 @@ def make_metadata_self_cross( * attn_backend: Backend for sourcing attention kernels * is_prompt: prefill if True, o/w decode - * prompt_lens: list of token counts for each sequence + * seq_lens: list of token counts for each sequence * context_lens: list of context lengths for each sequence * block_tables: self-attention block tables * slot_mapping: self-attention slot_mapping @@ -619,18 +619,18 @@ def make_metadata_self_cross( ''' if is_prompt: - num_prefills = len(prompt_lens) - num_prefill_tokens = sum(prompt_lens) + num_prefills = len(seq_lens) + num_prefill_tokens = sum(seq_lens) num_decode_tokens = 0 - prompt_lens_tensor, \ + seq_lens_tensor, \ context_lens_tensor, \ max_query_len, \ _, \ _, \ seq_start_loc, \ query_start_loc = make_metadata_tensors(is_prompt, - prompt_lens, + seq_lens, context_lens, device=device) @@ -643,10 +643,10 @@ def make_metadata_self_cross( slot_mapping=slot_mapping_tensor, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, - seq_lens=prompt_lens, - seq_lens_tensor=prompt_lens_tensor, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, - max_prefill_seq_len=max(prompt_lens), + max_prefill_seq_len=max(seq_lens), max_decode_seq_len=0, query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, @@ -662,16 +662,16 @@ def make_metadata_self_cross( num_prefills = 0 num_prefill_tokens = 0 - num_decode_tokens = len(prompt_lens) + num_decode_tokens = len(seq_lens) - prompt_lens_tensor, \ + seq_lens_tensor, \ context_lens_tensor, \ max_query_len, \ _, \ _, \ seq_start_loc, \ query_start_loc = make_metadata_tensors(is_prompt, - prompt_lens, + seq_lens, context_lens, device=device) @@ -684,11 +684,11 @@ def make_metadata_self_cross( slot_mapping=slot_mapping_tensor, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, - seq_lens=prompt_lens, - seq_lens_tensor=prompt_lens_tensor, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, max_prefill_seq_len=0, - max_decode_seq_len=max(prompt_lens), + max_decode_seq_len=max(seq_lens), query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, @@ -747,7 +747,7 @@ def self_attn_setup(batch_size, head_size, block_size, scale, - max_q_prompt_len, + max_q_seq_len, block_base_addr=0): ''' Set up test vectors & data structures for self-attention test. @@ -784,7 +784,7 @@ def self_attn_setup(batch_size, * head_size: Head dimension * block_size: Number of offsets per KV cache block * scale: attention scale parameter - * max_q_prompt_len: upper limit on query length for synthetic test vectors + * max_q_seq_len: upper limit on query length for synthetic test vectors * block_base_addr: self-attention block table base address Returns: @@ -799,10 +799,10 @@ def self_attn_setup(batch_size, num_heads x head_size * prefill_packed_ideal_output: self-attn "prefill" ideal output; number_of_tokens x num_heads x head_size - * prefill_q_prompt_lens: list of token counts for each *prefill query* (one + * prefill_q_seq_lens: list of token counts for each *prefill query* (one less than baseline query) - * prefill_kv_prompt_lens: list of token counts for each self-attn *prefill - key/value* (should match prefill_q_prompt_lens) + * prefill_kv_seq_lens: list of token counts for each self-attn *prefill + key/value* (should match prefill_q_seq_lens) * decode_packed_query: "decode" query; number_of_tokens x num_heads x head_size * decode_packed_key: self-attn "decode" key; number_of_tokens x num_heads x @@ -811,13 +811,13 @@ def self_attn_setup(batch_size, x head_size * decode_packed_ideal_output: self-attn "decode" ideal output; number_of_tokens x num_heads x head_size - * decode_q_prompt_lens: list of token counts for each *decode query* (should + * decode_q_seq_lens: list of token counts for each *decode query* (should be 1) - * decode_kv_prompt_lens: list of token counts for each self-attn *decode - key/value* (should match decode_q_prompt_lens) - * q_prompt_lens: "baseline" query seq lens; number_of_tokens x num_heads x + * decode_kv_seq_lens: list of token counts for each self-attn *decode + key/value* (should match decode_q_seq_lens) + * q_seq_lens: "baseline" query seq lens; number_of_tokens x num_heads x head_size - * kv_prompt_lens: self-attn "baseline" key/value seq lens; number_of_tokens + * kv_seq_lens: self-attn "baseline" key/value seq lens; number_of_tokens x num_heads x head_size * decode_block_tables: fake self-attn decode-phase block table * decode_slot_mapping: fake self-attn decode-phase slot mapping @@ -826,7 +826,7 @@ def self_attn_setup(batch_size, * max_block_idx: highest block address in the self-attention block-table ''' - max_kv_prompt_len = max_q_prompt_len + max_kv_seq_len = max_q_seq_len query, \ key, \ @@ -837,41 +837,41 @@ def self_attn_setup(batch_size, decode_query, \ decode_key, \ decode_value, \ - q_prompt_lens, \ - kv_prompt_lens, \ + q_seq_lens, \ + kv_seq_lens, \ _, \ _, \ - prefill_q_prompt_lens, \ - prefill_kv_prompt_lens, \ - decode_q_prompt_lens, \ - decode_kv_prompt_lens = make_qkv(batch_size, - max_q_prompt_len, - max_kv_prompt_len, - num_heads, - head_size, - is_cross_attn=False) - - causal_mask = build_causal_mask(max_q_prompt_len, - max_kv_prompt_len).to(CUDA_DEVICE) + prefill_q_seq_lens, \ + prefill_kv_seq_lens, \ + decode_q_seq_lens, \ + decode_kv_seq_lens = make_qkv(batch_size, + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + is_cross_attn=False) + + causal_mask = build_causal_mask(max_q_seq_len, + max_kv_seq_len).to(CUDA_DEVICE) ideal_output = ref_masked_attention(query, key, value, scale=scale, custom_mask=causal_mask, - q_prompt_lens=q_prompt_lens, - kv_prompt_lens=kv_prompt_lens) + q_seq_lens=q_seq_lens, + kv_seq_lens=kv_seq_lens) prefill_ideal_output = torch.zeros_like(ideal_output) decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) - for bdx, prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): - prefill_ideal_output[bdx, :prefill_q_prompt_len] = ideal_output[ - bdx, :prefill_q_prompt_len] - decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_prompt_len:( - prefill_q_prompt_len + 1)] + for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): + prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ + bdx, :prefill_q_seq_len] + decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:( + prefill_q_seq_len + 1)] prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_prompt_lens) + prefill_q_seq_lens) decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, [1 for _ in range(batch_size)]) @@ -882,37 +882,37 @@ def self_attn_setup(batch_size, _, \ _, \ max_block_idx = make_block_tables_slot_mapping( - block_size, q_prompt_lens, block_base_addr=block_base_addr) + block_size, q_seq_lens, block_base_addr=block_base_addr) prefill_packed_query, \ prefill_packed_key, \ prefill_packed_value, _, _ = pack_qkv( - prefill_query, prefill_key, prefill_value, prefill_q_prompt_lens, - prefill_kv_prompt_lens) + prefill_query, prefill_key, prefill_value, prefill_q_seq_lens, + prefill_kv_seq_lens) decode_packed_query, \ decode_packed_key, \ decode_packed_value, \ _, \ _ = pack_qkv( - decode_query, decode_key, decode_value, decode_q_prompt_lens, - decode_kv_prompt_lens) + decode_query, decode_key, decode_value, decode_q_seq_lens, + decode_kv_seq_lens) return query, \ prefill_packed_query, \ prefill_packed_key, \ prefill_packed_value, \ prefill_packed_ideal_output, \ - prefill_q_prompt_lens, \ - prefill_kv_prompt_lens, \ + prefill_q_seq_lens, \ + prefill_kv_seq_lens, \ decode_packed_query, \ decode_packed_key, \ decode_packed_value, \ decode_packed_ideal_output, \ - decode_q_prompt_lens, \ - decode_kv_prompt_lens, \ - q_prompt_lens, \ - kv_prompt_lens, \ + decode_q_seq_lens, \ + decode_kv_seq_lens, \ + q_seq_lens, \ + kv_seq_lens, \ decode_block_tables, \ decode_slot_mapping, \ prefill_slot_mapping, \ @@ -921,15 +921,15 @@ def self_attn_setup(batch_size, def cross_attn_setup_reuses_query(query, - q_prompt_lens, - prefill_q_prompt_lens, + q_seq_lens, + prefill_q_seq_lens, batch_size, num_heads, head_size, block_size, scale, - max_q_prompt_len, - max_kv_prompt_len, + max_q_seq_len, + max_kv_seq_len, block_base_addr=0): ''' Set up test vectors & data structures for cross-attention test. @@ -962,16 +962,16 @@ def cross_attn_setup_reuses_query(query, * query: pre-existing "baseline" query; batch_size x padded_seq_len x num_heads x head_size - * q_prompt_lens: list of token-counts for each "baseline" query sequence - * prefill_q_prompt_lens: list of token-counts for each "prefill" query + * q_seq_lens: list of token-counts for each "baseline" query sequence + * prefill_q_seq_lens: list of token-counts for each "prefill" query sequence * batch_size * num_heads: Number of attention heads * head_size: Head dimension * block_size: Number of offsets per KV cache block * scale: attention scale parameter - * max_q_prompt_len: upper limit on query length for synthetic test vectors - * max_kv_prompt_len: upper limit on key/value length for synthetic test + * max_q_seq_len: upper limit on query length for synthetic test vectors + * max_kv_seq_len: upper limit on key/value length for synthetic test vectors * block_base_addr: cross-attention block table base address @@ -984,7 +984,7 @@ def cross_attn_setup_reuses_query(query, num_heads x head_size * decode_packed_ideal_output: "decode" ideal output; number_of_tokens x num_heads x head_size - * kv_prompt_lens: list of token-counts for each key/value + * kv_seq_lens: list of token-counts for each key/value * decode_block_tables: fake decode-phase block tables * decode_slot_mapping: fake decode-phase slot mapping * prefill_slot_mapping: fake prefill-phase slot mapping @@ -1002,15 +1002,15 @@ def cross_attn_setup_reuses_query(query, _, \ _, \ _, \ - kv_prompt_lens, \ + kv_seq_lens, \ _, \ _, \ _, \ _, \ _, \ _ = make_qkv(batch_size, - max_q_prompt_len, - max_kv_prompt_len, + max_q_seq_len, + max_kv_seq_len, num_heads, head_size, is_cross_attn=True) @@ -1019,19 +1019,19 @@ def cross_attn_setup_reuses_query(query, key, value, scale=scale, - q_prompt_lens=q_prompt_lens, - kv_prompt_lens=kv_prompt_lens) + q_seq_lens=q_seq_lens, + kv_seq_lens=kv_seq_lens) prefill_ideal_output = torch.zeros_like(ideal_output) decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) - for bdx, prefill_q_prompt_len in enumerate(prefill_q_prompt_lens): - prefill_ideal_output[bdx, :prefill_q_prompt_len] = ideal_output[ - bdx, :prefill_q_prompt_len] - decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_prompt_len:( - prefill_q_prompt_len + 1)] + for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): + prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ + bdx, :prefill_q_seq_len] + decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:( + prefill_q_seq_len + 1)] prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_prompt_lens) + prefill_q_seq_lens) decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, [1 for _ in range(batch_size)]) @@ -1045,17 +1045,17 @@ def cross_attn_setup_reuses_query(query, prefill_slot_mapping, \ decode_slot_mapping, \ max_block_idx = make_block_tables_slot_mapping( - block_size, kv_prompt_lens, block_base_addr=block_base_addr) + block_size, kv_seq_lens, block_base_addr=block_base_addr) # Packed key/value (query is already provided) _, packed_key, packed_value, _, _ = pack_qkv(None, key, value, None, - kv_prompt_lens) + kv_seq_lens) return packed_key, \ packed_value, \ prefill_packed_ideal_output, \ decode_packed_ideal_output, \ - kv_prompt_lens, \ + kv_seq_lens, \ decode_block_tables, \ decode_slot_mapping, \ prefill_slot_mapping, \ @@ -1083,12 +1083,12 @@ def run_cross_attention_test(attn, packed_query, packed_key, packed_value, @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_q_prompt_len", MAX_Q_PROMPT_LENS) -@pytest.mark.parametrize("max_kv_prompt_len", MAX_K_PROMPT_LENS) +@pytest.mark.parametrize("max_q_seq_len", MAX_Q_SEQ_LENS) +@pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) def test_prefill_decode_self_and_cross_attention( num_heads: int, head_size: int, backend_name: str, batch_size: int, - block_size: int, max_q_prompt_len: int, - max_kv_prompt_len: int) -> None: + block_size: int, max_q_seq_len: int, + max_kv_seq_len: int) -> None: ''' Test: @@ -1138,15 +1138,15 @@ def test_prefill_decode_self_and_cross_attention( self_prefill_packed_key, \ self_prefill_packed_value, \ self_prefill_packed_ideal_output, \ - prefill_q_prompt_lens, \ - self_prefill_kv_prompt_lens, \ + prefill_q_seq_lens, \ + self_prefill_kv_seq_lens, \ decode_packed_query, \ self_decode_packed_key, \ self_decode_packed_value, \ self_decode_packed_ideal_output, \ _, \ _, \ - q_prompt_lens, \ + q_seq_lens, \ _, \ self_decode_block_tables, \ self_decode_slot_mapping, \ @@ -1157,7 +1157,7 @@ def test_prefill_decode_self_and_cross_attention( head_size, block_size, scale, - max_q_prompt_len, + max_q_seq_len, block_base_addr=self_block_base_addr) # Cross-attention setup @@ -1166,21 +1166,21 @@ def test_prefill_decode_self_and_cross_attention( cross_prefill_packed_value, \ cross_prefill_packed_ideal_output, \ cross_decode_packed_ideal_output, \ - cross_kv_prompt_lens, \ + cross_kv_seq_lens, \ cross_decode_block_tables, \ cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ _ = cross_attn_setup_reuses_query(query, - q_prompt_lens, - prefill_q_prompt_lens, + q_seq_lens, + prefill_q_seq_lens, batch_size, num_heads, head_size, block_size, scale, - max_q_prompt_len, - max_kv_prompt_len, + max_q_seq_len, + max_kv_seq_len, block_base_addr=cross_block_base_addr) # PREFILL: self- and cross-attention tests @@ -1190,11 +1190,11 @@ def test_prefill_decode_self_and_cross_attention( prefill_attn_metadata: AttentionMetadata = make_metadata_self_cross( attn_backend, True, - prefill_q_prompt_lens, + prefill_q_seq_lens, context_lens, self_prefill_block_tables, self_prefill_slot_mapping, - cross_seq_lens=cross_kv_prompt_lens, + cross_seq_lens=cross_kv_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, ) @@ -1219,18 +1219,18 @@ def test_prefill_decode_self_and_cross_attention( cross_prefill_packed_actual_output.view_as( cross_prefill_packed_ideal_output)) - context_lens = copy.deepcopy(self_prefill_kv_prompt_lens) + context_lens = copy.deepcopy(self_prefill_kv_seq_lens) # DECODE: self- and cross-attention tests decode_attn_metadata: AttentionMetadata = make_metadata_self_cross( attn_backend, False, - q_prompt_lens, + q_seq_lens, context_lens, self_decode_block_tables, self_decode_slot_mapping, - cross_seq_lens=cross_kv_prompt_lens, + cross_seq_lens=cross_kv_seq_lens, cross_block_tables=cross_decode_block_tables, cross_slot_mapping=cross_decode_slot_mapping, ) From 2ced012a3e51a77abbbab2268d88730fdffa4a3f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 21 May 2024 06:19:19 -0400 Subject: [PATCH 031/239] fix wording nits (ben->been, decoder->encoder/decoder) --- tests/core/test_block_manager.py | 2 +- vllm/core/block_manager_v2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index d6ab246699903..81e3444815d4e 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -314,7 +314,7 @@ def test_swap_encoder_decoder(): assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks decoder_prompt.status = SequenceStatus.SWAPPED - # Swap decoder seq group from CPU -> GPU. + # Swap encoder/decoder seq group from CPU -> GPU. decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt) cross_cpu_blocks = block_manager.get_cross_block_table(seq_group) cpu_blocks = decoder_cpu_blocks + cross_cpu_blocks diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 4ae3361e7b234..978acd915b69b 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -222,7 +222,7 @@ def free(self, seq: Sequence) -> None: def free_cross(self, seq_group: SequenceGroup) -> None: request_id = seq_group.request_id if request_id not in self.cross_block_tables: - # Already freed or hasn't ben scheduled yet. + # Already freed or hasn't been scheduled yet. return self.cross_block_tables[request_id].free() del self.cross_block_tables[request_id] From 8286b4cfbe57001767617a9ee33066945f6baa3d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 13:46:58 -0400 Subject: [PATCH 032/239] changed two block manager tests to construct fake prompts that are equal in length to the bock size, rather than half the block size (which had been the case --- tests/core/test_block_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 81e3444815d4e..9dc1c88819b70 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -116,7 +116,7 @@ def test_allocate_encoder_decoder(): watermark=1 / num_gpu_blocks) for i in range((num_gpu_blocks - 1) // block_req_per_seq_group): _, _, seq_group = create_dummy_prompt_encoder_decoder( - str(i), block_size // 2, block_size // 2) + str(i), block_size, block_size) assert block_manager.can_allocate(seq_group) block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK @@ -365,8 +365,8 @@ def test_free_encoder_decoder(): decoder_prompt, encoder_prompt, seq_group = \ create_dummy_prompt_encoder_decoder( "1", - decoder_prompt_length=block_size // 2, - encoder_prompt_length=block_size // 2) + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) block_manager.allocate(seq_group) # Free allocated seq. From eba551cd7e1d53911cb392d773eec05cfe40cc4f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 13:50:04 -0400 Subject: [PATCH 033/239] keyword args for dummy prompt construction in block manager encoder/decoder tests --- tests/core/test_block_manager.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 9dc1c88819b70..19dfc09dbb001 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -103,7 +103,9 @@ def test_allocate_encoder_decoder(): # Allocate same sequence group to all available gpu blocks. for i in range(num_gpu_blocks // block_req_per_seq_group): _, _, seq_group = create_dummy_prompt_encoder_decoder( - str(i), block_size, block_size) + str(i), + decoder_prompt_length=block_size, + decoder_prompt_length=block_size) assert block_manager.can_allocate(seq_group) block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK @@ -116,7 +118,9 @@ def test_allocate_encoder_decoder(): watermark=1 / num_gpu_blocks) for i in range((num_gpu_blocks - 1) // block_req_per_seq_group): _, _, seq_group = create_dummy_prompt_encoder_decoder( - str(i), block_size, block_size) + str(i), + decoder_prompt_length=block_size, + decoder_prompt_length=block_size) assert block_manager.can_allocate(seq_group) block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK From a7c8b192cd7c6e6c815caf5acbbd4ed24b16925d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 14:00:05 -0400 Subject: [PATCH 034/239] bugfix - decoder prompt kwarg repeated in lieu of encoder prompt kwarg --- tests/core/test_block_manager.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 19dfc09dbb001..29956ff028143 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -73,7 +73,7 @@ def test_allocate(): # Allocate same sequence group to all available gpu blocks. for i in range(num_gpu_blocks): _, seq_group = create_dummy_prompt(str(i), block_size) - assert block_manager.can_allocate(seq_group) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK @@ -85,7 +85,7 @@ def test_allocate(): watermark=1 / num_gpu_blocks) for i in range(num_gpu_blocks - 1): _, seq_group = create_dummy_prompt(str(i), block_size) - assert block_manager.can_allocate(seq_group) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK @@ -105,8 +105,8 @@ def test_allocate_encoder_decoder(): _, _, seq_group = create_dummy_prompt_encoder_decoder( str(i), decoder_prompt_length=block_size, - decoder_prompt_length=block_size) - assert block_manager.can_allocate(seq_group) + encoder_prompt_length=block_size) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK @@ -120,8 +120,8 @@ def test_allocate_encoder_decoder(): _, _, seq_group = create_dummy_prompt_encoder_decoder( str(i), decoder_prompt_length=block_size, - decoder_prompt_length=block_size) - assert block_manager.can_allocate(seq_group) + encoder_prompt_length=block_size) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK From 9feb994966e365fac63bbec526cafb24cf00dcde Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 14:09:42 -0400 Subject: [PATCH 035/239] In block manager test which used with block to detect error - created a second with block for encoder-related call that previously shared a with block with the corresponding decoder-related call --- tests/core/test_block_manager.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 29956ff028143..808b0a5e651eb 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -386,6 +386,9 @@ def test_free_encoder_decoder(): # Block table for freed encoder & decoder seq's are deleted. with pytest.raises(KeyError): block_manager.get_block_table(decoder_prompt) + + # Block table for freed encoder & decoder seq's are deleted. + with pytest.raises(KeyError): block_manager.get_block_table(encoder_prompt) From 5eb0032bfaaf5bc43fab66f1fc8bea30045915b7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 14:17:50 -0400 Subject: [PATCH 036/239] refactoring block manager v1/v2 swap in/swap out functions --- vllm/core/block_manager_v1.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 03eba2e80c78d..119e444df1b11 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -570,12 +570,7 @@ def swap_in(self, self.cpu_allocator.free(cpu_block) self.cross_block_tables[request_id] = new_block_table - block_number_mapping = { - cpu_block.block_number: gpu_block.block_number - for cpu_block, gpu_block in mapping.items() - } - # convert to list of tuples once here - return list(block_number_mapping.items()) + return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] def can_swap_out(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) @@ -621,12 +616,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: self.gpu_allocator.free(gpu_block) self.cross_block_tables[request_id] = new_block_table - block_number_mapping = { - gpu_block.block_number: cpu_block.block_number - for gpu_block, cpu_block in mapping.items() - } - # convert to list of tuples once here - return list(block_number_mapping.items()) + return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] def _free_block_table(self, block_table: BlockTable) -> None: # when using a sliding window, each seq will only use up From 0644cde2aced6d7fb6c279025b2a4a3d8f5625d2 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 14:23:50 -0400 Subject: [PATCH 037/239] formatting; changed blocktable type specifier from Dict to List[int] --- tests/core/test_block_manager.py | 6 +++--- vllm/core/block_manager_v1.py | 6 ++++-- vllm/sequence.py | 6 +++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 808b0a5e651eb..cdaf2f22115e8 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -103,7 +103,7 @@ def test_allocate_encoder_decoder(): # Allocate same sequence group to all available gpu blocks. for i in range(num_gpu_blocks // block_req_per_seq_group): _, _, seq_group = create_dummy_prompt_encoder_decoder( - str(i), + str(i), decoder_prompt_length=block_size, encoder_prompt_length=block_size) assert block_manager.can_allocate(seq_group) == AllocStatus.OK @@ -118,8 +118,8 @@ def test_allocate_encoder_decoder(): watermark=1 / num_gpu_blocks) for i in range((num_gpu_blocks - 1) // block_req_per_seq_group): _, _, seq_group = create_dummy_prompt_encoder_decoder( - str(i), - decoder_prompt_length=block_size, + str(i), + decoder_prompt_length=block_size, encoder_prompt_length=block_size) assert block_manager.can_allocate(seq_group) == AllocStatus.OK block_manager.allocate(seq_group) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 119e444df1b11..2482cf17956f2 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -570,7 +570,8 @@ def swap_in(self, self.cpu_allocator.free(cpu_block) self.cross_block_tables[request_id] = new_block_table - return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] + return [(cpu_block.block_number, gpu_block.block_number) + for cpu_block, gpu_block in mapping.items()] def can_swap_out(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) @@ -616,7 +617,8 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: self.gpu_allocator.free(gpu_block) self.cross_block_tables[request_id] = new_block_table - return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] + return [(cpu_block.block_number, gpu_block.block_number) + for cpu_block, gpu_block in mapping.items()] def _free_block_table(self, block_table: BlockTable) -> None: # when using a sliding window, each seq will only use up diff --git a/vllm/sequence.py b/vllm/sequence.py index a11c411876ea8..6b07a00f09c6f 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -527,8 +527,8 @@ def get_seqs( seq for seq in self.seqs_dict.values() if seq.status == status ] - def get_encoder_seq(self) -> Sequence: - return self.encoder_seq # type: ignore + def get_encoder_seq(self) -> Optional[Sequence]: + return self.encoder_seq def get_unfinished_seqs(self) -> List[Sequence]: return [ @@ -635,7 +635,7 @@ def __init__( state: Optional[SequenceGroupState] = None, multi_modal_data: Optional[MultiModalData] = None, encoder_seq_data: Optional[SequenceData] = None, - cross_block_table: Optional[Dict[int, List[int]]] = None, + cross_block_table: Optional[List[int]] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt From 19ed7413e315ce665cc07722d72fb874a362fafd Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 14:39:50 -0400 Subject: [PATCH 038/239] prefixed internal method with _ --- vllm/core/block_manager_v1.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 2482cf17956f2..648ff843fd4e5 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -260,7 +260,7 @@ def __init__( # request ID self.cross_block_tables: Dict[str, BlockTable] = {} - def get_seq_num_required_blocks(self, seq: Sequence) -> int: + def _get_seq_num_required_blocks(self, seq: Sequence) -> int: if seq is None: return 0 return len(seq.logical_token_blocks) @@ -269,9 +269,9 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - self_num_required_blocks = self.get_seq_num_required_blocks( + self_num_required_blocks = self._get_seq_num_required_blocks( seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) - cross_num_required_blocks = self.get_seq_num_required_blocks( + cross_num_required_blocks = self._get_seq_num_required_blocks( seq_group.get_encoder_seq()) num_required_blocks = self_num_required_blocks + \ cross_num_required_blocks From a5579729928c4151e501138f82340c0afa2dc327 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 17:47:19 -0400 Subject: [PATCH 039/239] refactored self-/cross-attention allocation functions into a single helper function --- vllm/core/block_manager_v1.py | 57 ++++++++++++----------------------- 1 file changed, 19 insertions(+), 38 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 648ff843fd4e5..9f08d4a7939aa 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -290,11 +290,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def allocate_self_block_tables(self, seq_group: SequenceGroup) -> None: - # NOTE: Here we assume that all sequences in the group have the same - # decoder prompt. - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - + def _allocate_sequence(self, seq: Sequence, ref_count: int) -> BlockTable: # Allocate new physical token blocks that will store the prompt tokens. num_prompt_blocks = len(seq.logical_token_blocks) @@ -304,7 +300,7 @@ def allocate_self_block_tables(self, seq_group: SequenceGroup) -> None: and logical_idx >= self.block_sliding_window): block = block_table[logical_idx % self.block_sliding_window] # Set the reference counts of the token blocks. - block.ref_count = seq_group.num_seqs() + block.ref_count = ref_count #seq_group.num_seqs() elif self.enable_caching: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), @@ -312,47 +308,32 @@ def allocate_self_block_tables(self, seq_group: SequenceGroup) -> None: else: block = self.gpu_allocator.allocate() # Set the reference counts of the token blocks. - block.ref_count = seq_group.num_seqs() + block.ref_count = ref_count #seq_group.num_seqs() block_table.append(block) - # Assign the decoder block table for each sequence. - for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): - self.block_tables[seq.seq_id] = block_table.copy() + return block_table - def allocate_cross_block_table(self, seq_group: SequenceGroup) -> None: + def allocate(self, seq_group: SequenceGroup) -> None: + # Allocate decoder sequences + # # NOTE: Here we assume that all sequences in the group have the same - # encoder prompt. + # decoder prompt. + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + block_table: BlockTable = self._allocate_sequence(seq, seq_group.num_seqs()) - # Allocate new physical token blocks that will store the prompt tokens. - seq = seq_group.get_encoder_seq() - if seq is not None: - block_table: BlockTable = [] - num_prompt_blocks = len(seq.logical_token_blocks) - for logical_idx in range(num_prompt_blocks): - if (self.block_sliding_window is not None - and logical_idx >= self.block_sliding_window): - block = block_table[logical_idx % - self.block_sliding_window] - # Set the reference counts of the token blocks. - block.ref_count = seq_group.num_seqs() - elif self.enable_caching: - block = self.gpu_allocator.allocate( - seq.hash_of_block(logical_idx), - seq.num_hashed_tokens_of_block(logical_idx)) - else: - block = self.gpu_allocator.allocate() - # Set the reference counts of the token blocks. - # TODO: feature not supported with encoder/decoder - block.ref_count = seq_group.num_seqs() - block_table.append(block) + # Assign the self-attention block tables for each sequence. + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): + self.block_tables[seq.seq_id] = block_table.copy() + # Allocate encoder sequence + encoder_seq = seq_group.get_encoder_seq() + if encoder_seq is not None: + # A SequenceGroup has only a single encoder sequence (at most), + # thus allocate with a ref count of 1 + block_table: BlockTable = self._allocate_sequence(encoder_seq, 1) # Assign the cross-attention block table for the SequenceGroup. self.cross_block_tables[seq_group.request_id] = block_table - def allocate(self, seq_group: SequenceGroup) -> None: - self.allocate_self_block_tables(seq_group) - self.allocate_cross_block_table(seq_group) - def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots: int = 0) -> bool: From e48bebf727ae67ffbdff206d168eab3e77b988da Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 17:59:44 -0400 Subject: [PATCH 040/239] Refactored block manager v2 self-/cross-block-table alloc functions together --- vllm/core/block_manager_v2.py | 38 ++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 978acd915b69b..a8085f54ac79d 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -121,7 +121,18 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def allocate_self_block_tables(self, seq_group: SequenceGroup) -> None: + def _allocate_sequence(self, seq: Sequence) -> BlockTable: + block_table = BlockTable( + block_size=self.block_size, + block_allocator=self.block_allocator, + ) + assert self.block_sliding_window is None + block_table.allocate(seq.get_token_ids()) + + return block_table + + def allocate(self, seq_group: SequenceGroup) -> None: + # Allocate self-attention block tables for decoder sequences waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) assert not (set(seq.seq_id for seq in waiting_seqs) & self.block_tables.keys()), "block table already exists" @@ -129,43 +140,34 @@ def allocate_self_block_tables(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # prompt. seq = waiting_seqs[0] - - block_table = BlockTable( - block_size=self.block_size, - block_allocator=self.block_allocator, - ) - assert self.block_sliding_window is None - block_table.allocate(seq.get_token_ids()) + block_table: BlockTable = self._allocate_sequence(seq) self.block_tables[seq.seq_id] = block_table # Assign the block table for each sequence. for seq in waiting_seqs[1:]: self.block_tables[seq.seq_id] = block_table.fork() - def allocate_cross_block_table(self, seq_group: SequenceGroup) -> None: + # Allocate cross-attention block table for encoder sequence + # # NOTE: Here we assume that all sequences in the group have the same - # prompt. + # encoder prompt. request_id = seq_group.request_id - seq = seq_group.encoder_seq + encoder_seq = seq_group.encoder_seq assert (request_id not in self.cross_block_tables), \ "block table already exists" - seq = seq_group.get_encoder_seq() - if seq is not None: + encoder_seq = seq_group.get_encoder_seq() + if encoder_seq is not None: block_table = BlockTable( block_size=self.block_size, block_allocator=self.block_allocator, ) assert self.block_sliding_window is None - block_table.allocate(seq.get_token_ids()) + block_table.allocate(encoder_seq.get_token_ids()) self.cross_block_tables[request_id] = block_table - def allocate(self, seq_group: SequenceGroup) -> None: - self.allocate_self_block_tables(seq_group) - self.allocate_cross_block_table(seq_group) - def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots: int) -> bool: """Determine if there is enough space in the GPU KV cache to continue From ac2da978c786d998247cfe55a3d2a788109b71e4 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 18:11:58 -0400 Subject: [PATCH 041/239] formatting --- vllm/core/block_manager_v1.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 9f08d4a7939aa..fa53b3cd33229 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -319,7 +319,8 @@ def allocate(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # decoder prompt. seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - block_table: BlockTable = self._allocate_sequence(seq, seq_group.num_seqs()) + block_table: BlockTable = \ + self._allocate_sequence(seq, seq_group.num_seqs()) # Assign the self-attention block tables for each sequence. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): @@ -330,7 +331,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: if encoder_seq is not None: # A SequenceGroup has only a single encoder sequence (at most), # thus allocate with a ref count of 1 - block_table: BlockTable = self._allocate_sequence(encoder_seq, 1) + block_table = self._allocate_sequence(encoder_seq, 1) # Assign the cross-attention block table for the SequenceGroup. self.cross_block_tables[seq_group.request_id] = block_table From e985a2f05080a0e311f52adf119447993322541f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 18:40:07 -0400 Subject: [PATCH 042/239] refactored out block manager v1 swap_n/swap_out helper functions --- vllm/core/block_manager_v1.py | 116 ++++++++++++++++------------------ 1 file changed, 54 insertions(+), 62 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index fa53b3cd33229..dd6d8d702fae0 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -300,7 +300,7 @@ def _allocate_sequence(self, seq: Sequence, ref_count: int) -> BlockTable: and logical_idx >= self.block_sliding_window): block = block_table[logical_idx % self.block_sliding_window] # Set the reference counts of the token blocks. - block.ref_count = ref_count #seq_group.num_seqs() + block.ref_count = ref_count #seq_group.num_seqs() elif self.enable_caching: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), @@ -308,7 +308,7 @@ def _allocate_sequence(self, seq: Sequence, ref_count: int) -> BlockTable: else: block = self.gpu_allocator.allocate() # Set the reference counts of the token blocks. - block.ref_count = ref_count #seq_group.num_seqs() + block.ref_count = ref_count #seq_group.num_seqs() block_table.append(block) return block_table @@ -507,6 +507,26 @@ def can_swap_in(self, else: return AllocStatus.LATER + def _swap_in_block_table( + self, block_table: BlockTable, + mapping: Dict[PhysicalTokenBlock, + PhysicalTokenBlock]) -> BlockTable: + new_block_table = [] + + for cpu_block in block_table: + if cpu_block in mapping: + gpu_block = mapping[cpu_block] + gpu_block.ref_count += 1 + else: + gpu_block = self.gpu_allocator.allocate( + cpu_block.block_hash, cpu_block.num_hashed_tokens) + mapping[cpu_block] = gpu_block + new_block_table.append(gpu_block) + # Free the CPU block swapped in to GPU. + self.cpu_allocator.free(cpu_block) + + return new_block_table + def swap_in(self, seq_group: SequenceGroup, num_lookahead_slots: int = 0) -> List[Tuple[int, int]]: @@ -519,38 +539,14 @@ def swap_in(self, # dict is efficient in lookup `if cpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - new_block_table: BlockTable = [] - block_table = self.block_tables[seq.seq_id] - - for cpu_block in block_table: - if cpu_block in mapping: - gpu_block = mapping[cpu_block] - gpu_block.ref_count += 1 - else: - gpu_block = self.gpu_allocator.allocate( - cpu_block.block_hash, cpu_block.num_hashed_tokens) - mapping[cpu_block] = gpu_block - new_block_table.append(gpu_block) - # Free the CPU block swapped in to GPU. - self.cpu_allocator.free(cpu_block) - self.block_tables[seq.seq_id] = new_block_table + self.block_tables[seq.seq_id] = \ + self._swap_in_block_table(self.block_tables[seq.seq_id], + mapping) if seq_group.encoder_seq is not None: - new_block_table = [] - block_table = self.cross_block_tables[request_id] - - for cpu_block in block_table: - if cpu_block in mapping: - gpu_block = mapping[cpu_block] - gpu_block.ref_count += 1 - else: - gpu_block = self.gpu_allocator.allocate( - cpu_block.block_hash, cpu_block.num_hashed_tokens) - mapping[cpu_block] = gpu_block - new_block_table.append(gpu_block) - # Free the CPU block swapped in to GPU. - self.cpu_allocator.free(cpu_block) - self.cross_block_tables[request_id] = new_block_table + self.cross_block_tables[request_id] = \ + self._swap_in_block_table(self.cross_block_tables[request_id], + mapping) return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] @@ -559,6 +555,26 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) return len(blocks) <= self.cpu_allocator.get_num_free_blocks() + def _swap_out_block_table( + self, block_table: BlockTable, + mapping: Dict[PhysicalTokenBlock, + PhysicalTokenBlock]) -> BlockTable: + + new_block_table: BlockTable = [] + for gpu_block in block_table: + if gpu_block in mapping: + cpu_block = mapping[gpu_block] + cpu_block.ref_count += 1 + else: + cpu_block = self.cpu_allocator.allocate( + gpu_block.block_hash, gpu_block.num_hashed_tokens) + mapping[gpu_block] = cpu_block + new_block_table.append(cpu_block) + # Free the GPU block swapped out to CPU. + self.gpu_allocator.free(gpu_block) + + return new_block_table + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: request_id = seq_group.request_id @@ -566,38 +582,14 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: # dict is efficient in lookup `if gpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - new_block_table: BlockTable = [] - block_table = self.block_tables[seq.seq_id] - - for gpu_block in block_table: - if gpu_block in mapping: - cpu_block = mapping[gpu_block] - cpu_block.ref_count += 1 - else: - cpu_block = self.cpu_allocator.allocate( - gpu_block.block_hash, gpu_block.num_hashed_tokens) - mapping[gpu_block] = cpu_block - new_block_table.append(cpu_block) - # Free the GPU block swapped out to CPU. - self.gpu_allocator.free(gpu_block) - self.block_tables[seq.seq_id] = new_block_table + self.block_tables[seq.seq_id] = \ + self._swap_out_block_table(self.block_tables[seq.seq_id], + mapping) if seq_group.encoder_seq is not None: - new_block_table = [] - block_table = self.cross_block_tables[request_id] - - for gpu_block in block_table: - if gpu_block in mapping: - cpu_block = mapping[gpu_block] - cpu_block.ref_count += 1 - else: - cpu_block = self.cpu_allocator.allocate( - gpu_block.block_hash, gpu_block.num_hashed_tokens) - mapping[gpu_block] = cpu_block - new_block_table.append(cpu_block) - # Free the GPU block swapped out to CPU. - self.gpu_allocator.free(gpu_block) - self.cross_block_tables[request_id] = new_block_table + self.cross_block_tables[request_id] = \ + self._swap_out_block_table(self.cross_block_tables[request_id], + mapping) return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] From 98c5863ef946dbd52221b6b83517e483f48b3848 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 22 May 2024 18:53:40 -0400 Subject: [PATCH 043/239] Help function avoids prefix caching code in encoder/decoder scenarios; alloc method asserts no prefix caching + enc/dec; refactoring --- vllm/core/block_manager_v1.py | 36 +++++++++++++++++------------------ vllm/core/block_manager_v2.py | 16 ---------------- 2 files changed, 18 insertions(+), 34 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index dd6d8d702fae0..40274bd29e9b0 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -290,7 +290,10 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def _allocate_sequence(self, seq: Sequence, ref_count: int) -> BlockTable: + def _allocate_sequence(self, \ + seq: Sequence, \ + ref_count: int, \ + decoder_only: bool = True) -> BlockTable: # Allocate new physical token blocks that will store the prompt tokens. num_prompt_blocks = len(seq.logical_token_blocks) @@ -300,27 +303,36 @@ def _allocate_sequence(self, seq: Sequence, ref_count: int) -> BlockTable: and logical_idx >= self.block_sliding_window): block = block_table[logical_idx % self.block_sliding_window] # Set the reference counts of the token blocks. - block.ref_count = ref_count #seq_group.num_seqs() - elif self.enable_caching: + block.ref_count = ref_count + elif decoder_only and self.enable_caching: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), seq.num_hashed_tokens_of_block(logical_idx)) else: block = self.gpu_allocator.allocate() # Set the reference counts of the token blocks. - block.ref_count = ref_count #seq_group.num_seqs() + block.ref_count = ref_count block_table.append(block) return block_table def allocate(self, seq_group: SequenceGroup) -> None: + decoder_only = \ + seq_group.get_encoder_seq() is None + + assert decoder_only or (not self.enable_caching), \ + "Automatic prefix caching currently not " + \ + "supported for encoder/decoder models." + # Allocate decoder sequences # # NOTE: Here we assume that all sequences in the group have the same # decoder prompt. seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] block_table: BlockTable = \ - self._allocate_sequence(seq, seq_group.num_seqs()) + self._allocate_sequence(seq, + seq_group.num_seqs(), + decoder_only) # Assign the self-attention block tables for each sequence. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): @@ -331,7 +343,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: if encoder_seq is not None: # A SequenceGroup has only a single encoder sequence (at most), # thus allocate with a ref count of 1 - block_table = self._allocate_sequence(encoder_seq, 1) + block_table = self._allocate_sequence(encoder_seq, 1, decoder_only) # Assign the cross-attention block table for the SequenceGroup. self.cross_block_tables[seq_group.request_id] = block_table @@ -661,18 +673,6 @@ def access_all_blocks_in_seq( for block in block_table: block.last_accessed = access_time - def access_all_cross_blocks_in_seq_group( - self, - seq_group: SequenceGroup, - access_time: float, - ) -> None: - if self.enable_caching: - # Update the last accessed time of all the blocks accessed - # in this step. - block_table = self.cross_block_tables[seq_group.request_id] - for block in block_table: - block.last_accessed = access_time - def compute_full_blocks_in_seq(self, seq: Sequence): if seq.seq_id not in self.block_tables: return diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index a8085f54ac79d..31d1a60657832 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -260,22 +260,6 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float): block_ids, # type: ignore now) - def access_all_cross_blocks_in_seq_group( - self, - seq_group: SequenceGroup, - now: float, - ) -> None: - if self.enable_caching: - # Update the last accessed time of all the blocks accessed - # in this step. - block_table = self.cross_block_tables[seq_group.request_id] - block_ids = [] - for block_id in block_table.physical_block_ids: - block_ids.append(block_id) - self.block_allocator.mark_blocks_as_accessed( - block_ids, # type: ignore - now) - def mark_blocks_as_computed(self, seq_group: SequenceGroup): # The only need for mark block as computed is for prefix caching, # while currently we could determine whether one block is computed From 611bcec382ae4291f3a963d9eaa75889fe897251 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 11:17:48 -0400 Subject: [PATCH 044/239] fixed bugs introduced by merge --- tests/layer/test_self_and_cross_attn.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/layer/test_self_and_cross_attn.py index 811878a347e97..6cc365faa4ea5 100644 --- a/tests/layer/test_self_and_cross_attn.py +++ b/tests/layer/test_self_and_cross_attn.py @@ -1063,19 +1063,18 @@ def cross_attn_setup_reuses_query(query, max_block_idx -def run_self_attention_test(attn, packed_query, packed_key, packed_value, - kv_cache, attn_metadata: AttentionMetadata, scale): +def run_self_attention_test(attn: Attention, packed_query, packed_key, packed_value, + kv_cache, attn_metadata: AttentionMetadata): attn_metadata.do_cross_attn = False return attn.forward(packed_query, packed_key, packed_value, kv_cache, - attn_metadata, scale) + attn_metadata) -def run_cross_attention_test(attn, packed_query, packed_key, packed_value, - kv_cache, attn_metadata: AttentionMetadata, - scale): +def run_cross_attention_test(attn: Attention, packed_query, packed_key, packed_value, + kv_cache, attn_metadata: AttentionMetadata): attn_metadata.do_cross_attn = True return attn.forward(packed_query, packed_key, packed_value, kv_cache, - attn_metadata, scale) + attn_metadata) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -1201,7 +1200,7 @@ def test_prefill_decode_self_and_cross_attention( self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test( attn, prefill_packed_query, self_prefill_packed_key, - self_prefill_packed_value, kv_cache, prefill_attn_metadata, scale) + self_prefill_packed_value, kv_cache, prefill_attn_metadata) # - Prefill self-attention correct? assert torch.allclose( @@ -1211,7 +1210,7 @@ def test_prefill_decode_self_and_cross_attention( cross_prefill_packed_actual_output: torch.Tensor = run_cross_attention_test( attn, prefill_packed_query, cross_prefill_packed_key, - cross_prefill_packed_value, kv_cache, prefill_attn_metadata, scale) + cross_prefill_packed_value, kv_cache, prefill_attn_metadata) # - Prefill cross-attention correct? assert torch.allclose( @@ -1237,7 +1236,7 @@ def test_prefill_decode_self_and_cross_attention( self_decode_packed_actual_output: torch.Tensor = run_self_attention_test( attn, decode_packed_query, self_decode_packed_key, - self_decode_packed_value, kv_cache, decode_attn_metadata, scale) + self_decode_packed_value, kv_cache, decode_attn_metadata) # - Decode self-attention correct? assert torch.allclose( @@ -1246,8 +1245,7 @@ def test_prefill_decode_self_and_cross_attention( self_decode_packed_ideal_output)) cross_decode_packed_actual_output: torch.Tensor = run_cross_attention_test( - attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata, - scale) + attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) # - Decode cross-attention correct? assert torch.allclose( From ed2f56deee922614f296f591afb57321387d6112 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 11:20:22 -0400 Subject: [PATCH 045/239] moved enc/dec test into tests/kernels so that it will be run automatically using existing buildkite config --- tests/{layer => kernels}/test_self_and_cross_attn.py | 0 tests/layer/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/{layer => kernels}/test_self_and_cross_attn.py (100%) delete mode 100644 tests/layer/__init__.py diff --git a/tests/layer/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py similarity index 100% rename from tests/layer/test_self_and_cross_attn.py rename to tests/kernels/test_self_and_cross_attn.py diff --git a/tests/layer/__init__.py b/tests/layer/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 From b4ec9c6de46a114986f58b8f83608e8bad1ec755 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 11:21:07 -0400 Subject: [PATCH 046/239] formatting --- tests/kernels/test_self_and_cross_attn.py | 76 +++++++++++------------ 1 file changed, 36 insertions(+), 40 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 6cc365faa4ea5..04dc52d51af19 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -43,22 +43,20 @@ def build_causal_mask(q_max_seq_len, kv_max_seq_len): ''' # Create a matrix where entry (i, j) is True if i >= j - mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), - diagonal=1) + mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) # Replace True with float('-inf') and False with 0 mask = mask.masked_fill(mask == 1, float('-inf')).masked_fill(mask == 0, 0.0) return mask -def ref_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - custom_mask: Optional[torch.Tensor] = None, - q_seq_lens: Optional[List] = None, - kv_seq_lens: Optional[List] = None) -> torch.Tensor: +def ref_masked_attention(query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + custom_mask: Optional[torch.Tensor] = None, + q_seq_lens: Optional[List] = None, + kv_seq_lens: Optional[List] = None) -> torch.Tensor: ''' "Golden" masked attention reference. Supports two types of masking: @@ -215,26 +213,23 @@ def make_qkv(batch_size, decode_value = torch.zeros( (batch_size, 1, num_heads * head_size)).to(device) - for bdx, (q_seq_len, - kv_seq_len) in enumerate(zip(q_seq_lens, kv_seq_lens)): + for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, + kv_seq_lens)): query[bdx, q_seq_len:, :] = 0 key[bdx, kv_seq_len:, :] = 0 value[bdx, kv_seq_len:, :] = 0 - prefill_query[bdx, - 0:(q_seq_len - 1), :] = query[bdx, - 0:(q_seq_len - 1), :] - prefill_key[bdx, - 0:(kv_seq_len - 1), :] = key[bdx, - 0:(kv_seq_len - 1), :] - prefill_value[bdx, 0:(kv_seq_len - - 1), :] = value[bdx, 0:(kv_seq_len - 1), :] - - decode_query[bdx, :, :] = query[bdx, - (q_seq_len - 1):q_seq_len, :] + prefill_query[bdx, 0:(q_seq_len - 1), :] = query[bdx, + 0:(q_seq_len - 1), :] + prefill_key[bdx, 0:(kv_seq_len - 1), :] = key[bdx, + 0:(kv_seq_len - 1), :] + prefill_value[bdx, + 0:(kv_seq_len - 1), :] = value[bdx, + 0:(kv_seq_len - 1), :] + + decode_query[bdx, :, :] = query[bdx, (q_seq_len - 1):q_seq_len, :] decode_key[bdx, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :] - decode_value[bdx, :, :] = value[bdx, - (kv_seq_len - 1):kv_seq_len, :] + decode_value[bdx, :, :] = value[bdx, (kv_seq_len - 1):kv_seq_len, :] prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens] prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens] @@ -304,12 +299,10 @@ def pack_tensor(unpacked_tensor, seq_lens, device=CUDA_DEVICE): start_loc_list = [0] + list(itertools.accumulate(seq_lens)) packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device) - for bdx, (seq_len, - start_loc) in enumerate(zip(seq_lens, start_loc_list)): + for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)): packed_tensor[start_loc:( - start_loc + - seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :] + start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :] return packed_tensor, start_loc_list @@ -405,9 +398,7 @@ def make_metadata_tensors(is_prompt: bool, * seq_start_loc: start idx of each sequence * query_start_loc: start idx of each query ''' - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=device) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) context_lens_tensor = None if context_lens is None else torch.tensor( context_lens, dtype=torch.int, device=device) max_context_len = None if context_lens is None else max(context_lens) @@ -1063,15 +1054,17 @@ def cross_attn_setup_reuses_query(query, max_block_idx -def run_self_attention_test(attn: Attention, packed_query, packed_key, packed_value, - kv_cache, attn_metadata: AttentionMetadata): +def run_self_attention_test(attn: Attention, packed_query, packed_key, + packed_value, kv_cache, + attn_metadata: AttentionMetadata): attn_metadata.do_cross_attn = False return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) -def run_cross_attention_test(attn: Attention, packed_query, packed_key, packed_value, - kv_cache, attn_metadata: AttentionMetadata): +def run_cross_attention_test(attn: Attention, packed_query, packed_key, + packed_value, kv_cache, + attn_metadata: AttentionMetadata): attn_metadata.do_cross_attn = True return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) @@ -1084,10 +1077,13 @@ def run_cross_attention_test(attn: Attention, packed_query, packed_key, packed_v @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_q_seq_len", MAX_Q_SEQ_LENS) @pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) -def test_prefill_decode_self_and_cross_attention( - num_heads: int, head_size: int, backend_name: str, batch_size: int, - block_size: int, max_q_seq_len: int, - max_kv_seq_len: int) -> None: +def test_prefill_decode_self_and_cross_attention(num_heads: int, + head_size: int, + backend_name: str, + batch_size: int, + block_size: int, + max_q_seq_len: int, + max_kv_seq_len: int) -> None: ''' Test: From 84f5510a0a4e7d0b81b32e772e1cf710be83112b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 13:49:30 -0400 Subject: [PATCH 047/239] block manager v1 NotImplementError's for sliding window and automatic prefix caching --- vllm/core/block_manager_v1.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 40274bd29e9b0..d5da128f1a691 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -277,6 +277,11 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: cross_num_required_blocks if self.block_sliding_window is not None: + if seq_group.get_encoder_seq() is not None: + raise NotImplementedError( + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported.") + num_required_blocks = min(num_required_blocks, self.block_sliding_window) num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() @@ -320,9 +325,16 @@ def allocate(self, seq_group: SequenceGroup) -> None: decoder_only = \ seq_group.get_encoder_seq() is None - assert decoder_only or (not self.enable_caching), \ - "Automatic prefix caching currently not " + \ - "supported for encoder/decoder models." + if (self.block_sliding_window is not None) and \ + (not decoder_only): + raise NotImplementedError( + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported.") + + if self.enable_caching and (not decoder_only): + raise NotImplementedError( + "Automatic prefix caching currently not " + \ + "supported for encoder/decoder models.") # Allocate decoder sequences # From cc61959d2075816ee49fa7a802e3c2240e737546 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 13:56:11 -0400 Subject: [PATCH 048/239] Fixes --- vllm/core/block_manager_v2.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 31d1a60657832..9c6466de468e5 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -152,7 +152,6 @@ def allocate(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # encoder prompt. request_id = seq_group.request_id - encoder_seq = seq_group.encoder_seq assert (request_id not in self.cross_block_tables), \ @@ -160,12 +159,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: encoder_seq = seq_group.get_encoder_seq() if encoder_seq is not None: - block_table = BlockTable( - block_size=self.block_size, - block_allocator=self.block_allocator, - ) - assert self.block_sliding_window is None - block_table.allocate(encoder_seq.get_token_ids()) + block_table: BlockTable = self._allocate_sequence(encoder_seq) self.cross_block_tables[request_id] = block_table def can_append_slots(self, seq_group: SequenceGroup, @@ -229,8 +223,6 @@ def free_cross(self, seq_group: SequenceGroup) -> None: self.cross_block_tables[request_id].free() del self.cross_block_tables[request_id] - del self.cross_block_tables[seq_group.request_id] - def get_block_table(self, seq: Sequence) -> List[int]: assert seq.seq_id in self.block_tables block_ids = self.block_tables[seq.seq_id].physical_block_ids From dcb9abe115cfd6bfa8f2131c645cbc0bb6acb2ab Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 13:58:30 -0400 Subject: [PATCH 049/239] formatting --- vllm/core/block_manager_v1.py | 2 +- vllm/core/block_manager_v2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index d5da128f1a691..95e9e5e20940d 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -281,7 +281,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: raise NotImplementedError( "Sliding window attention for encoder/decoder models " + \ "is not currently supported.") - + num_required_blocks = min(num_required_blocks, self.block_sliding_window) num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 9c6466de468e5..b89f1cd05d1c1 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -159,7 +159,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: encoder_seq = seq_group.get_encoder_seq() if encoder_seq is not None: - block_table: BlockTable = self._allocate_sequence(encoder_seq) + block_table = self._allocate_sequence(encoder_seq) self.cross_block_tables[request_id] = block_table def can_append_slots(self, seq_group: SequenceGroup, From 5cd154102ab645bc246dd734d797e0d5d8a1652f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 16:23:24 -0400 Subject: [PATCH 050/239] Added explanatory comment to XFormersImpl.forward() --- vllm/attention/backends/xformers.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 36f1343e995df..f9d0d924b395e 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -395,14 +395,29 @@ def __init__( def forward( self, query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor], attn_metadata: "XFormersMetadata", kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. + For decoder-only models: query, key and value must be non-None. + + For encoder/decoder models: + * XFormersImpl.forward() may be invoked for both self- and cross- + attention layers. + * For self-attention: query, key and value must be non-None. + * For cross-attention: + * Query must be non-None + * During prefill, key and value must be non-None; key and value + get cached for use during decode. + * During decode, key and value may be None, since: + (1) key and value tensors were cached during prefill, and + (2) cross-attention key and value tensors do not grow during + decode + Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] From f3c430b3a226e249c611630ce776530d48971f0a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 16:37:32 -0400 Subject: [PATCH 051/239] Explanatory comment about sequence argument. --- vllm/sequence.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 6b07a00f09c6f..9670786f8f16c 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -420,7 +420,8 @@ class SequenceGroup: for an embedding model. pooling_params: The pooling parameters used to generate the pooling for an embedding model. - encoder_seq: Optional, the single encoder sequence. + encoder_seq: Optional, the single encoder sequence. Should be None + unless you are working with an encoder/decoder model. """ def __init__( From f2564e0f1cc95fec5880847aa380a462ecd3d0bf Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 17:02:37 -0400 Subject: [PATCH 052/239] clarifying comment --- vllm/sequence.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 9670786f8f16c..9c8fcccab75ae 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -614,11 +614,15 @@ class SequenceGroupMetadata: used in prefix caching. state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. - encoder_seq_data: Optional, the sequence data - for the single encoder prompt. - cross_block_table: Optional, the cross-attention - block table associated with - the single encoder prompt. + encoder_seq_data: Optional sequence data for encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. + cross_block_table: Optional cross-attention block table associated + with the encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. """ def __init__( From e8c40fcf152c5d2f6514830644c8eb683eee7aa9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 17:08:00 -0400 Subject: [PATCH 053/239] explanatory comment --- vllm/sequence.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 6b07a00f09c6f..a456ecc111e4c 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -613,11 +613,15 @@ class SequenceGroupMetadata: used in prefix caching. state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. - encoder_seq_data: Optional, the sequence data - for the single encoder prompt. - cross_block_table: Optional, the cross-attention - block table associated with - the single encoder prompt. + encoder_seq_data: Optional sequence data for encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. + cross_block_table: Optional cross-attention block table associated + with the encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. """ def __init__( From 5ccb70be1209521d0aa1e3d7cae7bf7707ac2fd8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 17:18:03 -0400 Subject: [PATCH 054/239] various fixes according to reviews --- vllm/core/block_manager_v1.py | 2 +- vllm/core/block_manager_v2.py | 14 ++++++++++++++ vllm/sequence.py | 3 ++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 95e9e5e20940d..1c81edb7a2df3 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -352,7 +352,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: # Allocate encoder sequence encoder_seq = seq_group.get_encoder_seq() - if encoder_seq is not None: + if not decoder_only: # A SequenceGroup has only a single encoder sequence (at most), # thus allocate with a ref count of 1 block_table = self._allocate_sequence(encoder_seq, 1, decoder_only) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index b89f1cd05d1c1..f094bf99e3201 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -132,6 +132,9 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: return block_table def allocate(self, seq_group: SequenceGroup) -> None: + decoder_only = \ + seq_group.get_encoder_seq() is None + # Allocate self-attention block tables for decoder sequences waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) assert not (set(seq.seq_id for seq in waiting_seqs) @@ -157,6 +160,17 @@ def allocate(self, seq_group: SequenceGroup) -> None: not in self.cross_block_tables), \ "block table already exists" + if (self.block_sliding_window is not None) and \ + (not decoder_only): + raise NotImplementedError( + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported.") + + if self.enable_caching and (not decoder_only): + raise NotImplementedError( + "Automatic prefix caching currently not " + \ + "supported for encoder/decoder models.") + encoder_seq = seq_group.get_encoder_seq() if encoder_seq is not None: block_table = self._allocate_sequence(encoder_seq) diff --git a/vllm/sequence.py b/vllm/sequence.py index a456ecc111e4c..9c8fcccab75ae 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -420,7 +420,8 @@ class SequenceGroup: for an embedding model. pooling_params: The pooling parameters used to generate the pooling for an embedding model. - encoder_seq: Optional, the single encoder sequence. + encoder_seq: Optional, the single encoder sequence. Should be None + unless you are working with an encoder/decoder model. """ def __init__( From dfcc28b19188a11c74aee06265051eb8fbbe599f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 17:22:41 -0400 Subject: [PATCH 055/239] slight refactoring --- vllm/core/block_manager_v1.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 1c81edb7a2df3..2daf45182bba9 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -322,8 +322,8 @@ def _allocate_sequence(self, \ return block_table def allocate(self, seq_group: SequenceGroup) -> None: - decoder_only = \ - seq_group.get_encoder_seq() is None + encoder_seq = seq_group.get_encoder_seq() + decoder_only = encoder_seq is None if (self.block_sliding_window is not None) and \ (not decoder_only): @@ -351,7 +351,6 @@ def allocate(self, seq_group: SequenceGroup) -> None: self.block_tables[seq.seq_id] = block_table.copy() # Allocate encoder sequence - encoder_seq = seq_group.get_encoder_seq() if not decoder_only: # A SequenceGroup has only a single encoder sequence (at most), # thus allocate with a ref count of 1 From 8d3ad05a9f7d568f16eea6e090f6803869fc5443 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 17:26:54 -0400 Subject: [PATCH 056/239] small refactor --- vllm/core/block_manager_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index f094bf99e3201..6e02359f51782 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -132,8 +132,9 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: return block_table def allocate(self, seq_group: SequenceGroup) -> None: + encoder_seq = seq_group.get_encoder_seq() decoder_only = \ - seq_group.get_encoder_seq() is None + encoder_seq is None # Allocate self-attention block tables for decoder sequences waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) @@ -171,8 +172,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: "Automatic prefix caching currently not " + \ "supported for encoder/decoder models.") - encoder_seq = seq_group.get_encoder_seq() - if encoder_seq is not None: + if not decoder_only: block_table = self._allocate_sequence(encoder_seq) self.cross_block_tables[request_id] = block_table From 5a7697976a964cf23d6141d9e432abb63d3f9e9d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 17:34:34 -0400 Subject: [PATCH 057/239] replaced all encoder_seq is not None with not decoder_only --- vllm/core/block_manager_v1.py | 19 +++++++++++++++---- vllm/core/block_manager_v2.py | 5 ++++- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 2daf45182bba9..2e5d531565379 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -496,6 +496,10 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: def _get_physical_blocks( self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: + encoder_seq = seq_group.get_encoder_seq() + decoder_only = \ + encoder_seq is None + # NOTE: Here, we assume that the physical blocks are only shared by # the sequences in the same group. request_id = seq_group.request_id @@ -505,7 +509,7 @@ def _get_physical_blocks( continue blocks.update(self.block_tables[seq.seq_id]) # Cross-attention blocks - if seq_group.encoder_seq is not None: + if not decoder_only: blocks.update(self.cross_block_tables[request_id]) return list(blocks) @@ -514,9 +518,12 @@ def can_swap_in(self, num_lookahead_slots: int = 0) -> AllocStatus: assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" + encoder_seq = seq_group.get_encoder_seq() + decoder_only = encoder_seq is None + blocks = self._get_physical_blocks(seq_group) num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) - if seq_group.encoder_seq is not None: + if not decoder_only: num_swapped_seqs += 1 num_free_blocks = self.gpu_allocator.get_num_free_blocks() # NOTE: Conservatively, we assume that every sequence will allocate @@ -556,6 +563,8 @@ def swap_in(self, assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" + encoder_seq = seq_group.get_encoder_seq() + decoder_only = encoder_seq is None request_id = seq_group.request_id # CPU block -> GPU block. @@ -566,7 +575,7 @@ def swap_in(self, self._swap_in_block_table(self.block_tables[seq.seq_id], mapping) - if seq_group.encoder_seq is not None: + if not decoder_only: self.cross_block_tables[request_id] = \ self._swap_in_block_table(self.cross_block_tables[request_id], mapping) @@ -600,6 +609,8 @@ def _swap_out_block_table( def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: request_id = seq_group.request_id + encoder_seq = seq_group.get_encoder_seq() + decoder_only = encoder_seq is None # GPU block -> CPU block. # dict is efficient in lookup `if gpu_block in mapping` @@ -609,7 +620,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: self._swap_out_block_table(self.block_tables[seq.seq_id], mapping) - if seq_group.encoder_seq is not None: + if not decoder_only: self.cross_block_tables[request_id] = \ self._swap_out_block_table(self.cross_block_tables[request_id], mapping) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 6e02359f51782..a8090c1f93b5a 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -91,6 +91,9 @@ def __init__( def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. + encoder_seq = seq_group.get_encoder_seq() + decoder_only = encoder_seq is None + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = BlockTable.get_num_required_blocks( @@ -98,7 +101,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: block_size=self.block_size, ) - if seq_group.encoder_seq is not None: + if not decoder_only: num_required_blocks += BlockTable.get_num_required_blocks( seq_group.encoder_seq.get_token_ids(), block_size=self.block_size, From 09ae4adb656b79897d62d28015f968b0c7471d8e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 17:51:23 -0400 Subject: [PATCH 058/239] added is_encoder_decoder() method to sequence group --- vllm/core/block_manager_v1.py | 36 ++++++++++++++--------------------- vllm/core/block_manager_v2.py | 16 ++++++---------- vllm/sequence.py | 3 +++ 3 files changed, 23 insertions(+), 32 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 2e5d531565379..69a280c8bf9c6 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -277,7 +277,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: cross_num_required_blocks if self.block_sliding_window is not None: - if seq_group.get_encoder_seq() is not None: + if seq_group.is_encoder_decoder(): raise NotImplementedError( "Sliding window attention for encoder/decoder models " + \ "is not currently supported.") @@ -298,7 +298,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: def _allocate_sequence(self, \ seq: Sequence, \ ref_count: int, \ - decoder_only: bool = True) -> BlockTable: + is_encoder_decoder: bool = True) -> BlockTable: # Allocate new physical token blocks that will store the prompt tokens. num_prompt_blocks = len(seq.logical_token_blocks) @@ -309,7 +309,7 @@ def _allocate_sequence(self, \ block = block_table[logical_idx % self.block_sliding_window] # Set the reference counts of the token blocks. block.ref_count = ref_count - elif decoder_only and self.enable_caching: + elif not is_encoder_decoder and self.enable_caching: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), seq.num_hashed_tokens_of_block(logical_idx)) @@ -323,15 +323,15 @@ def _allocate_sequence(self, \ def allocate(self, seq_group: SequenceGroup) -> None: encoder_seq = seq_group.get_encoder_seq() - decoder_only = encoder_seq is None + is_encoder_decoder = seq_group.is_encoder_decoder() if (self.block_sliding_window is not None) and \ - (not decoder_only): + is_encoder_decoder: raise NotImplementedError( "Sliding window attention for encoder/decoder models " + \ "is not currently supported.") - if self.enable_caching and (not decoder_only): + if self.enable_caching and is_encoder_decoder: raise NotImplementedError( "Automatic prefix caching currently not " + \ "supported for encoder/decoder models.") @@ -344,17 +344,18 @@ def allocate(self, seq_group: SequenceGroup) -> None: block_table: BlockTable = \ self._allocate_sequence(seq, seq_group.num_seqs(), - decoder_only) + is_encoder_decoder) # Assign the self-attention block tables for each sequence. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() # Allocate encoder sequence - if not decoder_only: + if is_encoder_decoder: # A SequenceGroup has only a single encoder sequence (at most), # thus allocate with a ref count of 1 - block_table = self._allocate_sequence(encoder_seq, 1, decoder_only) + block_table = self._allocate_sequence(encoder_seq, 1, + is_encoder_decoder) # Assign the cross-attention block table for the SequenceGroup. self.cross_block_tables[seq_group.request_id] = block_table @@ -496,9 +497,6 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: def _get_physical_blocks( self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: - encoder_seq = seq_group.get_encoder_seq() - decoder_only = \ - encoder_seq is None # NOTE: Here, we assume that the physical blocks are only shared by # the sequences in the same group. @@ -509,7 +507,7 @@ def _get_physical_blocks( continue blocks.update(self.block_tables[seq.seq_id]) # Cross-attention blocks - if not decoder_only: + if seq_group.is_encoder_decoder(): blocks.update(self.cross_block_tables[request_id]) return list(blocks) @@ -518,12 +516,10 @@ def can_swap_in(self, num_lookahead_slots: int = 0) -> AllocStatus: assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" - encoder_seq = seq_group.get_encoder_seq() - decoder_only = encoder_seq is None blocks = self._get_physical_blocks(seq_group) num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) - if not decoder_only: + if seq_group.is_encoder_decoder(): num_swapped_seqs += 1 num_free_blocks = self.gpu_allocator.get_num_free_blocks() # NOTE: Conservatively, we assume that every sequence will allocate @@ -563,8 +559,6 @@ def swap_in(self, assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" - encoder_seq = seq_group.get_encoder_seq() - decoder_only = encoder_seq is None request_id = seq_group.request_id # CPU block -> GPU block. @@ -575,7 +569,7 @@ def swap_in(self, self._swap_in_block_table(self.block_tables[seq.seq_id], mapping) - if not decoder_only: + if seq_group.is_encoder_decoder(): self.cross_block_tables[request_id] = \ self._swap_in_block_table(self.cross_block_tables[request_id], mapping) @@ -609,8 +603,6 @@ def _swap_out_block_table( def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: request_id = seq_group.request_id - encoder_seq = seq_group.get_encoder_seq() - decoder_only = encoder_seq is None # GPU block -> CPU block. # dict is efficient in lookup `if gpu_block in mapping` @@ -620,7 +612,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: self._swap_out_block_table(self.block_tables[seq.seq_id], mapping) - if not decoder_only: + if seq_group.is_encoder_decoder(): self.cross_block_tables[request_id] = \ self._swap_out_block_table(self.cross_block_tables[request_id], mapping) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index a8090c1f93b5a..0dd2ffcd182ec 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -91,19 +91,16 @@ def __init__( def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - encoder_seq = seq_group.get_encoder_seq() - decoder_only = encoder_seq is None seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - num_required_blocks = BlockTable.get_num_required_blocks( seq.get_token_ids(), block_size=self.block_size, ) - if not decoder_only: + if seq_group.is_encoder_decoder(): num_required_blocks += BlockTable.get_num_required_blocks( - seq_group.encoder_seq.get_token_ids(), + seq_group.get_encoder_seq().get_token_ids(), block_size=self.block_size, ) @@ -136,8 +133,7 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: def allocate(self, seq_group: SequenceGroup) -> None: encoder_seq = seq_group.get_encoder_seq() - decoder_only = \ - encoder_seq is None + is_encoder_decoder = seq_group.is_encoder_decoder() # Allocate self-attention block tables for decoder sequences waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) @@ -165,17 +161,17 @@ def allocate(self, seq_group: SequenceGroup) -> None: "block table already exists" if (self.block_sliding_window is not None) and \ - (not decoder_only): + is_encoder_decoder: raise NotImplementedError( "Sliding window attention for encoder/decoder models " + \ "is not currently supported.") - if self.enable_caching and (not decoder_only): + if self.enable_caching and is_encoder_decoder: raise NotImplementedError( "Automatic prefix caching currently not " + \ "supported for encoder/decoder models.") - if not decoder_only: + if is_encoder_decoder: block_table = self._allocate_sequence(encoder_seq) self.cross_block_tables[request_id] = block_table diff --git a/vllm/sequence.py b/vllm/sequence.py index 9c8fcccab75ae..ad6c8d54974c3 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -528,6 +528,9 @@ def get_seqs( seq for seq in self.seqs_dict.values() if seq.status == status ] + def is_encoder_decoder(self) -> bool: + return self.encoder_seq is not None + def get_encoder_seq(self) -> Optional[Sequence]: return self.encoder_seq From ecd1a998579ac171ce1936444fe9f7c8a6a09c92 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 18:59:03 -0400 Subject: [PATCH 059/239] tests for NotImplemented errors when encoder/decoder models are used with prefix cache or SWA --- tests/core/block/test_block_manager_v2.py | 103 +++++++++++++++++++++- tests/core/test_block_manager.py | 64 +++++++++++++- vllm/core/block_manager_v1.py | 29 +++--- vllm/core/block_manager_v2.py | 28 ++++-- 4 files changed, 205 insertions(+), 19 deletions(-) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 06c3389cfa0f0..cf423d292a25e 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -1,6 +1,8 @@ import pytest -from vllm.core.block_manager_v2 import BlockSpaceManagerV2 +from vllm.core.block_manager_v2 import (BlockSpaceManagerV2, + str_not_impl_enc_dec_prefix_cache, + str_not_impl_enc_dec_swa) from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, SequenceStatus from vllm.utils import chunk_list @@ -103,6 +105,105 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, assert can_allocate_result == AllocStatus.LATER +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [16]) +@pytest.mark.parametrize("num_seqs_per_group", [1]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_allocate_encoder_decoder_fails_with_swa(block_size: int, + num_seqs_per_group: int, + num_gpu_blocks: int, + watermark: float): + ''' + SWA short for Sliding Window Attention. + + At time of writing block manager v2 does not support SWA. + + However even when SWA is implemented for block manager v2, + there will still most likely be a separate workstream required + to enable SWA for encoder/decoder models. + + Therefore this test enforces that one of the following cases + hold true: + 1. Block manager v2 does not support SWA at all (true at time of writing) + 2. Block manager v2 fails with NotImplementError when SWA is enabled + AND a SequenceGroup with an encoder sequence (i.e. in support of an + encoder/decoder model) is passed into can_allocate() as an argument + + The setup for this test is stripped down version of + test_can_allocate_seq_group_encoder_decoder() + ''' + + with pytest.raises((NotImplementedError, AssertionError)) as exc_info: + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + sliding_window=5 # SWA + ) + + num_output_blocks_per_seq = 1 + num_prompt_blocks = 1 + num_output_blocks = num_output_blocks_per_seq + seq_group = create_seq_group_encoder_decoder( + seq_prompt_len=block_size * num_prompt_blocks, + seq_output_lens=[ + block_size * num_output_blocks_per_seq + for _ in range(num_seqs_per_group) + ], + request_id="0") + + assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks + block_manager.can_allocate(seq_group) + + # Assert that either + # 1. Block manager v2 constructor fails with assertion that sliding window + # is not yet supported (most likely near-term outcome at time of + # writing), or + # 2. can_allocate() fails with NotImplementedError due to combiantion of + # encoder/decoder and sliding window attention + if isinstance(exc_info.value, NotImplementedError): + assert str(exc_info.value) == str_not_impl_enc_dec_swa + elif isinstance(exc_info.value, AssertionError): + assert str(exc_info.value) == "Sliding window not yet supported" + + +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [16]) +@pytest.mark.parametrize("num_seqs_per_group", [1]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_allocate_encoder_decoder_fails_with_prefix_cache( + block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, + watermark: float): + + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + enable_caching=True # Prefix cache + ) + + num_output_blocks_per_seq = 1 + num_prompt_blocks = 1 + num_output_blocks = num_output_blocks_per_seq + seq_group = create_seq_group_encoder_decoder( + seq_prompt_len=block_size * num_prompt_blocks, + seq_output_lens=[ + block_size * num_output_blocks_per_seq + for _ in range(num_seqs_per_group) + ], + request_id="0") + + assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks + + # Assert that either can_allocate() fails with NotImplementedError + # due to combination of encoder/decoder and prefix cache + with pytest.raises(NotImplementedError) as exc_info: + block_manager.can_allocate(seq_group) + assert str(exc_info.value) == str_not_impl_enc_dec_prefix_cache + + @pytest.mark.parametrize("block_size", [1, 8]) @pytest.mark.parametrize("prompt_len", [1, 7, 8]) @pytest.mark.parametrize("num_slots_to_append", [1, 8, 129]) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index cdaf2f22115e8..6039f568fcf1e 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -7,7 +7,9 @@ from vllm import SamplingParams from vllm.block import PhysicalTokenBlock from vllm.core.block_manager_v1 import (BlockSpaceManagerV1, - UncachedBlockAllocator) + UncachedBlockAllocator, + str_not_impl_enc_dec_prefix_cache, + str_not_impl_enc_dec_swa) from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device @@ -126,6 +128,66 @@ def test_allocate_encoder_decoder(): assert block_manager.can_allocate(seq_group) != AllocStatus.OK +def test_allocate_encoder_decoder_fails_with_swa(): + # SWA short for sliding window attention + + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + sliding_window=5) # swa + + # Allocate same sequence group to all available gpu blocks. + _, _, seq_group = create_dummy_prompt_encoder_decoder( + "0", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + + # Assert that can_allocate() fails due to SWA + with pytest.raises(NotImplementedError) as exc_info: + block_manager.can_allocate(seq_group) + + assert str(exc_info.value) == str_not_impl_enc_dec_swa + + # Assert that allocate() fails due to SWA + with pytest.raises(NotImplementedError) as exc_info: + block_manager.allocate(seq_group) + + assert str(exc_info.value) == str_not_impl_enc_dec_swa + + +def test_allocate_encoder_decoder_fails_with_prefix_caching(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=True) # Prefix cache + + # Allocate same sequence group to all available gpu blocks. + _, _, seq_group = create_dummy_prompt_encoder_decoder( + "0", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + + # Assert that can_allocate() fails due to prefix caching + with pytest.raises(NotImplementedError) as exc_info: + block_manager.can_allocate(seq_group) + + assert str(exc_info.value) == str_not_impl_enc_dec_prefix_cache + + # Assert that allocate() fails due to prefix caching + with pytest.raises(NotImplementedError) as exc_info: + block_manager.allocate(seq_group) + + assert str(exc_info.value) == str_not_impl_enc_dec_prefix_cache + + def test_append_slot_single_seq(): block_size = 4 num_cpu_blocks = 4 diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 69a280c8bf9c6..904b12cd97b01 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -15,6 +15,17 @@ from vllm.utils import Device logger = init_logger(__name__) +''' +Exception strings for non-implemented encoder/decoder scenarios +''' + +str_not_impl_enc_dec_swa = \ + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported." + +str_not_impl_enc_dec_prefix_cache = \ + "Prefix caching for encoder/decoder models " + \ + "is not currently supported." class BlockAllocatorBase(ABC): @@ -269,6 +280,10 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. + is_encoder_decoder = seq_group.is_encoder_decoder() + if self.enable_caching and is_encoder_decoder: + raise NotImplementedError(str_not_impl_enc_dec_prefix_cache) + self_num_required_blocks = self._get_seq_num_required_blocks( seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) cross_num_required_blocks = self._get_seq_num_required_blocks( @@ -277,10 +292,8 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: cross_num_required_blocks if self.block_sliding_window is not None: - if seq_group.is_encoder_decoder(): - raise NotImplementedError( - "Sliding window attention for encoder/decoder models " + \ - "is not currently supported.") + if is_encoder_decoder: + raise NotImplementedError(str_not_impl_enc_dec_swa) num_required_blocks = min(num_required_blocks, self.block_sliding_window) @@ -327,14 +340,10 @@ def allocate(self, seq_group: SequenceGroup) -> None: if (self.block_sliding_window is not None) and \ is_encoder_decoder: - raise NotImplementedError( - "Sliding window attention for encoder/decoder models " + \ - "is not currently supported.") + raise NotImplementedError(str_not_impl_enc_dec_swa) if self.enable_caching and is_encoder_decoder: - raise NotImplementedError( - "Automatic prefix caching currently not " + \ - "supported for encoder/decoder models.") + raise NotImplementedError(str_not_impl_enc_dec_prefix_cache) # Allocate decoder sequences # diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 0dd2ffcd182ec..d2dadd9a63dc2 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -8,6 +8,17 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device +''' +Exception strings for non-implemented encoder/decoder scenarios +''' + +str_not_impl_enc_dec_swa = \ + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported." + +str_not_impl_enc_dec_prefix_cache = \ + "Prefix caching for encoder/decoder models " + \ + "is not currently supported." SeqId = int EncoderSeqId = str @@ -92,13 +103,20 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. + is_encoder_decoder = seq_group.is_encoder_decoder() + if self.enable_caching and is_encoder_decoder: + raise NotImplementedError(str_not_impl_enc_dec_prefix_cache) + + if self.block_sliding_window is not None and is_encoder_decoder: + raise NotImplementedError(str_not_impl_enc_dec_swa) + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = BlockTable.get_num_required_blocks( seq.get_token_ids(), block_size=self.block_size, ) - if seq_group.is_encoder_decoder(): + if is_encoder_decoder: num_required_blocks += BlockTable.get_num_required_blocks( seq_group.get_encoder_seq().get_token_ids(), block_size=self.block_size, @@ -162,14 +180,10 @@ def allocate(self, seq_group: SequenceGroup) -> None: if (self.block_sliding_window is not None) and \ is_encoder_decoder: - raise NotImplementedError( - "Sliding window attention for encoder/decoder models " + \ - "is not currently supported.") + raise NotImplementedError(str_not_impl_enc_dec_swa) if self.enable_caching and is_encoder_decoder: - raise NotImplementedError( - "Automatic prefix caching currently not " + \ - "supported for encoder/decoder models.") + raise NotImplementedError(str_not_impl_enc_dec_prefix_cache) if is_encoder_decoder: block_table = self._allocate_sequence(encoder_seq) From d3935f73b5038ba7acc75fff07282b7f7fda6ed5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 19:05:36 -0400 Subject: [PATCH 060/239] rename tests --- tests/core/block/test_block_manager_v2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index cf423d292a25e..c893bc8f4209e 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -109,10 +109,10 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, @pytest.mark.parametrize("num_gpu_blocks", [16]) @pytest.mark.parametrize("num_seqs_per_group", [1]) @pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_allocate_encoder_decoder_fails_with_swa(block_size: int, - num_seqs_per_group: int, - num_gpu_blocks: int, - watermark: float): +def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, + num_seqs_per_group: int, + num_gpu_blocks: int, + watermark: float): ''' SWA short for Sliding Window Attention. @@ -172,7 +172,7 @@ def test_allocate_encoder_decoder_fails_with_swa(block_size: int, @pytest.mark.parametrize("num_gpu_blocks", [16]) @pytest.mark.parametrize("num_seqs_per_group", [1]) @pytest.mark.parametrize("watermark", [0.0, 0.5]) -def test_allocate_encoder_decoder_fails_with_prefix_cache( +def test_can_allocate_encoder_decoder_fails_with_prefix_cache( block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, watermark: float): From e6a7125383488af42dd5020b65824394c9c112e9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 19:10:35 -0400 Subject: [PATCH 061/239] spelling error --- tests/core/block/test_block_manager_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index c893bc8f4209e..19ea89d01ca7a 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -160,7 +160,7 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, # 1. Block manager v2 constructor fails with assertion that sliding window # is not yet supported (most likely near-term outcome at time of # writing), or - # 2. can_allocate() fails with NotImplementedError due to combiantion of + # 2. can_allocate() fails with NotImplementedError due to combination of # encoder/decoder and sliding window attention if isinstance(exc_info.value, NotImplementedError): assert str(exc_info.value) == str_not_impl_enc_dec_swa From 68b476203ba9c8342e3f6ba5d9db5e7d369a7a52 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 23 May 2024 19:14:25 -0400 Subject: [PATCH 062/239] isort --- vllm/core/block_manager_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index d2dadd9a63dc2..b43f39a8ffaef 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -8,6 +8,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device + ''' Exception strings for non-implemented encoder/decoder scenarios ''' From a80325dcbe4af189e3542f00ffe92a11a7243e92 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 25 May 2024 21:45:13 -0400 Subject: [PATCH 063/239] return output of SequenceGroup constructor --- tests/core/utils.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index 376af0f0eac4f..fb53b6cc5e18b 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -145,14 +145,11 @@ def create_seq_group_encoder_decoder( block_size=16, ) - seq_group = SequenceGroup(request_id=request_id, - seqs=seqs, - sampling_params=sampling_params, - arrival_time=time.time(), - encoder_seq=encoder_seq) - - return seq_group - + return SequenceGroup(request_id=request_id, + seqs=seqs, + sampling_params=sampling_params, + arrival_time=time.time(), + encoder_seq=encoder_seq) def round_up_to_next_block(seq_len: int, block_size: int) -> int: return (seq_len + block_size - 1) // block_size From 8b387767512a657fd0051c674f4a594159b67eee Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 25 May 2024 21:56:25 -0400 Subject: [PATCH 064/239] capitalize constants --- tests/core/block/test_block_manager_v2.py | 8 ++++---- tests/core/test_block_manager.py | 12 ++++++------ vllm/core/block_manager_v1.py | 17 ++++++++--------- vllm/core/block_manager_v2.py | 12 ++++++------ 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 19ea89d01ca7a..3aed0c58bd264 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -1,8 +1,8 @@ import pytest from vllm.core.block_manager_v2 import (BlockSpaceManagerV2, - str_not_impl_enc_dec_prefix_cache, - str_not_impl_enc_dec_swa) + STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, SequenceStatus from vllm.utils import chunk_list @@ -163,7 +163,7 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, # 2. can_allocate() fails with NotImplementedError due to combination of # encoder/decoder and sliding window attention if isinstance(exc_info.value, NotImplementedError): - assert str(exc_info.value) == str_not_impl_enc_dec_swa + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA elif isinstance(exc_info.value, AssertionError): assert str(exc_info.value) == "Sliding window not yet supported" @@ -201,7 +201,7 @@ def test_can_allocate_encoder_decoder_fails_with_prefix_cache( # due to combination of encoder/decoder and prefix cache with pytest.raises(NotImplementedError) as exc_info: block_manager.can_allocate(seq_group) - assert str(exc_info.value) == str_not_impl_enc_dec_prefix_cache + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE @pytest.mark.parametrize("block_size", [1, 8]) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 6039f568fcf1e..7e487a021d3c2 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -8,8 +8,8 @@ from vllm.block import PhysicalTokenBlock from vllm.core.block_manager_v1 import (BlockSpaceManagerV1, UncachedBlockAllocator, - str_not_impl_enc_dec_prefix_cache, - str_not_impl_enc_dec_swa) + STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device @@ -150,13 +150,13 @@ def test_allocate_encoder_decoder_fails_with_swa(): with pytest.raises(NotImplementedError) as exc_info: block_manager.can_allocate(seq_group) - assert str(exc_info.value) == str_not_impl_enc_dec_swa + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA # Assert that allocate() fails due to SWA with pytest.raises(NotImplementedError) as exc_info: block_manager.allocate(seq_group) - assert str(exc_info.value) == str_not_impl_enc_dec_swa + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA def test_allocate_encoder_decoder_fails_with_prefix_caching(): @@ -179,13 +179,13 @@ def test_allocate_encoder_decoder_fails_with_prefix_caching(): with pytest.raises(NotImplementedError) as exc_info: block_manager.can_allocate(seq_group) - assert str(exc_info.value) == str_not_impl_enc_dec_prefix_cache + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE # Assert that allocate() fails due to prefix caching with pytest.raises(NotImplementedError) as exc_info: block_manager.allocate(seq_group) - assert str(exc_info.value) == str_not_impl_enc_dec_prefix_cache + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE def test_append_slot_single_seq(): diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 904b12cd97b01..312690ee45893 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -19,11 +19,11 @@ Exception strings for non-implemented encoder/decoder scenarios ''' -str_not_impl_enc_dec_swa = \ +STR_NOT_IMPL_ENC_DEC_SWA = \ "Sliding window attention for encoder/decoder models " + \ "is not currently supported." -str_not_impl_enc_dec_prefix_cache = \ +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ "Prefix caching for encoder/decoder models " + \ "is not currently supported." @@ -272,9 +272,8 @@ def __init__( self.cross_block_tables: Dict[str, BlockTable] = {} def _get_seq_num_required_blocks(self, seq: Sequence) -> int: - if seq is None: - return 0 - return len(seq.logical_token_blocks) + return 0 if seq is None \ + else len(seq.logical_token_blocks) def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share @@ -282,7 +281,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: is_encoder_decoder = seq_group.is_encoder_decoder() if self.enable_caching and is_encoder_decoder: - raise NotImplementedError(str_not_impl_enc_dec_prefix_cache) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) self_num_required_blocks = self._get_seq_num_required_blocks( seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) @@ -293,7 +292,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: if self.block_sliding_window is not None: if is_encoder_decoder: - raise NotImplementedError(str_not_impl_enc_dec_swa) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) num_required_blocks = min(num_required_blocks, self.block_sliding_window) @@ -340,10 +339,10 @@ def allocate(self, seq_group: SequenceGroup) -> None: if (self.block_sliding_window is not None) and \ is_encoder_decoder: - raise NotImplementedError(str_not_impl_enc_dec_swa) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) if self.enable_caching and is_encoder_decoder: - raise NotImplementedError(str_not_impl_enc_dec_prefix_cache) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) # Allocate decoder sequences # diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index b43f39a8ffaef..6113561032dd1 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -13,11 +13,11 @@ Exception strings for non-implemented encoder/decoder scenarios ''' -str_not_impl_enc_dec_swa = \ +STR_NOT_IMPL_ENC_DEC_SWA = \ "Sliding window attention for encoder/decoder models " + \ "is not currently supported." -str_not_impl_enc_dec_prefix_cache = \ +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ "Prefix caching for encoder/decoder models " + \ "is not currently supported." @@ -106,10 +106,10 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: is_encoder_decoder = seq_group.is_encoder_decoder() if self.enable_caching and is_encoder_decoder: - raise NotImplementedError(str_not_impl_enc_dec_prefix_cache) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) if self.block_sliding_window is not None and is_encoder_decoder: - raise NotImplementedError(str_not_impl_enc_dec_swa) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = BlockTable.get_num_required_blocks( @@ -181,10 +181,10 @@ def allocate(self, seq_group: SequenceGroup) -> None: if (self.block_sliding_window is not None) and \ is_encoder_decoder: - raise NotImplementedError(str_not_impl_enc_dec_swa) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) if self.enable_caching and is_encoder_decoder: - raise NotImplementedError(str_not_impl_enc_dec_prefix_cache) + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) if is_encoder_decoder: block_table = self._allocate_sequence(encoder_seq) From f39c3132af87d410507644c9ea86aec1156f3533 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 25 May 2024 22:20:06 -0400 Subject: [PATCH 065/239] refactored swap-block-table functionality --- vllm/core/block_manager_v1.py | 68 +++++++++++++++-------------------- 1 file changed, 29 insertions(+), 39 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 312690ee45893..90a485b39e9d6 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -541,23 +541,25 @@ def can_swap_in(self, else: return AllocStatus.LATER - def _swap_in_block_table( + def _swap_block_table( self, block_table: BlockTable, + src_allocator: BlockAllocatorBase, + dest_allocator: BlockAllocatorBase, mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock]) -> BlockTable: new_block_table = [] - for cpu_block in block_table: - if cpu_block in mapping: - gpu_block = mapping[cpu_block] - gpu_block.ref_count += 1 + for from_block in block_table: + if from_block in mapping: + to_block = mapping[from_block] + to_block.ref_count += 1 else: - gpu_block = self.gpu_allocator.allocate( - cpu_block.block_hash, cpu_block.num_hashed_tokens) - mapping[cpu_block] = gpu_block - new_block_table.append(gpu_block) - # Free the CPU block swapped in to GPU. - self.cpu_allocator.free(cpu_block) + to_block = dest_allocator.allocate( + from_block.block_hash, from_block.num_hashed_tokens) + mapping[from_block] = to_block + new_block_table.append(to_block) + # Free the source block swapped in to destination. + src_allocator.free(from_block) return new_block_table @@ -574,13 +576,17 @@ def swap_in(self, mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): self.block_tables[seq.seq_id] = \ - self._swap_in_block_table(self.block_tables[seq.seq_id], - mapping) + self._swap_block_table(self.block_tables[seq.seq_id], + self.cpu_allocator, + self.gpu_allocator, + mapping) if seq_group.is_encoder_decoder(): self.cross_block_tables[request_id] = \ - self._swap_in_block_table(self.cross_block_tables[request_id], - mapping) + self._swap_block_table(self.cross_block_tables[request_id], + self.cpu_allocator, + self.gpu_allocator, + mapping) return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] @@ -589,26 +595,6 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) return len(blocks) <= self.cpu_allocator.get_num_free_blocks() - def _swap_out_block_table( - self, block_table: BlockTable, - mapping: Dict[PhysicalTokenBlock, - PhysicalTokenBlock]) -> BlockTable: - - new_block_table: BlockTable = [] - for gpu_block in block_table: - if gpu_block in mapping: - cpu_block = mapping[gpu_block] - cpu_block.ref_count += 1 - else: - cpu_block = self.cpu_allocator.allocate( - gpu_block.block_hash, gpu_block.num_hashed_tokens) - mapping[gpu_block] = cpu_block - new_block_table.append(cpu_block) - # Free the GPU block swapped out to CPU. - self.gpu_allocator.free(gpu_block) - - return new_block_table - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: request_id = seq_group.request_id @@ -617,13 +603,17 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): self.block_tables[seq.seq_id] = \ - self._swap_out_block_table(self.block_tables[seq.seq_id], - mapping) + self._swap_block_table(self.block_tables[seq.seq_id], + self.gpu_allocator, + self.cpu_allocator, + mapping) if seq_group.is_encoder_decoder(): self.cross_block_tables[request_id] = \ - self._swap_out_block_table(self.cross_block_tables[request_id], - mapping) + self._swap_block_table(self.cross_block_tables[request_id], + self.gpu_allocator, + self.cpu_allocator, + mapping) return [(cpu_block.block_number, gpu_block.block_number) for cpu_block, gpu_block in mapping.items()] From 90b5a0e5303c937e56c5b8893fc0cbaeb985ac3f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 25 May 2024 22:51:09 -0400 Subject: [PATCH 066/239] Refactored block manager + enc dec + unsupported feature checks into utils --- tests/core/block/test_block_manager_v2.py | 6 ++-- tests/core/test_block_manager.py | 6 ++-- tests/core/utils.py | 1 + vllm/core/block/utils.py | 41 +++++++++++++++++++++++ vllm/core/block_manager_v1.py | 34 ++++--------------- vllm/core/block_manager_v2.py | 35 ++++--------------- 6 files changed, 60 insertions(+), 63 deletions(-) create mode 100644 vllm/core/block/utils.py diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 3aed0c58bd264..f1488916b508a 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -1,8 +1,8 @@ import pytest -from vllm.core.block_manager_v2 import (BlockSpaceManagerV2, - STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, - STR_NOT_IMPL_ENC_DEC_SWA) +from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) +from vllm.core.block_manager_v2 import BlockSpaceManagerV2 from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, SequenceStatus from vllm.utils import chunk_list diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 7e487a021d3c2..2264fe80c9c03 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -6,10 +6,10 @@ from vllm import SamplingParams from vllm.block import PhysicalTokenBlock +from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) from vllm.core.block_manager_v1 import (BlockSpaceManagerV1, - UncachedBlockAllocator, - STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, - STR_NOT_IMPL_ENC_DEC_SWA) + UncachedBlockAllocator) from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device diff --git a/tests/core/utils.py b/tests/core/utils.py index fb53b6cc5e18b..7ac565c0eccf1 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -151,5 +151,6 @@ def create_seq_group_encoder_decoder( arrival_time=time.time(), encoder_seq=encoder_seq) + def round_up_to_next_block(seq_len: int, block_size: int) -> int: return (seq_len + block_size - 1) // block_size diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py new file mode 100644 index 0000000000000..6599011771cea --- /dev/null +++ b/vllm/core/block/utils.py @@ -0,0 +1,41 @@ +"""Block manager utils.""" +from typing import Union + +from vllm.core.block_manager_v1 import BlockSpaceManagerV1 +from vllm.core.block_manager_v2 import BlockSpaceManagerV2 +from vllm.sequence import SequenceGroup + +''' +Exception strings for non-implemented block manager encoder/decoder scenarios +''' + +STR_NOT_IMPL_ENC_DEC_SWA = \ + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ + "Prefix caching for encoder/decoder models " + \ + "is not currently supported." + +def check_no_caching_or_swa_for_blckmgr_encdec( + block_mgr: Union[BlockSpaceManagerV1, + BlockSpaceManagerV2], + seq_group: SequenceGroup) -> None: + ''' + Enforce that prefix caching & sliding-window attention (SWA) + are currently unsupported *specifically* for encoder/decoder models. + + Raises NotImplementedError if unsupported scenario is detected. + + Arguments: + + * block_mgr: BlockSpaceManager instance + * seq_group: SequenceGroup passed to block_mgr + ''' + + if seq_group.is_encoder_decoder(): + if block_mgr.block_sliding_window is not None: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) + + if block_mgr.enable_caching: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) \ No newline at end of file diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 90a485b39e9d6..fa64b96a5e7dc 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -8,6 +8,7 @@ from typing import Set, Tuple from vllm.block import BlockTable, PhysicalTokenBlock +from vllm.core.block.utils import check_no_caching_or_swa_for_blckmgr_encdec from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger @@ -15,17 +16,6 @@ from vllm.utils import Device logger = init_logger(__name__) -''' -Exception strings for non-implemented encoder/decoder scenarios -''' - -STR_NOT_IMPL_ENC_DEC_SWA = \ - "Sliding window attention for encoder/decoder models " + \ - "is not currently supported." - -STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ - "Prefix caching for encoder/decoder models " + \ - "is not currently supported." class BlockAllocatorBase(ABC): @@ -279,9 +269,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - is_encoder_decoder = seq_group.is_encoder_decoder() - if self.enable_caching and is_encoder_decoder: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) + check_no_caching_or_swa_for_blckmgr_encdec(self, seq_group) self_num_required_blocks = self._get_seq_num_required_blocks( seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) @@ -291,8 +279,6 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: cross_num_required_blocks if self.block_sliding_window is not None: - if is_encoder_decoder: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) num_required_blocks = min(num_required_blocks, self.block_sliding_window) @@ -334,15 +320,8 @@ def _allocate_sequence(self, \ return block_table def allocate(self, seq_group: SequenceGroup) -> None: - encoder_seq = seq_group.get_encoder_seq() is_encoder_decoder = seq_group.is_encoder_decoder() - - if (self.block_sliding_window is not None) and \ - is_encoder_decoder: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) - - if self.enable_caching and is_encoder_decoder: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) + check_no_caching_or_swa_for_blckmgr_encdec(self, seq_group) # Allocate decoder sequences # @@ -362,8 +341,8 @@ def allocate(self, seq_group: SequenceGroup) -> None: if is_encoder_decoder: # A SequenceGroup has only a single encoder sequence (at most), # thus allocate with a ref count of 1 - block_table = self._allocate_sequence(encoder_seq, 1, - is_encoder_decoder) + block_table = self._allocate_sequence(seq_group.get_encoder_seq(), + 1, is_encoder_decoder) # Assign the cross-attention block table for the SequenceGroup. self.cross_block_tables[seq_group.request_id] = block_table @@ -542,8 +521,7 @@ def can_swap_in(self, return AllocStatus.LATER def _swap_block_table( - self, block_table: BlockTable, - src_allocator: BlockAllocatorBase, + self, block_table: BlockTable, src_allocator: BlockAllocatorBase, dest_allocator: BlockAllocatorBase, mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock]) -> BlockTable: diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 6113561032dd1..246ab9c297c5b 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -5,22 +5,11 @@ from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.utils import check_no_caching_or_swa_for_blckmgr_encdec from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device -''' -Exception strings for non-implemented encoder/decoder scenarios -''' - -STR_NOT_IMPL_ENC_DEC_SWA = \ - "Sliding window attention for encoder/decoder models " + \ - "is not currently supported." - -STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ - "Prefix caching for encoder/decoder models " + \ - "is not currently supported." - SeqId = int EncoderSeqId = str @@ -104,12 +93,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - is_encoder_decoder = seq_group.is_encoder_decoder() - if self.enable_caching and is_encoder_decoder: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) - - if self.block_sliding_window is not None and is_encoder_decoder: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) + check_no_caching_or_swa_for_blckmgr_encdec(self, seq_group) seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = BlockTable.get_num_required_blocks( @@ -117,7 +101,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: block_size=self.block_size, ) - if is_encoder_decoder: + if seq_group.is_encoder_decoder(): num_required_blocks += BlockTable.get_num_required_blocks( seq_group.get_encoder_seq().get_token_ids(), block_size=self.block_size, @@ -151,8 +135,6 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: return block_table def allocate(self, seq_group: SequenceGroup) -> None: - encoder_seq = seq_group.get_encoder_seq() - is_encoder_decoder = seq_group.is_encoder_decoder() # Allocate self-attention block tables for decoder sequences waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) @@ -179,15 +161,10 @@ def allocate(self, seq_group: SequenceGroup) -> None: not in self.cross_block_tables), \ "block table already exists" - if (self.block_sliding_window is not None) and \ - is_encoder_decoder: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) - - if self.enable_caching and is_encoder_decoder: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) + check_no_caching_or_swa_for_blckmgr_encdec(self, seq_group) - if is_encoder_decoder: - block_table = self._allocate_sequence(encoder_seq) + if seq_group.is_encoder_decoder(): + block_table = self._allocate_sequence(seq_group.get_encoder_seq()) self.cross_block_tables[request_id] = block_table def can_append_slots(self, seq_group: SequenceGroup, From 9ee2582172b2b273ede9cb0e3ced9d9f197ecc0b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 25 May 2024 22:57:02 -0400 Subject: [PATCH 067/239] removed circular import --- vllm/core/block/utils.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 6599011771cea..14b99496b12dc 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -1,10 +1,5 @@ """Block manager utils.""" -from typing import Union - -from vllm.core.block_manager_v1 import BlockSpaceManagerV1 -from vllm.core.block_manager_v2 import BlockSpaceManagerV2 from vllm.sequence import SequenceGroup - ''' Exception strings for non-implemented block manager encoder/decoder scenarios ''' @@ -17,10 +12,9 @@ "Prefix caching for encoder/decoder models " + \ "is not currently supported." + def check_no_caching_or_swa_for_blckmgr_encdec( - block_mgr: Union[BlockSpaceManagerV1, - BlockSpaceManagerV2], - seq_group: SequenceGroup) -> None: + block_mgr, seq_group: SequenceGroup) -> None: ''' Enforce that prefix caching & sliding-window attention (SWA) are currently unsupported *specifically* for encoder/decoder models. @@ -38,4 +32,4 @@ def check_no_caching_or_swa_for_blckmgr_encdec( raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) if block_mgr.enable_caching: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) \ No newline at end of file + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) From 5d0ac231b751466771f25e9275acede785bf4344 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 25 May 2024 22:58:09 -0400 Subject: [PATCH 068/239] apparently isort has to run last? --- vllm/core/block/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 14b99496b12dc..4113f7e52b84f 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -1,5 +1,6 @@ """Block manager utils.""" from vllm.sequence import SequenceGroup + ''' Exception strings for non-implemented block manager encoder/decoder scenarios ''' From 1bcc949c7c4634da50d80d7bc4b47185e6ac6f18 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sun, 26 May 2024 12:20:12 -0400 Subject: [PATCH 069/239] slight name change --- vllm/core/block/utils.py | 2 +- vllm/core/block_manager_v1.py | 6 +++--- vllm/core/block_manager_v2.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 4113f7e52b84f..3dee7ff16dd84 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -14,7 +14,7 @@ "is not currently supported." -def check_no_caching_or_swa_for_blckmgr_encdec( +def check_no_caching_or_swa_for_blockmgr_encdec( block_mgr, seq_group: SequenceGroup) -> None: ''' Enforce that prefix caching & sliding-window attention (SWA) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index fa64b96a5e7dc..201cba309f6ef 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -8,7 +8,7 @@ from typing import Set, Tuple from vllm.block import BlockTable, PhysicalTokenBlock -from vllm.core.block.utils import check_no_caching_or_swa_for_blckmgr_encdec +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger @@ -269,7 +269,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - check_no_caching_or_swa_for_blckmgr_encdec(self, seq_group) + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) self_num_required_blocks = self._get_seq_num_required_blocks( seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) @@ -321,7 +321,7 @@ def _allocate_sequence(self, \ def allocate(self, seq_group: SequenceGroup) -> None: is_encoder_decoder = seq_group.is_encoder_decoder() - check_no_caching_or_swa_for_blckmgr_encdec(self, seq_group) + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) # Allocate decoder sequences # diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 246ab9c297c5b..6185a65983d3a 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -5,7 +5,7 @@ from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator -from vllm.core.block.utils import check_no_caching_or_swa_for_blckmgr_encdec +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device @@ -93,7 +93,7 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - check_no_caching_or_swa_for_blckmgr_encdec(self, seq_group) + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = BlockTable.get_num_required_blocks( @@ -161,7 +161,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: not in self.cross_block_tables), \ "block table already exists" - check_no_caching_or_swa_for_blckmgr_encdec(self, seq_group) + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) if seq_group.is_encoder_decoder(): block_table = self._allocate_sequence(seq_group.get_encoder_seq()) From 1bece71b45331ed5e371a3842e5a1bba5fe7a160 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 12:27:47 -0400 Subject: [PATCH 070/239] wip merge --- vllm/core/block_manager_v2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index b19f4b184db94..cad42ab3c1ba2 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -138,7 +138,6 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable: block_allocator=self.block_allocator, max_block_sliding_window=self.max_block_sliding_window, ) - assert self.block_sliding_window is None block_table.allocate(seq.get_token_ids()) return block_table From 1d882ca8d5825ab68988740e81796abadd083b06 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 12:38:45 -0400 Subject: [PATCH 071/239] fixed utils to correctly handle encoder/decoder unsupported scenarios --- vllm/core/block/utils.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 3dee7ff16dd84..dd9345ab52d40 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -13,6 +13,26 @@ "Prefix caching for encoder/decoder models " + \ "is not currently supported." +def _get_block_mgr_sliding_window_attr(block_mgr): + ''' + BlockManagerV1 and BlockManagerV2 have slightly different + members related to sliding window attention (SWA). This + function extracts the appropriate member to use for determining + whether SWA is enabled. + + Arguments: + + * block_mgr: BlockManagerV1 or BlockManagerV2 instance + ''' + + if hasattr(block_mgr, 'block_sliding_window'): + return block_mgr.block_sliding_window + if hasattr(block_mgr, 'max_block_sliding_window'): + return block_mgr.max_block_sliding_window + + raise AttributeError("Block manager instance has neither " + \ + "block_sliding_window nor " + \ + "max_block_sliding_window attributes.") def check_no_caching_or_swa_for_blockmgr_encdec( block_mgr, seq_group: SequenceGroup) -> None: @@ -29,7 +49,7 @@ def check_no_caching_or_swa_for_blockmgr_encdec( ''' if seq_group.is_encoder_decoder(): - if block_mgr.block_sliding_window is not None: + if _get_block_mgr_sliding_window_attr(block_mgr) is not None: raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) if block_mgr.enable_caching: From dfd94692e0b35343e64aace3cd4a496564be5809 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 12:39:17 -0400 Subject: [PATCH 072/239] formatting --- vllm/core/block/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index dd9345ab52d40..c582ab270473c 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -13,6 +13,7 @@ "Prefix caching for encoder/decoder models " + \ "is not currently supported." + def _get_block_mgr_sliding_window_attr(block_mgr): ''' BlockManagerV1 and BlockManagerV2 have slightly different @@ -34,6 +35,7 @@ def _get_block_mgr_sliding_window_attr(block_mgr): "block_sliding_window nor " + \ "max_block_sliding_window attributes.") + def check_no_caching_or_swa_for_blockmgr_encdec( block_mgr, seq_group: SequenceGroup) -> None: ''' From 3c3687e9f59269268264e9f058ef82220fbac4ea Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 12:52:02 -0400 Subject: [PATCH 073/239] renamed xformers metadata is_cross_attn to is_encoder_decoder_attn --- vllm/attention/backends/xformers.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 5059fd8dc265b..b45eda9a68dd5 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -119,7 +119,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # If True, prefill_metadata() and decode_metadata() will return # seqlen & memory-mapping data structures for cross-attention; # otherwise, self-attention data structures will be returned. - is_cross_attn: bool = False + is_encoder_decoder_attn: bool = False # (batch_size,). The "cross-sequence-length" per sequence,i.e. the key/value # sequence length (usually encoder sequence length) in the cross-attention @@ -166,7 +166,7 @@ def has_valid_cross_attn_metadata(self): @property def do_cross_attn(self): - return self.is_cross_attn + return self.is_encoder_decoder_attn @do_cross_attn.setter def do_cross_attn(self, state: bool): @@ -188,9 +188,9 @@ def do_cross_attn(self, state: bool): assert self.cross_seq_lens is not None self.max_cross_seq_len = max(self.cross_seq_lens) - self.is_cross_attn = True + self.is_encoder_decoder_attn = True else: - self.is_cross_attn = False + self.is_encoder_decoder_attn = False @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: @@ -225,7 +225,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - is_cross_attn=False, # Begin cross-attention fields below... + is_encoder_decoder_attn=False, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -261,7 +261,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - is_cross_attn=True, # Begin cross-attention fields below... + is_encoder_decoder_attn=True, # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, max_cross_seq_len=self.max_cross_seq_len, @@ -297,7 +297,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - is_cross_attn=False, # Begin cross-attention fields below... + is_encoder_decoder_attn=False, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -328,7 +328,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - is_cross_attn=True, # Begin cross-attention fields below... + is_encoder_decoder_attn=True, # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, max_cross_seq_len=self.max_cross_seq_len, @@ -593,7 +593,7 @@ def _run_memory_efficient_xformers_forward( # FIXME(woosuk): This is a hack. if attn_metadata.attn_bias is None: if self.alibi_slopes is None: - if attn_metadata.is_cross_attn: + if attn_metadata.is_encoder_decoder_attn: attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens, attn_metadata.cross_seq_lens) else: From 6f07c77ef3f1367369a2b5d96b5d0ed576b0a5ff Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 13:04:36 -0400 Subject: [PATCH 074/239] wip getting tests to pass after merge --- tests/kernels/test_self_and_cross_attn.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 04dc52d51af19..3cc60e5412d11 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -14,7 +14,8 @@ # If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] # # TODO: FlashAttention forward only supports head dimension at most 128 -# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 +# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d0 +# 37782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] @@ -113,7 +114,7 @@ def make_qkv(batch_size, max_kv_seq_len, num_heads, head_size, - is_cross_attn=True, + is_encoder_decoder_attn=True, force_max_len=False, device=CUDA_DEVICE): ''' @@ -137,12 +138,12 @@ def make_qkv(batch_size, * max_kv_seq_len: max key/value seq len * num_heads * head_size - * is_cross_attn: if True, query seqlen may differ from key/value seqlen (as + * is_encoder_decoder_attn: if True, query seqlen may differ from key/value seqlen (as is often the case for cross-attention); o/w, query/key/value seqlens match at each batch index (max_kv_seq_len is unused) * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query seqlens are random in [2,max_q_seq_lens]. Same for key/value seqlens - and max_kv_seq_len, unless forced by is_cross_attn=False + and max_kv_seq_len, unless forced by is_encoder_decoder_attn=False * device: CPU or CUDA device Returns: @@ -178,7 +179,7 @@ def make_qkv(batch_size, random.randint(2, max_q_seq_len) for _ in range(batch_size) ] kv_seq_lens = None - if not is_cross_attn: + if not is_encoder_decoder_attn: # K,V seq lens match Q for self-attention kv_seq_lens = q_seq_lens else: @@ -644,7 +645,7 @@ def make_metadata_self_cross( context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - is_cross_attn=False, + is_encoder_decoder_attn=False, cross_seq_lens=cross_seq_lens, cross_slot_mapping=cross_slot_mapping_tensor, cross_block_tables=cross_block_tables) @@ -685,7 +686,7 @@ def make_metadata_self_cross( context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - is_cross_attn=False, + is_encoder_decoder_attn=False, cross_seq_lens=cross_seq_lens, cross_slot_mapping=cross_slot_mapping_tensor, cross_block_tables=cross_block_tables) @@ -840,7 +841,7 @@ def self_attn_setup(batch_size, max_kv_seq_len, num_heads, head_size, - is_cross_attn=False) + is_encoder_decoder_attn=False) causal_mask = build_causal_mask(max_q_seq_len, max_kv_seq_len).to(CUDA_DEVICE) @@ -1004,7 +1005,7 @@ def cross_attn_setup_reuses_query(query, max_kv_seq_len, num_heads, head_size, - is_cross_attn=True) + is_encoder_decoder_attn=True) ideal_output = ref_masked_attention(query, key, From 481c6463e8b7f7f744fd799bb945301a52182118 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 14:54:50 -0400 Subject: [PATCH 075/239] passing tests; formatting --- tests/kernels/test_self_and_cross_attn.py | 7 ++++--- vllm/attention/backends/xformers.py | 12 ++++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 3cc60e5412d11..d99a246712425 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -138,9 +138,10 @@ def make_qkv(batch_size, * max_kv_seq_len: max key/value seq len * num_heads * head_size - * is_encoder_decoder_attn: if True, query seqlen may differ from key/value seqlen (as - is often the case for cross-attention); o/w, query/key/value seqlens match - at each batch index (max_kv_seq_len is unused) + * is_encoder_decoder_attn: if True, query seqlen may differ from + key/value seqlen (as is often the case for cross-attention); + o/w, query/key/value seqlens match at each batch index + (max_kv_seq_len is unused) * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query seqlens are random in [2,max_q_seq_lens]. Same for key/value seqlens and max_kv_seq_len, unless forced by is_encoder_decoder_attn=False diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index b45eda9a68dd5..6886b4836bd87 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -225,7 +225,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - is_encoder_decoder_attn=False, # Begin cross-attention fields below... + is_encoder_decoder_attn= + False, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -261,7 +262,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - is_encoder_decoder_attn=True, # Begin cross-attention fields below... + is_encoder_decoder_attn= + True, # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, max_cross_seq_len=self.max_cross_seq_len, @@ -297,7 +299,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - is_encoder_decoder_attn=False, # Begin cross-attention fields below... + is_encoder_decoder_attn= + False, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -328,7 +331,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - is_encoder_decoder_attn=True, # Begin cross-attention fields below... + is_encoder_decoder_attn= + True, # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, max_cross_seq_len=self.max_cross_seq_len, From 9c8e19d3bc8a56b1ad31c58d786a9e4c25c593b2 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 15:21:29 -0400 Subject: [PATCH 076/239] removed overprovisioning from make_block_tables_slot_mapping() --- tests/kernels/test_self_and_cross_attn.py | 25 +++++++---------------- 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index d99a246712425..83576c52b8688 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -510,7 +510,7 @@ def make_block_tables_slot_mapping(block_size, # Over-provision block table blocks by 1 num_blocks_list = [ - num_tokens_to_min_blocks(num_tokens, block_size) + 1 + num_tokens_to_min_blocks(num_tokens, block_size) for num_tokens in seq_lens ] max_block_table_len = max(num_blocks_list) @@ -521,7 +521,7 @@ def make_block_tables_slot_mapping(block_size, decode_slot_mapping = [] slot_mapping = [] block_base_idx = block_base_addr + sum( - num_blocks_list) * 2 - 1 # Support more blocks than needed + num_blocks_list) # Support more blocks than needed max_block_idx = block_base_idx for sdx, num_tokens in enumerate(seq_lens): num_blocks = num_blocks_list[sdx] @@ -692,21 +692,6 @@ def make_metadata_self_cross( cross_slot_mapping=cross_slot_mapping_tensor, cross_block_tables=cross_block_tables) - -def make_attention(num_heads: int, head_size: int, scale: float): - ''' - Construct an instance of the Attention wrapper, suited to the number of - attention heads and head dimension (num_heads and head_size respectively) as - well as the attention scale factor (scale) - ''' - - return Attention( - num_heads, - head_size, - scale=scale, - ) - - def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): ''' Compute & build entities required for the self-/cross-attention test. @@ -730,7 +715,11 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): scale = float(1.0 / (head_size**0.5)) attn_backend = make_backend(backend_name) - attn = make_attention(num_heads, head_size, scale) + attn = Attention( + num_heads, + head_size, + scale=scale, + ) kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) return scale, attn_backend, attn, kv_cache From ed17ee38478c6b67cdbb63d6cb7f929a9bd2a08b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 15:25:06 -0400 Subject: [PATCH 077/239] comments' --- tests/kernels/test_self_and_cross_attn.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 83576c52b8688..f132bf571defa 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -481,7 +481,7 @@ def make_block_tables_slot_mapping(block_size, The first block is at - block_base_addr + sum(num_blocks_list) * 2 - 1 + block_base_addr + sum(min. block count for each seq_len) and subsequent blocks count downward toward block_base_addr @@ -508,7 +508,7 @@ def make_block_tables_slot_mapping(block_size, * max_block_idx: the highest block address within this block table ''' - # Over-provision block table blocks by 1 + # Provision minimum number of KV cache blocks num_blocks_list = [ num_tokens_to_min_blocks(num_tokens, block_size) for num_tokens in seq_lens @@ -692,6 +692,7 @@ def make_metadata_self_cross( cross_slot_mapping=cross_slot_mapping_tensor, cross_block_tables=cross_block_tables) + def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): ''' Compute & build entities required for the self-/cross-attention test. @@ -716,10 +717,10 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): scale = float(1.0 / (head_size**0.5)) attn_backend = make_backend(backend_name) attn = Attention( - num_heads, - head_size, - scale=scale, - ) + num_heads, + head_size, + scale=scale, + ) kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) return scale, attn_backend, attn, kv_cache From d630aa8090463bdc12554e63d79dda1ed7caa253 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 15:34:08 -0400 Subject: [PATCH 078/239] clarified block table address formula --- tests/kernels/test_self_and_cross_attn.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index f132bf571defa..a4379b1ece49f 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -479,11 +479,23 @@ def make_block_tables_slot_mapping(block_size, ''' Construct fake block tables & slot mappings. - The first block is at + For a sequence with num_tokens tokens the minimum number + of required KV cache blocks is - block_base_addr + sum(min. block count for each seq_len) + num_blocks = (num_tokens + block_size) // block_size - and subsequent blocks count downward toward block_base_addr + Then the minimum KV cache size in blocks is + + total_cache_blocks = sum(num_blocks for all seqs) + + Then, the blocktable mapping counts downward from + + block_base_addr + total_cache_blocks + + to + + block_base_addr + Arguments: @@ -520,8 +532,9 @@ def make_block_tables_slot_mapping(block_size, prefill_slot_mapping = [] decode_slot_mapping = [] slot_mapping = [] - block_base_idx = block_base_addr + sum( - num_blocks_list) # Support more blocks than needed + # Compute uppermost address of block table + total_cache_blocks = sum(num_blocks_list) + block_base_idx = block_base_addr + total_cache_blocks max_block_idx = block_base_idx for sdx, num_tokens in enumerate(seq_lens): num_blocks = num_blocks_list[sdx] From b664806905cee4697470907557f323bc25fd9ddb Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 22:22:00 -0400 Subject: [PATCH 079/239] wip changing cross attention flag --- tests/kernels/test_self_and_cross_attn.py | 11 ++-- vllm/attention/backends/abstract.py | 7 +++ vllm/attention/backends/xformers.py | 77 +++++++++++++---------- 3 files changed, 58 insertions(+), 37 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index a4379b1ece49f..e2dcf5b02f165 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -7,7 +7,8 @@ import torch from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionType) from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import make_tensor_with_pad @@ -114,7 +115,7 @@ def make_qkv(batch_size, max_kv_seq_len, num_heads, head_size, - is_encoder_decoder_attn=True, + attn_type: AttentionType = AttentionType.ENCODER_DECODER, force_max_len=False, device=CUDA_DEVICE): ''' @@ -180,7 +181,7 @@ def make_qkv(batch_size, random.randint(2, max_q_seq_len) for _ in range(batch_size) ] kv_seq_lens = None - if not is_encoder_decoder_attn: + if attn_type != AttentionType.ENCODER_DECODER: # K,V seq lens match Q for self-attention kv_seq_lens = q_seq_lens else: @@ -845,7 +846,7 @@ def self_attn_setup(batch_size, max_kv_seq_len, num_heads, head_size, - is_encoder_decoder_attn=False) + attn_type=False) causal_mask = build_causal_mask(max_q_seq_len, max_kv_seq_len).to(CUDA_DEVICE) @@ -1009,7 +1010,7 @@ def cross_attn_setup_reuses_query(query, max_kv_seq_len, num_heads, head_size, - is_encoder_decoder_attn=True) + attn_type=True) ideal_output = ref_masked_attention(query, key, diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 6396103bf5efa..15e9a7fa5af3a 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -5,6 +5,13 @@ import torch +from enum import Enum, auto + +class AttentionType(Enum): + DECODER = auto() # Decoder attention between previously layer Q/K/V + ENCODER = auto() # Encoder attention between previously layer Q/K/V + ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V + class AttentionBackend(ABC): """Abstract class for attention backends.""" diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6886b4836bd87..144cb68bbff0b 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -10,11 +10,12 @@ LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger + logger = init_logger(__name__) @@ -119,7 +120,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # If True, prefill_metadata() and decode_metadata() will return # seqlen & memory-mapping data structures for cross-attention; # otherwise, self-attention data structures will be returned. - is_encoder_decoder_attn: bool = False + _attn_type: AttentionType = AttentionType.DECODER # (batch_size,). The "cross-sequence-length" per sequence,i.e. the key/value # sequence length (usually encoder sequence length) in the cross-attention @@ -165,13 +166,13 @@ def has_valid_cross_attn_metadata(self): return True @property - def do_cross_attn(self): - return self.is_encoder_decoder_attn + def attention_type(self) -> AttentionType: + return self._attn_type - @do_cross_attn.setter - def do_cross_attn(self, state: bool): + @attention_type.setter + def attention_type(self, atype: AttentionType) -> None: - if state: + if atype == AttentionType.ENCODER_DECODER: assert self.has_valid_cross_attn_metadata, \ "Must have self.cross_seq_lens not None " + \ "in order to enable cross-attention" @@ -188,17 +189,18 @@ def do_cross_attn(self, state: bool): assert self.cross_seq_lens is not None self.max_cross_seq_len = max(self.cross_seq_lens) - self.is_encoder_decoder_attn = True + self._attn_type = AttentionType.ENCODER_DECODER else: - self.is_encoder_decoder_attn = False + # AttentionType.{ENCODER,DECODER} + self._attn_type = atype @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self.num_prefills == 0: return None - if not self.do_cross_attn: - # Self-attention prefill + if self._attn_type != AttentionType.ENCODER_DECODER: + # Decoder or encoder self-attention prefill if self._self_cached_prefill_metadata is not None: return self._self_cached_prefill_metadata @@ -225,8 +227,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - is_encoder_decoder_attn= - False, # Begin cross-attention fields below... + _attn_type= + self._attn_type, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -235,7 +237,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: return self._self_cached_prefill_metadata else: - # Cross-attention prefill + # Encoder/decoder cross-attention prefill if self._cross_cached_prefill_metadata is not None: return self._cross_cached_prefill_metadata @@ -262,8 +264,9 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - is_encoder_decoder_attn= - True, # Begin cross-attention fields below... + _attn_type= + AttentionType.ENCODER_DECODER, + # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, max_cross_seq_len=self.max_cross_seq_len, @@ -276,8 +279,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: if self.num_decode_tokens == 0: return None - if not self.do_cross_attn: - # Self-attention decode + if self._attn_type != AttentionType.ENCODER_DECODER: + # Decoder or encoder self-attention prefill if self._self_cached_decode_metadata is not None: return self._self_cached_decode_metadata @@ -299,8 +302,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - is_encoder_decoder_attn= - False, # Begin cross-attention fields below... + _attn_type= + self._attn_type, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -309,7 +312,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: return self._self_cached_decode_metadata else: - # Cross-attention decode + # Encoder/decoder cross-attention decode if self._cross_cached_decode_metadata is not None: return self._cross_cached_decode_metadata @@ -331,8 +334,9 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - is_encoder_decoder_attn= - True, # Begin cross-attention fields below... + _attn_type= + AttentionType.ENCODER_DECODER, + # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, max_cross_seq_len=self.max_cross_seq_len, @@ -443,7 +447,7 @@ def forward( # Self-attention vs. cross-attention will impact # which KV cache memory-mapping & which # seqlen datastructures we utilize - do_cross_attn = attn_metadata.do_cross_attn + attn_type = attn_metadata._attn_type if (kv_cache is not None): # Even if there are no new key/value pairs to cache, @@ -454,7 +458,7 @@ def forward( if (key is not None) and (value is not None): - if do_cross_attn: + if attn_type == AttentionType.ENCODER_DECODER: # Update cross-attention KV cache (prefill-only) # During cross-attention decode, key & value will be None, # preventing this IF-statement branch from running @@ -476,9 +480,9 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - assert do_cross_attn or (key.shape[0] + assert attn_type == AttentionType.ENCODER_DECODER or (key.shape[0] == num_prefill_tokens + num_decode_tokens) - assert do_cross_attn or (value.shape[0] + assert attn_type == AttentionType.ENCODER_DECODER or (value.shape[0] == num_prefill_tokens + num_decode_tokens) output = torch.empty_like(query) @@ -487,7 +491,9 @@ def forward( # QKV for prefill. query = query[:num_prefill_tokens] - if not do_cross_attn and key is not None and value is not None: + if attn_type != AttentionType.ENCODER_DECODER \ + and key is not None and value is not None: + key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] @@ -529,7 +535,7 @@ def forward( output[:num_prefill_tokens] = out if decode_meta := attn_metadata.decode_metadata: - if do_cross_attn: + if attn_type == AttentionType.ENCODER_DECODER: # Paged attention against cross-attention KV cache seq_lens_arg = decode_meta.cross_seq_lens_tensor max_seq_len_arg = decode_meta.max_cross_seq_len @@ -597,12 +603,19 @@ def _run_memory_efficient_xformers_forward( # FIXME(woosuk): This is a hack. if attn_metadata.attn_bias is None: if self.alibi_slopes is None: - if attn_metadata.is_encoder_decoder_attn: + if attn_metadata.attention_type() == AttentionType.ENCODER_DECODER: + # Default enc/dec cross-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens, attn_metadata.cross_seq_lens) else: - attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.seq_lens) + if attn_metadata.attention_type() == AttentionType.ENCODER: + # Default encoder self-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens) + else: + # Default decoder self-attention mask is causal + attn_bias = BlockDiagonalCausalMask.from_seqlens( + attn_metadata.seq_lens) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) From 611df433882c1e10235084426d63fd817466dd19 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 22:27:41 -0400 Subject: [PATCH 080/239] yapf fix --- vllm/core/block/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index c582ab270473c..4da5a965616ac 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -1,6 +1,5 @@ """Block manager utils.""" from vllm.sequence import SequenceGroup - ''' Exception strings for non-implemented block manager encoder/decoder scenarios ''' From 8ee49dde309a93fd309f0117f74cde4949e958e4 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 22:30:12 -0400 Subject: [PATCH 081/239] yapf fix --- vllm/core/block/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 4da5a965616ac..2c412a8f472e0 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -1,8 +1,7 @@ """Block manager utils.""" from vllm.sequence import SequenceGroup -''' -Exception strings for non-implemented block manager encoder/decoder scenarios -''' + +# Exception strings for non-implemented block manager enc/dec scenarios STR_NOT_IMPL_ENC_DEC_SWA = \ "Sliding window attention for encoder/decoder models " + \ From 039c25eb6661f2aa89b4239235451f2c6f61d63d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 23:03:44 -0400 Subject: [PATCH 082/239] upstream merge --- tests/core/utils.py | 36 +++++++++++++++++++++++++++--------- vllm/core/block/utils.py | 1 + 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index 1ccc5c3cc0a8e..cd2045b8a1889 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -55,12 +55,24 @@ def create_dummy_prompt_encoder_decoder( # and prompt "0 ... block_size". decoder_prompt_tokens = list(range(decoder_prompt_length)) decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens]) - decoder_prompt = Sequence(int(request_id), decoder_prompt_str, - decoder_prompt_tokens, block_size) + + decoder_prompt = Sequence(int(request_id), + inputs={ + "prompt": decoder_prompt_str, + "prompt_token_ids": decoder_prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size) + encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) - encoder_prompt = Sequence(int(request_id), encoder_prompt_str, - encoder_prompt_tokens, block_size) + encoder_prompt = Sequence(int(request_id), + inputs={ + "prompt": encoder_prompt_str, + "prompt_token_ids": encoder_prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size) seq_group = SequenceGroup(request_id=request_id, seqs=[decoder_prompt], sampling_params=SamplingParams( @@ -134,8 +146,11 @@ def create_seq_group_encoder_decoder( for seq_id_offset, output_len in enumerate(seq_output_lens): seq = Sequence( seq_id=seq_id_start + seq_id_offset, - prompt="", - prompt_token_ids=prompt_token_ids, + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, block_size=16, ) @@ -149,8 +164,11 @@ def create_seq_group_encoder_decoder( # Encoder sequence encoder_seq = Sequence( seq_id=seq_id_start + len(seq_output_lens), - prompt="", - prompt_token_ids=prompt_token_ids, + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, block_size=16, ) @@ -162,4 +180,4 @@ def create_seq_group_encoder_decoder( def round_up_to_next_block(seq_len: int, block_size: int) -> int: - return (seq_len + block_size - 1) // block_size + return (seq_len + block_size - 1) // block_size \ No newline at end of file diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 4da5a965616ac..c582ab270473c 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -1,5 +1,6 @@ """Block manager utils.""" from vllm.sequence import SequenceGroup + ''' Exception strings for non-implemented block manager encoder/decoder scenarios ''' From 8e9ef5bb5ae7bc3ece7ae527e591df093ff7f31e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 28 May 2024 23:06:08 -0400 Subject: [PATCH 083/239] fix formatting issue --- vllm/core/block/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index c582ab270473c..372bfb5ed2f9e 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -1,9 +1,7 @@ """Block manager utils.""" from vllm.sequence import SequenceGroup -''' -Exception strings for non-implemented block manager encoder/decoder scenarios -''' +# Exception strings for non-implemented block manager encoder/decoder scenarios STR_NOT_IMPL_ENC_DEC_SWA = \ "Sliding window attention for encoder/decoder models " + \ From 19d1ca5a6471a55603f257b7c6f6f1364b9d9b0e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 11:32:16 -0400 Subject: [PATCH 084/239] passing tests with new attention type enum --- tests/kernels/test_self_and_cross_attn.py | 64 ++++++++++++++++------- vllm/attention/backends/xformers.py | 6 +-- 2 files changed, 49 insertions(+), 21 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index e2dcf5b02f165..b36212abf01d1 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -593,6 +593,7 @@ def make_metadata_self_cross( context_lens: List[int], block_tables, slot_mapping, + is_encoder_only_test: bool, device=CUDA_DEVICE, cross_seq_lens: Optional[List[int]] = None, cross_block_tables: Optional[torch.Tensor] = None, @@ -625,6 +626,9 @@ def make_metadata_self_cross( * AttentionMetadata structure supporting self- and cross-attention ''' + default_attn_type = AttentionType.ENCODER if is_encoder_only_test \ + else AttentionType.DECODER + if is_prompt: num_prefills = len(seq_lens) num_prefill_tokens = sum(seq_lens) @@ -660,7 +664,7 @@ def make_metadata_self_cross( context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - is_encoder_decoder_attn=False, + _attn_type=default_attn_type, cross_seq_lens=cross_seq_lens, cross_slot_mapping=cross_slot_mapping_tensor, cross_block_tables=cross_block_tables) @@ -701,7 +705,7 @@ def make_metadata_self_cross( context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - is_encoder_decoder_attn=False, + _attn_type=default_attn_type, cross_seq_lens=cross_seq_lens, cross_slot_mapping=cross_slot_mapping_tensor, cross_block_tables=cross_block_tables) @@ -745,6 +749,7 @@ def self_attn_setup(batch_size, block_size, scale, max_q_seq_len, + attn_type: AttentionType, block_base_addr=0): ''' Set up test vectors & data structures for self-attention test. @@ -846,7 +851,7 @@ def self_attn_setup(batch_size, max_kv_seq_len, num_heads, head_size, - attn_type=False) + attn_type=attn_type) causal_mask = build_causal_mask(max_q_seq_len, max_kv_seq_len).to(CUDA_DEVICE) @@ -1010,7 +1015,7 @@ def cross_attn_setup_reuses_query(query, max_kv_seq_len, num_heads, head_size, - attn_type=True) + attn_type=AttentionType.ENCODER_DECODER) ideal_output = ref_masked_attention(query, key, @@ -1062,8 +1067,9 @@ def cross_attn_setup_reuses_query(query, def run_self_attention_test(attn: Attention, packed_query, packed_key, packed_value, kv_cache, - attn_metadata: AttentionMetadata): - attn_metadata.do_cross_attn = False + attn_metadata: AttentionMetadata, + attn_type: AttentionType): + attn_metadata.attention_type = attn_type return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) @@ -1071,10 +1077,27 @@ def run_self_attention_test(attn: Attention, packed_query, packed_key, def run_cross_attention_test(attn: Attention, packed_query, packed_key, packed_value, kv_cache, attn_metadata: AttentionMetadata): - attn_metadata.do_cross_attn = True + attn_metadata.attention_type = AttentionType.ENCODER_DECODER return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) +@pytest.mark.skip() +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("max_q_seq_len", MAX_Q_SEQ_LENS) +@pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) +def test_encoder_attention(num_heads: int, + head_size: int, + backend_name: str, + batch_size: int, + block_size: int, + max_q_seq_len: int, + max_kv_seq_len: int) -> None: + + pass @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -1083,15 +1106,15 @@ def run_cross_attention_test(attn: Attention, packed_query, packed_key, @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_q_seq_len", MAX_Q_SEQ_LENS) @pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) -def test_prefill_decode_self_and_cross_attention(num_heads: int, - head_size: int, - backend_name: str, - batch_size: int, - block_size: int, - max_q_seq_len: int, - max_kv_seq_len: int) -> None: +def test_enc_dec_self_and_cross_attention_prefill_decode_phases(num_heads: int, + head_size: int, + backend_name: str, + batch_size: int, + block_size: int, + max_q_seq_len: int, + max_kv_seq_len: int) -> None: ''' - Test: + Encoder/decoder attention test: * Construct fake test vectors for self- and cross-attention * Construct attention metadata structure with self- and cross-attention @@ -1159,6 +1182,7 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, block_size, scale, max_q_seq_len, + attn_type=AttentionType.DECODER, block_base_addr=self_block_base_addr) # Cross-attention setup @@ -1195,6 +1219,7 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, context_lens, self_prefill_block_tables, self_prefill_slot_mapping, + is_encoder_only_test=False, cross_seq_lens=cross_kv_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, @@ -1202,7 +1227,8 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test( attn, prefill_packed_query, self_prefill_packed_key, - self_prefill_packed_value, kv_cache, prefill_attn_metadata) + self_prefill_packed_value, kv_cache, prefill_attn_metadata, + attn_type=AttentionType.DECODER) # - Prefill self-attention correct? assert torch.allclose( @@ -1231,6 +1257,7 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, context_lens, self_decode_block_tables, self_decode_slot_mapping, + is_encoder_only_test=False, cross_seq_lens=cross_kv_seq_lens, cross_block_tables=cross_decode_block_tables, cross_slot_mapping=cross_decode_slot_mapping, @@ -1238,7 +1265,8 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, self_decode_packed_actual_output: torch.Tensor = run_self_attention_test( attn, decode_packed_query, self_decode_packed_key, - self_decode_packed_value, kv_cache, decode_attn_metadata) + self_decode_packed_value, kv_cache, decode_attn_metadata, + attn_type=AttentionType.DECODER) # - Decode self-attention correct? assert torch.allclose( @@ -1253,4 +1281,4 @@ def test_prefill_decode_self_and_cross_attention(num_heads: int, assert torch.allclose( cross_decode_packed_ideal_output, cross_decode_packed_actual_output.view_as( - cross_decode_packed_ideal_output)) + cross_decode_packed_ideal_output)) \ No newline at end of file diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 144cb68bbff0b..152f339e27485 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -447,7 +447,7 @@ def forward( # Self-attention vs. cross-attention will impact # which KV cache memory-mapping & which # seqlen datastructures we utilize - attn_type = attn_metadata._attn_type + attn_type = attn_metadata.attention_type if (kv_cache is not None): # Even if there are no new key/value pairs to cache, @@ -603,12 +603,12 @@ def _run_memory_efficient_xformers_forward( # FIXME(woosuk): This is a hack. if attn_metadata.attn_bias is None: if self.alibi_slopes is None: - if attn_metadata.attention_type() == AttentionType.ENCODER_DECODER: + if attn_metadata.attention_type == AttentionType.ENCODER_DECODER: # Default enc/dec cross-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens, attn_metadata.cross_seq_lens) else: - if attn_metadata.attention_type() == AttentionType.ENCODER: + if attn_metadata.attention_type == AttentionType.ENCODER: # Default encoder self-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens) From 700b6dca120d859a5b2a8d89f4b88a2e51187a86 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 11:33:39 -0400 Subject: [PATCH 085/239] formatting --- tests/kernels/test_self_and_cross_attn.py | 43 ++++++++++++----------- vllm/attention/backends/abstract.py | 8 ++--- vllm/attention/backends/xformers.py | 26 +++++++------- 3 files changed, 38 insertions(+), 39 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index b36212abf01d1..c9ae74e788754 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -7,8 +7,7 @@ import torch from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionType) +from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import make_tensor_with_pad @@ -1081,6 +1080,7 @@ def run_cross_attention_test(attn: Attention, packed_query, packed_key, return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) + @pytest.mark.skip() @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -1089,16 +1089,13 @@ def run_cross_attention_test(attn: Attention, packed_query, packed_key, @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_q_seq_len", MAX_Q_SEQ_LENS) @pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) -def test_encoder_attention(num_heads: int, - head_size: int, - backend_name: str, - batch_size: int, - block_size: int, - max_q_seq_len: int, - max_kv_seq_len: int) -> None: +def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, + batch_size: int, block_size: int, + max_q_seq_len: int, max_kv_seq_len: int) -> None: pass + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -1106,13 +1103,9 @@ def test_encoder_attention(num_heads: int, @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_q_seq_len", MAX_Q_SEQ_LENS) @pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) -def test_enc_dec_self_and_cross_attention_prefill_decode_phases(num_heads: int, - head_size: int, - backend_name: str, - batch_size: int, - block_size: int, - max_q_seq_len: int, - max_kv_seq_len: int) -> None: +def test_enc_dec_self_and_cross_attention_prefill_decode_phases( + num_heads: int, head_size: int, backend_name: str, batch_size: int, + block_size: int, max_q_seq_len: int, max_kv_seq_len: int) -> None: ''' Encoder/decoder attention test: @@ -1226,8 +1219,12 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases(num_heads: int, ) self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test( - attn, prefill_packed_query, self_prefill_packed_key, - self_prefill_packed_value, kv_cache, prefill_attn_metadata, + attn, + prefill_packed_query, + self_prefill_packed_key, + self_prefill_packed_value, + kv_cache, + prefill_attn_metadata, attn_type=AttentionType.DECODER) # - Prefill self-attention correct? @@ -1264,8 +1261,12 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases(num_heads: int, ) self_decode_packed_actual_output: torch.Tensor = run_self_attention_test( - attn, decode_packed_query, self_decode_packed_key, - self_decode_packed_value, kv_cache, decode_attn_metadata, + attn, + decode_packed_query, + self_decode_packed_key, + self_decode_packed_value, + kv_cache, + decode_attn_metadata, attn_type=AttentionType.DECODER) # - Decode self-attention correct? @@ -1281,4 +1282,4 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases(num_heads: int, assert torch.allclose( cross_decode_packed_ideal_output, cross_decode_packed_actual_output.view_as( - cross_decode_packed_ideal_output)) \ No newline at end of file + cross_decode_packed_ideal_output)) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 15e9a7fa5af3a..cffd2d577777c 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,16 +1,16 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, fields +from enum import Enum, auto from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar) import torch -from enum import Enum, auto class AttentionType(Enum): - DECODER = auto() # Decoder attention between previously layer Q/K/V - ENCODER = auto() # Encoder attention between previously layer Q/K/V - ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V + DECODER = auto() # Decoder attention between previously layer Q/K/V + ENCODER = auto() # Encoder attention between previously layer Q/K/V + ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V class AttentionBackend(ABC): diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 152f339e27485..3e6fe0717b0e7 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -15,7 +15,6 @@ PagedAttentionMetadata) from vllm.logger import init_logger - logger = init_logger(__name__) @@ -227,8 +226,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - _attn_type= - self._attn_type, # Begin cross-attention fields below... + _attn_type=self. + _attn_type, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -264,8 +263,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - _attn_type= - AttentionType.ENCODER_DECODER, + _attn_type=AttentionType.ENCODER_DECODER, # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, @@ -302,8 +300,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - _attn_type= - self._attn_type, # Begin cross-attention fields below... + _attn_type=self. + _attn_type, # Begin cross-attention fields below... cross_seq_lens=None, cross_seq_lens_tensor=None, max_cross_seq_len=None, @@ -334,8 +332,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - _attn_type= - AttentionType.ENCODER_DECODER, + _attn_type=AttentionType.ENCODER_DECODER, # Begin cross-attention fields below... cross_seq_lens=self.cross_seq_lens, cross_seq_lens_tensor=self.cross_seq_lens_tensor, @@ -480,10 +477,10 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - assert attn_type == AttentionType.ENCODER_DECODER or (key.shape[0] - == num_prefill_tokens + num_decode_tokens) - assert attn_type == AttentionType.ENCODER_DECODER or (value.shape[0] - == num_prefill_tokens + num_decode_tokens) + assert attn_type == AttentionType.ENCODER_DECODER or ( + key.shape[0] == num_prefill_tokens + num_decode_tokens) + assert attn_type == AttentionType.ENCODER_DECODER or ( + value.shape[0] == num_prefill_tokens + num_decode_tokens) output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. @@ -603,7 +600,8 @@ def _run_memory_efficient_xformers_forward( # FIXME(woosuk): This is a hack. if attn_metadata.attn_bias is None: if self.alibi_slopes is None: - if attn_metadata.attention_type == AttentionType.ENCODER_DECODER: + if attn_metadata.attention_type == \ + AttentionType.ENCODER_DECODER: # Default enc/dec cross-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens, attn_metadata.cross_seq_lens) From 76c639a461ff30762e23001e73ada922ee8d7c3f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 11:40:27 -0400 Subject: [PATCH 086/239] wip encoder test --- tests/kernels/test_self_and_cross_attn.py | 189 +++++++++++++++++++++- 1 file changed, 187 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index c9ae74e788754..02c58a1d57909 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -718,8 +718,9 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): * num_heads: Number of attention heads * head_size: Head dimension - * num_blocks: Number of KV cache blocks + * num_blocks: Number of KV cache blocks (no KV cache if None) * block_size: Number of offsets within a KV cache block + (no KV cache if None) * backend_name: selection of backend Returns: @@ -729,6 +730,7 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): * attn: Attention wrapper instance * kv_cache: fake KV cache, 2 x num_blocks x (block_size * num_heads * head_size) + * None if num_blocks or block_size is None ''' scale = float(1.0 / (head_size**0.5)) @@ -738,8 +740,14 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): head_size, scale=scale, ) + if num_blocks is None or num_heads is None: + # Caller does not require a KV cache + return scale, attn_backend, attn, None + + # Construct KV cache kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) return scale, attn_backend, attn, kv_cache + def self_attn_setup(batch_size, @@ -1093,7 +1101,184 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_seq_len: int, max_kv_seq_len: int) -> None: - pass + ''' + Encoder-only attention test: + + * Construct fake test vectors for self- and cross-attention + * Construct attention metadata structure with self- and cross-attention + attributes + * Test self- and cross-attention in the following order + + * Prefill self-attention + * Prefill cross-attention + * Decode self-attention + * Decode cross-attention + * This order would exacerbate any accidental overlap in the + self-/cross-attention block tables, which we attempt to avoid + * Validate output correctness against ideal reference attention + implementation + + Block tables are constructed such that cross-attention KV cache is in a + higher, non-intersecting address-space than self-attention KV cache. + + Self- and cross-attention share the same query tensor but not the K/V + tensors. Self-attention K/Vs must have the same seq len as Q while + cross-attention K/Vs are allowed to differ in seq len, as is often the case + for cross-attention. + ''' + + # Num KV cache blocks + # num_blocks = 4096 + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + scale, \ + attn_backend, \ + attn, \ + kv_cache = basic_setup(num_heads, + head_size, + None, + None, + backend_name) + + # Self-attention setup + + self_block_base_addr = 0 + + query, \ + prefill_packed_query, \ + self_prefill_packed_key, \ + self_prefill_packed_value, \ + self_prefill_packed_ideal_output, \ + prefill_q_seq_lens, \ + self_prefill_kv_seq_lens, \ + decode_packed_query, \ + self_decode_packed_key, \ + self_decode_packed_value, \ + self_decode_packed_ideal_output, \ + _, \ + _, \ + q_seq_lens, \ + _, \ + self_decode_block_tables, \ + self_decode_slot_mapping, \ + self_prefill_slot_mapping, \ + self_prefill_block_tables, \ + cross_block_base_addr = self_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + attn_type=AttentionType.DECODER, + block_base_addr=self_block_base_addr) + + # Cross-attention setup + + cross_prefill_packed_key, \ + cross_prefill_packed_value, \ + cross_prefill_packed_ideal_output, \ + cross_decode_packed_ideal_output, \ + cross_kv_seq_lens, \ + cross_decode_block_tables, \ + cross_decode_slot_mapping, \ + cross_prefill_slot_mapping, \ + cross_prefill_block_tables, \ + _ = cross_attn_setup_reuses_query(query, + q_seq_lens, + prefill_q_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + max_kv_seq_len, + block_base_addr=cross_block_base_addr) + + # PREFILL: self- and cross-attention tests + + context_lens = [0 for _ in range(batch_size)] + + prefill_attn_metadata: AttentionMetadata = make_metadata_self_cross( + attn_backend, + True, + prefill_q_seq_lens, + context_lens, + self_prefill_block_tables, + self_prefill_slot_mapping, + is_encoder_only_test=False, + cross_seq_lens=cross_kv_seq_lens, + cross_block_tables=cross_prefill_block_tables, + cross_slot_mapping=cross_prefill_slot_mapping, + ) + + self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test( + attn, + prefill_packed_query, + self_prefill_packed_key, + self_prefill_packed_value, + kv_cache, + prefill_attn_metadata, + attn_type=AttentionType.DECODER) + + # - Prefill self-attention correct? + assert torch.allclose( + self_prefill_packed_ideal_output, + self_prefill_packed_actual_output.view_as( + self_prefill_packed_ideal_output)) + + cross_prefill_packed_actual_output: torch.Tensor = run_cross_attention_test( + attn, prefill_packed_query, cross_prefill_packed_key, + cross_prefill_packed_value, kv_cache, prefill_attn_metadata) + + # - Prefill cross-attention correct? + assert torch.allclose( + cross_prefill_packed_ideal_output, + cross_prefill_packed_actual_output.view_as( + cross_prefill_packed_ideal_output)) + + context_lens = copy.deepcopy(self_prefill_kv_seq_lens) + + # DECODE: self- and cross-attention tests + + decode_attn_metadata: AttentionMetadata = make_metadata_self_cross( + attn_backend, + False, + q_seq_lens, + context_lens, + self_decode_block_tables, + self_decode_slot_mapping, + is_encoder_only_test=False, + cross_seq_lens=cross_kv_seq_lens, + cross_block_tables=cross_decode_block_tables, + cross_slot_mapping=cross_decode_slot_mapping, + ) + + self_decode_packed_actual_output: torch.Tensor = run_self_attention_test( + attn, + decode_packed_query, + self_decode_packed_key, + self_decode_packed_value, + kv_cache, + decode_attn_metadata, + attn_type=AttentionType.DECODER) + + # - Decode self-attention correct? + assert torch.allclose( + self_decode_packed_ideal_output, + self_decode_packed_actual_output.view_as( + self_decode_packed_ideal_output)) + + cross_decode_packed_actual_output: torch.Tensor = run_cross_attention_test( + attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) + + # - Decode cross-attention correct? + assert torch.allclose( + cross_decode_packed_ideal_output, + cross_decode_packed_actual_output.view_as( + cross_decode_packed_ideal_output)) + @pytest.mark.parametrize("num_heads", NUM_HEADS) From 882640e51a3727c058503e0fc04c91c32f2e11bf Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 12:09:33 -0400 Subject: [PATCH 087/239] first pass at encoder attention test --- tests/kernels/test_self_and_cross_attn.py | 333 ++++++++++++---------- 1 file changed, 190 insertions(+), 143 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 02c58a1d57909..64f7ec0eaac40 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -748,16 +748,154 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) return scale, attn_backend, attn, kv_cache +def encoder_attn_setup(batch_size, + num_heads, + head_size, + scale, + max_q_seq_len): + ''' + Set up test vectors & data structures for encoder attention test. + + A triplet of synthetic query/key/value tensors are constructed ("baseline" + query/key/value). Given this is a self-attention test, the key & value + sequences will have the same length as the corresponding queries. + + "Prefill" query/key/value tensors are derived by masking out the last value + in each baseline query/key/value. These tensors are used to test prefill & + populate KV cache for a subsequent decode test. + + "Decode" query/key/value tensors are derived by extracting *only* the last + value from each baseline query/key/value (i.e. complement of the prefill + tensors.) These tensors are used to test decode, conditional on the kv cache + being populated during the prefill test. + + The baseline query/key/value tensors are passed to an ideal reference + self-attention implementation to generate a "Baseline" ideal output tensor. + This tensor is split into the "Prefill" ideal output tensor (all but the + last element of each output sequence) and the "Decode" ideal output tensor + (*only* the last element of each output sequence); the "Prefill" and + "Decode" ideal output tensors can be used to validate the prefill and decode + test results, respectively. + + This function also constructs the self-attention KV cache memory mapping + (slot mapping and block table), ensuring that the block table starts at + block_base_addr + + Arguments: + + * batch_size + * num_heads: Number of attention heads + * head_size: Head dimension + * block_size: Number of offsets per KV cache block + * scale: attention scale parameter + * max_q_seq_len: upper limit on query length for synthetic test vectors + * block_base_addr: self-attention block table base address + + Returns: + + * query: "baseline" query; batch_size x padded_seq_len x num_heads x + head_size + * prefill_packed_query: "prefill" query; number_of_tokens x num_heads x + head_size + * prefill_packed_key: self-attn "prefill" key; number_of_tokens x num_heads + x head_size + * prefill_packed_value: self-attn "prefill" value; number_of_tokens x + num_heads x head_size + * prefill_packed_ideal_output: self-attn "prefill" ideal output; + number_of_tokens x num_heads x head_size + * prefill_q_seq_lens: list of token counts for each *prefill query* (one + less than baseline query) + * prefill_kv_seq_lens: list of token counts for each self-attn *prefill + key/value* (should match prefill_q_seq_lens) + * decode_packed_query: "decode" query; number_of_tokens x num_heads x + head_size + * decode_packed_key: self-attn "decode" key; number_of_tokens x num_heads x + head_size + * decode_packed_value: self-attn "decode" key; number_of_tokens x num_heads + x head_size + * decode_packed_ideal_output: self-attn "decode" ideal output; + number_of_tokens x num_heads x head_size + * decode_q_seq_lens: list of token counts for each *decode query* (should + be 1) + * decode_kv_seq_lens: list of token counts for each self-attn *decode + key/value* (should match decode_q_seq_lens) + * q_seq_lens: "baseline" query seq lens; number_of_tokens x num_heads x + head_size + * kv_seq_lens: self-attn "baseline" key/value seq lens; number_of_tokens + x num_heads x head_size + * decode_block_tables: fake self-attn decode-phase block table + * decode_slot_mapping: fake self-attn decode-phase slot mapping + * prefill_slot_mapping: fake self-attn prefill-phase slot mapping + * prefill_block_tables: fake self-attn prefill-phase block table + * max_block_idx: highest block address in the self-attention block-table + ''' + max_kv_seq_len = max_q_seq_len -def self_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - attn_type: AttentionType, - block_base_addr=0): + query, \ + key, \ + value, \ + prefill_query, \ + prefill_key, \ + prefill_value, \ + decode_query, \ + decode_key, \ + decode_value, \ + q_seq_lens, \ + kv_seq_lens, \ + _, \ + _, \ + prefill_q_seq_lens, \ + prefill_kv_seq_lens, \ + decode_q_seq_lens, \ + decode_kv_seq_lens = make_qkv(batch_size, + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + attn_type=AttentionType.ENCODER) + + # No attention mask + ideal_output = ref_masked_attention(query, + key, + value, + scale=scale, + q_seq_lens=q_seq_lens, + kv_seq_lens=kv_seq_lens) + + prefill_ideal_output = torch.zeros_like(ideal_output) + for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): + prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ + bdx, :prefill_q_seq_len] + + prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, + prefill_q_seq_lens) + + prefill_packed_query, \ + prefill_packed_key, \ + prefill_packed_value, _, _ = pack_qkv( + prefill_query, prefill_key, prefill_value, prefill_q_seq_lens, + prefill_kv_seq_lens) + + return query, \ + prefill_packed_query, \ + prefill_packed_key, \ + prefill_packed_value, \ + prefill_packed_ideal_output, \ + prefill_q_seq_lens, \ + prefill_kv_seq_lens, \ + decode_q_seq_lens, \ + decode_kv_seq_lens, \ + q_seq_lens, \ + kv_seq_lens + +def decoder_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + block_base_addr=0): ''' Set up test vectors & data structures for self-attention test. @@ -858,7 +996,7 @@ def self_attn_setup(batch_size, max_kv_seq_len, num_heads, head_size, - attn_type=attn_type) + attn_type=AttentionType.DECODER) causal_mask = build_causal_mask(max_q_seq_len, max_kv_seq_len).to(CUDA_DEVICE) @@ -929,17 +1067,17 @@ def self_attn_setup(batch_size, max_block_idx -def cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=0): +def enc_dec_attn_setup_reuses_query(query, + q_seq_lens, + prefill_q_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + max_kv_seq_len, + block_base_addr=0): ''' Set up test vectors & data structures for cross-attention test. @@ -1135,7 +1273,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, scale, \ attn_backend, \ attn, \ - kv_cache = basic_setup(num_heads, + _ = basic_setup(num_heads, head_size, None, None, @@ -1146,140 +1284,50 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, self_block_base_addr = 0 query, \ - prefill_packed_query, \ - self_prefill_packed_key, \ - self_prefill_packed_value, \ - self_prefill_packed_ideal_output, \ + packed_query, \ + packed_key, \ + packed_value, \ + packed_ideal_output, \ prefill_q_seq_lens, \ - self_prefill_kv_seq_lens, \ - decode_packed_query, \ - self_decode_packed_key, \ - self_decode_packed_value, \ - self_decode_packed_ideal_output, \ + prefill_kv_seq_lens, \ _, \ _, \ q_seq_lens, \ - _, \ - self_decode_block_tables, \ - self_decode_slot_mapping, \ - self_prefill_slot_mapping, \ - self_prefill_block_tables, \ - cross_block_base_addr = self_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - attn_type=AttentionType.DECODER, - block_base_addr=self_block_base_addr) - - # Cross-attention setup - - cross_prefill_packed_key, \ - cross_prefill_packed_value, \ - cross_prefill_packed_ideal_output, \ - cross_decode_packed_ideal_output, \ - cross_kv_seq_lens, \ - cross_decode_block_tables, \ - cross_decode_slot_mapping, \ - cross_prefill_slot_mapping, \ - cross_prefill_block_tables, \ - _ = cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=cross_block_base_addr) - - # PREFILL: self- and cross-attention tests + kv_seq_lens = encoder_attn_setup(batch_size, + num_heads, + head_size, + scale, + max_q_seq_len) context_lens = [0 for _ in range(batch_size)] - prefill_attn_metadata: AttentionMetadata = make_metadata_self_cross( + attn_metadata: AttentionMetadata = make_metadata_self_cross( attn_backend, True, prefill_q_seq_lens, context_lens, - self_prefill_block_tables, - self_prefill_slot_mapping, - is_encoder_only_test=False, - cross_seq_lens=cross_kv_seq_lens, - cross_block_tables=cross_prefill_block_tables, - cross_slot_mapping=cross_prefill_slot_mapping, - ) - - self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test( - attn, - prefill_packed_query, - self_prefill_packed_key, - self_prefill_packed_value, - kv_cache, - prefill_attn_metadata, - attn_type=AttentionType.DECODER) - - # - Prefill self-attention correct? - assert torch.allclose( - self_prefill_packed_ideal_output, - self_prefill_packed_actual_output.view_as( - self_prefill_packed_ideal_output)) - - cross_prefill_packed_actual_output: torch.Tensor = run_cross_attention_test( - attn, prefill_packed_query, cross_prefill_packed_key, - cross_prefill_packed_value, kv_cache, prefill_attn_metadata) - - # - Prefill cross-attention correct? - assert torch.allclose( - cross_prefill_packed_ideal_output, - cross_prefill_packed_actual_output.view_as( - cross_prefill_packed_ideal_output)) - - context_lens = copy.deepcopy(self_prefill_kv_seq_lens) - - # DECODE: self- and cross-attention tests - - decode_attn_metadata: AttentionMetadata = make_metadata_self_cross( - attn_backend, - False, - q_seq_lens, - context_lens, - self_decode_block_tables, - self_decode_slot_mapping, - is_encoder_only_test=False, - cross_seq_lens=cross_kv_seq_lens, - cross_block_tables=cross_decode_block_tables, - cross_slot_mapping=cross_decode_slot_mapping, + None, + None, + is_encoder_only_test=True, + cross_seq_lens=None, + cross_block_tables=None, + cross_slot_mapping=None, ) - self_decode_packed_actual_output: torch.Tensor = run_self_attention_test( + packed_actual_output: torch.Tensor = run_self_attention_test( attn, - decode_packed_query, - self_decode_packed_key, - self_decode_packed_value, - kv_cache, - decode_attn_metadata, - attn_type=AttentionType.DECODER) - - # - Decode self-attention correct? + packed_query, + packed_key, + packed_value, + None, + attn_metadata, + attn_type=AttentionType.ENCODER) + + # - Is encoder attention result correct? assert torch.allclose( - self_decode_packed_ideal_output, - self_decode_packed_actual_output.view_as( - self_decode_packed_ideal_output)) - - cross_decode_packed_actual_output: torch.Tensor = run_cross_attention_test( - attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) - - # - Decode cross-attention correct? - assert torch.allclose( - cross_decode_packed_ideal_output, - cross_decode_packed_actual_output.view_as( - cross_decode_packed_ideal_output)) - - + packed_ideal_output, + packed_actual_output.view_as( + packed_ideal_output)) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -1354,13 +1402,12 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_decode_slot_mapping, \ self_prefill_slot_mapping, \ self_prefill_block_tables, \ - cross_block_base_addr = self_attn_setup(batch_size, + cross_block_base_addr = decoder_attn_setup(batch_size, num_heads, head_size, block_size, scale, max_q_seq_len, - attn_type=AttentionType.DECODER, block_base_addr=self_block_base_addr) # Cross-attention setup @@ -1374,7 +1421,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - _ = cross_attn_setup_reuses_query(query, + _ = enc_dec_attn_setup_reuses_query(query, q_seq_lens, prefill_q_seq_lens, batch_size, From 584297e391915d4eae3ae73a2fea47d79a58cf95 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 12:46:07 -0400 Subject: [PATCH 088/239] wip encoder attention test --- tests/kernels/test_self_and_cross_attn.py | 39 +++++++++++++---------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 64f7ec0eaac40..027f17826aacf 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -751,8 +751,10 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): def encoder_attn_setup(batch_size, num_heads, head_size, + block_size, scale, - max_q_seq_len): + max_q_seq_len, + block_base_addr=0): ''' Set up test vectors & data structures for encoder attention test. @@ -871,6 +873,15 @@ def encoder_attn_setup(batch_size, prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, prefill_q_seq_lens) + _, \ + _, \ + prefill_slot_mapping, \ + prefill_block_tables, \ + _, \ + _, \ + _ = make_block_tables_slot_mapping( + block_size, q_seq_lens, block_base_addr=block_base_addr) + prefill_packed_query, \ prefill_packed_key, \ prefill_packed_value, _, _ = pack_qkv( @@ -884,10 +895,8 @@ def encoder_attn_setup(batch_size, prefill_packed_ideal_output, \ prefill_q_seq_lens, \ prefill_kv_seq_lens, \ - decode_q_seq_lens, \ - decode_kv_seq_lens, \ - q_seq_lens, \ - kv_seq_lens + prefill_slot_mapping, \ + prefill_block_tables def decoder_attn_setup(batch_size, num_heads, @@ -1227,7 +1236,7 @@ def run_cross_attention_test(attn: Attention, packed_query, packed_key, attn_metadata) -@pytest.mark.skip() +#@pytest.mark.skip() @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -1280,22 +1289,20 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, backend_name) # Self-attention setup - - self_block_base_addr = 0 - - query, \ + # Let encoder_attn_setup() choose default block table + # base address + _, \ packed_query, \ packed_key, \ packed_value, \ packed_ideal_output, \ prefill_q_seq_lens, \ - prefill_kv_seq_lens, \ - _, \ _, \ - q_seq_lens, \ - kv_seq_lens = encoder_attn_setup(batch_size, + slot_mapping, \ + block_tables = encoder_attn_setup(batch_size, num_heads, head_size, + block_size, scale, max_q_seq_len) @@ -1306,8 +1313,8 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, True, prefill_q_seq_lens, context_lens, - None, - None, + block_tables, + slot_mapping, is_encoder_only_test=True, cross_seq_lens=None, cross_block_tables=None, From a89c7c678b965cce38fb74ee3688cc74d218aee1 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 13:10:49 -0400 Subject: [PATCH 089/239] encoder attention test passes! --- tests/kernels/test_self_and_cross_attn.py | 80 +++++++++++------------ 1 file changed, 38 insertions(+), 42 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 027f17826aacf..def055d356d6a 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -837,25 +837,25 @@ def encoder_attn_setup(batch_size, query, \ key, \ value, \ - prefill_query, \ - prefill_key, \ - prefill_value, \ - decode_query, \ - decode_key, \ - decode_value, \ + _, \ + _, \ + _, \ + _, \ + _, \ + _, \ q_seq_lens, \ kv_seq_lens, \ _, \ _, \ - prefill_q_seq_lens, \ - prefill_kv_seq_lens, \ - decode_q_seq_lens, \ - decode_kv_seq_lens = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, - attn_type=AttentionType.ENCODER) + _, \ + _, \ + _, \ + _ = make_qkv(batch_size, + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + attn_type=AttentionType.ENCODER) # No attention mask ideal_output = ref_masked_attention(query, @@ -865,38 +865,36 @@ def encoder_attn_setup(batch_size, q_seq_lens=q_seq_lens, kv_seq_lens=kv_seq_lens) - prefill_ideal_output = torch.zeros_like(ideal_output) - for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): - prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ - bdx, :prefill_q_seq_len] + # prefill_ideal_output = torch.zeros_like(ideal_output) + # for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): + # prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ + # bdx, :prefill_q_seq_len] - prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_seq_lens) + packed_ideal_output, _ = pack_tensor(ideal_output, + q_seq_lens) + block_tables, \ _, \ _, \ - prefill_slot_mapping, \ - prefill_block_tables, \ _, \ + slot_mapping, \ _, \ _ = make_block_tables_slot_mapping( block_size, q_seq_lens, block_base_addr=block_base_addr) - prefill_packed_query, \ - prefill_packed_key, \ - prefill_packed_value, _, _ = pack_qkv( - prefill_query, prefill_key, prefill_value, prefill_q_seq_lens, - prefill_kv_seq_lens) + packed_query, \ + packed_key, \ + packed_value, _, _ = pack_qkv( + query, key, value, q_seq_lens, + kv_seq_lens) - return query, \ - prefill_packed_query, \ - prefill_packed_key, \ - prefill_packed_value, \ - prefill_packed_ideal_output, \ - prefill_q_seq_lens, \ - prefill_kv_seq_lens, \ - prefill_slot_mapping, \ - prefill_block_tables + return packed_query, \ + packed_key, \ + packed_value, \ + packed_ideal_output, \ + block_tables, \ + slot_mapping, \ + q_seq_lens def decoder_attn_setup(batch_size, num_heads, @@ -1291,15 +1289,13 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, # Self-attention setup # Let encoder_attn_setup() choose default block table # base address - _, \ packed_query, \ packed_key, \ packed_value, \ packed_ideal_output, \ - prefill_q_seq_lens, \ - _, \ + block_tables, \ slot_mapping, \ - block_tables = encoder_attn_setup(batch_size, + q_seq_lens = encoder_attn_setup(batch_size, num_heads, head_size, block_size, @@ -1311,7 +1307,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, attn_metadata: AttentionMetadata = make_metadata_self_cross( attn_backend, True, - prefill_q_seq_lens, + q_seq_lens, context_lens, block_tables, slot_mapping, From 0bbd0db0f260d1b027fb46f47a417ab4d3532600 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 13:15:47 -0400 Subject: [PATCH 090/239] formatting --- tests/kernels/test_self_and_cross_attn.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index def055d356d6a..f3c5c3fd08cc5 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -747,7 +747,8 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): # Construct KV cache kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) return scale, attn_backend, attn, kv_cache - + + def encoder_attn_setup(batch_size, num_heads, head_size, @@ -870,8 +871,7 @@ def encoder_attn_setup(batch_size, # prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ # bdx, :prefill_q_seq_len] - packed_ideal_output, _ = pack_tensor(ideal_output, - q_seq_lens) + packed_ideal_output, _ = pack_tensor(ideal_output, q_seq_lens) block_tables, \ _, \ @@ -896,6 +896,7 @@ def encoder_attn_setup(batch_size, slot_mapping, \ q_seq_lens + def decoder_attn_setup(batch_size, num_heads, head_size, @@ -1245,7 +1246,6 @@ def run_cross_attention_test(attn: Attention, packed_query, packed_key, def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_seq_len: int, max_kv_seq_len: int) -> None: - ''' Encoder-only attention test: @@ -1327,10 +1327,9 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - assert torch.allclose( - packed_ideal_output, - packed_actual_output.view_as( - packed_ideal_output)) + assert torch.allclose(packed_ideal_output, + packed_actual_output.view_as(packed_ideal_output)) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) From af998ca8afad19e82320d15119d342bfbbca31bb Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 13:45:51 -0400 Subject: [PATCH 091/239] encoder test arguments --- tests/kernels/test_self_and_cross_attn.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index def055d356d6a..b0c8a143c45f3 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1233,18 +1233,15 @@ def run_cross_attention_test(attn: Attention, packed_query, packed_key, return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) - -#@pytest.mark.skip() @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_q_seq_len", MAX_Q_SEQ_LENS) -@pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) +@pytest.mark.parametrize("max_seq_len", MAX_Q_SEQ_LENS) def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, - max_q_seq_len: int, max_kv_seq_len: int) -> None: + max_seq_len: int) -> None: ''' Encoder-only attention test: @@ -1300,7 +1297,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, head_size, block_size, scale, - max_q_seq_len) + max_seq_len) context_lens = [0 for _ in range(batch_size)] From 78c678add4d25a2c13fd7e5ae8966a83bafee933 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 14:42:02 -0400 Subject: [PATCH 092/239] type hints; formatting --- tests/kernels/test_self_and_cross_attn.py | 349 +++++++++++----------- 1 file changed, 176 insertions(+), 173 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index b0c8a143c45f3..6e2a3ed19ef76 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1,7 +1,7 @@ import copy import itertools import random -from typing import List, Optional +from typing import List, Optional, Union import pytest import torch @@ -29,7 +29,8 @@ MAX_K_SEQ_LENS = [128] -def build_causal_mask(q_max_seq_len, kv_max_seq_len): +def build_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ + -> torch.Tensor: ''' Create a q_max_seq_len x kv_max_seq_len causal mask @@ -109,14 +110,14 @@ def ref_masked_attention(query: torch.Tensor, return out -def make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, +def make_qkv(batch_size: int, + max_q_seq_len: int, + max_kv_seq_len: int, + num_heads: int, + head_size: int, attn_type: AttentionType = AttentionType.ENCODER_DECODER, - force_max_len=False, - device=CUDA_DEVICE): + force_max_len: bool = False, + device: Union[torch.device, str] = CUDA_DEVICE) -> tuple: ''' Construct QKV test tensors for self- and cross-attention. @@ -276,7 +277,9 @@ def make_qkv(batch_size, decode_kv_seq_lens -def pack_tensor(unpacked_tensor, seq_lens, device=CUDA_DEVICE): +def pack_tensor(unpacked_tensor: torch.Tensor, + seq_lens: List[int], + device: Union[torch.device, str] = CUDA_DEVICE) -> tuple: ''' Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an unpadded number_of_tokens x num_heads x head_size tensor, where @@ -309,7 +312,8 @@ def pack_tensor(unpacked_tensor, seq_lens, device=CUDA_DEVICE): return packed_tensor, start_loc_list -def pack_qkv(query, key, value, q_seq_lens, kv_seq_lens): +def pack_qkv(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + q_seq_lens: List[int], kv_seq_lens: List[int]) -> tuple: ''' Individually pack each of Q, K and V, each with dimensions batch_size x padded_seq_len x num_heads x head_size, into respective number_of_tokens x @@ -379,7 +383,8 @@ def make_backend(backend_name: str) -> AttentionBackend: def make_metadata_tensors(is_prompt: bool, seq_lens: List[int], context_lens: List[int], - device=CUDA_DEVICE) -> tuple: + device: Union[torch.device, str] = \ + CUDA_DEVICE) -> tuple: ''' Build scalar & tensor values required to build attention metadata structure. @@ -434,12 +439,13 @@ def make_metadata_tensors(is_prompt: bool, query_start_loc -def make_kv_cache(num_blocks, - num_heads, - head_size, - block_size, - device=CUDA_DEVICE, - default_val=0.0): +def make_kv_cache(num_blocks: int, + num_heads: int, + head_size: int, + block_size: int, + device: Union[torch.device, str] = \ + CUDA_DEVICE, + default_val: float=0.0) -> torch.Tensor: ''' Create a fake KV cache. @@ -464,7 +470,7 @@ def make_kv_cache(num_blocks, return kv_cache -def num_tokens_to_min_blocks(num_tokens, block_size): +def num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: ''' Compute the minimum number of blocks required to hold num_tokens tokens, given block_size @@ -472,10 +478,11 @@ def num_tokens_to_min_blocks(num_tokens, block_size): return (num_tokens + block_size) // block_size -def make_block_tables_slot_mapping(block_size, - seq_lens, - block_base_addr=0, - device=CUDA_DEVICE): +def make_block_tables_slot_mapping(block_size: int, + seq_lens: List, + block_base_addr: int=0, + device: Union[torch.device, str] = \ + CUDA_DEVICE) -> tuple: ''' Construct fake block tables & slot mappings. @@ -585,15 +592,15 @@ def make_block_tables_slot_mapping(block_size, max_block_idx -def make_metadata_self_cross( +def make_test_metadata( attn_backend: AttentionBackend, is_prompt: bool, seq_lens: List[int], context_lens: List[int], - block_tables, - slot_mapping, + block_tables: torch.Tensor, + slot_mapping: torch.Tensor, is_encoder_only_test: bool, - device=CUDA_DEVICE, + device: Union[torch.device, str] = CUDA_DEVICE, cross_seq_lens: Optional[List[int]] = None, cross_block_tables: Optional[torch.Tensor] = None, cross_slot_mapping: Optional[List[int]] = None, @@ -602,6 +609,10 @@ def make_metadata_self_cross( Construct fake attention metadata for a combined self-/cross-attention scenario i.e. an encoder/decoder model. + is_encoder_only_test=True causes the default attention metadata attention + type to be AttentionType.ENCODER. False causes the default to + be AttentionType.DECODER. + Assumptions: * No chunked prefill -> a batch is 100% prefill or 100% decode, never both @@ -614,6 +625,8 @@ def make_metadata_self_cross( * context_lens: list of context lengths for each sequence * block_tables: self-attention block tables * slot_mapping: self-attention slot_mapping + * is_encoder_only_test: True if testing encoder; False if testing + decoder self-attention or encoder/decoder cross-attention. * device: CPU or CUDA device * cross_seq_lens: list of token counts for each encoder sequence, if any exist @@ -644,13 +657,9 @@ def make_metadata_self_cross( context_lens, device=device) - slot_mapping_tensor = slot_mapping - - cross_slot_mapping_tensor = cross_slot_mapping - return attn_backend.make_metadata( num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, + slot_mapping=slot_mapping, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -665,7 +674,7 @@ def make_metadata_self_cross( use_cuda_graph=False, _attn_type=default_attn_type, cross_seq_lens=cross_seq_lens, - cross_slot_mapping=cross_slot_mapping_tensor, + cross_slot_mapping=cross_slot_mapping, cross_block_tables=cross_block_tables) else: # not is_prompt @@ -685,13 +694,9 @@ def make_metadata_self_cross( context_lens, device=device) - slot_mapping_tensor = slot_mapping - - cross_slot_mapping_tensor = cross_slot_mapping - return attn_backend.make_metadata( num_prefills=num_prefills, - slot_mapping=slot_mapping_tensor, + slot_mapping=slot_mapping, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -706,11 +711,12 @@ def make_metadata_self_cross( use_cuda_graph=False, _attn_type=default_attn_type, cross_seq_lens=cross_seq_lens, - cross_slot_mapping=cross_slot_mapping_tensor, + cross_slot_mapping=cross_slot_mapping, cross_block_tables=cross_block_tables) -def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): +def basic_setup(num_heads: int, head_size: int, num_blocks: int, + block_size: int, backend_name: str) -> tuple: ''' Compute & build entities required for the self-/cross-attention test. @@ -747,37 +753,24 @@ def basic_setup(num_heads, head_size, num_blocks, block_size, backend_name): # Construct KV cache kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) return scale, attn_backend, attn, kv_cache - -def encoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - block_base_addr=0): + + +def encoder_attn_setup(batch_size: int, + num_heads: int, + head_size: int, + block_size: int, + scale: float, + max_q_seq_len: int, + block_base_addr: int = 0) -> tuple: ''' Set up test vectors & data structures for encoder attention test. - A triplet of synthetic query/key/value tensors are constructed ("baseline" - query/key/value). Given this is a self-attention test, the key & value + A triplet of synthetic query/key/value tensors are constructed. + Given this is an encoder attention test, the key & value sequences will have the same length as the corresponding queries. - "Prefill" query/key/value tensors are derived by masking out the last value - in each baseline query/key/value. These tensors are used to test prefill & - populate KV cache for a subsequent decode test. - - "Decode" query/key/value tensors are derived by extracting *only* the last - value from each baseline query/key/value (i.e. complement of the prefill - tensors.) These tensors are used to test decode, conditional on the kv cache - being populated during the prefill test. - - The baseline query/key/value tensors are passed to an ideal reference - self-attention implementation to generate a "Baseline" ideal output tensor. - This tensor is split into the "Prefill" ideal output tensor (all but the - last element of each output sequence) and the "Decode" ideal output tensor - (*only* the last element of each output sequence); the "Prefill" and - "Decode" ideal output tensors can be used to validate the prefill and decode - test results, respectively. + The query/key/value tensors are passed to an ideal reference + self-attention implementation to generate an ideal output tensor. This function also constructs the self-attention KV cache memory mapping (slot mapping and block table), ensuring that the block table starts at @@ -794,42 +787,14 @@ def encoder_attn_setup(batch_size, * block_base_addr: self-attention block table base address Returns: - - * query: "baseline" query; batch_size x padded_seq_len x num_heads x - head_size - * prefill_packed_query: "prefill" query; number_of_tokens x num_heads x - head_size - * prefill_packed_key: self-attn "prefill" key; number_of_tokens x num_heads - x head_size - * prefill_packed_value: self-attn "prefill" value; number_of_tokens x - num_heads x head_size - * prefill_packed_ideal_output: self-attn "prefill" ideal output; - number_of_tokens x num_heads x head_size - * prefill_q_seq_lens: list of token counts for each *prefill query* (one - less than baseline query) - * prefill_kv_seq_lens: list of token counts for each self-attn *prefill - key/value* (should match prefill_q_seq_lens) - * decode_packed_query: "decode" query; number_of_tokens x num_heads x - head_size - * decode_packed_key: self-attn "decode" key; number_of_tokens x num_heads x - head_size - * decode_packed_value: self-attn "decode" key; number_of_tokens x num_heads - x head_size - * decode_packed_ideal_output: self-attn "decode" ideal output; - number_of_tokens x num_heads x head_size - * decode_q_seq_lens: list of token counts for each *decode query* (should - be 1) - * decode_kv_seq_lens: list of token counts for each self-attn *decode - key/value* (should match decode_q_seq_lens) - * q_seq_lens: "baseline" query seq lens; number_of_tokens x num_heads x - head_size - * kv_seq_lens: self-attn "baseline" key/value seq lens; number_of_tokens - x num_heads x head_size - * decode_block_tables: fake self-attn decode-phase block table - * decode_slot_mapping: fake self-attn decode-phase slot mapping - * prefill_slot_mapping: fake self-attn prefill-phase slot mapping - * prefill_block_tables: fake self-attn prefill-phase block table - * max_block_idx: highest block address in the self-attention block-table + + * packed_query: number_of_tokens x num_heads x head_size + * packed_key: number_of_tokens x num_heads x head_size + * packed_value: number_of_tokens x num_heads x head_size + * packed_ideal_output: number_of_tokens x num_heads x head_size + * block_tables: fake self-attn decode-phase block table + * slot_mapping: fake self-attn decode-phase slot mapping + * q_seq_lens: list of query sequence lengths ''' max_kv_seq_len = max_q_seq_len @@ -857,7 +822,7 @@ def encoder_attn_setup(batch_size, head_size, attn_type=AttentionType.ENCODER) - # No attention mask + # No causal attention mask ideal_output = ref_masked_attention(query, key, value, @@ -865,13 +830,7 @@ def encoder_attn_setup(batch_size, q_seq_lens=q_seq_lens, kv_seq_lens=kv_seq_lens) - # prefill_ideal_output = torch.zeros_like(ideal_output) - # for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): - # prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ - # bdx, :prefill_q_seq_len] - - packed_ideal_output, _ = pack_tensor(ideal_output, - q_seq_lens) + packed_ideal_output, _ = pack_tensor(ideal_output, q_seq_lens) block_tables, \ _, \ @@ -896,13 +855,14 @@ def encoder_attn_setup(batch_size, slot_mapping, \ q_seq_lens -def decoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - block_base_addr=0): + +def decoder_attn_setup(batch_size: int, + num_heads: int, + head_size: int, + block_size: int, + scale: float, + max_q_seq_len: int, + block_base_addr: int = 0) -> tuple: ''' Set up test vectors & data structures for self-attention test. @@ -1074,17 +1034,18 @@ def decoder_attn_setup(batch_size, max_block_idx -def enc_dec_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=0): +def enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, + q_seq_lens: List, + prefill_q_seq_lens: List, + batch_size: int, + num_heads: int, + head_size: int, + block_size: int, + scale: float, + max_q_seq_len: int, + max_kv_seq_len: int, + block_base_addr: Optional[int]=0) \ + -> tuple: ''' Set up test vectors & data structures for cross-attention test. @@ -1092,7 +1053,7 @@ def enc_dec_attn_setup_reuses_query(query, ("baseline" key/value). Given this is a cross-attention test, we assume query tensors were already synthesized for a prior self-attention test and will be reused for cross-attention. The key & value sequences generated here - will may have a different length than the corresponding queries (as is often + may have a different length than the corresponding queries (as is often the case for cross-attention between decoder and encoder sequences.) Cross attention key & value tensors do not grow during autoregressive @@ -1217,22 +1178,63 @@ def enc_dec_attn_setup_reuses_query(query, max_block_idx -def run_self_attention_test(attn: Attention, packed_query, packed_key, - packed_value, kv_cache, +def run_self_attention_test(attn: Attention, packed_query: torch.Tensor, + packed_key: torch.Tensor, + packed_value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - attn_type: AttentionType): + attn_type: AttentionType) -> torch.Tensor: + ''' + Run encoder attention or decoder self-attention test. + + attn_metadata.attention_type is assigned attn_type in order to configure + the kernel invocation for either encoder or decoder self-attention. + + Arguments: + + * attn: Attention wrapper instance + * packed_{query,key,value}: total_num_tokens x (num_heads*head_size) + * kv_cache + * attn_metadata: attention metadata for encoder/decoder-self attention + * attn_type: AttentionType.DECODER or AttentionType.ENCODER + + Returns: + * Attention.forward() applied to packed_{query,key,value}, kv_cache + & attn_metadata + ''' + assert attn_type in [AttentionType.DECODER, AttentionType.ENCODER] attn_metadata.attention_type = attn_type return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) -def run_cross_attention_test(attn: Attention, packed_query, packed_key, - packed_value, kv_cache, - attn_metadata: AttentionMetadata): +def run_cross_attention_test(attn: Attention, packed_query: torch.Tensor, + packed_key: torch.Tensor, + packed_value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: + ''' + Run encoder/decoder cross-attention test. + + attn_metadata.attention_type is assigned AttentionType.ENCODER_DECODER + in order to configure the kernel invocation for encoder/decoder cross- + attention. + + Arguments: + + * attn: Attention wrapper instance + * packed_{query,key,value}: total_num_tokens x (num_heads*head_size) + * kv_cache + * attn_metadata: attention metadata for encoder/decoder-self attention + + Returns: + * Attention.forward() applied to packed_{query,key,value}, kv_cache + & attn_metadata + ''' attn_metadata.attention_type = AttentionType.ENCODER_DECODER return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -1242,50 +1244,49 @@ def run_cross_attention_test(attn: Attention, packed_query, packed_key, def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_seq_len: int) -> None: - ''' Encoder-only attention test: - * Construct fake test vectors for self- and cross-attention - * Construct attention metadata structure with self- and cross-attention - attributes - * Test self- and cross-attention in the following order - - * Prefill self-attention - * Prefill cross-attention - * Decode self-attention - * Decode cross-attention - * This order would exacerbate any accidental overlap in the - self-/cross-attention block tables, which we attempt to avoid + * Construct fake test vectors for encoder attention + * Construct attention metadata structure with encoder-attention- + specific attributes + * Run encoder attention with metadata structure & test vectors * Validate output correctness against ideal reference attention implementation - Block tables are constructed such that cross-attention KV cache is in a - higher, non-intersecting address-space than self-attention KV cache. + Encoder attention (by default) does not restrict which sequence offsets + may attend to each other. Thus the reference ideal reference + implementation does not employ a causal attention mask. - Self- and cross-attention share the same query tensor but not the K/V - tensors. Self-attention K/Vs must have the same seq len as Q while - cross-attention K/Vs are allowed to differ in seq len, as is often the case - for cross-attention. - ''' + Encoder attention does not utilize KV cache however the XFormer backend + requires block_tables & slot_mapping to be non-None and have a valid + structure, thus this test constructs dummy memory-mapping structures. - # Num KV cache blocks - # num_blocks = 4096 + Encoder attention is basically structured like decoder self-attention + in that Q/K/V are all derived from the previous layer output & have + the same sequence length (in contrast to encoder/decoder cross- + attention where K/V are drawn from the encoder hidden states and + may have a different length than Q derived from decoder previous + layer output.) + ''' # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init + # instance. Encoder attention does not require KV cache. scale, \ attn_backend, \ attn, \ _ = basic_setup(num_heads, - head_size, - None, - None, - backend_name) + head_size, + None, + None, + backend_name) # Self-attention setup # Let encoder_attn_setup() choose default block table - # base address + # base address; the block_tables and slot_mapping + # tensors are not actually utilized by encoder attention + # anyway but are required to be present & valid by the + # backend. packed_query, \ packed_key, \ packed_value, \ @@ -1301,7 +1302,13 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, context_lens = [0 for _ in range(batch_size)] - attn_metadata: AttentionMetadata = make_metadata_self_cross( + # Metadata config for encoder attention: + # + # * Use prefill kernel + # * Signal that this is an encoder-only test so that + # metadata attention_type property is correctly + # configured as AttentionType.ENCODER + attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, q_seq_lens, @@ -1309,9 +1316,6 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, block_tables, slot_mapping, is_encoder_only_test=True, - cross_seq_lens=None, - cross_block_tables=None, - cross_slot_mapping=None, ) packed_actual_output: torch.Tensor = run_self_attention_test( @@ -1324,10 +1328,9 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - assert torch.allclose( - packed_ideal_output, - packed_actual_output.view_as( - packed_ideal_output)) + assert torch.allclose(packed_ideal_output, + packed_actual_output.view_as(packed_ideal_output)) + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -1421,7 +1424,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - _ = enc_dec_attn_setup_reuses_query(query, + _ = enc_dec_cross_attn_setup_reuses_query(query, q_seq_lens, prefill_q_seq_lens, batch_size, @@ -1437,7 +1440,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( context_lens = [0 for _ in range(batch_size)] - prefill_attn_metadata: AttentionMetadata = make_metadata_self_cross( + prefill_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, prefill_q_seq_lens, @@ -1479,7 +1482,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # DECODE: self- and cross-attention tests - decode_attn_metadata: AttentionMetadata = make_metadata_self_cross( + decode_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, False, q_seq_lens, From 641f43139bc0b7935e759ef9eb89b2f0d3889484 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 14:46:32 -0400 Subject: [PATCH 093/239] typo --- vllm/attention/backends/abstract.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index cffd2d577777c..ece0da25ee6f2 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -8,8 +8,8 @@ class AttentionType(Enum): - DECODER = auto() # Decoder attention between previously layer Q/K/V - ENCODER = auto() # Encoder attention between previously layer Q/K/V + DECODER = auto() # Decoder attention between previous layer Q/K/V + ENCODER = auto() # Encoder attention between previous layer Q/K/V ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V From c7f54907ba3ecaecee761512d9266bc47c17e310 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 14:53:04 -0400 Subject: [PATCH 094/239] changed helper function naming convention --- tests/kernels/test_self_and_cross_attn.py | 30 +++++++++++------------ 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 6e2a3ed19ef76..e45ea629a2a44 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1178,11 +1178,11 @@ def enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, max_block_idx -def run_self_attention_test(attn: Attention, packed_query: torch.Tensor, - packed_key: torch.Tensor, - packed_value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - attn_type: AttentionType) -> torch.Tensor: +def run_encoder_or_decoder_self_attention_test(attn: Attention, packed_query: torch.Tensor, + packed_key: torch.Tensor, + packed_value: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + attn_type: AttentionType) -> torch.Tensor: ''' Run encoder attention or decoder self-attention test. @@ -1207,11 +1207,11 @@ def run_self_attention_test(attn: Attention, packed_query: torch.Tensor, attn_metadata) -def run_cross_attention_test(attn: Attention, packed_query: torch.Tensor, - packed_key: torch.Tensor, - packed_value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata) -> torch.Tensor: +def run_encoder_decoder_cross_attention_test(attn: Attention, packed_query: torch.Tensor, + packed_key: torch.Tensor, + packed_value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -1318,7 +1318,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, is_encoder_only_test=True, ) - packed_actual_output: torch.Tensor = run_self_attention_test( + packed_actual_output: torch.Tensor = run_encoder_or_decoder_self_attention_test( attn, packed_query, packed_key, @@ -1453,7 +1453,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_slot_mapping=cross_prefill_slot_mapping, ) - self_prefill_packed_actual_output: torch.Tensor = run_self_attention_test( + self_prefill_packed_actual_output: torch.Tensor = run_encoder_or_decoder_self_attention_test( attn, prefill_packed_query, self_prefill_packed_key, @@ -1468,7 +1468,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_prefill_packed_actual_output.view_as( self_prefill_packed_ideal_output)) - cross_prefill_packed_actual_output: torch.Tensor = run_cross_attention_test( + cross_prefill_packed_actual_output: torch.Tensor = run_encoder_decoder_cross_attention_test( attn, prefill_packed_query, cross_prefill_packed_key, cross_prefill_packed_value, kv_cache, prefill_attn_metadata) @@ -1495,7 +1495,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_slot_mapping=cross_decode_slot_mapping, ) - self_decode_packed_actual_output: torch.Tensor = run_self_attention_test( + self_decode_packed_actual_output: torch.Tensor = run_encoder_or_decoder_self_attention_test( attn, decode_packed_query, self_decode_packed_key, @@ -1510,7 +1510,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_decode_packed_actual_output.view_as( self_decode_packed_ideal_output)) - cross_decode_packed_actual_output: torch.Tensor = run_cross_attention_test( + cross_decode_packed_actual_output: torch.Tensor = run_encoder_decoder_cross_attention_test( attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) # - Decode cross-attention correct? From 9c78f8555dce717521e0bf0c77306c0444f44824 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:05:18 -0400 Subject: [PATCH 095/239] check we are not testing decode-phase/encoder attention --- tests/kernels/test_self_and_cross_attn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index e45ea629a2a44..a211b7e2cc210 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1202,6 +1202,7 @@ def run_encoder_or_decoder_self_attention_test(attn: Attention, packed_query: to & attn_metadata ''' assert attn_type in [AttentionType.DECODER, AttentionType.ENCODER] + assert attn_metadata.is_prompt or attn_type != AttentionType.ENCODER attn_metadata.attention_type = attn_type return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) From bf93a9eba5c70f82d7e66c6b2165c6e673467a49 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:08:55 -0400 Subject: [PATCH 096/239] refactoring --- tests/kernels/test_self_and_cross_attn.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index a211b7e2cc210..9e6b3c16ee86a 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1189,6 +1189,10 @@ def run_encoder_or_decoder_self_attention_test(attn: Attention, packed_query: to attn_metadata.attention_type is assigned attn_type in order to configure the kernel invocation for either encoder or decoder self-attention. + attn_type must be AttentionType.ENCODER or DECODER; if ENCODER, + attn_metadata.num_decode_tokens must be 0 (i.e. there is no such thing as + "decode-phase enocder attention".) + Arguments: * attn: Attention wrapper instance @@ -1202,7 +1206,7 @@ def run_encoder_or_decoder_self_attention_test(attn: Attention, packed_query: to & attn_metadata ''' assert attn_type in [AttentionType.DECODER, AttentionType.ENCODER] - assert attn_metadata.is_prompt or attn_type != AttentionType.ENCODER + assert attn_metadata.num_decode_tokens==0 or attn_type != AttentionType.ENCODER attn_metadata.attention_type = attn_type return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) From cd759f2819db3848dad6eda4d78fe0da9932ce3e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:21:59 -0400 Subject: [PATCH 097/239] formatting --- tests/kernels/test_self_and_cross_attn.py | 37 +++++++++++++---------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 9e6b3c16ee86a..28348924ca64b 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1178,11 +1178,11 @@ def enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, max_block_idx -def run_encoder_or_decoder_self_attention_test(attn: Attention, packed_query: torch.Tensor, - packed_key: torch.Tensor, - packed_value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - attn_type: AttentionType) -> torch.Tensor: +def run_encoder_or_decoder_self_attention_test( + attn: Attention, packed_query: torch.Tensor, packed_key: torch.Tensor, + packed_value: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + attn_type: AttentionType) -> torch.Tensor: ''' Run encoder attention or decoder self-attention test. @@ -1206,17 +1206,17 @@ def run_encoder_or_decoder_self_attention_test(attn: Attention, packed_query: to & attn_metadata ''' assert attn_type in [AttentionType.DECODER, AttentionType.ENCODER] - assert attn_metadata.num_decode_tokens==0 or attn_type != AttentionType.ENCODER + assert attn_metadata.num_decode_tokens == 0 or \ + attn_type != AttentionType.ENCODER attn_metadata.attention_type = attn_type return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) -def run_encoder_decoder_cross_attention_test(attn: Attention, packed_query: torch.Tensor, - packed_key: torch.Tensor, - packed_value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata) -> torch.Tensor: +def run_encoder_decoder_cross_attention_test( + attn: Attention, packed_query: torch.Tensor, packed_key: torch.Tensor, + packed_value: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -1323,7 +1323,8 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, is_encoder_only_test=True, ) - packed_actual_output: torch.Tensor = run_encoder_or_decoder_self_attention_test( + packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( attn, packed_query, packed_key, @@ -1458,7 +1459,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_slot_mapping=cross_prefill_slot_mapping, ) - self_prefill_packed_actual_output: torch.Tensor = run_encoder_or_decoder_self_attention_test( + self_prefill_packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( attn, prefill_packed_query, self_prefill_packed_key, @@ -1473,7 +1475,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_prefill_packed_actual_output.view_as( self_prefill_packed_ideal_output)) - cross_prefill_packed_actual_output: torch.Tensor = run_encoder_decoder_cross_attention_test( + cross_prefill_packed_actual_output: torch.Tensor = \ + run_encoder_decoder_cross_attention_test( attn, prefill_packed_query, cross_prefill_packed_key, cross_prefill_packed_value, kv_cache, prefill_attn_metadata) @@ -1500,7 +1503,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_slot_mapping=cross_decode_slot_mapping, ) - self_decode_packed_actual_output: torch.Tensor = run_encoder_or_decoder_self_attention_test( + self_decode_packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( attn, decode_packed_query, self_decode_packed_key, @@ -1515,7 +1519,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_decode_packed_actual_output.view_as( self_decode_packed_ideal_output)) - cross_decode_packed_actual_output: torch.Tensor = run_encoder_decoder_cross_attention_test( + cross_decode_packed_actual_output: torch.Tensor = \ + run_encoder_decoder_cross_attention_test( attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) # - Decode cross-attention correct? From 1af36258ea3f80d7c89bc62fd1c91c445a789ee5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:23:57 -0400 Subject: [PATCH 098/239] removing unnecessary check --- vllm/attention/backends/xformers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 3e6fe0717b0e7..90ed07b029b6a 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -488,8 +488,7 @@ def forward( # QKV for prefill. query = query[:num_prefill_tokens] - if attn_type != AttentionType.ENCODER_DECODER \ - and key is not None and value is not None: + if key is not None and value is not None: key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] From eb5cf0cbd5f4ae3135f030b0bab4db1333ac61a6 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:45:46 -0400 Subject: [PATCH 099/239] unit test for encoder/decoder+chunked prefill non-support; added attention utils file for error strings --- tests/kernels/test_self_and_cross_attn.py | 13 +++++++++++++ vllm/attention/backends/utils.py | 5 +++++ vllm/attention/backends/xformers.py | 17 +++++++++++++---- 3 files changed, 31 insertions(+), 4 deletions(-) create mode 100644 vllm/attention/backends/utils.py diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 28348924ca64b..5128108f52ad6 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -8,6 +8,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType +from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import make_tensor_with_pad @@ -1528,3 +1529,15 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_decode_packed_ideal_output, cross_decode_packed_actual_output.view_as( cross_decode_packed_ideal_output)) + + # Set up a contrived scenario where the attention metadata + # is configured for chunked prefill & encoder/decoder cross- + # attention. Required that this triggers a NotImplementedError. + decode_attn_metadata.num_prefill_tokens = 1 + with pytest.raises(NotImplementedError) as exc_info: + run_encoder_decoder_cross_attention_test(attn, decode_packed_query, + None, None, kv_cache, + decode_attn_metadata) + + # "Encoder decoder models do not currently support chunked prefill" + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py new file mode 100644 index 0000000000000..f893460cce06e --- /dev/null +++ b/vllm/attention/backends/utils.py @@ -0,0 +1,5 @@ +"""Attention utils""" + +STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ +"Encoder/decoder models " + \ +"currently do not support chunked prefill." \ No newline at end of file diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 90ed07b029b6a..ecd5413fba507 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -477,10 +478,18 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - assert attn_type == AttentionType.ENCODER_DECODER or ( - key.shape[0] == num_prefill_tokens + num_decode_tokens) - assert attn_type == AttentionType.ENCODER_DECODER or ( - value.shape[0] == num_prefill_tokens + num_decode_tokens) + if attn_type == AttentionType.ENCODER_DECODER: + # Encoder/decoder models are currently incompatible + # with chunked prefill. + if num_prefill_tokens > 0 and num_decode_tokens > 0: + raise NotImplementedError( \ + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL) + else: + # This is a decoder self-attention scenario; + # ensure key/value shape match total number of + # tokens to process + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. From afcb42e125b16a8afa5bb05740783e684b53f6fd Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:46:39 -0400 Subject: [PATCH 100/239] explanatory comment --- vllm/attention/backends/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index f893460cce06e..6141f3ead64ad 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,5 +1,8 @@ """Attention utils""" +# Error string(s) for encoder/decoder +# unsupported attention scenarios + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ "Encoder/decoder models " + \ "currently do not support chunked prefill." \ No newline at end of file From d13e08e72dafce77ba15bd2bb5df610506a274f3 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:52:37 -0400 Subject: [PATCH 101/239] refactoring --- tests/kernels/test_self_and_cross_attn.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 5128108f52ad6..bdbb083397a93 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1530,9 +1530,20 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_decode_packed_actual_output.view_as( cross_decode_packed_ideal_output)) + # The following test conditions could in principle be a + # standalone test, however the test setup is so involved that it is easier + # to piggyback off of the test vectors & other data structures + # created for testing decode-phase encoder/decoder cross- + # attention above. + # ---- # Set up a contrived scenario where the attention metadata # is configured for chunked prefill & encoder/decoder cross- # attention. Required that this triggers a NotImplementedError. + # + # We assume that decode_attn_metadata.num_decode_tokens > 1 + # already; the line below sets up a chunked prefill + # metadata configuration where there is nominally a mix + # of prefill and decode tokens. decode_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: run_encoder_decoder_cross_attention_test(attn, decode_packed_query, From ab92fb0bd237a1e3ef382751254f5b884dbaf549 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:56:18 -0400 Subject: [PATCH 102/239] spelling fix --- tests/kernels/test_self_and_cross_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index bdbb083397a93..187af811874b3 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1192,7 +1192,7 @@ def run_encoder_or_decoder_self_attention_test( attn_type must be AttentionType.ENCODER or DECODER; if ENCODER, attn_metadata.num_decode_tokens must be 0 (i.e. there is no such thing as - "decode-phase enocder attention".) + "decode-phase encoder attention".) Arguments: From 582a0f5cdf0928d8e662e0df6e0326dc606d567f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 15:59:49 -0400 Subject: [PATCH 103/239] rename --- vllm/attention/backends/xformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index ecd5413fba507..f0d9e3576bfc6 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -147,7 +147,7 @@ def __post_init__(self): self.attn_bias: Optional[List[AttentionBias]] = None @property - def has_valid_cross_attn_metadata(self): + def is_all_cross_attn_metadata_set(self): # No cross-attention metadata is present whatsoever no_md = (self.cross_seq_lens is None) and (self.cross_slot_mapping is @@ -173,7 +173,7 @@ def attention_type(self) -> AttentionType: def attention_type(self, atype: AttentionType) -> None: if atype == AttentionType.ENCODER_DECODER: - assert self.has_valid_cross_attn_metadata, \ + assert self.is_all_cross_attn_metadata_set, \ "Must have self.cross_seq_lens not None " + \ "in order to enable cross-attention" From a20be6da315d8c3df2b4f2048d5ef0785157b344 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 17:12:29 -0400 Subject: [PATCH 104/239] skip enc/dec tests if HIP --- tests/kernels/test_self_and_cross_attn.py | 9 ++++++--- vllm/attention/backends/utils.py | 6 +++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 187af811874b3..e45e14f2342ad 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -8,10 +8,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType -from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL +from vllm.attention.backends.utils import (STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, + STR_NOT_IMPL_ENC_DEC_ROCM_HIP) from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import make_tensor_with_pad +from vllm.utils import is_hip + # If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] # # TODO: FlashAttention forward only supports head dimension at most 128 @@ -1240,7 +1243,7 @@ def run_encoder_decoder_cross_attention_test( return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) - +@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @@ -1338,7 +1341,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, assert torch.allclose(packed_ideal_output, packed_actual_output.view_as(packed_ideal_output)) - +@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 6141f3ead64ad..727921641cd55 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -5,4 +5,8 @@ STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ "Encoder/decoder models " + \ -"currently do not support chunked prefill." \ No newline at end of file +"currently do not support chunked prefill." + +STR_NOT_IMPL_ENC_DEC_ROCM_HIP = \ +"Encoder/decoder models currently" + \ +"do not support ROCm/HIP." \ No newline at end of file From a9a162da1adbbcb5dc89742d2fc3a5c764aec6d7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 17:13:50 -0400 Subject: [PATCH 105/239] formatting --- tests/kernels/test_self_and_cross_attn.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index e45e14f2342ad..36f9616432e3a 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -8,12 +8,10 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType -from vllm.attention.backends.utils import (STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, - STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +from vllm.attention.backends.utils import ( + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_ROCM_HIP) from vllm.attention.backends.xformers import XFormersBackend -from vllm.utils import make_tensor_with_pad - -from vllm.utils import is_hip +from vllm.utils import is_hip, make_tensor_with_pad # If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] # @@ -1243,6 +1241,7 @@ def run_encoder_decoder_cross_attention_test( return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) + @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -1341,6 +1340,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, assert torch.allclose(packed_ideal_output, packed_actual_output.view_as(packed_ideal_output)) + @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) From 6e3cfe141af8894a49d56aff2ec77235a82c3276 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 17:54:03 -0400 Subject: [PATCH 106/239] Refactored checks into utils file --- vllm/attention/backends/utils.py | 42 ++++++++++++++++++++++++++++- vllm/attention/backends/xformers.py | 21 +++++++-------- 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 727921641cd55..5616a28ae9f73 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,5 +1,10 @@ """Attention utils""" +from vllm.utils import is_hip +from vllm.attention import AttentionMetadata +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.backends.xformers import XFormersMetadata + # Error string(s) for encoder/decoder # unsupported attention scenarios @@ -9,4 +14,39 @@ STR_NOT_IMPL_ENC_DEC_ROCM_HIP = \ "Encoder/decoder models currently" + \ -"do not support ROCm/HIP." \ No newline at end of file +"do not support ROCm/HIP." + +STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND = \ +"Encoder/decoder models currently support only the XFormers backend." + +# Check for unsupported encoder/decoder scenarios + +def check_hip_or_chunked_prefill_attention_encdec( + attn_metadata: AttentionMetadata): + ''' + Check for unsupported encoder/decoder scenarios when invoking + attention. + + Arguments: + + * attn_metadata: Attention metadata structure + ''' + if is_hip(): + # AMD ROCm/HIP support currently not implemented for + # encoder/decoder models + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ROCM_HIP) + + if not isinstance(attn_metadata,XFormersMetadata): + # Right now encoder/decoder support is only implemented + # for the XFormers backend. Pretty unlikely to encounter + # this case currently given this function will be invoked inside + # xFormers backend. + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND) + + if attn_metadata.attention_type != AttentionType.DECODER: + # Encoder/decoder models are currently incompatible + # with chunked prefill. + if attn_metadata.num_prefill_tokens > 0 and \ + attn_metadata.num_decode_tokens > 0: + raise NotImplementedError( \ + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index f0d9e3576bfc6..c23cabd0a07ba 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -11,7 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -447,6 +446,13 @@ def forward( # seqlen datastructures we utilize attn_type = attn_metadata.attention_type + if attn_type != AttentionType.DECODER: + # Raise NotImplementedError for unsupported encoder/decoder + # scenarios + from vllm.attention.backends.utils import \ + check_hip_or_chunked_prefill_attention_encdec + check_hip_or_chunked_prefill_attention_encdec(attn_metadata) + if (kv_cache is not None): # Even if there are no new key/value pairs to cache, # we still need to break out key_cache and value_cache @@ -478,16 +484,9 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - if attn_type == AttentionType.ENCODER_DECODER: - # Encoder/decoder models are currently incompatible - # with chunked prefill. - if num_prefill_tokens > 0 and num_decode_tokens > 0: - raise NotImplementedError( \ - STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL) - else: - # This is a decoder self-attention scenario; - # ensure key/value shape match total number of - # tokens to process + if attn_type != AttentionType.ENCODER_DECODER: + # Only enforce this shape-constraint for decoder + # self-attention assert key.shape[0] == num_prefill_tokens + num_decode_tokens assert value.shape[0] == num_prefill_tokens + num_decode_tokens From 622ce09f19d8d21080b2a72377c929c0e149eb8f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 18:14:06 -0400 Subject: [PATCH 107/239] format --- tests/kernels/test_self_and_cross_attn.py | 134 ++++++++++++++++++++++ vllm/attention/backends/utils.py | 19 +-- vllm/attention/backends/xformers.py | 4 +- 3 files changed, 146 insertions(+), 11 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 36f9616432e3a..c12b19b929147 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1555,3 +1555,137 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # "Encoder decoder models do not currently support chunked prefill" assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL + + +@pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") +@pytest.mark.parametrize("num_heads", [256]) +@pytest.mark.parametrize("head_size", [16]) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", [16]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("max_q_seq_len", [64]) +@pytest.mark.parametrize("max_kv_seq_len", [64]) +def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, + backend_name: str, batch_size: int, + block_size: int, max_q_seq_len: int, + max_kv_seq_len: int) -> None: + ''' + Encoder/decoder not-implemented-for-ROCm-HIP test: + + * Construct fake test vectors for self- and cross-attention + * Construct attention metadata structure with self- and cross-attention + attributes + * Test self- and cross-attention in the following order + + * Prefill self-attention + * Prefill cross-attention + * Decode self-attention + * Decode cross-attention + * This order would exacerbate any accidental overlap in the + self-/cross-attention block tables, which we attempt to avoid + * Validate output correctness against ideal reference attention + implementation + + Block tables are constructed such that cross-attention KV cache is in a + higher, non-intersecting address-space than self-attention KV cache. + + Self- and cross-attention share the same query tensor but not the K/V + tensors. Self-attention K/Vs must have the same seq len as Q while + cross-attention K/Vs are allowed to differ in seq len, as is often the case + for cross-attention. + ''' + + # Num KV cache blocks + num_blocks = 4096 + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + scale, \ + attn_backend, \ + attn, \ + kv_cache = basic_setup(num_heads, + head_size, + num_blocks, + block_size, + backend_name) + + # Self-attention setup + + self_block_base_addr = 0 + + query, \ + prefill_packed_query, \ + self_prefill_packed_key, \ + self_prefill_packed_value, \ + self_prefill_packed_ideal_output, \ + prefill_q_seq_lens, \ + self_prefill_kv_seq_lens, \ + decode_packed_query, \ + self_decode_packed_key, \ + self_decode_packed_value, \ + self_decode_packed_ideal_output, \ + _, \ + _, \ + q_seq_lens, \ + _, \ + self_decode_block_tables, \ + self_decode_slot_mapping, \ + self_prefill_slot_mapping, \ + self_prefill_block_tables, \ + cross_block_base_addr = decoder_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + block_base_addr=self_block_base_addr) + + # Cross-attention setup + + cross_prefill_packed_key, \ + cross_prefill_packed_value, \ + cross_prefill_packed_ideal_output, \ + cross_decode_packed_ideal_output, \ + cross_kv_seq_lens, \ + cross_decode_block_tables, \ + cross_decode_slot_mapping, \ + cross_prefill_slot_mapping, \ + cross_prefill_block_tables, \ + _ = enc_dec_cross_attn_setup_reuses_query(query, + q_seq_lens, + prefill_q_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + max_kv_seq_len, + block_base_addr=cross_block_base_addr) + + # PREFILL: self- and cross-attention tests + + context_lens = [0 for _ in range(batch_size)] + + prefill_attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + True, + prefill_q_seq_lens, + context_lens, + self_prefill_block_tables, + self_prefill_slot_mapping, + is_encoder_only_test=False, + cross_seq_lens=cross_kv_seq_lens, + cross_block_tables=cross_prefill_block_tables, + cross_slot_mapping=cross_prefill_slot_mapping, + ) + + with pytest.raises(NotImplementedError) as exc_info: + run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, + cross_prefill_packed_key, + cross_prefill_packed_value, + kv_cache, + prefill_attn_metadata) + + # "Encoder decoder models do not currently support ROCm/HIP" + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 5616a28ae9f73..ad88b4f964a54 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,9 +1,9 @@ """Attention utils""" -from vllm.utils import is_hip from vllm.attention import AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.xformers import XFormersMetadata +from vllm.utils import is_hip # Error string(s) for encoder/decoder # unsupported attention scenarios @@ -21,6 +21,7 @@ # Check for unsupported encoder/decoder scenarios + def check_hip_or_chunked_prefill_attention_encdec( attn_metadata: AttentionMetadata): ''' @@ -35,18 +36,18 @@ def check_hip_or_chunked_prefill_attention_encdec( # AMD ROCm/HIP support currently not implemented for # encoder/decoder models raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ROCM_HIP) - - if not isinstance(attn_metadata,XFormersMetadata): + + if not isinstance(attn_metadata, XFormersMetadata): # Right now encoder/decoder support is only implemented # for the XFormers backend. Pretty unlikely to encounter # this case currently given this function will be invoked inside # xFormers backend. raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND) - - if attn_metadata.attention_type != AttentionType.DECODER: + + if attn_metadata.attention_type != AttentionType.DECODER \ + and attn_metadata.num_prefill_tokens > 0 and \ + attn_metadata.num_decode_tokens > 0: # Encoder/decoder models are currently incompatible # with chunked prefill. - if attn_metadata.num_prefill_tokens > 0 and \ - attn_metadata.num_decode_tokens > 0: - raise NotImplementedError( \ - STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL) + raise NotImplementedError( \ + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index c23cabd0a07ba..6c5bd8fa3726a 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -449,8 +449,8 @@ def forward( if attn_type != AttentionType.DECODER: # Raise NotImplementedError for unsupported encoder/decoder # scenarios - from vllm.attention.backends.utils import \ - check_hip_or_chunked_prefill_attention_encdec + from vllm.attention.backends.utils import ( + check_hip_or_chunked_prefill_attention_encdec) check_hip_or_chunked_prefill_attention_encdec(attn_metadata) if (kv_cache is not None): From 4d88a898184f09b8ad69fe498bb23629fd99f338 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 19:09:18 -0400 Subject: [PATCH 108/239] wip trying to combine attention metadata caches --- vllm/attention/backends/xformers.py | 234 ++++++++++++---------------- 1 file changed, 102 insertions(+), 132 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6c5bd8fa3726a..e74b5fa1052db 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -198,149 +198,119 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self.num_prefills == 0: return None - if self._attn_type != AttentionType.ENCODER_DECODER: - # Decoder or encoder self-attention prefill - - if self._self_cached_prefill_metadata is not None: - return self._self_cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.block_tables is not None - - self._self_cached_prefill_metadata = XFormersMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self. - num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - _attn_type=self. - _attn_type, # Begin cross-attention fields below... - cross_seq_lens=None, - cross_seq_lens_tensor=None, - max_cross_seq_len=None, - cross_block_tables=None, - cross_slot_mapping=None) + target_attention_type = self.attention_type + + if self._self_cached_prefill_metadata is not None: + self._self_cached_prefill_metadata.attention_type = \ + target_attention_type return self._self_cached_prefill_metadata + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + + if self.is_all_cross_attn_metadata_set: + # This attention metadata structure could support + # encoder/decoder cross-attention; make sure to + # set the appropriate fields + cross_seq_lens=self.cross_seq_lens, + cross_seq_lens_tensor=self.cross_seq_lens_tensor, + max_cross_seq_len=self.max_cross_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables else: - # Encoder/decoder cross-attention prefill - - if self._cross_cached_prefill_metadata is not None: - return self._cross_cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.block_tables is not None - - self._cross_cached_prefill_metadata = XFormersMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self. - num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - _attn_type=AttentionType.ENCODER_DECODER, - # Begin cross-attention fields below... - cross_seq_lens=self.cross_seq_lens, - cross_seq_lens_tensor=self.cross_seq_lens_tensor, - max_cross_seq_len=self.max_cross_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cross_cached_prefill_metadata + # This attention metadata structure supports + # decoder-only self-attention; there are no fields + # to support encoder/decoder cross-attention + cross_seq_lens=None, + cross_seq_lens_tensor=None, + max_cross_seq_len=None, + cross_slot_mapping=None, + cross_block_tables=None + + self._self_cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None, + context_lens_tensor=self.context_lens_tensor[:self. + num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + _attn_type=self. + attention_type, # Begin cross-attention fields below... + cross_seq_lens=cross_seq_lens, + cross_seq_lens_tensor=cross_seq_lens_tensor, + max_cross_seq_len=max_cross_seq_len, + cross_slot_mapping=cross_slot_mapping, + cross_block_tables=cross_block_tables) + return self._self_cached_prefill_metadata @property def decode_metadata(self) -> Optional["XFormersMetadata"]: if self.num_decode_tokens == 0: return None - if self._attn_type != AttentionType.ENCODER_DECODER: - # Decoder or encoder self-attention prefill - - if self._self_cached_decode_metadata is not None: - return self._self_cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None - - self._self_cached_decode_metadata = XFormersMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - _attn_type=self. - _attn_type, # Begin cross-attention fields below... - cross_seq_lens=None, - cross_seq_lens_tensor=None, - max_cross_seq_len=None, - cross_block_tables=None, - cross_slot_mapping=None) - return self._self_cached_decode_metadata + target_attention_type = self.attention_type + if self._self_cached_decode_metadata is not None: + self._self_cached_decode_metadata.attention_type = \ + target_attention_type + return self._self_cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + if self.is_all_cross_attn_metadata_set: + # This attention metadata structure could support + # encoder/decoder cross-attention; make sure to + # set the appropriate fields + cross_seq_lens=self.cross_seq_lens, + cross_seq_lens_tensor=self.cross_seq_lens_tensor, + max_cross_seq_len=self.max_cross_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables else: - # Encoder/decoder cross-attention decode - - if self._cross_cached_decode_metadata is not None: - return self._cross_cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None - - self._cross_cached_decode_metadata = XFormersMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - _attn_type=AttentionType.ENCODER_DECODER, - # Begin cross-attention fields below... - cross_seq_lens=self.cross_seq_lens, - cross_seq_lens_tensor=self.cross_seq_lens_tensor, - max_cross_seq_len=self.max_cross_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cross_cached_decode_metadata - + # This attention metadata structure supports + # decoder-only self-attention; there are no fields + # to support encoder/decoder cross-attention + cross_seq_lens=None, + cross_seq_lens_tensor=None, + max_cross_seq_len=None, + cross_slot_mapping=None, + cross_block_tables=None + + self._self_cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + _attn_type=target_attention_type, + # Begin cross-attention fields below... + cross_seq_lens=cross_seq_lens, + cross_seq_lens_tensor=cross_seq_lens_tensor, + max_cross_seq_len=max_cross_seq_len, + cross_slot_mapping=cross_slot_mapping, + cross_block_tables=cross_block_tables) + return self._self_cached_decode_metadata class XFormersImpl(AttentionImpl[XFormersMetadata]): """ From 3dfcb556f43b4b753fea5e85e780a4fda85e99f9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 19:36:22 -0400 Subject: [PATCH 109/239] wip trying to merge self/cross caches; trying to fix attn_bias issues; just tried having xformers backend clear mask when changing target attention type --- vllm/attention/backends/xformers.py | 32 +++++++++++++++++------------ 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index e74b5fa1052db..71a34170abc3e 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -162,8 +162,22 @@ def is_all_cross_attn_metadata_set(self): assert ( not invalid_md_if_not_no_md), "Invalid cross-attention metadata" + self._maybe_infer_implicit_cross_attention_metadata() return True + def _maybe_infer_implicit_cross_attention_metadata(self): + # Infer implicit cross-attention fields + # from user-provided fields, if needed + if self.cross_seq_lens_tensor is None: + assert self.seq_lens_tensor is not None + self.cross_seq_lens_tensor = torch.tensor( + self.cross_seq_lens, + dtype=self.seq_lens_tensor.dtype, + device=self.seq_lens_tensor.device) + if self.max_cross_seq_len is None: + assert self.cross_seq_lens is not None + self.max_cross_seq_len = max(self.cross_seq_lens) + @property def attention_type(self) -> AttentionType: return self._attn_type @@ -175,19 +189,7 @@ def attention_type(self, atype: AttentionType) -> None: assert self.is_all_cross_attn_metadata_set, \ "Must have self.cross_seq_lens not None " + \ "in order to enable cross-attention" - - # Infer implicit cross-attention fields - # from user-provided fields, if needed - if self.cross_seq_lens_tensor is None: - assert self.seq_lens_tensor is not None - self.cross_seq_lens_tensor = torch.tensor( - self.cross_seq_lens, - dtype=self.seq_lens_tensor.dtype, - device=self.seq_lens_tensor.device) - if self.max_cross_seq_len is None: - assert self.cross_seq_lens is not None - self.max_cross_seq_len = max(self.cross_seq_lens) - + self._maybe_infer_implicit_cross_attention_metadata() self._attn_type = AttentionType.ENCODER_DECODER else: # AttentionType.{ENCODER,DECODER} @@ -263,6 +265,10 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: target_attention_type = self.attention_type if self._self_cached_decode_metadata is not None: + if self._self_cached_decode_metadata.attention_type != \ + target_attention_type: + self._self_cached_decode_metadata.attn_bias = None + self._self_cached_decode_metadata.attention_type = \ target_attention_type return self._self_cached_decode_metadata From 696072392eca111c7a463ad7ad920528dc0c0e6a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 29 May 2024 19:38:48 -0400 Subject: [PATCH 110/239] wip --- vllm/attention/backends/xformers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 71a34170abc3e..d97f6d306e5bd 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -203,6 +203,10 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: target_attention_type = self.attention_type if self._self_cached_prefill_metadata is not None: + if self._self_cached_prefill_metadata.attention_type != \ + target_attention_type: + self._self_cached_prefill_metadata.attn_bias = None + self._self_cached_prefill_metadata.attention_type = \ target_attention_type return self._self_cached_prefill_metadata From 31275ccba58545b1c090f24e600507b04adeec0c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 30 May 2024 09:32:28 -0400 Subject: [PATCH 111/239] wip merging attention metadata --- vllm/attention/backends/xformers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index d97f6d306e5bd..6f21f7caf83e1 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -111,8 +111,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): _self_cached_prefill_metadata: Optional["XFormersMetadata"] = None _self_cached_decode_metadata: Optional["XFormersMetadata"] = None # Cross-attention prefill/decode metadata cache - _cross_cached_prefill_metadata: Optional["XFormersMetadata"] = None - _cross_cached_decode_metadata: Optional["XFormersMetadata"] = None # Begin cross-attention fields... From a643436c0aac5db823a009bef65146d1610a7522 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 30 May 2024 10:35:29 -0400 Subject: [PATCH 112/239] simplied is_all_cross_attn_metadata_set() --- vllm/attention/backends/xformers.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6c5bd8fa3726a..4ef052424f176 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -148,21 +148,9 @@ def __post_init__(self): @property def is_all_cross_attn_metadata_set(self): # No cross-attention metadata is present whatsoever - no_md = (self.cross_seq_lens is - None) and (self.cross_slot_mapping is - None) and (self.cross_block_tables is None) - # If any cross-attention metadata is present, it is invalid - invalid_md_if_not_no_md = (self.cross_seq_lens is None) or ( - self.cross_slot_mapping is None) or (self.cross_block_tables is - None) - - if no_md: - return False - - assert ( - not invalid_md_if_not_no_md), "Invalid cross-attention metadata" - - return True + return (self.cross_seq_lens is not None) and \ + (self.cross_slot_mapping is not None) and \ + (self.cross_block_tables is not None) @property def attention_type(self) -> AttentionType: @@ -173,8 +161,8 @@ def attention_type(self, atype: AttentionType) -> None: if atype == AttentionType.ENCODER_DECODER: assert self.is_all_cross_attn_metadata_set, \ - "Must have self.cross_seq_lens not None " + \ - "in order to enable cross-attention" + "Must enable self.cross_seq_lens, self.cross_slot_mapping, " + \ + "self.cross_block_tables in order to perform cross-attention" # Infer implicit cross-attention fields # from user-provided fields, if needed From 2a1d84ac71fbbdfd50262574c505e34bdc464157 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sun, 2 Jun 2024 23:06:00 -0400 Subject: [PATCH 113/239] test: envs.VLLM_ATTENTION_BACKEND --- tests/kernels/test_self_and_cross_attn.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index c12b19b929147..5d92d73bfef38 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -13,6 +13,9 @@ from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import is_hip, make_tensor_with_pad +from vllm.logger import init_logger +logger = init_logger(__name__) + # If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] # # TODO: FlashAttention forward only supports head dimension at most 128 @@ -1278,6 +1281,10 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, layer output.) ''' + import vllm.envs as envs + print("envs.VLLM_ATTENTION_BACKEND: "+str(envs.VLLM_ATTENTION_BACKEND)) + logger.info("envs.VLLM_ATTENTION_BACKEND: "+str(envs.VLLM_ATTENTION_BACKEND)) + # Attention scale factor, attention backend instance, attention wrapper # instance. Encoder attention does not require KV cache. scale, \ From f6e0310f1a955d7cf6e08e27a6b4f7fa12f80c88 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sun, 2 Jun 2024 23:08:01 -0400 Subject: [PATCH 114/239] formatting' --- tests/kernels/test_self_and_cross_attn.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 5d92d73bfef38..73625f6f68501 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -11,9 +11,9 @@ from vllm.attention.backends.utils import ( STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_ROCM_HIP) from vllm.attention.backends.xformers import XFormersBackend +from vllm.logger import init_logger from vllm.utils import is_hip, make_tensor_with_pad -from vllm.logger import init_logger logger = init_logger(__name__) # If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] @@ -1282,8 +1282,9 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, ''' import vllm.envs as envs - print("envs.VLLM_ATTENTION_BACKEND: "+str(envs.VLLM_ATTENTION_BACKEND)) - logger.info("envs.VLLM_ATTENTION_BACKEND: "+str(envs.VLLM_ATTENTION_BACKEND)) + print("envs.VLLM_ATTENTION_BACKEND: " + str(envs.VLLM_ATTENTION_BACKEND)) + logger.info("envs.VLLM_ATTENTION_BACKEND: ", + str(envs.VLLM_ATTENTION_BACKEND)) # Attention scale factor, attention backend instance, attention wrapper # instance. Encoder attention does not require KV cache. From 60c01c3b67121e6fb7bdfb074f7fb7f39fef0023 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sun, 2 Jun 2024 23:45:30 -0400 Subject: [PATCH 115/239] attempted to fix issue whereby selector test doesn't cleanup environment variables --- tests/kernels/test_attention_selector.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index f439afa9b7d2b..cefd856898ac7 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -34,6 +34,9 @@ def test_env(name: str, device: str): if name_backup is not None: os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + else: + # VLLM_ATTENTION_BACKEND was unset + os.environ.pop('VLLM_ATTENTION_BACKEND', None) def test_flash_attn(): @@ -73,6 +76,9 @@ def test_flash_attn(): if name_backup is not None: os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + else: + # VLLM_ATTENTION_BACKEND was unset + os.environ.pop('VLLM_ATTENTION_BACKEND', None) def test_invalid_env(): @@ -81,4 +87,9 @@ def test_invalid_env(): os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID" with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) - os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + + if name_backup is not None: + os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + else: + # VLLM_ATTENTION_BACKEND was unset + os.environ.pop('VLLM_ATTENTION_BACKEND', None) From 5c94166b203f22ad20d7909661f5229c6954e09c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 00:48:56 -0400 Subject: [PATCH 116/239] (1) In top-level tests utils.py added env var context manager, (2) in tests/kernels added utils.py w/ vLLM backend context manager, (3) all unit tests for backend selector & enc/dec use backend context manager --- tests/kernels/test_attention_selector.py | 101 ++- tests/kernels/test_self_and_cross_attn.py | 772 +++++++++++----------- tests/kernels/utils.py | 25 + tests/utils.py | 26 + 4 files changed, 484 insertions(+), 440 deletions(-) create mode 100644 tests/kernels/utils.py diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index cefd856898ac7..110137d2820b7 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -1,10 +1,10 @@ -import os from unittest.mock import patch import pytest import torch from vllm.attention.selector import which_attn_to_use +from tests.kernels.utils import backend_override_fixture @pytest.mark.parametrize( @@ -14,82 +14,63 @@ def test_env(name: str, device: str): """Test that the attention selector can be set via environment variable. Note that we do not test FlashAttn because it is the default backend. """ - name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) - os.environ["VLLM_ATTENTION_BACKEND"] = name - if device == "cpu": - with patch("vllm.attention.selector.is_cpu", return_value=True): + with backend_override_fixture(name): + + if device == "cpu": + with patch("vllm.attention.selector.is_cpu", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "TORCH_SDPA" + elif device == "hip": + with patch("vllm.attention.selector.is_hip", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "ROCM_FLASH" + else: backend = which_attn_to_use(8, 16, 8, None, torch.float16, torch.float16, 16) - assert backend.name == "TORCH_SDPA" - elif device == "hip": - with patch("vllm.attention.selector.is_hip", return_value=True): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) - assert backend.name == "ROCM_FLASH" - else: - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) - assert backend.name == name - - if name_backup is not None: - os.environ["VLLM_ATTENTION_BACKEND"] = name_backup - else: - # VLLM_ATTENTION_BACKEND was unset - os.environ.pop('VLLM_ATTENTION_BACKEND', None) + assert backend.name == name def test_flash_attn(): """Test FlashAttn validation.""" - name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) - os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN" - # Unsupported CUDA arch - with patch("torch.cuda.get_device_capability", return_value=[7, 5]): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + with backend_override_fixture("FLASH_ATTN"): - # Unsupported data type - backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) - assert backend.name != "FLASH_ATTN" + # Unsupported CUDA arch + with patch("torch.cuda.get_device_capability", return_value=[7, 5]): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" - # Unsupported kv cache data type - backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) - assert backend.name != "FLASH_ATTN" + # Unsupported data type + backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) + assert backend.name != "FLASH_ATTN" - # Unsupported block size - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) - assert backend.name != "FLASH_ATTN" + # Unsupported kv cache data type + backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) + assert backend.name != "FLASH_ATTN" - # Unsupported sliding window - backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + # Unsupported block size + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) + assert backend.name != "FLASH_ATTN" - # flash-attn is not installed - with patch.dict('sys.modules', {'vllm_flash_attn': None}): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + # Unsupported sliding window + backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) assert backend.name != "FLASH_ATTN" - # Unsupported head size - backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + # flash-attn is not installed + with patch.dict('sys.modules', {'vllm_flash_attn': None}): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" - if name_backup is not None: - os.environ["VLLM_ATTENTION_BACKEND"] = name_backup - else: - # VLLM_ATTENTION_BACKEND was unset - os.environ.pop('VLLM_ATTENTION_BACKEND', None) + # Unsupported head size + backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" def test_invalid_env(): """Throw an exception if the backend name is invalid.""" - name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) - os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID" - with pytest.raises(ValueError): - which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) - - if name_backup is not None: - os.environ["VLLM_ATTENTION_BACKEND"] = name_backup - else: - # VLLM_ATTENTION_BACKEND was unset - os.environ.pop('VLLM_ATTENTION_BACKEND', None) + + with backend_override_fixture("INVALID"), pytest.raises(ValueError): + which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) \ No newline at end of file diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 73625f6f68501..684a86cfc4586 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -13,6 +13,7 @@ from vllm.attention.backends.xformers import XFormersBackend from vllm.logger import init_logger from vllm.utils import is_hip, make_tensor_with_pad +from tests.kernels.utils import backend_override_fixture logger = init_logger(__name__) @@ -27,7 +28,7 @@ BATCH_SIZES = [1, 16] BLOCK_SIZES = [16] -BACKEND_NAMES = ["xformers"] +BACKEND_NAMES = ["XFORMERS"] CUDA_DEVICE = "cuda:0" MAX_Q_SEQ_LENS = [128] @@ -371,15 +372,22 @@ def make_backend(backend_name: str) -> AttentionBackend: Construct the backend instance determined by the backend_name string argument. - "xformers" -> construct xformers backend + "XFORMERS" -> construct xformers backend + + TODO: other backends + + Note: at time of writing the Attention wrapper automatically selects + its own backend for Attention.forward(); so the backend instance which + you generate with this function is not meant to be used for *running* + inference, but rather for generating compatible metadata structures + using backend.make_metadata() - TODO: flash attention backend Returns: * Backend instance ''' - if backend_name == "xformers": + if backend_name == "XFORMERS": return XFormersBackend() raise AssertionError( f"Unrecognized backend_name {backend_name} for unit test") @@ -1281,72 +1289,70 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, layer output.) ''' - import vllm.envs as envs - print("envs.VLLM_ATTENTION_BACKEND: " + str(envs.VLLM_ATTENTION_BACKEND)) - logger.info("envs.VLLM_ATTENTION_BACKEND: ", - str(envs.VLLM_ATTENTION_BACKEND)) - - # Attention scale factor, attention backend instance, attention wrapper - # instance. Encoder attention does not require KV cache. - scale, \ - attn_backend, \ - attn, \ - _ = basic_setup(num_heads, - head_size, - None, - None, - backend_name) - - # Self-attention setup - # Let encoder_attn_setup() choose default block table - # base address; the block_tables and slot_mapping - # tensors are not actually utilized by encoder attention - # anyway but are required to be present & valid by the - # backend. - packed_query, \ - packed_key, \ - packed_value, \ - packed_ideal_output, \ - block_tables, \ - slot_mapping, \ - q_seq_lens = encoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_seq_len) - - context_lens = [0 for _ in range(batch_size)] - - # Metadata config for encoder attention: - # - # * Use prefill kernel - # * Signal that this is an encoder-only test so that - # metadata attention_type property is correctly - # configured as AttentionType.ENCODER - attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - q_seq_lens, - context_lens, - block_tables, - slot_mapping, - is_encoder_only_test=True, - ) - - packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( - attn, - packed_query, - packed_key, - packed_value, - None, - attn_metadata, - attn_type=AttentionType.ENCODER) - - # - Is encoder attention result correct? - assert torch.allclose(packed_ideal_output, - packed_actual_output.view_as(packed_ideal_output)) + with backend_override_fixture(backend_name): + # Force Attention wrapper backend + + # Attention scale factor, attention backend instance, attention wrapper + # instance. Encoder attention does not require KV cache. + scale, \ + attn_backend, \ + attn, \ + _ = basic_setup(num_heads, + head_size, + None, + None, + backend_name) + + # Self-attention setup + # Let encoder_attn_setup() choose default block table + # base address; the block_tables and slot_mapping + # tensors are not actually utilized by encoder attention + # anyway but are required to be present & valid by the + # backend. + packed_query, \ + packed_key, \ + packed_value, \ + packed_ideal_output, \ + block_tables, \ + slot_mapping, \ + q_seq_lens = encoder_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_seq_len) + + context_lens = [0 for _ in range(batch_size)] + + # Metadata config for encoder attention: + # + # * Use prefill kernel + # * Signal that this is an encoder-only test so that + # metadata attention_type property is correctly + # configured as AttentionType.ENCODER + attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + True, + q_seq_lens, + context_lens, + block_tables, + slot_mapping, + is_encoder_only_test=True, + ) + + packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( + attn, + packed_query, + packed_key, + packed_value, + None, + attn_metadata, + attn_type=AttentionType.ENCODER) + + # - Is encoder attention result correct? + assert torch.allclose(packed_ideal_output, + packed_actual_output.view_as(packed_ideal_output)) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @@ -1386,314 +1392,320 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( for cross-attention. ''' - # Num KV cache blocks - num_blocks = 4096 - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - scale, \ - attn_backend, \ - attn, \ - kv_cache = basic_setup(num_heads, - head_size, - num_blocks, - block_size, - backend_name) - - # Self-attention setup - - self_block_base_addr = 0 - - query, \ - prefill_packed_query, \ - self_prefill_packed_key, \ - self_prefill_packed_value, \ - self_prefill_packed_ideal_output, \ - prefill_q_seq_lens, \ - self_prefill_kv_seq_lens, \ - decode_packed_query, \ - self_decode_packed_key, \ - self_decode_packed_value, \ - self_decode_packed_ideal_output, \ - _, \ - _, \ - q_seq_lens, \ - _, \ - self_decode_block_tables, \ - self_decode_slot_mapping, \ - self_prefill_slot_mapping, \ - self_prefill_block_tables, \ - cross_block_base_addr = decoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - block_base_addr=self_block_base_addr) - - # Cross-attention setup - - cross_prefill_packed_key, \ - cross_prefill_packed_value, \ - cross_prefill_packed_ideal_output, \ - cross_decode_packed_ideal_output, \ - cross_kv_seq_lens, \ - cross_decode_block_tables, \ - cross_decode_slot_mapping, \ - cross_prefill_slot_mapping, \ - cross_prefill_block_tables, \ - _ = enc_dec_cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=cross_block_base_addr) - - # PREFILL: self- and cross-attention tests - - context_lens = [0 for _ in range(batch_size)] - - prefill_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - prefill_q_seq_lens, - context_lens, - self_prefill_block_tables, - self_prefill_slot_mapping, - is_encoder_only_test=False, - cross_seq_lens=cross_kv_seq_lens, - cross_block_tables=cross_prefill_block_tables, - cross_slot_mapping=cross_prefill_slot_mapping, - ) - - self_prefill_packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( - attn, - prefill_packed_query, - self_prefill_packed_key, - self_prefill_packed_value, - kv_cache, - prefill_attn_metadata, - attn_type=AttentionType.DECODER) - - # - Prefill self-attention correct? - assert torch.allclose( - self_prefill_packed_ideal_output, - self_prefill_packed_actual_output.view_as( - self_prefill_packed_ideal_output)) - - cross_prefill_packed_actual_output: torch.Tensor = \ - run_encoder_decoder_cross_attention_test( - attn, prefill_packed_query, cross_prefill_packed_key, - cross_prefill_packed_value, kv_cache, prefill_attn_metadata) - - # - Prefill cross-attention correct? - assert torch.allclose( - cross_prefill_packed_ideal_output, - cross_prefill_packed_actual_output.view_as( - cross_prefill_packed_ideal_output)) - - context_lens = copy.deepcopy(self_prefill_kv_seq_lens) - - # DECODE: self- and cross-attention tests - - decode_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - False, - q_seq_lens, - context_lens, - self_decode_block_tables, - self_decode_slot_mapping, - is_encoder_only_test=False, - cross_seq_lens=cross_kv_seq_lens, - cross_block_tables=cross_decode_block_tables, - cross_slot_mapping=cross_decode_slot_mapping, - ) - - self_decode_packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( - attn, - decode_packed_query, - self_decode_packed_key, - self_decode_packed_value, - kv_cache, - decode_attn_metadata, - attn_type=AttentionType.DECODER) - - # - Decode self-attention correct? - assert torch.allclose( - self_decode_packed_ideal_output, - self_decode_packed_actual_output.view_as( - self_decode_packed_ideal_output)) - - cross_decode_packed_actual_output: torch.Tensor = \ - run_encoder_decoder_cross_attention_test( - attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) - - # - Decode cross-attention correct? - assert torch.allclose( - cross_decode_packed_ideal_output, - cross_decode_packed_actual_output.view_as( - cross_decode_packed_ideal_output)) - - # The following test conditions could in principle be a - # standalone test, however the test setup is so involved that it is easier - # to piggyback off of the test vectors & other data structures - # created for testing decode-phase encoder/decoder cross- - # attention above. - # ---- - # Set up a contrived scenario where the attention metadata - # is configured for chunked prefill & encoder/decoder cross- - # attention. Required that this triggers a NotImplementedError. - # - # We assume that decode_attn_metadata.num_decode_tokens > 1 - # already; the line below sets up a chunked prefill - # metadata configuration where there is nominally a mix - # of prefill and decode tokens. - decode_attn_metadata.num_prefill_tokens = 1 - with pytest.raises(NotImplementedError) as exc_info: - run_encoder_decoder_cross_attention_test(attn, decode_packed_query, - None, None, kv_cache, - decode_attn_metadata) - - # "Encoder decoder models do not currently support chunked prefill" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL - - -@pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") -@pytest.mark.parametrize("num_heads", [256]) -@pytest.mark.parametrize("head_size", [16]) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", [16]) -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("max_q_seq_len", [64]) -@pytest.mark.parametrize("max_kv_seq_len", [64]) -def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, - backend_name: str, batch_size: int, - block_size: int, max_q_seq_len: int, - max_kv_seq_len: int) -> None: - ''' - Encoder/decoder not-implemented-for-ROCm-HIP test: - - * Construct fake test vectors for self- and cross-attention - * Construct attention metadata structure with self- and cross-attention - attributes - * Test self- and cross-attention in the following order + with backend_override_fixture(backend_name): + # Force Attention wrapper backend + + # Num KV cache blocks + num_blocks = 4096 + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + scale, \ + attn_backend, \ + attn, \ + kv_cache = basic_setup(num_heads, + head_size, + num_blocks, + block_size, + backend_name) + + # Self-attention setup + + self_block_base_addr = 0 + + query, \ + prefill_packed_query, \ + self_prefill_packed_key, \ + self_prefill_packed_value, \ + self_prefill_packed_ideal_output, \ + prefill_q_seq_lens, \ + self_prefill_kv_seq_lens, \ + decode_packed_query, \ + self_decode_packed_key, \ + self_decode_packed_value, \ + self_decode_packed_ideal_output, \ + _, \ + _, \ + q_seq_lens, \ + _, \ + self_decode_block_tables, \ + self_decode_slot_mapping, \ + self_prefill_slot_mapping, \ + self_prefill_block_tables, \ + cross_block_base_addr = decoder_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + block_base_addr=self_block_base_addr) + + # Cross-attention setup + + cross_prefill_packed_key, \ + cross_prefill_packed_value, \ + cross_prefill_packed_ideal_output, \ + cross_decode_packed_ideal_output, \ + cross_kv_seq_lens, \ + cross_decode_block_tables, \ + cross_decode_slot_mapping, \ + cross_prefill_slot_mapping, \ + cross_prefill_block_tables, \ + _ = enc_dec_cross_attn_setup_reuses_query(query, + q_seq_lens, + prefill_q_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + max_kv_seq_len, + block_base_addr=cross_block_base_addr) + + # PREFILL: self- and cross-attention tests + + context_lens = [0 for _ in range(batch_size)] + + prefill_attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + True, + prefill_q_seq_lens, + context_lens, + self_prefill_block_tables, + self_prefill_slot_mapping, + is_encoder_only_test=False, + cross_seq_lens=cross_kv_seq_lens, + cross_block_tables=cross_prefill_block_tables, + cross_slot_mapping=cross_prefill_slot_mapping, + ) + + self_prefill_packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( + attn, + prefill_packed_query, + self_prefill_packed_key, + self_prefill_packed_value, + kv_cache, + prefill_attn_metadata, + attn_type=AttentionType.DECODER) + + # - Prefill self-attention correct? + assert torch.allclose( + self_prefill_packed_ideal_output, + self_prefill_packed_actual_output.view_as( + self_prefill_packed_ideal_output)) + + cross_prefill_packed_actual_output: torch.Tensor = \ + run_encoder_decoder_cross_attention_test( + attn, prefill_packed_query, cross_prefill_packed_key, + cross_prefill_packed_value, kv_cache, prefill_attn_metadata) + + # - Prefill cross-attention correct? + assert torch.allclose( + cross_prefill_packed_ideal_output, + cross_prefill_packed_actual_output.view_as( + cross_prefill_packed_ideal_output)) + + context_lens = copy.deepcopy(self_prefill_kv_seq_lens) + + # DECODE: self- and cross-attention tests + + decode_attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + False, + q_seq_lens, + context_lens, + self_decode_block_tables, + self_decode_slot_mapping, + is_encoder_only_test=False, + cross_seq_lens=cross_kv_seq_lens, + cross_block_tables=cross_decode_block_tables, + cross_slot_mapping=cross_decode_slot_mapping, + ) + + self_decode_packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( + attn, + decode_packed_query, + self_decode_packed_key, + self_decode_packed_value, + kv_cache, + decode_attn_metadata, + attn_type=AttentionType.DECODER) + + # - Decode self-attention correct? + assert torch.allclose( + self_decode_packed_ideal_output, + self_decode_packed_actual_output.view_as( + self_decode_packed_ideal_output)) + + cross_decode_packed_actual_output: torch.Tensor = \ + run_encoder_decoder_cross_attention_test( + attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) + + # - Decode cross-attention correct? + assert torch.allclose( + cross_decode_packed_ideal_output, + cross_decode_packed_actual_output.view_as( + cross_decode_packed_ideal_output)) + + # The following test conditions could in principle be a + # standalone test, however the test setup is so involved that it is easier + # to piggyback off of the test vectors & other data structures + # created for testing decode-phase encoder/decoder cross- + # attention above. + # ---- + # Set up a contrived scenario where the attention metadata + # is configured for chunked prefill & encoder/decoder cross- + # attention. Required that this triggers a NotImplementedError. + # + # We assume that decode_attn_metadata.num_decode_tokens > 1 + # already; the line below sets up a chunked prefill + # metadata configuration where there is nominally a mix + # of prefill and decode tokens. + decode_attn_metadata.num_prefill_tokens = 1 + with pytest.raises(NotImplementedError) as exc_info: + run_encoder_decoder_cross_attention_test(attn, decode_packed_query, + None, None, kv_cache, + decode_attn_metadata) + + # "Encoder decoder models do not currently support chunked prefill" + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL + + +# @pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") +# @pytest.mark.parametrize("num_heads", [256]) +# @pytest.mark.parametrize("head_size", [16]) +# @pytest.mark.parametrize("backend_name", BACKEND_NAMES) +# @pytest.mark.parametrize("batch_size", [16]) +# @pytest.mark.parametrize("block_size", [16]) +# @pytest.mark.parametrize("max_q_seq_len", [64]) +# @pytest.mark.parametrize("max_kv_seq_len", [64]) +# def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, +# backend_name: str, batch_size: int, +# block_size: int, max_q_seq_len: int, +# max_kv_seq_len: int) -> None: +# ''' +# Encoder/decoder not-implemented-for-ROCm-HIP test: + +# * Construct fake test vectors for self- and cross-attention +# * Construct attention metadata structure with self- and cross-attention +# attributes +# * Test self- and cross-attention in the following order - * Prefill self-attention - * Prefill cross-attention - * Decode self-attention - * Decode cross-attention - * This order would exacerbate any accidental overlap in the - self-/cross-attention block tables, which we attempt to avoid - * Validate output correctness against ideal reference attention - implementation - - Block tables are constructed such that cross-attention KV cache is in a - higher, non-intersecting address-space than self-attention KV cache. - - Self- and cross-attention share the same query tensor but not the K/V - tensors. Self-attention K/Vs must have the same seq len as Q while - cross-attention K/Vs are allowed to differ in seq len, as is often the case - for cross-attention. - ''' - - # Num KV cache blocks - num_blocks = 4096 - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - scale, \ - attn_backend, \ - attn, \ - kv_cache = basic_setup(num_heads, - head_size, - num_blocks, - block_size, - backend_name) - - # Self-attention setup - - self_block_base_addr = 0 - - query, \ - prefill_packed_query, \ - self_prefill_packed_key, \ - self_prefill_packed_value, \ - self_prefill_packed_ideal_output, \ - prefill_q_seq_lens, \ - self_prefill_kv_seq_lens, \ - decode_packed_query, \ - self_decode_packed_key, \ - self_decode_packed_value, \ - self_decode_packed_ideal_output, \ - _, \ - _, \ - q_seq_lens, \ - _, \ - self_decode_block_tables, \ - self_decode_slot_mapping, \ - self_prefill_slot_mapping, \ - self_prefill_block_tables, \ - cross_block_base_addr = decoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - block_base_addr=self_block_base_addr) - - # Cross-attention setup - - cross_prefill_packed_key, \ - cross_prefill_packed_value, \ - cross_prefill_packed_ideal_output, \ - cross_decode_packed_ideal_output, \ - cross_kv_seq_lens, \ - cross_decode_block_tables, \ - cross_decode_slot_mapping, \ - cross_prefill_slot_mapping, \ - cross_prefill_block_tables, \ - _ = enc_dec_cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=cross_block_base_addr) - - # PREFILL: self- and cross-attention tests - - context_lens = [0 for _ in range(batch_size)] - - prefill_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - prefill_q_seq_lens, - context_lens, - self_prefill_block_tables, - self_prefill_slot_mapping, - is_encoder_only_test=False, - cross_seq_lens=cross_kv_seq_lens, - cross_block_tables=cross_prefill_block_tables, - cross_slot_mapping=cross_prefill_slot_mapping, - ) - - with pytest.raises(NotImplementedError) as exc_info: - run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, - cross_prefill_packed_key, - cross_prefill_packed_value, - kv_cache, - prefill_attn_metadata) - - # "Encoder decoder models do not currently support ROCm/HIP" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP +# * Prefill self-attention +# * Prefill cross-attention +# * Decode self-attention +# * Decode cross-attention +# * This order would exacerbate any accidental overlap in the +# self-/cross-attention block tables, which we attempt to avoid +# * Validate output correctness against ideal reference attention +# implementation + +# Block tables are constructed such that cross-attention KV cache is in a +# higher, non-intersecting address-space than self-attention KV cache. + +# Self- and cross-attention share the same query tensor but not the K/V +# tensors. Self-attention K/Vs must have the same seq len as Q while +# cross-attention K/Vs are allowed to differ in seq len, as is often the case +# for cross-attention. +# ''' + +# with backend_override_fixture(backend_name): +# # Force Attention wrapper backend + +# # Num KV cache blocks +# num_blocks = 4096 + +# # Attention scale factor, attention backend instance, attention wrapper +# # instance, KV cache init +# scale, \ +# attn_backend, \ +# attn, \ +# kv_cache = basic_setup(num_heads, +# head_size, +# num_blocks, +# block_size, +# backend_name) + +# # Self-attention setup + +# self_block_base_addr = 0 + +# query, \ +# prefill_packed_query, \ +# self_prefill_packed_key, \ +# self_prefill_packed_value, \ +# self_prefill_packed_ideal_output, \ +# prefill_q_seq_lens, \ +# self_prefill_kv_seq_lens, \ +# decode_packed_query, \ +# self_decode_packed_key, \ +# self_decode_packed_value, \ +# self_decode_packed_ideal_output, \ +# _, \ +# _, \ +# q_seq_lens, \ +# _, \ +# self_decode_block_tables, \ +# self_decode_slot_mapping, \ +# self_prefill_slot_mapping, \ +# self_prefill_block_tables, \ +# cross_block_base_addr = decoder_attn_setup(batch_size, +# num_heads, +# head_size, +# block_size, +# scale, +# max_q_seq_len, +# block_base_addr=self_block_base_addr) + +# # Cross-attention setup + +# cross_prefill_packed_key, \ +# cross_prefill_packed_value, \ +# cross_prefill_packed_ideal_output, \ +# cross_decode_packed_ideal_output, \ +# cross_kv_seq_lens, \ +# cross_decode_block_tables, \ +# cross_decode_slot_mapping, \ +# cross_prefill_slot_mapping, \ +# cross_prefill_block_tables, \ +# _ = enc_dec_cross_attn_setup_reuses_query(query, +# q_seq_lens, +# prefill_q_seq_lens, +# batch_size, +# num_heads, +# head_size, +# block_size, +# scale, +# max_q_seq_len, +# max_kv_seq_len, +# block_base_addr=cross_block_base_addr) + +# # PREFILL: self- and cross-attention tests + +# context_lens = [0 for _ in range(batch_size)] + +# prefill_attn_metadata: AttentionMetadata = make_test_metadata( +# attn_backend, +# True, +# prefill_q_seq_lens, +# context_lens, +# self_prefill_block_tables, +# self_prefill_slot_mapping, +# is_encoder_only_test=False, +# cross_seq_lens=cross_kv_seq_lens, +# cross_block_tables=cross_prefill_block_tables, +# cross_slot_mapping=cross_prefill_slot_mapping, +# ) + +# with pytest.raises(NotImplementedError) as exc_info: +# run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, +# cross_prefill_packed_key, +# cross_prefill_packed_value, +# kv_cache, +# prefill_attn_metadata) + +# # "Encoder decoder models do not currently support ROCm/HIP" +# assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py new file mode 100644 index 0000000000000..1d54a513f287c --- /dev/null +++ b/tests/kernels/utils.py @@ -0,0 +1,25 @@ +"""Kernel test utils""" + +from tests.utils import env_var_fixture +from contextlib import contextmanager +from typing import Iterator + +# Configure + +@contextmanager +def backend_override_fixture(backend_name: str) -> Iterator[None]: + ''' + Text fixture, temporarily configures the vLLM backend by setting + VLLM_ATTENTION_BACKEND, then resets the environment outside of + the fixture. + + Usage: + + with backend_override_fixture("backend_name"): + # code that depends on vLLM backend + + # VLLM_ATTENTION_BACKEND is returned to original value + # or unset + ''' + with env_var_fixture('VLLM_ATTENTION_BACKEND', backend_name): + yield # Control is yielded to the enclosed block, environment variable is set \ No newline at end of file diff --git a/tests/utils.py b/tests/utils.py index 329842911e159..7c8740236956a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,8 @@ import warnings from contextlib import contextmanager +from typing import Iterator + import ray import requests @@ -101,3 +103,27 @@ def error_on_warning(): warnings.simplefilter("error") yield + +@contextmanager +def env_var_fixture(var_name: str, value: str) -> Iterator[None]: + ''' + Text fixture, temporarily assigns value var_name environment variable, + then resets environment variable outside of test fixture. + + Usage: + + with env_var_fixture("my_var","my_val"): + # code that depends on my_val == "my_val" + + # my_var is returned to original value or unset + ''' + original_value = os.environ.get(var_name) # Store the original value + os.environ[var_name] = value # Set the new value + try: + yield + finally: + # Restore the original value + if original_value is None: + del os.environ[var_name] + else: + os.environ[var_name] = original_value \ No newline at end of file From eaa627fd3da6f3491a08cf06d8d8fd1eee89f991 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 01:19:39 -0400 Subject: [PATCH 117/239] wip tests --- tests/kernels/test_self_and_cross_attn.py | 268 +++++++++++----------- 1 file changed, 134 insertions(+), 134 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 684a86cfc4586..67772b1286e2a 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1574,138 +1574,138 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL -# @pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") -# @pytest.mark.parametrize("num_heads", [256]) -# @pytest.mark.parametrize("head_size", [16]) -# @pytest.mark.parametrize("backend_name", BACKEND_NAMES) -# @pytest.mark.parametrize("batch_size", [16]) -# @pytest.mark.parametrize("block_size", [16]) -# @pytest.mark.parametrize("max_q_seq_len", [64]) -# @pytest.mark.parametrize("max_kv_seq_len", [64]) -# def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, -# backend_name: str, batch_size: int, -# block_size: int, max_q_seq_len: int, -# max_kv_seq_len: int) -> None: -# ''' -# Encoder/decoder not-implemented-for-ROCm-HIP test: - -# * Construct fake test vectors for self- and cross-attention -# * Construct attention metadata structure with self- and cross-attention -# attributes -# * Test self- and cross-attention in the following order +@pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") +@pytest.mark.parametrize("num_heads", [256]) +@pytest.mark.parametrize("head_size", [16]) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", [16]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("max_q_seq_len", [64]) +@pytest.mark.parametrize("max_kv_seq_len", [64]) +def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, + backend_name: str, batch_size: int, + block_size: int, max_q_seq_len: int, + max_kv_seq_len: int) -> None: + ''' + Encoder/decoder not-implemented-for-ROCm-HIP test: + + * Construct fake test vectors for self- and cross-attention + * Construct attention metadata structure with self- and cross-attention + attributes + * Test self- and cross-attention in the following order -# * Prefill self-attention -# * Prefill cross-attention -# * Decode self-attention -# * Decode cross-attention -# * This order would exacerbate any accidental overlap in the -# self-/cross-attention block tables, which we attempt to avoid -# * Validate output correctness against ideal reference attention -# implementation - -# Block tables are constructed such that cross-attention KV cache is in a -# higher, non-intersecting address-space than self-attention KV cache. - -# Self- and cross-attention share the same query tensor but not the K/V -# tensors. Self-attention K/Vs must have the same seq len as Q while -# cross-attention K/Vs are allowed to differ in seq len, as is often the case -# for cross-attention. -# ''' - -# with backend_override_fixture(backend_name): -# # Force Attention wrapper backend - -# # Num KV cache blocks -# num_blocks = 4096 - -# # Attention scale factor, attention backend instance, attention wrapper -# # instance, KV cache init -# scale, \ -# attn_backend, \ -# attn, \ -# kv_cache = basic_setup(num_heads, -# head_size, -# num_blocks, -# block_size, -# backend_name) - -# # Self-attention setup - -# self_block_base_addr = 0 - -# query, \ -# prefill_packed_query, \ -# self_prefill_packed_key, \ -# self_prefill_packed_value, \ -# self_prefill_packed_ideal_output, \ -# prefill_q_seq_lens, \ -# self_prefill_kv_seq_lens, \ -# decode_packed_query, \ -# self_decode_packed_key, \ -# self_decode_packed_value, \ -# self_decode_packed_ideal_output, \ -# _, \ -# _, \ -# q_seq_lens, \ -# _, \ -# self_decode_block_tables, \ -# self_decode_slot_mapping, \ -# self_prefill_slot_mapping, \ -# self_prefill_block_tables, \ -# cross_block_base_addr = decoder_attn_setup(batch_size, -# num_heads, -# head_size, -# block_size, -# scale, -# max_q_seq_len, -# block_base_addr=self_block_base_addr) - -# # Cross-attention setup - -# cross_prefill_packed_key, \ -# cross_prefill_packed_value, \ -# cross_prefill_packed_ideal_output, \ -# cross_decode_packed_ideal_output, \ -# cross_kv_seq_lens, \ -# cross_decode_block_tables, \ -# cross_decode_slot_mapping, \ -# cross_prefill_slot_mapping, \ -# cross_prefill_block_tables, \ -# _ = enc_dec_cross_attn_setup_reuses_query(query, -# q_seq_lens, -# prefill_q_seq_lens, -# batch_size, -# num_heads, -# head_size, -# block_size, -# scale, -# max_q_seq_len, -# max_kv_seq_len, -# block_base_addr=cross_block_base_addr) - -# # PREFILL: self- and cross-attention tests - -# context_lens = [0 for _ in range(batch_size)] - -# prefill_attn_metadata: AttentionMetadata = make_test_metadata( -# attn_backend, -# True, -# prefill_q_seq_lens, -# context_lens, -# self_prefill_block_tables, -# self_prefill_slot_mapping, -# is_encoder_only_test=False, -# cross_seq_lens=cross_kv_seq_lens, -# cross_block_tables=cross_prefill_block_tables, -# cross_slot_mapping=cross_prefill_slot_mapping, -# ) - -# with pytest.raises(NotImplementedError) as exc_info: -# run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, -# cross_prefill_packed_key, -# cross_prefill_packed_value, -# kv_cache, -# prefill_attn_metadata) - -# # "Encoder decoder models do not currently support ROCm/HIP" -# assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP + * Prefill self-attention + * Prefill cross-attention + * Decode self-attention + * Decode cross-attention + * This order would exacerbate any accidental overlap in the + self-/cross-attention block tables, which we attempt to avoid + * Validate output correctness against ideal reference attention + implementation + + Block tables are constructed such that cross-attention KV cache is in a + higher, non-intersecting address-space than self-attention KV cache. + + Self- and cross-attention share the same query tensor but not the K/V + tensors. Self-attention K/Vs must have the same seq len as Q while + cross-attention K/Vs are allowed to differ in seq len, as is often the case + for cross-attention. + ''' + + with backend_override_fixture(backend_name): + # Force Attention wrapper backend + + # Num KV cache blocks + num_blocks = 4096 + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + scale, \ + attn_backend, \ + attn, \ + kv_cache = basic_setup(num_heads, + head_size, + num_blocks, + block_size, + backend_name) + + # Self-attention setup + + self_block_base_addr = 0 + + query, \ + prefill_packed_query, \ + self_prefill_packed_key, \ + self_prefill_packed_value, \ + self_prefill_packed_ideal_output, \ + prefill_q_seq_lens, \ + self_prefill_kv_seq_lens, \ + decode_packed_query, \ + self_decode_packed_key, \ + self_decode_packed_value, \ + self_decode_packed_ideal_output, \ + _, \ + _, \ + q_seq_lens, \ + _, \ + self_decode_block_tables, \ + self_decode_slot_mapping, \ + self_prefill_slot_mapping, \ + self_prefill_block_tables, \ + cross_block_base_addr = decoder_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + block_base_addr=self_block_base_addr) + + # Cross-attention setup + + cross_prefill_packed_key, \ + cross_prefill_packed_value, \ + cross_prefill_packed_ideal_output, \ + cross_decode_packed_ideal_output, \ + cross_kv_seq_lens, \ + cross_decode_block_tables, \ + cross_decode_slot_mapping, \ + cross_prefill_slot_mapping, \ + cross_prefill_block_tables, \ + _ = enc_dec_cross_attn_setup_reuses_query(query, + q_seq_lens, + prefill_q_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + max_kv_seq_len, + block_base_addr=cross_block_base_addr) + + # PREFILL: self- and cross-attention tests + + context_lens = [0 for _ in range(batch_size)] + + prefill_attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + True, + prefill_q_seq_lens, + context_lens, + self_prefill_block_tables, + self_prefill_slot_mapping, + is_encoder_only_test=False, + cross_seq_lens=cross_kv_seq_lens, + cross_block_tables=cross_prefill_block_tables, + cross_slot_mapping=cross_prefill_slot_mapping, + ) + + with pytest.raises(NotImplementedError) as exc_info: + run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, + cross_prefill_packed_key, + cross_prefill_packed_value, + kv_cache, + prefill_attn_metadata) + + # "Encoder decoder models do not currently support ROCm/HIP" + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP From 9c597c4b26cd5ecc0d38798818eaab99752ab03b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 01:29:22 -0400 Subject: [PATCH 118/239] FIX: test_attention_selector.py was leaking VLLM_ATTENTION_BACKEND values; fixed with backend context manager --- tests/kernels/test_attention_selector.py | 91 +++++++++++------------- tests/kernels/utils.py | 23 ++++++ tests/utils.py | 27 +++++++ 3 files changed, 93 insertions(+), 48 deletions(-) create mode 100644 tests/kernels/utils.py diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index f439afa9b7d2b..1726f58cee088 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -1,10 +1,10 @@ -import os from unittest.mock import patch import pytest import torch from vllm.attention.selector import which_attn_to_use +from tests.kernels.utils import backend_override_fixture @pytest.mark.parametrize( @@ -14,71 +14,66 @@ def test_env(name: str, device: str): """Test that the attention selector can be set via environment variable. Note that we do not test FlashAttn because it is the default backend. """ - name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) - os.environ["VLLM_ATTENTION_BACKEND"] = name - if device == "cpu": - with patch("vllm.attention.selector.is_cpu", return_value=True): + with backend_override_fixture(name): + + if device == "cpu": + with patch("vllm.attention.selector.is_cpu", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "TORCH_SDPA" + elif device == "hip": + with patch("vllm.attention.selector.is_hip", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "ROCM_FLASH" + else: backend = which_attn_to_use(8, 16, 8, None, torch.float16, torch.float16, 16) - assert backend.name == "TORCH_SDPA" - elif device == "hip": - with patch("vllm.attention.selector.is_hip", return_value=True): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) - assert backend.name == "ROCM_FLASH" - else: - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) - assert backend.name == name - - if name_backup is not None: - os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + assert backend.name == name def test_flash_attn(): """Test FlashAttn validation.""" - name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) - os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN" - # Unsupported CUDA arch - with patch("torch.cuda.get_device_capability", return_value=[7, 5]): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + with backend_override_fixture("FLASH_ATTN"): - # Unsupported data type - backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) - assert backend.name != "FLASH_ATTN" + # Unsupported CUDA arch + with patch("torch.cuda.get_device_capability", return_value=[7, 5]): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, + 16) + assert backend.name != "FLASH_ATTN" - # Unsupported kv cache data type - backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) - assert backend.name != "FLASH_ATTN" + # Unsupported data type + backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, + 16) + assert backend.name != "FLASH_ATTN" - # Unsupported block size - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) - assert backend.name != "FLASH_ATTN" + # Unsupported kv cache data type + backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) + assert backend.name != "FLASH_ATTN" - # Unsupported sliding window - backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + # Unsupported block size + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) + assert backend.name != "FLASH_ATTN" - # flash-attn is not installed - with patch.dict('sys.modules', {'vllm_flash_attn': None}): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + # Unsupported sliding window + backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) assert backend.name != "FLASH_ATTN" - # Unsupported head size - backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + # flash-attn is not installed + with patch.dict('sys.modules', {'vllm_flash_attn': None}): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, + 16) + assert backend.name != "FLASH_ATTN" - if name_backup is not None: - os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + # Unsupported head size + backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" def test_invalid_env(): """Throw an exception if the backend name is invalid.""" - name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) - os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID" - with pytest.raises(ValueError): + + with backend_override_fixture("INVALID"), pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) - os.environ["VLLM_ATTENTION_BACKEND"] = name_backup diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py new file mode 100644 index 0000000000000..955a96bae2a80 --- /dev/null +++ b/tests/kernels/utils.py @@ -0,0 +1,23 @@ +"""Kernel test utils""" + +from tests.utils import env_var_fixture +from contextlib import contextmanager +from typing import Iterator + +@contextmanager +def backend_override_fixture(backend_name: str) -> Iterator[None]: + ''' + Text fixture, temporarily configures the vLLM backend by setting + VLLM_ATTENTION_BACKEND, then resets the environment outside of + the fixture. + + Usage: + + with backend_override_fixture("backend_name"): + # code that depends on vLLM backend + + # VLLM_ATTENTION_BACKEND is returned to original value + # or unset + ''' + with env_var_fixture('VLLM_ATTENTION_BACKEND', backend_name): + yield diff --git a/tests/utils.py b/tests/utils.py index 329842911e159..48666ca652dd7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,8 @@ import warnings from contextlib import contextmanager +from typing import Iterator + import ray import requests @@ -101,3 +103,28 @@ def error_on_warning(): warnings.simplefilter("error") yield + + +@contextmanager +def env_var_fixture(var_name: str, value: str) -> Iterator[None]: + ''' + Text fixture, temporarily assigns value var_name environment variable, + then resets environment variable outside of test fixture. + + Usage: + + with env_var_fixture("my_var","my_val"): + # code that depends on my_val == "my_val" + + # my_var is returned to original value or unset + ''' + original_value = os.environ.get(var_name) # Store the original value + os.environ[var_name] = value # Set the new value + try: + yield + finally: + # Restore the original value + if original_value is None: + del os.environ[var_name] + else: + os.environ[var_name] = original_value From 9831ce63077cd4b531979dde442b4185a18b16b8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 03:06:50 -0400 Subject: [PATCH 119/239] formatting --- tests/kernels/test_attention_selector.py | 2 +- tests/kernels/utils.py | 4 +++- tests/utils.py | 1 - 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 1726f58cee088..b0b383974904c 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -3,8 +3,8 @@ import pytest import torch -from vllm.attention.selector import which_attn_to_use from tests.kernels.utils import backend_override_fixture +from vllm.attention.selector import which_attn_to_use @pytest.mark.parametrize( diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 955a96bae2a80..8ebc2fc5905aa 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1,9 +1,11 @@ """Kernel test utils""" -from tests.utils import env_var_fixture from contextlib import contextmanager from typing import Iterator +from tests.utils import env_var_fixture + + @contextmanager def backend_override_fixture(backend_name: str) -> Iterator[None]: ''' diff --git a/tests/utils.py b/tests/utils.py index 48666ca652dd7..adbff8e8dc1c6 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,7 +4,6 @@ import time import warnings from contextlib import contextmanager - from typing import Iterator import ray From faf9118554e2677634b3b3bec9d25c5506d1a7d5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 03:37:21 -0400 Subject: [PATCH 120/239] formatting --- tests/kernels/test_self_and_cross_attn.py | 25 ++++++++++++----------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 67772b1286e2a..9b1d22e19c57d 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -6,6 +6,7 @@ import pytest import torch +from tests.kernels.utils import backend_override_fixture from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import ( @@ -13,7 +14,6 @@ from vllm.attention.backends.xformers import XFormersBackend from vllm.logger import init_logger from vllm.utils import is_hip, make_tensor_with_pad -from tests.kernels.utils import backend_override_fixture logger = init_logger(__name__) @@ -1351,8 +1351,9 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - assert torch.allclose(packed_ideal_output, - packed_actual_output.view_as(packed_ideal_output)) + assert torch.allclose( + packed_ideal_output, + packed_actual_output.view_as(packed_ideal_output)) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @@ -1542,7 +1543,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_decode_packed_actual_output: torch.Tensor = \ run_encoder_decoder_cross_attention_test( - attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) + attn, decode_packed_query, None, + None, kv_cache, decode_attn_metadata) # - Decode cross-attention correct? assert torch.allclose( @@ -1551,7 +1553,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_decode_packed_ideal_output)) # The following test conditions could in principle be a - # standalone test, however the test setup is so involved that it is easier + # standalone test, however the test setup is + # so involved that it is easier # to piggyback off of the test vectors & other data structures # created for testing decode-phase encoder/decoder cross- # attention above. @@ -1567,8 +1570,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decode_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: run_encoder_decoder_cross_attention_test(attn, decode_packed_query, - None, None, kv_cache, - decode_attn_metadata) + None, None, kv_cache, + decode_attn_metadata) # "Encoder decoder models do not currently support chunked prefill" assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL @@ -1701,11 +1704,9 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, ) with pytest.raises(NotImplementedError) as exc_info: - run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, - cross_prefill_packed_key, - cross_prefill_packed_value, - kv_cache, - prefill_attn_metadata) + run_encoder_decoder_cross_attention_test( + attn, prefill_packed_query, cross_prefill_packed_key, + cross_prefill_packed_value, kv_cache, prefill_attn_metadata) # "Encoder decoder models do not currently support ROCm/HIP" assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP From 61d63bd8ea84bbce3889aebb5e60d68316af4f22 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 03:57:28 -0400 Subject: [PATCH 121/239] removed comment about supported head_size's, which is not relevant under current encoder/decoder test conditions --- tests/kernels/test_self_and_cross_attn.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 9b1d22e19c57d..7171ba0c2d84c 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -17,11 +17,6 @@ logger = init_logger(__name__) -# If not is_hip(): supported head sizes are [64, 80, 96, 112, 128, 256] -# -# TODO: FlashAttention forward only supports head dimension at most 128 -# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d0 -# 37782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] From b9b604821899f11539041bb51ebed5cffbf1c3a0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 04:03:54 -0400 Subject: [PATCH 122/239] small refactors --- tests/kernels/test_self_and_cross_attn.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 7171ba0c2d84c..37bcc582b8bab 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -17,6 +17,7 @@ logger = init_logger(__name__) + HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] @@ -348,13 +349,6 @@ def pack_qkv(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, packed_query, q_start_loc_list = pack_tensor(query, q_seq_lens) packed_key, kv_start_loc_list = pack_tensor(key, kv_seq_lens) packed_value, _ = pack_tensor(value, kv_seq_lens) - if packed_query is not None: - packed_query = packed_query.view( - -1, packed_query.shape[-1] * packed_query.shape[-2]) - packed_key = packed_key.view(-1, - packed_key.shape[-1] * packed_key.shape[-2]) - packed_value = packed_value.view( - -1, packed_value.shape[-1] * packed_value.shape[-2]) return packed_query, \ packed_key, \ packed_value, \ From e2e208234f6d690d7cbbec217bac4e9dbd4f5d37 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 04:44:39 -0400 Subject: [PATCH 123/239] refactoring --- tests/kernels/test_self_and_cross_attn.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 37bcc582b8bab..5b2b0f718860e 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -12,12 +12,8 @@ from vllm.attention.backends.utils import ( STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_ROCM_HIP) from vllm.attention.backends.xformers import XFormersBackend -from vllm.logger import init_logger from vllm.utils import is_hip, make_tensor_with_pad -logger = init_logger(__name__) - - HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] From b2238738b9a83ea21c58b2142daad80be0a8a659 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 05:20:45 -0400 Subject: [PATCH 124/239] make_qkv() tensors are 4D --- tests/kernels/test_self_and_cross_attn.py | 69 +++++++++-------------- 1 file changed, 27 insertions(+), 42 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 5b2b0f718860e..79d98142c5121 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -195,42 +195,45 @@ def make_qkv(batch_size: int, actual_max_kv_seq_len = max(kv_seq_lens) query = torch.rand( - (batch_size, max_q_seq_len, num_heads * head_size)).to(device) + (batch_size, max_q_seq_len, num_heads, head_size)).to(device) key = torch.rand( - (batch_size, max_kv_seq_len, num_heads * head_size)).to(device) + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) value = torch.rand( - (batch_size, max_kv_seq_len, num_heads * head_size)).to(device) + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) prefill_query = torch.zeros( - (batch_size, max_q_seq_len, num_heads * head_size)).to(device) + (batch_size, max_q_seq_len, num_heads, head_size)).to(device) prefill_key = torch.zeros( - (batch_size, max_kv_seq_len, num_heads * head_size)).to(device) + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) prefill_value = torch.zeros( - (batch_size, max_kv_seq_len, num_heads * head_size)).to(device) + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) decode_query = torch.zeros( - (batch_size, 1, num_heads * head_size)).to(device) - decode_key = torch.zeros((batch_size, 1, num_heads * head_size)).to(device) + (batch_size, 1, num_heads, head_size)).to(device) + decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device) decode_value = torch.zeros( - (batch_size, 1, num_heads * head_size)).to(device) + (batch_size, 1, num_heads, head_size)).to(device) for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, kv_seq_lens)): - query[bdx, q_seq_len:, :] = 0 - key[bdx, kv_seq_len:, :] = 0 - value[bdx, kv_seq_len:, :] = 0 - - prefill_query[bdx, 0:(q_seq_len - 1), :] = query[bdx, - 0:(q_seq_len - 1), :] - prefill_key[bdx, 0:(kv_seq_len - 1), :] = key[bdx, - 0:(kv_seq_len - 1), :] - prefill_value[bdx, - 0:(kv_seq_len - 1), :] = value[bdx, - 0:(kv_seq_len - 1), :] - - decode_query[bdx, :, :] = query[bdx, (q_seq_len - 1):q_seq_len, :] - decode_key[bdx, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :] - decode_value[bdx, :, :] = value[bdx, (kv_seq_len - 1):kv_seq_len, :] + query[bdx, q_seq_len:, :, :] = 0 + key[bdx, kv_seq_len:, :, :] = 0 + value[bdx, kv_seq_len:, :, :] = 0 + + prefill_query[bdx, + 0:(q_seq_len - 1), :, :] = query[bdx, + 0:(q_seq_len - 1), :, :] + prefill_key[bdx, + 0:(kv_seq_len - 1), :, :] = key[bdx, + 0:(kv_seq_len - 1), :, :] + prefill_value[bdx, 0:(kv_seq_len - + 1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :] + + decode_query[bdx, :, :, :] = query[bdx, + (q_seq_len - 1):q_seq_len, :, :] + decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :] + decode_value[bdx, :, :, :] = value[bdx, + (kv_seq_len - 1):kv_seq_len, :, :] prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens] prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens] @@ -238,24 +241,6 @@ def make_qkv(batch_size: int, decode_q_seq_lens = [1 for _ in q_seq_lens] decode_kv_seq_lens = [1 for _ in kv_seq_lens] - query = query.view(batch_size, query.shape[1], num_heads, head_size) - key = key.view(batch_size, key.shape[1], num_heads, head_size) - value = value.view(batch_size, value.shape[1], num_heads, head_size) - - prefill_query = prefill_query.view(batch_size, prefill_query.shape[1], - num_heads, head_size) - prefill_key = prefill_key.view(batch_size, prefill_key.shape[1], num_heads, - head_size) - prefill_value = prefill_value.view(batch_size, prefill_value.shape[1], - num_heads, head_size) - - decode_query = decode_query.view(batch_size, decode_query.shape[1], - num_heads, head_size) - decode_key = decode_key.view(batch_size, decode_key.shape[1], num_heads, - head_size) - decode_value = decode_value.view(batch_size, decode_value.shape[1], - num_heads, head_size) - return query, \ key, \ value, \ From 2ea335cf656962208c8d989c2af07ea0dd66ce20 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 05:25:27 -0400 Subject: [PATCH 125/239] combined seq_start_loc init with cumsum --- tests/kernels/test_self_and_cross_attn.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 79d98142c5121..da4d3c757e4ea 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -393,15 +393,9 @@ def make_metadata_tensors(is_prompt: bool, context_lens, dtype=torch.int, device=device) max_context_len = None if context_lens is None else max(context_lens) max_seq_len = None if seq_lens is None else max(seq_lens) + + seq_start_loc = torch.cat([torch.tensor([0], dtype=torch.int32, device=device), torch.cumsum(seq_lens_tensor, dim=0, dtype=torch.int32)]) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=device) - - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) if is_prompt: # Prefill: query_start_loc matches seq_start_loc From 02875abd2532cd08abbf457c544e9f8081320479 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 06:22:25 -0400 Subject: [PATCH 126/239] xformers metadata allows unspecified values for most Optional members; xformers forward only enforces non-none query start locs for prefix caching scenario; simplify self/cross attention test metadata building --- tests/kernels/test_self_and_cross_attn.py | 45 ++++++----------------- vllm/attention/backends/xformers.py | 43 ++++++++++++++-------- 2 files changed, 39 insertions(+), 49 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index da4d3c757e4ea..7cf6736b213d1 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -363,8 +363,7 @@ def make_backend(backend_name: str) -> AttentionBackend: f"Unrecognized backend_name {backend_name} for unit test") -def make_metadata_tensors(is_prompt: bool, - seq_lens: List[int], +def make_metadata_tensors(seq_lens: List[int], context_lens: List[int], device: Union[torch.device, str] = \ CUDA_DEVICE) -> tuple: @@ -393,27 +392,17 @@ def make_metadata_tensors(is_prompt: bool, context_lens, dtype=torch.int, device=device) max_context_len = None if context_lens is None else max(context_lens) max_seq_len = None if seq_lens is None else max(seq_lens) - - seq_start_loc = torch.cat([torch.tensor([0], dtype=torch.int32, device=device), torch.cumsum(seq_lens_tensor, dim=0, dtype=torch.int32)]) - - if is_prompt: - # Prefill: query_start_loc matches seq_start_loc - query_start_loc = copy.deepcopy(seq_start_loc) - max_query_len = max_seq_len - else: - # Decode: one new query input token per batch element, thus - # query_start_loc is the cumsum of [1,1,1,...] - query_start_loc = list(range(len(seq_start_loc))) - max_query_len = 1 + seq_start_loc = torch.cat([ + torch.tensor([0], dtype=torch.int32, device=device), + torch.cumsum(seq_lens_tensor, dim=0, dtype=torch.int32) + ]) return seq_lens_tensor, \ context_lens_tensor, \ - max_query_len, \ max_context_len, \ max_seq_len, \ - seq_start_loc, \ - query_start_loc + seq_start_loc def make_kv_cache(num_blocks: int, @@ -625,14 +614,11 @@ def make_test_metadata( seq_lens_tensor, \ context_lens_tensor, \ - max_query_len, \ _, \ _, \ - seq_start_loc, \ - query_start_loc = make_metadata_tensors(is_prompt, - seq_lens, - context_lens, - device=device) + seq_start_loc = make_metadata_tensors(seq_lens, + context_lens, + device=device) return attn_backend.make_metadata( num_prefills=num_prefills, @@ -641,10 +627,8 @@ def make_test_metadata( num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, max_prefill_seq_len=max(seq_lens), max_decode_seq_len=0, - query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, @@ -662,14 +646,11 @@ def make_test_metadata( seq_lens_tensor, \ context_lens_tensor, \ - max_query_len, \ _, \ _, \ - seq_start_loc, \ - query_start_loc = make_metadata_tensors(is_prompt, - seq_lens, - context_lens, - device=device) + seq_start_loc = make_metadata_tensors(seq_lens, + context_lens, + device=device) return attn_backend.make_metadata( num_prefills=num_prefills, @@ -678,10 +659,8 @@ def make_test_metadata( num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, max_prefill_seq_len=0, max_decode_seq_len=max(seq_lens), - query_start_loc=query_start_loc, seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 4ef052424f176..bbff0b2dac906 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -80,8 +80,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # |-------------------- seq_len ----------------------| # |-- query_len ---| - # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] # FIXME: It is for flash attn. # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. @@ -89,23 +87,29 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Maximum sequence length among decode batch. 0 if there are prefill # requests only. max_decode_seq_len: int - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + # FIXME: It is for flash attn. # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] + seq_start_loc: Optional[torch.Tensor] = None + # (batch_size,) A tensor of context lengths (tokens that are computed # so far). - context_lens_tensor: Optional[torch.Tensor] + context_lens_tensor: Optional[torch.Tensor] = None - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None # Self-attention prefill/decode metadata cache _self_cached_prefill_metadata: Optional["XFormersMetadata"] = None @@ -194,10 +198,12 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: assert self.seq_lens is not None assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None assert self.context_lens_tensor is not None assert self.block_tables is not None + query_start_loc = None if self.query_start_loc is None \ + else self.query_start_loc[:self.num_prefills + 1] + self._self_cached_prefill_metadata = XFormersMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, @@ -208,7 +214,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], + query_start_loc=query_start_loc, seq_start_loc=None, context_lens_tensor=self.context_lens_tensor[:self. num_prefills], @@ -231,10 +237,12 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: assert self.seq_lens is not None assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None assert self.context_lens_tensor is not None assert self.block_tables is not None + query_start_loc = None if self.query_start_loc is None \ + else self.query_start_loc[:self.num_prefills + 1] + self._cross_cached_prefill_metadata = XFormersMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, @@ -245,7 +253,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], + query_start_loc=query_start_loc, seq_start_loc=None, context_lens_tensor=self.context_lens_tensor[:self. num_prefills], @@ -503,6 +511,9 @@ def forward( assert out.shape == output[:num_prefill_tokens].shape output[:num_prefill_tokens] = out else: + assert prefill_meta.query_start_loc is not None + assert prefill_meta.max_query_len is not None + # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, From 79f307d218f87ccac90776859446d34690b6591b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 06:42:41 -0400 Subject: [PATCH 127/239] refactored slot mapping logic --- tests/kernels/test_self_and_cross_attn.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 7cf6736b213d1..d28c173f510fa 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -513,17 +513,13 @@ def make_block_tables_slot_mapping(block_size: int, num_blocks = num_blocks_list[sdx] block_table = list( range(block_base_idx, block_base_idx - num_blocks, -1)) - for idx in range(num_tokens - 1): - prefill_slot_mapping.append((idx % block_size) + - block_table[idx // block_size] * - block_size) - slot_mapping.append((idx % block_size) + - block_table[idx // block_size] * block_size) - idx = num_tokens - 1 - decode_slot_mapping.append((idx % block_size) + - block_table[idx // block_size] * block_size) - slot_mapping.append((idx % block_size) + - block_table[idx // block_size] * block_size) + for idx in range(num_tokens): + mapping_value = (idx % block_size) + block_table[idx // block_size] * block_size + slot_mapping.append(mapping_value) + if idx < num_tokens - 1: + prefill_slot_mapping.append(mapping_value) + elif idx == num_tokens - 1: + decode_slot_mapping.append(mapping_value) block_base_idx -= num_blocks block_tables.append(block_table) From e790a00fbde0175d81bab5806d378f68d3d08b9e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 06:43:35 -0400 Subject: [PATCH 128/239] formatting --- tests/kernels/test_self_and_cross_attn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index d28c173f510fa..c127cfbfa7e93 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -514,7 +514,8 @@ def make_block_tables_slot_mapping(block_size: int, block_table = list( range(block_base_idx, block_base_idx - num_blocks, -1)) for idx in range(num_tokens): - mapping_value = (idx % block_size) + block_table[idx // block_size] * block_size + mapping_value = ( + idx % block_size) + block_table[idx // block_size] * block_size slot_mapping.append(mapping_value) if idx < num_tokens - 1: prefill_slot_mapping.append(mapping_value) From aae601b88788eb8db2394793990cdab47c4c2264 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 08:01:41 -0400 Subject: [PATCH 129/239] selective renaming of cross -> encoder --- tests/kernels/test_self_and_cross_attn.py | 18 ++++---- vllm/attention/backends/xformers.py | 50 +++++++++++------------ 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index c127cfbfa7e93..e4a9993a143b3 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -564,7 +564,7 @@ def make_test_metadata( slot_mapping: torch.Tensor, is_encoder_only_test: bool, device: Union[torch.device, str] = CUDA_DEVICE, - cross_seq_lens: Optional[List[int]] = None, + encoder_seq_lens: Optional[List[int]] = None, cross_block_tables: Optional[torch.Tensor] = None, cross_slot_mapping: Optional[List[int]] = None, ) -> AttentionMetadata: @@ -591,7 +591,7 @@ def make_test_metadata( * is_encoder_only_test: True if testing encoder; False if testing decoder self-attention or encoder/decoder cross-attention. * device: CPU or CUDA device - * cross_seq_lens: list of token counts for each encoder sequence, if any + * encoder_seq_lens: list of token counts for each encoder sequence, if any exist * cross_block_tables: cross-attention block tables, if required * cross_slot_mapping: cross-attention slot mapping, if required @@ -631,7 +631,7 @@ def make_test_metadata( block_tables=block_tables, use_cuda_graph=False, _attn_type=default_attn_type, - cross_seq_lens=cross_seq_lens, + encoder_seq_lens=encoder_seq_lens, cross_slot_mapping=cross_slot_mapping, cross_block_tables=cross_block_tables) @@ -663,7 +663,7 @@ def make_test_metadata( block_tables=block_tables, use_cuda_graph=False, _attn_type=default_attn_type, - cross_seq_lens=cross_seq_lens, + encoder_seq_lens=encoder_seq_lens, cross_slot_mapping=cross_slot_mapping, cross_block_tables=cross_block_tables) @@ -1387,7 +1387,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_prefill_packed_value, \ cross_prefill_packed_ideal_output, \ cross_decode_packed_ideal_output, \ - cross_kv_seq_lens, \ + encoder_kv_seq_lens, \ cross_decode_block_tables, \ cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ @@ -1416,7 +1416,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_prefill_block_tables, self_prefill_slot_mapping, is_encoder_only_test=False, - cross_seq_lens=cross_kv_seq_lens, + encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, ) @@ -1460,7 +1460,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_decode_block_tables, self_decode_slot_mapping, is_encoder_only_test=False, - cross_seq_lens=cross_kv_seq_lens, + encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_decode_block_tables, cross_slot_mapping=cross_decode_slot_mapping, ) @@ -1609,7 +1609,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, cross_prefill_packed_value, \ cross_prefill_packed_ideal_output, \ cross_decode_packed_ideal_output, \ - cross_kv_seq_lens, \ + encoder_kv_seq_lens, \ cross_decode_block_tables, \ cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ @@ -1638,7 +1638,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, self_prefill_block_tables, self_prefill_slot_mapping, is_encoder_only_test=False, - cross_seq_lens=cross_kv_seq_lens, + encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index bbff0b2dac906..fca06a1eeebb4 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -128,13 +128,13 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # (batch_size,). The "cross-sequence-length" per sequence,i.e. the key/value # sequence length (usually encoder sequence length) in the cross-attention # computation. None if this is self-attention - cross_seq_lens: Optional[List[int]] = None - cross_seq_lens_tensor: Optional[torch.Tensor] = None + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None # The maximum cross-sequence-length, if cross_seq_lens is specified. # Note that for cross-attention there is no difference in key/value # sequence length between prefill and decode - max_cross_seq_len: Optional[int] = None + max_encoder_seq_len: Optional[int] = None # Cross-attention memory-mapping data structures: slot mapping # and block tables @@ -152,7 +152,7 @@ def __post_init__(self): @property def is_all_cross_attn_metadata_set(self): # No cross-attention metadata is present whatsoever - return (self.cross_seq_lens is not None) and \ + return (self.encoder_seq_lens is not None) and \ (self.cross_slot_mapping is not None) and \ (self.cross_block_tables is not None) @@ -170,15 +170,15 @@ def attention_type(self, atype: AttentionType) -> None: # Infer implicit cross-attention fields # from user-provided fields, if needed - if self.cross_seq_lens_tensor is None: + if self.encoder_seq_lens_tensor is None: assert self.seq_lens_tensor is not None - self.cross_seq_lens_tensor = torch.tensor( - self.cross_seq_lens, + self.encoder_seq_lens_tensor = torch.tensor( + self.encoder_seq_lens, dtype=self.seq_lens_tensor.dtype, device=self.seq_lens_tensor.device) - if self.max_cross_seq_len is None: - assert self.cross_seq_lens is not None - self.max_cross_seq_len = max(self.cross_seq_lens) + if self.max_encoder_seq_len is None: + assert self.encoder_seq_lens is not None + self.max_encoder_seq_len = max(self.encoder_seq_lens) self._attn_type = AttentionType.ENCODER_DECODER else: @@ -222,9 +222,9 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: use_cuda_graph=False, _attn_type=self. _attn_type, # Begin cross-attention fields below... - cross_seq_lens=None, - cross_seq_lens_tensor=None, - max_cross_seq_len=None, + encoder_seq_lens=None, + encoder_seq_lens_tensor=None, + max_encoder_seq_len=None, cross_block_tables=None, cross_slot_mapping=None) return self._self_cached_prefill_metadata @@ -261,9 +261,9 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: use_cuda_graph=False, _attn_type=AttentionType.ENCODER_DECODER, # Begin cross-attention fields below... - cross_seq_lens=self.cross_seq_lens, - cross_seq_lens_tensor=self.cross_seq_lens_tensor, - max_cross_seq_len=self.max_cross_seq_len, + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables) return self._cross_cached_prefill_metadata @@ -298,9 +298,9 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: use_cuda_graph=self.use_cuda_graph, _attn_type=self. _attn_type, # Begin cross-attention fields below... - cross_seq_lens=None, - cross_seq_lens_tensor=None, - max_cross_seq_len=None, + encoder_seq_lens=None, + encoder_seq_lens_tensor=None, + max_encoder_seq_len=None, cross_block_tables=None, cross_slot_mapping=None) return self._self_cached_decode_metadata @@ -330,9 +330,9 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: use_cuda_graph=self.use_cuda_graph, _attn_type=AttentionType.ENCODER_DECODER, # Begin cross-attention fields below... - cross_seq_lens=self.cross_seq_lens, - cross_seq_lens_tensor=self.cross_seq_lens_tensor, - max_cross_seq_len=self.max_cross_seq_len, + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables) return self._cross_cached_decode_metadata @@ -540,8 +540,8 @@ def forward( if decode_meta := attn_metadata.decode_metadata: if attn_type == AttentionType.ENCODER_DECODER: # Paged attention against cross-attention KV cache - seq_lens_arg = decode_meta.cross_seq_lens_tensor - max_seq_len_arg = decode_meta.max_cross_seq_len + seq_lens_arg = decode_meta.encoder_seq_lens_tensor + max_seq_len_arg = decode_meta.max_encoder_seq_len block_tables_arg = decode_meta.cross_block_tables else: # Paged attention against self-attention KV cache @@ -610,7 +610,7 @@ def _run_memory_efficient_xformers_forward( AttentionType.ENCODER_DECODER: # Default enc/dec cross-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.seq_lens, attn_metadata.cross_seq_lens) + attn_metadata.seq_lens, attn_metadata.encoder_seq_lens) else: if attn_metadata.attention_type == AttentionType.ENCODER: # Default encoder self-attention mask is non-causal From d871c9fd8eb9774a9c3699240d34ab05cc0df6ef Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 08:04:18 -0400 Subject: [PATCH 130/239] added encoder, enc/dec cross-attention bias members --- vllm/attention/backends/xformers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index fca06a1eeebb4..d653c1d501e71 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -148,6 +148,8 @@ def __post_init__(self): # from xformer API. # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[AttentionBias]] = None + self.encoder_attn_bias: Optional[List[AttentionBias]] = None + self.cross_attn_bias: Optional[List[AttentionBias]] = None @property def is_all_cross_attn_metadata_set(self): From 90d5c0dfcbadd6dba3005095455338029965f167 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 08:20:29 -0400 Subject: [PATCH 131/239] xformers metadata now uses a different attn_bias for self, encoder and cross --- vllm/attention/backends/xformers.py | 34 +++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index d653c1d501e71..0432ec4e87e85 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -339,6 +339,29 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: cross_block_tables=self.cross_block_tables) return self._cross_cached_decode_metadata +def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ + Optional[List[Optional[AttentionBias]]]: + attn_type = attn_metadata.attention_type + if attn_type == AttentionType.DECODER: + return attn_metadata.attn_bias + elif attn_type == AttentionType.ENCODER: + return attn_metadata.encoder_attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + return attn_metadata.cross_attn_bias + else: + raise AttributeError(f"Invalid attn_metadata.attention_type {attn_type}") + +def _set_attn_bias(attn_metadata: XFormersMetadata, + attn_bias: List[Optional[AttentionBias]]) -> None: + attn_type = attn_metadata.attention_type + if attn_type == AttentionType.DECODER: + attn_metadata.attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER: + attn_metadata.encoder_attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + attn_metadata.cross_attn_bias = attn_bias + else: + raise AttributeError(f"Invalid attn_metadata.attention_type {attn_type}") class XFormersImpl(AttentionImpl[XFormersMetadata]): """ @@ -606,7 +629,8 @@ def _run_memory_efficient_xformers_forward( # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. - if attn_metadata.attn_bias is None: + attn_bias = _get_attn_bias(attn_metadata) + if attn_bias is None: if self.alibi_slopes is None: if attn_metadata.attention_type == \ AttentionType.ENCODER_DECODER: @@ -625,12 +649,14 @@ def _run_memory_efficient_xformers_forward( if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) - attn_metadata.attn_bias = [attn_bias] + attn_bias = [attn_bias] else: - attn_metadata.attn_bias = _make_alibi_bias( + attn_bias = _make_alibi_bias( self.alibi_slopes, self.num_kv_heads, query.dtype, attn_metadata.seq_lens) + _set_attn_bias(attn_metadata,attn_bias) + # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. @@ -643,7 +669,7 @@ def _run_memory_efficient_xformers_forward( query, key, value, - attn_bias=attn_metadata.attn_bias[0], + attn_bias=attn_bias[0], p=0.0, scale=self.scale) return out.view_as(original_query) From a973c2be7240f50adc95a093bbdd684d6c7cfc07 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 08:24:03 -0400 Subject: [PATCH 132/239] refactoring --- vllm/attention/backends/xformers.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 0432ec4e87e85..368170d3cc4ca 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -341,6 +341,20 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ Optional[List[Optional[AttentionBias]]]: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. + + Depends on attn_metadata having a valid attention_type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + + Returns: + * Appropriate attention bias value + ''' + attn_type = attn_metadata.attention_type if attn_type == AttentionType.DECODER: return attn_metadata.attn_bias @@ -353,6 +367,18 @@ def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ def _set_attn_bias(attn_metadata: XFormersMetadata, attn_bias: List[Optional[AttentionBias]]) -> None: + ''' + Update appropriate attention bias field of attention metadata, + according to attention type. + + Depends on attn_metadata having a valid attention_type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_bias: The desired attention bias value + ''' + attn_type = attn_metadata.attention_type if attn_type == AttentionType.DECODER: attn_metadata.attn_bias = attn_bias From d8d284ed372a863ae0ec95f469f3fa28aadefca8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 08:34:38 -0400 Subject: [PATCH 133/239] wip typing issues --- vllm/attention/backends/xformers.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 368170d3cc4ca..74ffca1ee4b6a 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -363,9 +363,11 @@ def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ elif attn_type == AttentionType.ENCODER_DECODER: return attn_metadata.cross_attn_bias else: - raise AttributeError(f"Invalid attn_metadata.attention_type {attn_type}") + raise AttributeError( + f"Invalid attn_metadata.attention_type {attn_type}") -def _set_attn_bias(attn_metadata: XFormersMetadata, + +def _set_attn_bias(attn_metadata: XFormersMetadata, attn_bias: List[Optional[AttentionBias]]) -> None: ''' Update appropriate attention bias field of attention metadata, @@ -387,7 +389,9 @@ def _set_attn_bias(attn_metadata: XFormersMetadata, elif attn_type == AttentionType.ENCODER_DECODER: attn_metadata.cross_attn_bias = attn_bias else: - raise AttributeError(f"Invalid attn_metadata.attention_type {attn_type}") + raise AttributeError( + f"Invalid attn_metadata.attention_type {attn_type}") + class XFormersImpl(AttentionImpl[XFormersMetadata]): """ @@ -677,11 +681,11 @@ def _run_memory_efficient_xformers_forward( self.sliding_window) attn_bias = [attn_bias] else: - attn_bias = _make_alibi_bias( - self.alibi_slopes, self.num_kv_heads, query.dtype, - attn_metadata.seq_lens) + attn_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, query.dtype, + attn_metadata.seq_lens) - _set_attn_bias(attn_metadata,attn_bias) + _set_attn_bias(attn_metadata, attn_bias) # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce @@ -712,7 +716,7 @@ def _run_memory_efficient_xformers_forward( query[None, start:end], key[None, start:end], value[None, start:end], - attn_bias=attn_metadata.attn_bias[i], + attn_bias=attn_bias[i], p=0.0, scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. From 27dc095e8c6b10b2c4470b6d4028a1b31053c7b1 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 08:53:13 -0400 Subject: [PATCH 134/239] added paged attention args collection, conditional on metadata attention type --- vllm/attention/backends/xformers.py | 55 ++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 74ffca1ee4b6a..eb5a5f0032ed6 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -364,7 +364,7 @@ def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ return attn_metadata.cross_attn_bias else: raise AttributeError( - f"Invalid attn_metadata.attention_type {attn_type}") + f"Invalid attn_metadata.attention_type {str(attn_type)}") def _set_attn_bias(attn_metadata: XFormersMetadata, @@ -390,7 +390,44 @@ def _set_attn_bias(attn_metadata: XFormersMetadata, attn_metadata.cross_attn_bias = attn_bias else: raise AttributeError( - f"Invalid attn_metadata.attention_type {attn_type}") + f"Invalid attn_metadata.attention_type {str(attn_type)}") + + +def _get_paged_attention_args(attn_metadata: XFormersMetadata) -> tuple: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. + + Depends on attn_metadata having a valid attention_type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + + Returns: + * Appropriate attention bias value + ''' + + attn_type = attn_metadata.attention_type + if attn_type == AttentionType.DECODER: + # Decoder self-attention + return attn_metadata.seq_lens_tensor, \ + attn_metadata.max_decode_seq_len, \ + attn_metadata.block_tables + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return attn_metadata.encoder_seq_lens_tensor, \ + attn_metadata.max_encoder_seq_len, \ + attn_metadata.cross_block_tables + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return attn_metadata.encoder_seq_lens_tensor, \ + attn_metadata.max_encoder_seq_len, \ + None + else: + raise AttributeError( + f"Invalid attn_metadata.attention_type {str(attn_type)}") class XFormersImpl(AttentionImpl[XFormersMetadata]): @@ -593,16 +630,10 @@ def forward( output[:num_prefill_tokens] = out if decode_meta := attn_metadata.decode_metadata: - if attn_type == AttentionType.ENCODER_DECODER: - # Paged attention against cross-attention KV cache - seq_lens_arg = decode_meta.encoder_seq_lens_tensor - max_seq_len_arg = decode_meta.max_encoder_seq_len - block_tables_arg = decode_meta.cross_block_tables - else: - # Paged attention against self-attention KV cache - seq_lens_arg = decode_meta.seq_lens_tensor - max_seq_len_arg = decode_meta.max_decode_seq_len - block_tables_arg = decode_meta.block_tables + + seq_lens_arg, \ + max_seq_len_arg,\ + block_tables_arg = _get_paged_attention_args(decode_meta) output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, From 24459051c1295a73ebfaf26e19c15836cd1c8aca Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 09:41:51 -0400 Subject: [PATCH 135/239] logic to support encoder-specific sequence length usage in xformers --- tests/kernels/test_self_and_cross_attn.py | 1 + vllm/attention/backends/xformers.py | 80 ++++++++++++++++------- 2 files changed, 58 insertions(+), 23 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index e4a9993a143b3..239c4a6139c64 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1278,6 +1278,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, block_tables, slot_mapping, is_encoder_only_test=True, + encoder_seq_lens=q_seq_lens ) packed_actual_output: torch.Tensor = \ diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index eb5a5f0032ed6..82a38f3af8123 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -150,11 +150,25 @@ def __post_init__(self): self.attn_bias: Optional[List[AttentionBias]] = None self.encoder_attn_bias: Optional[List[AttentionBias]] = None self.cross_attn_bias: Optional[List[AttentionBias]] = None + + if self.is_all_encoder_attn_metadata_set: + self._maybe_compute_implicit_encoder_attrs() + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return self.encoder_seq_lens is not None @property def is_all_cross_attn_metadata_set(self): - # No cross-attention metadata is present whatsoever - return (self.encoder_seq_lens is not None) and \ + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return self.is_all_encoder_attn_metadata_set and \ (self.cross_slot_mapping is not None) and \ (self.cross_block_tables is not None) @@ -162,27 +176,40 @@ def is_all_cross_attn_metadata_set(self): def attention_type(self) -> AttentionType: return self._attn_type + def _maybe_compute_implicit_encoder_attrs(self): + ''' + Encoder attention and cross-attention require some encoder-related + metadata attributes which may or may not be been provided by the user. + This method infers the implicit attributes from provided attributes + ''' + if self.encoder_seq_lens_tensor is None: + assert self.seq_lens_tensor is not None + self.encoder_seq_lens_tensor = torch.tensor( + self.encoder_seq_lens, + dtype=self.seq_lens_tensor.dtype, + device=self.seq_lens_tensor.device) + if self.max_encoder_seq_len is None: + assert self.encoder_seq_lens is not None + self.max_encoder_seq_len = max(self.encoder_seq_lens) + @attention_type.setter def attention_type(self, atype: AttentionType) -> None: if atype == AttentionType.ENCODER_DECODER: assert self.is_all_cross_attn_metadata_set, \ - "Must enable self.cross_seq_lens, self.cross_slot_mapping, " + \ + "Must set self.encoder_seq_lens, self.cross_slot_mapping, " + \ "self.cross_block_tables in order to perform cross-attention" - # Infer implicit cross-attention fields - # from user-provided fields, if needed - if self.encoder_seq_lens_tensor is None: - assert self.seq_lens_tensor is not None - self.encoder_seq_lens_tensor = torch.tensor( - self.encoder_seq_lens, - dtype=self.seq_lens_tensor.dtype, - device=self.seq_lens_tensor.device) - if self.max_encoder_seq_len is None: - assert self.encoder_seq_lens is not None - self.max_encoder_seq_len = max(self.encoder_seq_lens) + self._maybe_compute_implicit_encoder_attrs() self._attn_type = AttentionType.ENCODER_DECODER + elif atype == AttentionType.ENCODER: + assert self.is_all_encoder_attn_metadata_set, \ + "Must set self.encoder_seq_lens in order to perform cross-attention" + + self._maybe_compute_implicit_encoder_attrs() + + self._attn_type = AttentionType.ENCODER else: # AttentionType.{ENCODER,DECODER} self._attn_type = atype @@ -224,9 +251,9 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: use_cuda_graph=False, _attn_type=self. _attn_type, # Begin cross-attention fields below... - encoder_seq_lens=None, - encoder_seq_lens_tensor=None, - max_encoder_seq_len=None, + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, cross_block_tables=None, cross_slot_mapping=None) return self._self_cached_prefill_metadata @@ -300,9 +327,9 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: use_cuda_graph=self.use_cuda_graph, _attn_type=self. _attn_type, # Begin cross-attention fields below... - encoder_seq_lens=None, - encoder_seq_lens_tensor=None, - max_encoder_seq_len=None, + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, cross_block_tables=None, cross_slot_mapping=None) return self._self_cached_decode_metadata @@ -393,7 +420,7 @@ def _set_attn_bias(attn_metadata: XFormersMetadata, f"Invalid attn_metadata.attention_type {str(attn_type)}") -def _get_paged_attention_args(attn_metadata: XFormersMetadata) -> tuple: +def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata) -> tuple: ''' Extract appropriate attention bias from attention metadata according to attention type. @@ -633,7 +660,7 @@ def forward( seq_lens_arg, \ max_seq_len_arg,\ - block_tables_arg = _get_paged_attention_args(decode_meta) + block_tables_arg = _get_seq_len_block_table_args(decode_meta) output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, @@ -672,7 +699,14 @@ def _run_memory_efficient_xformers_forward( value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. """ - assert attn_metadata.seq_lens is not None + + # Enforce that the appropriate *_seq_lens attribute of attn_metadata + # (seq_lens or encoder_seq_lens) is set. + seq_lens, \ + _,\ + _ = _get_seq_len_block_table_args(attn_metadata) + assert seq_lens is not None + original_query = query if self.num_kv_heads != self.num_heads: # GQA/MQA requires the shape [B, M, G, H, K]. From 8dabdc2abf59179064667461fb339ef4eb758247 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 09:43:26 -0400 Subject: [PATCH 136/239] formatting --- tests/kernels/test_self_and_cross_attn.py | 3 +-- vllm/attention/backends/xformers.py | 10 +++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_self_and_cross_attn.py index 239c4a6139c64..c0d62f7531b93 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_self_and_cross_attn.py @@ -1278,8 +1278,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, block_tables, slot_mapping, is_encoder_only_test=True, - encoder_seq_lens=q_seq_lens - ) + encoder_seq_lens=q_seq_lens) packed_actual_output: torch.Tensor = \ run_encoder_or_decoder_self_attention_test( diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 82a38f3af8123..648706b4bf07c 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -150,7 +150,7 @@ def __post_init__(self): self.attn_bias: Optional[List[AttentionBias]] = None self.encoder_attn_bias: Optional[List[AttentionBias]] = None self.cross_attn_bias: Optional[List[AttentionBias]] = None - + if self.is_all_encoder_attn_metadata_set: self._maybe_compute_implicit_encoder_attrs() @@ -172,10 +172,6 @@ def is_all_cross_attn_metadata_set(self): (self.cross_slot_mapping is not None) and \ (self.cross_block_tables is not None) - @property - def attention_type(self) -> AttentionType: - return self._attn_type - def _maybe_compute_implicit_encoder_attrs(self): ''' Encoder attention and cross-attention require some encoder-related @@ -192,6 +188,10 @@ def _maybe_compute_implicit_encoder_attrs(self): assert self.encoder_seq_lens is not None self.max_encoder_seq_len = max(self.encoder_seq_lens) + @property + def attention_type(self) -> AttentionType: + return self._attn_type + @attention_type.setter def attention_type(self, atype: AttentionType) -> None: From c6200e676bf89c0ef1116b3d18f520448c5e912f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 10:34:51 -0400 Subject: [PATCH 137/239] test name change; encoder functionality can tolerate being provided with only encoder metadata --- ...and_decoder_self_and_encdec_cross_attn.py} | 36 ++++++++++++------ vllm/attention/backends/xformers.py | 38 ++++++++++--------- 2 files changed, 45 insertions(+), 29 deletions(-) rename tests/kernels/{test_self_and_cross_attn.py => test_encoder_and_decoder_self_and_encdec_cross_attn.py} (98%) diff --git a/tests/kernels/test_self_and_cross_attn.py b/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py similarity index 98% rename from tests/kernels/test_self_and_cross_attn.py rename to tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py index c0d62f7531b93..8e1b8849c8a9f 100644 --- a/tests/kernels/test_self_and_cross_attn.py +++ b/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py @@ -387,16 +387,18 @@ def make_metadata_tensors(seq_lens: List[int], * seq_start_loc: start idx of each sequence * query_start_loc: start idx of each query ''' - seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device) + seq_lens_tensor = None if seq_lens is None else \ + torch.tensor(seq_lens, dtype=torch.int, device=device) context_lens_tensor = None if context_lens is None else torch.tensor( context_lens, dtype=torch.int, device=device) max_context_len = None if context_lens is None else max(context_lens) max_seq_len = None if seq_lens is None else max(seq_lens) - seq_start_loc = torch.cat([ - torch.tensor([0], dtype=torch.int32, device=device), - torch.cumsum(seq_lens_tensor, dim=0, dtype=torch.int32) - ]) + seq_start_loc = None + # seq_start_loc = torch.cat([ + # torch.tensor([0], dtype=torch.int32, device=device), + # torch.cumsum(seq_lens_tensor, dim=0, dtype=torch.int32) + # ]) return seq_lens_tensor, \ context_lens_tensor, \ @@ -563,6 +565,8 @@ def make_test_metadata( block_tables: torch.Tensor, slot_mapping: torch.Tensor, is_encoder_only_test: bool, + num_prefills_or_decodes: int, + num_prefill_or_decode_tokens: int, device: Union[torch.device, str] = CUDA_DEVICE, encoder_seq_lens: Optional[List[int]] = None, cross_block_tables: Optional[torch.Tensor] = None, @@ -605,8 +609,8 @@ def make_test_metadata( else AttentionType.DECODER if is_prompt: - num_prefills = len(seq_lens) - num_prefill_tokens = sum(seq_lens) + num_prefills = num_prefills_or_decodes + num_prefill_tokens = num_prefill_or_decode_tokens num_decode_tokens = 0 seq_lens_tensor, \ @@ -624,9 +628,9 @@ def make_test_metadata( num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, - max_prefill_seq_len=max(seq_lens), + max_prefill_seq_len=None if seq_lens is None else max(seq_lens), max_decode_seq_len=0, - seq_start_loc=seq_start_loc, + # seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, @@ -639,7 +643,7 @@ def make_test_metadata( num_prefills = 0 num_prefill_tokens = 0 - num_decode_tokens = len(seq_lens) + num_decode_tokens = num_prefill_or_decode_tokens seq_lens_tensor, \ context_lens_tensor, \ @@ -658,7 +662,7 @@ def make_test_metadata( seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=0, max_decode_seq_len=max(seq_lens), - seq_start_loc=seq_start_loc, + # seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, @@ -1273,11 +1277,13 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, - q_seq_lens, + None, context_lens, block_tables, slot_mapping, is_encoder_only_test=True, + num_prefills_or_decodes=len(q_seq_lens), + num_prefill_or_decode_tokens=sum(q_seq_lens), encoder_seq_lens=q_seq_lens) packed_actual_output: torch.Tensor = \ @@ -1416,6 +1422,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_prefill_block_tables, self_prefill_slot_mapping, is_encoder_only_test=False, + num_prefills_or_decodes=len(prefill_q_seq_lens), + num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, @@ -1460,6 +1468,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_decode_block_tables, self_decode_slot_mapping, is_encoder_only_test=False, + num_prefills_or_decodes=len(q_seq_lens), + num_prefill_or_decode_tokens=len(q_seq_lens), encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_decode_block_tables, cross_slot_mapping=cross_decode_slot_mapping, @@ -1638,6 +1648,8 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, self_prefill_block_tables, self_prefill_slot_mapping, is_encoder_only_test=False, + num_prefills_or_decodes=len(prefill_q_seq_lens), + num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 648706b4bf07c..a7bdbb4d03839 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -179,13 +179,11 @@ def _maybe_compute_implicit_encoder_attrs(self): This method infers the implicit attributes from provided attributes ''' if self.encoder_seq_lens_tensor is None: - assert self.seq_lens_tensor is not None self.encoder_seq_lens_tensor = torch.tensor( self.encoder_seq_lens, - dtype=self.seq_lens_tensor.dtype, - device=self.seq_lens_tensor.device) + dtype=torch.int32, + device="cuda:0") if self.max_encoder_seq_len is None: - assert self.encoder_seq_lens is not None self.max_encoder_seq_len = max(self.encoder_seq_lens) @property @@ -225,8 +223,10 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self._self_cached_prefill_metadata is not None: return self._self_cached_prefill_metadata - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None + assert (self.seq_lens is not None) or \ + (self.encoder_seq_lens is not None) + assert (self.seq_lens_tensor is not None) or \ + (self.encoder_seq_lens_tensor is not None) assert self.context_lens_tensor is not None assert self.block_tables is not None @@ -238,8 +238,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + seq_lens=None if self.seq_lens is None else self.seq_lens[:self.num_prefills], + seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, @@ -264,8 +264,10 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self._cross_cached_prefill_metadata is not None: return self._cross_cached_prefill_metadata - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None + assert (self.seq_lens is not None) or \ + (self.encoder_seq_lens is not None) + assert (self.seq_lens_tensor is not None) or \ + (self.encoder_seq_lens_tensor is not None) assert self.context_lens_tensor is not None assert self.block_tables is not None @@ -277,8 +279,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + seq_lens=None if self.seq_lens is None else self.seq_lens[:self.num_prefills], + seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, @@ -308,7 +310,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: if self._self_cached_decode_metadata is not None: return self._self_cached_decode_metadata assert self.block_tables is not None - assert self.seq_lens_tensor is not None + assert (self.seq_lens_tensor is not None) or \ + (self.encoder_seq_lens_tensor is not None) self._self_cached_decode_metadata = XFormersMetadata( num_prefills=0, @@ -316,7 +319,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[self.num_prefills:], max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, @@ -340,7 +343,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: if self._cross_cached_decode_metadata is not None: return self._cross_cached_decode_metadata assert self.block_tables is not None - assert self.seq_lens_tensor is not None + assert (self.seq_lens_tensor is not None) or \ + (self.encoder_seq_lens_tensor is not None) self._cross_cached_decode_metadata = XFormersMetadata( num_prefills=0, @@ -348,7 +352,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[self.num_prefills:], max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, @@ -736,7 +740,7 @@ def _run_memory_efficient_xformers_forward( if attn_metadata.attention_type == AttentionType.ENCODER: # Default encoder self-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.seq_lens) + attn_metadata.encoder_seq_lens) else: # Default decoder self-attention mask is causal attn_bias = BlockDiagonalCausalMask.from_seqlens( From 0dc197b794de728dc5f40ce6d0e2058d8099397a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 10:35:35 -0400 Subject: [PATCH 138/239] formatting --- vllm/attention/backends/xformers.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index a7bdbb4d03839..d3af144c742ba 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -179,10 +179,9 @@ def _maybe_compute_implicit_encoder_attrs(self): This method infers the implicit attributes from provided attributes ''' if self.encoder_seq_lens_tensor is None: - self.encoder_seq_lens_tensor = torch.tensor( - self.encoder_seq_lens, - dtype=torch.int32, - device="cuda:0") + self.encoder_seq_lens_tensor = torch.tensor(self.encoder_seq_lens, + dtype=torch.int32, + device="cuda:0") if self.max_encoder_seq_len is None: self.max_encoder_seq_len = max(self.encoder_seq_lens) @@ -238,8 +237,10 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=None if self.seq_lens is None else self.seq_lens[:self.num_prefills], - seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[:self.num_prefills], + seq_lens=None if self.seq_lens is None else + self.seq_lens[:self.num_prefills], + seq_lens_tensor=None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, @@ -279,8 +280,10 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=None if self.seq_lens is None else self.seq_lens[:self.num_prefills], - seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[:self.num_prefills], + seq_lens=None if self.seq_lens is None else + self.seq_lens[:self.num_prefills], + seq_lens_tensor=None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, @@ -319,7 +322,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens=None, - seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:], max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, @@ -352,7 +356,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens=None, - seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[self.num_prefills:], + seq_lens_tensor=None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:], max_query_len=None, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, From 3d3c04ff2aa6be6e93fb0dbb8aca24739d848668 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 10:48:31 -0400 Subject: [PATCH 139/239] prefill supports shared metadata structure --- vllm/attention/backends/xformers.py | 210 +++++++++++++++++----------- 1 file changed, 125 insertions(+), 85 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index d3af144c742ba..dc3d3fdff4af4 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -115,10 +115,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): _self_cached_prefill_metadata: Optional["XFormersMetadata"] = None _self_cached_decode_metadata: Optional["XFormersMetadata"] = None # Cross-attention prefill/decode metadata cache - _cross_cached_prefill_metadata: Optional["XFormersMetadata"] = None _cross_cached_decode_metadata: Optional["XFormersMetadata"] = None - # Begin cross-attention fields... + # Begin encoder attn & enc/dec cross-attn fields... # If True, prefill_metadata() and decode_metadata() will return # seqlen & memory-mapping data structures for cross-attention; @@ -216,91 +215,132 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self.num_prefills == 0: return None - if self._attn_type != AttentionType.ENCODER_DECODER: - # Decoder or encoder self-attention prefill - - if self._self_cached_prefill_metadata is not None: - return self._self_cached_prefill_metadata - - assert (self.seq_lens is not None) or \ - (self.encoder_seq_lens is not None) - assert (self.seq_lens_tensor is not None) or \ - (self.encoder_seq_lens_tensor is not None) - assert self.context_lens_tensor is not None - assert self.block_tables is not None - - query_start_loc = None if self.query_start_loc is None \ - else self.query_start_loc[:self.num_prefills + 1] - - self._self_cached_prefill_metadata = XFormersMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=None if self.seq_lens is None else - self.seq_lens[:self.num_prefills], - seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self. - num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - _attn_type=self. - _attn_type, # Begin cross-attention fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_block_tables=None, - cross_slot_mapping=None) + if self._self_cached_prefill_metadata is not None: + self._self_cached_prefill_metadata.attention_type = self.attention_type return self._self_cached_prefill_metadata - else: - # Encoder/decoder cross-attention prefill - - if self._cross_cached_prefill_metadata is not None: - return self._cross_cached_prefill_metadata - - assert (self.seq_lens is not None) or \ - (self.encoder_seq_lens is not None) - assert (self.seq_lens_tensor is not None) or \ - (self.encoder_seq_lens_tensor is not None) - assert self.context_lens_tensor is not None - assert self.block_tables is not None - - query_start_loc = None if self.query_start_loc is None \ - else self.query_start_loc[:self.num_prefills + 1] - - self._cross_cached_prefill_metadata = XFormersMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=None if self.seq_lens is None else - self.seq_lens[:self.num_prefills], - seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self. - num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - _attn_type=AttentionType.ENCODER_DECODER, - # Begin cross-attention fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cross_cached_prefill_metadata + assert (self.seq_lens is not None) or \ + (self.encoder_seq_lens is not None) + assert (self.seq_lens_tensor is not None) or \ + (self.encoder_seq_lens_tensor is not None) + assert self.context_lens_tensor is not None + assert self.block_tables is not None + + query_start_loc = None if self.query_start_loc is None \ + else self.query_start_loc[:self.num_prefills + 1] + + self._self_cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=None if self.seq_lens is None else + self.seq_lens[:self.num_prefills], + seq_lens_tensor=None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=None, + context_lens_tensor=self.context_lens_tensor[:self. + num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + _attn_type=self.attention_type, + # Begin cross-attention fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._self_cached_prefill_metadata + + # if self._attn_type != AttentionType.ENCODER_DECODER: + # # Decoder or encoder self-attention prefill + + # if self._self_cached_prefill_metadata is not None: + # return self._self_cached_prefill_metadata + + # assert (self.seq_lens is not None) or \ + # (self.encoder_seq_lens is not None) + # assert (self.seq_lens_tensor is not None) or \ + # (self.encoder_seq_lens_tensor is not None) + # assert self.context_lens_tensor is not None + # assert self.block_tables is not None + + # query_start_loc = None if self.query_start_loc is None \ + # else self.query_start_loc[:self.num_prefills + 1] + + # self._self_cached_prefill_metadata = XFormersMetadata( + # num_prefills=self.num_prefills, + # num_prefill_tokens=self.num_prefill_tokens, + # num_decode_tokens=0, + # slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + # seq_lens=None if self.seq_lens is None else + # self.seq_lens[:self.num_prefills], + # seq_lens_tensor=None if self.seq_lens_tensor is None else + # self.seq_lens_tensor[:self.num_prefills], + # max_query_len=self.max_query_len, + # max_prefill_seq_len=self.max_prefill_seq_len, + # max_decode_seq_len=0, + # query_start_loc=query_start_loc, + # seq_start_loc=None, + # context_lens_tensor=self.context_lens_tensor[:self. + # num_prefills], + # block_tables=self.block_tables[:self.num_prefills], + # use_cuda_graph=False, + # _attn_type=self. + # _attn_type, # Begin cross-attention fields below... + # encoder_seq_lens=self.encoder_seq_lens, + # encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + # max_encoder_seq_len=self.max_encoder_seq_len, + # cross_block_tables=None, + # cross_slot_mapping=None) + # return self._self_cached_prefill_metadata + + # else: + # # Encoder/decoder cross-attention prefill + + # if self._cross_cached_prefill_metadata is not None: + # return self._cross_cached_prefill_metadata + + # assert (self.seq_lens is not None) or \ + # (self.encoder_seq_lens is not None) + # assert (self.seq_lens_tensor is not None) or \ + # (self.encoder_seq_lens_tensor is not None) + # assert self.context_lens_tensor is not None + # assert self.block_tables is not None + + # query_start_loc = None if self.query_start_loc is None \ + # else self.query_start_loc[:self.num_prefills + 1] + + # self._cross_cached_prefill_metadata = XFormersMetadata( + # num_prefills=self.num_prefills, + # num_prefill_tokens=self.num_prefill_tokens, + # num_decode_tokens=0, + # slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + # seq_lens=None if self.seq_lens is None else + # self.seq_lens[:self.num_prefills], + # seq_lens_tensor=None if self.seq_lens_tensor is None else + # self.seq_lens_tensor[:self.num_prefills], + # max_query_len=self.max_query_len, + # max_prefill_seq_len=self.max_prefill_seq_len, + # max_decode_seq_len=0, + # query_start_loc=query_start_loc, + # seq_start_loc=None, + # context_lens_tensor=self.context_lens_tensor[:self. + # num_prefills], + # block_tables=self.block_tables[:self.num_prefills], + # use_cuda_graph=False, + # _attn_type=AttentionType.ENCODER_DECODER, + # # Begin cross-attention fields below... + # encoder_seq_lens=self.encoder_seq_lens, + # encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + # max_encoder_seq_len=self.max_encoder_seq_len, + # cross_slot_mapping=self.cross_slot_mapping, + # cross_block_tables=self.cross_block_tables) + # return self._cross_cached_prefill_metadata @property def decode_metadata(self) -> Optional["XFormersMetadata"]: From c132caaa3be6bf67eacd628f0d21816de23e9dc7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 10:49:23 -0400 Subject: [PATCH 140/239] formatting --- vllm/attention/backends/xformers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index dc3d3fdff4af4..d52b1e8cd000c 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -216,7 +216,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: return None if self._self_cached_prefill_metadata is not None: - self._self_cached_prefill_metadata.attention_type = self.attention_type + self._self_cached_prefill_metadata.attention_type = \ + self.attention_type return self._self_cached_prefill_metadata assert (self.seq_lens is not None) or \ @@ -234,8 +235,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=None if self.seq_lens is None else - self.seq_lens[:self.num_prefills], + seq_lens=None + if self.seq_lens is None else self.seq_lens[:self.num_prefills], seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -243,8 +244,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: max_decode_seq_len=0, query_start_loc=query_start_loc, seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self. - num_prefills], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, _attn_type=self.attention_type, From 39ee51a508b82693e8bf03634667011a88d2388d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 10:54:34 -0400 Subject: [PATCH 141/239] full generalization of prefill & decode metadata structures --- vllm/attention/backends/xformers.py | 181 +++++----------------------- 1 file changed, 30 insertions(+), 151 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index d52b1e8cd000c..56a57ac5c4465 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -256,164 +256,43 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: cross_block_tables=self.cross_block_tables) return self._self_cached_prefill_metadata - # if self._attn_type != AttentionType.ENCODER_DECODER: - # # Decoder or encoder self-attention prefill - - # if self._self_cached_prefill_metadata is not None: - # return self._self_cached_prefill_metadata - - # assert (self.seq_lens is not None) or \ - # (self.encoder_seq_lens is not None) - # assert (self.seq_lens_tensor is not None) or \ - # (self.encoder_seq_lens_tensor is not None) - # assert self.context_lens_tensor is not None - # assert self.block_tables is not None - - # query_start_loc = None if self.query_start_loc is None \ - # else self.query_start_loc[:self.num_prefills + 1] - - # self._self_cached_prefill_metadata = XFormersMetadata( - # num_prefills=self.num_prefills, - # num_prefill_tokens=self.num_prefill_tokens, - # num_decode_tokens=0, - # slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - # seq_lens=None if self.seq_lens is None else - # self.seq_lens[:self.num_prefills], - # seq_lens_tensor=None if self.seq_lens_tensor is None else - # self.seq_lens_tensor[:self.num_prefills], - # max_query_len=self.max_query_len, - # max_prefill_seq_len=self.max_prefill_seq_len, - # max_decode_seq_len=0, - # query_start_loc=query_start_loc, - # seq_start_loc=None, - # context_lens_tensor=self.context_lens_tensor[:self. - # num_prefills], - # block_tables=self.block_tables[:self.num_prefills], - # use_cuda_graph=False, - # _attn_type=self. - # _attn_type, # Begin cross-attention fields below... - # encoder_seq_lens=self.encoder_seq_lens, - # encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - # max_encoder_seq_len=self.max_encoder_seq_len, - # cross_block_tables=None, - # cross_slot_mapping=None) - # return self._self_cached_prefill_metadata - - # else: - # # Encoder/decoder cross-attention prefill - - # if self._cross_cached_prefill_metadata is not None: - # return self._cross_cached_prefill_metadata - - # assert (self.seq_lens is not None) or \ - # (self.encoder_seq_lens is not None) - # assert (self.seq_lens_tensor is not None) or \ - # (self.encoder_seq_lens_tensor is not None) - # assert self.context_lens_tensor is not None - # assert self.block_tables is not None - - # query_start_loc = None if self.query_start_loc is None \ - # else self.query_start_loc[:self.num_prefills + 1] - - # self._cross_cached_prefill_metadata = XFormersMetadata( - # num_prefills=self.num_prefills, - # num_prefill_tokens=self.num_prefill_tokens, - # num_decode_tokens=0, - # slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - # seq_lens=None if self.seq_lens is None else - # self.seq_lens[:self.num_prefills], - # seq_lens_tensor=None if self.seq_lens_tensor is None else - # self.seq_lens_tensor[:self.num_prefills], - # max_query_len=self.max_query_len, - # max_prefill_seq_len=self.max_prefill_seq_len, - # max_decode_seq_len=0, - # query_start_loc=query_start_loc, - # seq_start_loc=None, - # context_lens_tensor=self.context_lens_tensor[:self. - # num_prefills], - # block_tables=self.block_tables[:self.num_prefills], - # use_cuda_graph=False, - # _attn_type=AttentionType.ENCODER_DECODER, - # # Begin cross-attention fields below... - # encoder_seq_lens=self.encoder_seq_lens, - # encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - # max_encoder_seq_len=self.max_encoder_seq_len, - # cross_slot_mapping=self.cross_slot_mapping, - # cross_block_tables=self.cross_block_tables) - # return self._cross_cached_prefill_metadata - @property def decode_metadata(self) -> Optional["XFormersMetadata"]: if self.num_decode_tokens == 0: return None - if self._attn_type != AttentionType.ENCODER_DECODER: - # Decoder or encoder self-attention prefill - - if self._self_cached_decode_metadata is not None: - return self._self_cached_decode_metadata - assert self.block_tables is not None - assert (self.seq_lens_tensor is not None) or \ - (self.encoder_seq_lens_tensor is not None) - - self._self_cached_decode_metadata = XFormersMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - seq_lens=None, - seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - _attn_type=self. - _attn_type, # Begin cross-attention fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_block_tables=None, - cross_slot_mapping=None) + if self._self_cached_decode_metadata is not None: + self._self_cached_decode_metadata.attention_type = \ + self.attention_type return self._self_cached_decode_metadata + assert self.block_tables is not None + assert (self.seq_lens_tensor is not None) or \ + (self.encoder_seq_lens_tensor is not None) - else: - # Encoder/decoder cross-attention decode - - if self._cross_cached_decode_metadata is not None: - return self._cross_cached_decode_metadata - assert self.block_tables is not None - assert (self.seq_lens_tensor is not None) or \ - (self.encoder_seq_lens_tensor is not None) - - self._cross_cached_decode_metadata = XFormersMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - seq_lens=None, - seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - _attn_type=AttentionType.ENCODER_DECODER, - # Begin cross-attention fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cross_cached_decode_metadata + self._self_cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + _attn_type=self. + _attn_type, # Begin cross-attention fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._self_cached_decode_metadata def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ Optional[List[Optional[AttentionBias]]]: From c41917bee6490dcb062e61c559ec3c17d7c951b8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 11:02:13 -0400 Subject: [PATCH 142/239] renamed metdata caching structure --- vllm/attention/backends/xformers.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 56a57ac5c4465..b56e566a89fe1 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -112,10 +112,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): query_start_loc: Optional[torch.Tensor] = None # Self-attention prefill/decode metadata cache - _self_cached_prefill_metadata: Optional["XFormersMetadata"] = None - _self_cached_decode_metadata: Optional["XFormersMetadata"] = None - # Cross-attention prefill/decode metadata cache - _cross_cached_decode_metadata: Optional["XFormersMetadata"] = None + _cached_prefill_metadata: Optional["XFormersMetadata"] = None + _cached_decode_metadata: Optional["XFormersMetadata"] = None # Begin encoder attn & enc/dec cross-attn fields... @@ -215,10 +213,10 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self.num_prefills == 0: return None - if self._self_cached_prefill_metadata is not None: - self._self_cached_prefill_metadata.attention_type = \ + if self._cached_prefill_metadata is not None: + self._cached_prefill_metadata.attention_type = \ self.attention_type - return self._self_cached_prefill_metadata + return self._cached_prefill_metadata assert (self.seq_lens is not None) or \ (self.encoder_seq_lens is not None) @@ -230,7 +228,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: query_start_loc = None if self.query_start_loc is None \ else self.query_start_loc[:self.num_prefills + 1] - self._self_cached_prefill_metadata = XFormersMetadata( + self._cached_prefill_metadata = XFormersMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, @@ -254,22 +252,22 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables) - return self._self_cached_prefill_metadata + return self._cached_prefill_metadata @property def decode_metadata(self) -> Optional["XFormersMetadata"]: if self.num_decode_tokens == 0: return None - if self._self_cached_decode_metadata is not None: - self._self_cached_decode_metadata.attention_type = \ + if self._cached_decode_metadata is not None: + self._cached_decode_metadata.attention_type = \ self.attention_type - return self._self_cached_decode_metadata + return self._cached_decode_metadata assert self.block_tables is not None assert (self.seq_lens_tensor is not None) or \ (self.encoder_seq_lens_tensor is not None) - self._self_cached_decode_metadata = XFormersMetadata( + self._cached_decode_metadata = XFormersMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, @@ -292,7 +290,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables) - return self._self_cached_decode_metadata + return self._cached_decode_metadata def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ Optional[List[Optional[AttentionBias]]]: From e738fb4ee2c98f6af912e61c9a4a99acfe887dcf Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 11:47:40 -0400 Subject: [PATCH 143/239] reverted my custom env var patch impl --- tests/kernels/utils.py | 25 ------------------------- tests/utils.py | 26 -------------------------- 2 files changed, 51 deletions(-) delete mode 100644 tests/kernels/utils.py diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py deleted file mode 100644 index 8ebc2fc5905aa..0000000000000 --- a/tests/kernels/utils.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Kernel test utils""" - -from contextlib import contextmanager -from typing import Iterator - -from tests.utils import env_var_fixture - - -@contextmanager -def backend_override_fixture(backend_name: str) -> Iterator[None]: - ''' - Text fixture, temporarily configures the vLLM backend by setting - VLLM_ATTENTION_BACKEND, then resets the environment outside of - the fixture. - - Usage: - - with backend_override_fixture("backend_name"): - # code that depends on vLLM backend - - # VLLM_ATTENTION_BACKEND is returned to original value - # or unset - ''' - with env_var_fixture('VLLM_ATTENTION_BACKEND', backend_name): - yield diff --git a/tests/utils.py b/tests/utils.py index adbff8e8dc1c6..329842911e159 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,7 +4,6 @@ import time import warnings from contextlib import contextmanager -from typing import Iterator import ray import requests @@ -102,28 +101,3 @@ def error_on_warning(): warnings.simplefilter("error") yield - - -@contextmanager -def env_var_fixture(var_name: str, value: str) -> Iterator[None]: - ''' - Text fixture, temporarily assigns value var_name environment variable, - then resets environment variable outside of test fixture. - - Usage: - - with env_var_fixture("my_var","my_val"): - # code that depends on my_val == "my_val" - - # my_var is returned to original value or unset - ''' - original_value = os.environ.get(var_name) # Store the original value - os.environ[var_name] = value # Set the new value - try: - yield - finally: - # Restore the original value - if original_value is None: - del os.environ[var_name] - else: - os.environ[var_name] = original_value From dfe9c10389beccfc43b2bddf687075e07e7283b9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:01:52 -0400 Subject: [PATCH 144/239] monkeypatch works --- tests/kernels/test_attention_selector.py | 93 ++++++++++++------------ 1 file changed, 47 insertions(+), 46 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index b0b383974904c..ebd2d460dc45e 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -1,79 +1,80 @@ +import os from unittest.mock import patch import pytest import torch -from tests.kernels.utils import backend_override_fixture from vllm.attention.selector import which_attn_to_use +_backend_env_var = "VLLM_ATTENTION_BACKEND" +_flash_attn_val = "FLASH_ATTN" +_invalid_val = "INVALID" + @pytest.mark.parametrize( "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"]) @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) -def test_env(name: str, device: str): +def test_env(name: str, device: str, monkeypatch): """Test that the attention selector can be set via environment variable. Note that we do not test FlashAttn because it is the default backend. """ - with backend_override_fixture(name): - - if device == "cpu": - with patch("vllm.attention.selector.is_cpu", return_value=True): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) - assert backend.name == "TORCH_SDPA" - elif device == "hip": - with patch("vllm.attention.selector.is_hip", return_value=True): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, - torch.float16, 16) - assert backend.name == "ROCM_FLASH" - else: + monkeypatch.setenv(_backend_env_var,name) + + if device == "cpu": + with patch("vllm.attention.selector.is_cpu", return_value=True): backend = which_attn_to_use(8, 16, 8, None, torch.float16, torch.float16, 16) - assert backend.name == name + assert backend.name == "TORCH_SDPA" + elif device == "hip": + with patch("vllm.attention.selector.is_hip", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "ROCM_FLASH" + else: + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == name -def test_flash_attn(): +def test_flash_attn(monkeypatch): """Test FlashAttn validation.""" - with backend_override_fixture("FLASH_ATTN"): - - # Unsupported CUDA arch - with patch("torch.cuda.get_device_capability", return_value=[7, 5]): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, - 16) - assert backend.name != "FLASH_ATTN" + monkeypatch.setenv(_backend_env_var,_flash_attn_val) - # Unsupported data type - backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, - 16) + # Unsupported CUDA arch + with patch("torch.cuda.get_device_capability", return_value=[7, 5]): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) assert backend.name != "FLASH_ATTN" - # Unsupported kv cache data type - backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) - assert backend.name != "FLASH_ATTN" + # Unsupported data type + backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) + assert backend.name != "FLASH_ATTN" - # Unsupported block size - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) - assert backend.name != "FLASH_ATTN" + # Unsupported kv cache data type + backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) + assert backend.name != "FLASH_ATTN" - # Unsupported sliding window - backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + # Unsupported block size + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) + assert backend.name != "FLASH_ATTN" - # flash-attn is not installed - with patch.dict('sys.modules', {'vllm_flash_attn': None}): - backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, - 16) - assert backend.name != "FLASH_ATTN" + # Unsupported sliding window + backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" - # Unsupported head size - backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) + # flash-attn is not installed + with patch.dict('sys.modules', {'vllm_flash_attn': None}): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) assert backend.name != "FLASH_ATTN" + # Unsupported head size + backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" -def test_invalid_env(): - """Throw an exception if the backend name is invalid.""" - with backend_override_fixture("INVALID"), pytest.raises(ValueError): +def test_invalid_env(monkeypatch): + """Throw an exception if the backend name is invalid.""" + monkeypatch.setenv(_backend_env_var,_invalid_val) + with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) From 822175834eb2846c826a63357f731966ed83abce Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:02:31 -0400 Subject: [PATCH 145/239] formatting --- tests/kernels/test_attention_selector.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index ebd2d460dc45e..0b4bb7c353cc3 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -1,4 +1,3 @@ -import os from unittest.mock import patch import pytest @@ -19,7 +18,7 @@ def test_env(name: str, device: str, monkeypatch): Note that we do not test FlashAttn because it is the default backend. """ - monkeypatch.setenv(_backend_env_var,name) + monkeypatch.setenv(_backend_env_var, name) if device == "cpu": with patch("vllm.attention.selector.is_cpu", return_value=True): @@ -40,7 +39,7 @@ def test_env(name: str, device: str, monkeypatch): def test_flash_attn(monkeypatch): """Test FlashAttn validation.""" - monkeypatch.setenv(_backend_env_var,_flash_attn_val) + monkeypatch.setenv(_backend_env_var, _flash_attn_val) # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=[7, 5]): @@ -75,6 +74,6 @@ def test_flash_attn(monkeypatch): def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" - monkeypatch.setenv(_backend_env_var,_invalid_val) + monkeypatch.setenv(_backend_env_var, _invalid_val) with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) From db2b2d23f8e17683f03ca36e0b9ea41e0ce0c74f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:05:23 -0400 Subject: [PATCH 146/239] wip monkeypatch --- .../test_encoder_and_decoder_self_and_encdec_cross_attn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py b/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py index 8e1b8849c8a9f..1b830a1dde63a 100644 --- a/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py +++ b/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py @@ -6,7 +6,6 @@ import pytest import torch -from tests.kernels.utils import backend_override_fixture from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import ( @@ -14,6 +13,8 @@ from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import is_hip, make_tensor_with_pad +_BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND" + HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] From cbb89b1dd912cb816d3c7d721bb129baa08c83b3 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:08:56 -0400 Subject: [PATCH 147/239] refactored constants into tests/kernels/utils.py --- tests/kernels/test_attention_selector.py | 12 +++++------- tests/kernels/utils.py | 5 +++++ 2 files changed, 10 insertions(+), 7 deletions(-) create mode 100644 tests/kernels/utils.py diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 0b4bb7c353cc3..7bc0439f3ee82 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -3,12 +3,10 @@ import pytest import torch +from tests.kernels.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, + STR_INVALID_VAL) from vllm.attention.selector import which_attn_to_use -_backend_env_var = "VLLM_ATTENTION_BACKEND" -_flash_attn_val = "FLASH_ATTN" -_invalid_val = "INVALID" - @pytest.mark.parametrize( "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"]) @@ -18,7 +16,7 @@ def test_env(name: str, device: str, monkeypatch): Note that we do not test FlashAttn because it is the default backend. """ - monkeypatch.setenv(_backend_env_var, name) + monkeypatch.setenv(STR_BACKEND_ENV_VAR, name) if device == "cpu": with patch("vllm.attention.selector.is_cpu", return_value=True): @@ -39,7 +37,7 @@ def test_env(name: str, device: str, monkeypatch): def test_flash_attn(monkeypatch): """Test FlashAttn validation.""" - monkeypatch.setenv(_backend_env_var, _flash_attn_val) + monkeypatch.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=[7, 5]): @@ -74,6 +72,6 @@ def test_flash_attn(monkeypatch): def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" - monkeypatch.setenv(_backend_env_var, _invalid_val) + monkeypatch.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py new file mode 100644 index 0000000000000..74ad9d8256e3f --- /dev/null +++ b/tests/kernels/utils.py @@ -0,0 +1,5 @@ +"""Kernel test utils""" + +STR_BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND" +STR_FLASH_ATTN_VAL = "FLASH_ATTN" +STR_INVALID_VAL = "INVALID" \ No newline at end of file From c9ce86be2f3b973e9df5057b6ef5ab956224f28e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:14:33 -0400 Subject: [PATCH 148/239] wip enc/dec monkeypatch integration --- ..._and_decoder_self_and_encdec_cross_attn.py | 144 +++++++++--------- 1 file changed, 73 insertions(+), 71 deletions(-) diff --git a/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py b/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py index 1b830a1dde63a..d87e548376257 100644 --- a/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py +++ b/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import copy import itertools import random @@ -6,6 +8,8 @@ import pytest import torch +from tests.kernels.utils import STR_BACKEND_ENV_VAR + from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import ( @@ -13,7 +17,6 @@ from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import is_hip, make_tensor_with_pad -_BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND" HEAD_SIZES = [64, 256] @@ -1195,8 +1198,7 @@ def run_encoder_decoder_cross_attention_test( ''' attn_metadata.attention_type = AttentionType.ENCODER_DECODER return attn.forward(packed_query, packed_key, packed_value, kv_cache, - attn_metadata) - + attn_metadata) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -1207,7 +1209,7 @@ def run_encoder_decoder_cross_attention_test( @pytest.mark.parametrize("max_seq_len", MAX_Q_SEQ_LENS) def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, - max_seq_len: int) -> None: + max_seq_len: int, monkeypatch) -> None: ''' Encoder-only attention test: @@ -1234,73 +1236,73 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, layer output.) ''' - with backend_override_fixture(backend_name): - # Force Attention wrapper backend - - # Attention scale factor, attention backend instance, attention wrapper - # instance. Encoder attention does not require KV cache. - scale, \ - attn_backend, \ - attn, \ - _ = basic_setup(num_heads, - head_size, - None, - None, - backend_name) - - # Self-attention setup - # Let encoder_attn_setup() choose default block table - # base address; the block_tables and slot_mapping - # tensors are not actually utilized by encoder attention - # anyway but are required to be present & valid by the - # backend. - packed_query, \ - packed_key, \ - packed_value, \ - packed_ideal_output, \ - block_tables, \ - slot_mapping, \ - q_seq_lens = encoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_seq_len) - - context_lens = [0 for _ in range(batch_size)] - - # Metadata config for encoder attention: - # - # * Use prefill kernel - # * Signal that this is an encoder-only test so that - # metadata attention_type property is correctly - # configured as AttentionType.ENCODER - attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - None, - context_lens, - block_tables, - slot_mapping, - is_encoder_only_test=True, - num_prefills_or_decodes=len(q_seq_lens), - num_prefill_or_decode_tokens=sum(q_seq_lens), - encoder_seq_lens=q_seq_lens) - - packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( - attn, - packed_query, - packed_key, - packed_value, - None, - attn_metadata, - attn_type=AttentionType.ENCODER) - - # - Is encoder attention result correct? - assert torch.allclose( - packed_ideal_output, - packed_actual_output.view_as(packed_ideal_output)) + # Force Attention wrapper backend + monkeypatch.setenv(STR_BACKEND_ENV_VAR,backend_name) + + # Attention scale factor, attention backend instance, attention wrapper + # instance. Encoder attention does not require KV cache. + scale, \ + attn_backend, \ + attn, \ + _ = basic_setup(num_heads, + head_size, + None, + None, + backend_name) + + # Self-attention setup + # Let encoder_attn_setup() choose default block table + # base address; the block_tables and slot_mapping + # tensors are not actually utilized by encoder attention + # anyway but are required to be present & valid by the + # backend. + packed_query, \ + packed_key, \ + packed_value, \ + packed_ideal_output, \ + block_tables, \ + slot_mapping, \ + q_seq_lens = encoder_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_seq_len) + + context_lens = [0 for _ in range(batch_size)] + + # Metadata config for encoder attention: + # + # * Use prefill kernel + # * Signal that this is an encoder-only test so that + # metadata attention_type property is correctly + # configured as AttentionType.ENCODER + attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + True, + None, + context_lens, + block_tables, + slot_mapping, + is_encoder_only_test=True, + num_prefills_or_decodes=len(q_seq_lens), + num_prefill_or_decode_tokens=sum(q_seq_lens), + encoder_seq_lens=q_seq_lens) + + packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( + attn, + packed_query, + packed_key, + packed_value, + None, + attn_metadata, + attn_type=AttentionType.ENCODER) + + # - Is encoder attention result correct? + assert torch.allclose( + packed_ideal_output, + packed_actual_output.view_as(packed_ideal_output)) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) From ca570e7a078540d8788dc6b7cf961039f699a761 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:17:57 -0400 Subject: [PATCH 149/239] a refactoring backend override functionality into tests/kernels/utils.py --- tests/kernels/test_attention_selector.py | 10 +++++----- tests/kernels/utils.py | 6 +++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 7bc0439f3ee82..ea3ccb026ea2b 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -3,8 +3,8 @@ import pytest import torch -from tests.kernels.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, - STR_INVALID_VAL) +from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL, + override_backend) from vllm.attention.selector import which_attn_to_use @@ -16,7 +16,7 @@ def test_env(name: str, device: str, monkeypatch): Note that we do not test FlashAttn because it is the default backend. """ - monkeypatch.setenv(STR_BACKEND_ENV_VAR, name) + override_backend(monkeypatch, name) if device == "cpu": with patch("vllm.attention.selector.is_cpu", return_value=True): @@ -37,7 +37,7 @@ def test_env(name: str, device: str, monkeypatch): def test_flash_attn(monkeypatch): """Test FlashAttn validation.""" - monkeypatch.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL) + override_backend(monkeypatch, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=[7, 5]): @@ -72,6 +72,6 @@ def test_flash_attn(monkeypatch): def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" - monkeypatch.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) + override_backend(monkeypatch, STR_INVALID_VAL) with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 74ad9d8256e3f..3874fad57ae43 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -2,4 +2,8 @@ STR_BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND" STR_FLASH_ATTN_VAL = "FLASH_ATTN" -STR_INVALID_VAL = "INVALID" \ No newline at end of file +STR_INVALID_VAL = "INVALID" + + +def override_backend(mpatch, backend_name): + mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name) From b2e131f95143984ddc0a38f5b9b08e9abc421b45 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:26:07 -0400 Subject: [PATCH 150/239] test rename --- ...s_attn.py => test_encoder_decoder_attn.py} | 605 +++++++++--------- 1 file changed, 303 insertions(+), 302 deletions(-) rename tests/kernels/{test_encoder_and_decoder_self_and_encdec_cross_attn.py => test_encoder_decoder_attn.py} (80%) diff --git a/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py b/tests/kernels/test_encoder_decoder_attn.py similarity index 80% rename from tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py rename to tests/kernels/test_encoder_decoder_attn.py index d87e548376257..0a835940c3d75 100644 --- a/tests/kernels/test_encoder_and_decoder_self_and_encdec_cross_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -1,4 +1,10 @@ -from unittest.mock import patch +""" +Test + +* Encoder attention +* Decoder self-attention +* Encoder/decoder cross-attention +""" import copy import itertools @@ -8,8 +14,7 @@ import pytest import torch -from tests.kernels.utils import STR_BACKEND_ENV_VAR - +from tests.kernels.utils import override_backend from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import ( @@ -17,7 +22,6 @@ from vllm.attention.backends.xformers import XFormersBackend from vllm.utils import is_hip, make_tensor_with_pad - HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] @@ -399,10 +403,6 @@ def make_metadata_tensors(seq_lens: List[int], max_seq_len = None if seq_lens is None else max(seq_lens) seq_start_loc = None - # seq_start_loc = torch.cat([ - # torch.tensor([0], dtype=torch.int32, device=device), - # torch.cumsum(seq_lens_tensor, dim=0, dtype=torch.int32) - # ]) return seq_lens_tensor, \ context_lens_tensor, \ @@ -634,7 +634,6 @@ def make_test_metadata( seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=None if seq_lens is None else max(seq_lens), max_decode_seq_len=0, - # seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, @@ -666,7 +665,6 @@ def make_test_metadata( seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=0, max_decode_seq_len=max(seq_lens), - # seq_start_loc=seq_start_loc, context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, @@ -1198,7 +1196,8 @@ def run_encoder_decoder_cross_attention_test( ''' attn_metadata.attention_type = AttentionType.ENCODER_DECODER return attn.forward(packed_query, packed_key, packed_value, kv_cache, - attn_metadata) + attn_metadata) + @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -1208,8 +1207,8 @@ def run_encoder_decoder_cross_attention_test( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_seq_len", MAX_Q_SEQ_LENS) def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, - batch_size: int, block_size: int, - max_seq_len: int, monkeypatch) -> None: + batch_size: int, block_size: int, max_seq_len: int, + monkeypatch) -> None: ''' Encoder-only attention test: @@ -1237,7 +1236,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, ''' # Force Attention wrapper backend - monkeypatch.setenv(STR_BACKEND_ENV_VAR,backend_name) + override_backend(monkeypatch, backend_name) # Attention scale factor, attention backend instance, attention wrapper # instance. Encoder attention does not require KV cache. @@ -1300,9 +1299,8 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - assert torch.allclose( - packed_ideal_output, - packed_actual_output.view_as(packed_ideal_output)) + assert torch.allclose(packed_ideal_output, + packed_actual_output.view_as(packed_ideal_output)) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @@ -1315,7 +1313,8 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, @pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) def test_enc_dec_self_and_cross_attention_prefill_decode_phases( num_heads: int, head_size: int, backend_name: str, batch_size: int, - block_size: int, max_q_seq_len: int, max_kv_seq_len: int) -> None: + block_size: int, max_q_seq_len: int, max_kv_seq_len: int, + monkeypatch) -> None: ''' Encoder/decoder attention test: @@ -1342,192 +1341,192 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( for cross-attention. ''' - with backend_override_fixture(backend_name): - # Force Attention wrapper backend - - # Num KV cache blocks - num_blocks = 4096 - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - scale, \ - attn_backend, \ - attn, \ - kv_cache = basic_setup(num_heads, - head_size, - num_blocks, - block_size, - backend_name) - - # Self-attention setup - - self_block_base_addr = 0 - - query, \ - prefill_packed_query, \ - self_prefill_packed_key, \ - self_prefill_packed_value, \ - self_prefill_packed_ideal_output, \ - prefill_q_seq_lens, \ - self_prefill_kv_seq_lens, \ - decode_packed_query, \ - self_decode_packed_key, \ - self_decode_packed_value, \ - self_decode_packed_ideal_output, \ - _, \ - _, \ - q_seq_lens, \ - _, \ - self_decode_block_tables, \ - self_decode_slot_mapping, \ - self_prefill_slot_mapping, \ - self_prefill_block_tables, \ - cross_block_base_addr = decoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - block_base_addr=self_block_base_addr) - - # Cross-attention setup - - cross_prefill_packed_key, \ - cross_prefill_packed_value, \ - cross_prefill_packed_ideal_output, \ - cross_decode_packed_ideal_output, \ - encoder_kv_seq_lens, \ - cross_decode_block_tables, \ - cross_decode_slot_mapping, \ - cross_prefill_slot_mapping, \ - cross_prefill_block_tables, \ - _ = enc_dec_cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=cross_block_base_addr) - - # PREFILL: self- and cross-attention tests - - context_lens = [0 for _ in range(batch_size)] - - prefill_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - prefill_q_seq_lens, - context_lens, - self_prefill_block_tables, - self_prefill_slot_mapping, - is_encoder_only_test=False, - num_prefills_or_decodes=len(prefill_q_seq_lens), - num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), - encoder_seq_lens=encoder_kv_seq_lens, - cross_block_tables=cross_prefill_block_tables, - cross_slot_mapping=cross_prefill_slot_mapping, - ) - - self_prefill_packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( - attn, - prefill_packed_query, - self_prefill_packed_key, - self_prefill_packed_value, - kv_cache, - prefill_attn_metadata, - attn_type=AttentionType.DECODER) - - # - Prefill self-attention correct? - assert torch.allclose( - self_prefill_packed_ideal_output, - self_prefill_packed_actual_output.view_as( - self_prefill_packed_ideal_output)) - - cross_prefill_packed_actual_output: torch.Tensor = \ - run_encoder_decoder_cross_attention_test( - attn, prefill_packed_query, cross_prefill_packed_key, - cross_prefill_packed_value, kv_cache, prefill_attn_metadata) - - # - Prefill cross-attention correct? - assert torch.allclose( - cross_prefill_packed_ideal_output, - cross_prefill_packed_actual_output.view_as( - cross_prefill_packed_ideal_output)) - - context_lens = copy.deepcopy(self_prefill_kv_seq_lens) - - # DECODE: self- and cross-attention tests - - decode_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - False, - q_seq_lens, - context_lens, - self_decode_block_tables, - self_decode_slot_mapping, - is_encoder_only_test=False, - num_prefills_or_decodes=len(q_seq_lens), - num_prefill_or_decode_tokens=len(q_seq_lens), - encoder_seq_lens=encoder_kv_seq_lens, - cross_block_tables=cross_decode_block_tables, - cross_slot_mapping=cross_decode_slot_mapping, - ) - - self_decode_packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( - attn, - decode_packed_query, - self_decode_packed_key, - self_decode_packed_value, - kv_cache, - decode_attn_metadata, - attn_type=AttentionType.DECODER) - - # - Decode self-attention correct? - assert torch.allclose( - self_decode_packed_ideal_output, - self_decode_packed_actual_output.view_as( - self_decode_packed_ideal_output)) - - cross_decode_packed_actual_output: torch.Tensor = \ - run_encoder_decoder_cross_attention_test( - attn, decode_packed_query, None, - None, kv_cache, decode_attn_metadata) - - # - Decode cross-attention correct? - assert torch.allclose( - cross_decode_packed_ideal_output, - cross_decode_packed_actual_output.view_as( - cross_decode_packed_ideal_output)) - - # The following test conditions could in principle be a - # standalone test, however the test setup is - # so involved that it is easier - # to piggyback off of the test vectors & other data structures - # created for testing decode-phase encoder/decoder cross- - # attention above. - # ---- - # Set up a contrived scenario where the attention metadata - # is configured for chunked prefill & encoder/decoder cross- - # attention. Required that this triggers a NotImplementedError. - # - # We assume that decode_attn_metadata.num_decode_tokens > 1 - # already; the line below sets up a chunked prefill - # metadata configuration where there is nominally a mix - # of prefill and decode tokens. - decode_attn_metadata.num_prefill_tokens = 1 - with pytest.raises(NotImplementedError) as exc_info: - run_encoder_decoder_cross_attention_test(attn, decode_packed_query, - None, None, kv_cache, - decode_attn_metadata) - - # "Encoder decoder models do not currently support chunked prefill" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL + # Force Attention wrapper backend + override_backend(monkeypatch, backend_name) + + # Num KV cache blocks + num_blocks = 4096 + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + scale, \ + attn_backend, \ + attn, \ + kv_cache = basic_setup(num_heads, + head_size, + num_blocks, + block_size, + backend_name) + + # Self-attention setup + + self_block_base_addr = 0 + + query, \ + prefill_packed_query, \ + self_prefill_packed_key, \ + self_prefill_packed_value, \ + self_prefill_packed_ideal_output, \ + prefill_q_seq_lens, \ + self_prefill_kv_seq_lens, \ + decode_packed_query, \ + self_decode_packed_key, \ + self_decode_packed_value, \ + self_decode_packed_ideal_output, \ + _, \ + _, \ + q_seq_lens, \ + _, \ + self_decode_block_tables, \ + self_decode_slot_mapping, \ + self_prefill_slot_mapping, \ + self_prefill_block_tables, \ + cross_block_base_addr = decoder_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + block_base_addr=self_block_base_addr) + + # Cross-attention setup + + cross_prefill_packed_key, \ + cross_prefill_packed_value, \ + cross_prefill_packed_ideal_output, \ + cross_decode_packed_ideal_output, \ + encoder_kv_seq_lens, \ + cross_decode_block_tables, \ + cross_decode_slot_mapping, \ + cross_prefill_slot_mapping, \ + cross_prefill_block_tables, \ + _ = enc_dec_cross_attn_setup_reuses_query(query, + q_seq_lens, + prefill_q_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + max_kv_seq_len, + block_base_addr=cross_block_base_addr) + + # PREFILL: self- and cross-attention tests + + context_lens = [0 for _ in range(batch_size)] + + prefill_attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + True, + prefill_q_seq_lens, + context_lens, + self_prefill_block_tables, + self_prefill_slot_mapping, + is_encoder_only_test=False, + num_prefills_or_decodes=len(prefill_q_seq_lens), + num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), + encoder_seq_lens=encoder_kv_seq_lens, + cross_block_tables=cross_prefill_block_tables, + cross_slot_mapping=cross_prefill_slot_mapping, + ) + + self_prefill_packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( + attn, + prefill_packed_query, + self_prefill_packed_key, + self_prefill_packed_value, + kv_cache, + prefill_attn_metadata, + attn_type=AttentionType.DECODER) + + # - Prefill self-attention correct? + assert torch.allclose( + self_prefill_packed_ideal_output, + self_prefill_packed_actual_output.view_as( + self_prefill_packed_ideal_output)) + + cross_prefill_packed_actual_output: torch.Tensor = \ + run_encoder_decoder_cross_attention_test( + attn, prefill_packed_query, cross_prefill_packed_key, + cross_prefill_packed_value, kv_cache, prefill_attn_metadata) + + # - Prefill cross-attention correct? + assert torch.allclose( + cross_prefill_packed_ideal_output, + cross_prefill_packed_actual_output.view_as( + cross_prefill_packed_ideal_output)) + + context_lens = copy.deepcopy(self_prefill_kv_seq_lens) + + # DECODE: self- and cross-attention tests + + decode_attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + False, + q_seq_lens, + context_lens, + self_decode_block_tables, + self_decode_slot_mapping, + is_encoder_only_test=False, + num_prefills_or_decodes=len(q_seq_lens), + num_prefill_or_decode_tokens=len(q_seq_lens), + encoder_seq_lens=encoder_kv_seq_lens, + cross_block_tables=cross_decode_block_tables, + cross_slot_mapping=cross_decode_slot_mapping, + ) + + self_decode_packed_actual_output: torch.Tensor = \ + run_encoder_or_decoder_self_attention_test( + attn, + decode_packed_query, + self_decode_packed_key, + self_decode_packed_value, + kv_cache, + decode_attn_metadata, + attn_type=AttentionType.DECODER) + + # - Decode self-attention correct? + assert torch.allclose( + self_decode_packed_ideal_output, + self_decode_packed_actual_output.view_as( + self_decode_packed_ideal_output)) + + cross_decode_packed_actual_output: torch.Tensor = \ + run_encoder_decoder_cross_attention_test( + attn, decode_packed_query, None, + None, kv_cache, decode_attn_metadata) + + # - Decode cross-attention correct? + assert torch.allclose( + cross_decode_packed_ideal_output, + cross_decode_packed_actual_output.view_as( + cross_decode_packed_ideal_output)) + + # The following test conditions could in principle be a + # standalone test, however the test setup is + # so involved that it is easier + # to piggyback off of the test vectors & other data structures + # created for testing decode-phase encoder/decoder cross- + # attention above. + # ---- + # Set up a contrived scenario where the attention metadata + # is configured for chunked prefill & encoder/decoder cross- + # attention. Required that this triggers a NotImplementedError. + # + # We assume that decode_attn_metadata.num_decode_tokens > 1 + # already; the line below sets up a chunked prefill + # metadata configuration where there is nominally a mix + # of prefill and decode tokens. + decode_attn_metadata.num_prefill_tokens = 1 + with pytest.raises(NotImplementedError) as exc_info: + run_encoder_decoder_cross_attention_test(attn, decode_packed_query, + None, None, kv_cache, + decode_attn_metadata) + + # "Encoder decoder models do not currently support chunked prefill" + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL @pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") @@ -1541,7 +1540,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_q_seq_len: int, - max_kv_seq_len: int) -> None: + max_kv_seq_len: int, monkeypatch) -> None: ''' Encoder/decoder not-implemented-for-ROCm-HIP test: @@ -1568,100 +1567,102 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, for cross-attention. ''' - with backend_override_fixture(backend_name): - # Force Attention wrapper backend - - # Num KV cache blocks - num_blocks = 4096 - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - scale, \ - attn_backend, \ - attn, \ - kv_cache = basic_setup(num_heads, - head_size, - num_blocks, - block_size, - backend_name) - - # Self-attention setup - - self_block_base_addr = 0 - - query, \ - prefill_packed_query, \ - self_prefill_packed_key, \ - self_prefill_packed_value, \ - self_prefill_packed_ideal_output, \ - prefill_q_seq_lens, \ - self_prefill_kv_seq_lens, \ - decode_packed_query, \ - self_decode_packed_key, \ - self_decode_packed_value, \ - self_decode_packed_ideal_output, \ - _, \ - _, \ - q_seq_lens, \ - _, \ - self_decode_block_tables, \ - self_decode_slot_mapping, \ - self_prefill_slot_mapping, \ - self_prefill_block_tables, \ - cross_block_base_addr = decoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - block_base_addr=self_block_base_addr) - - # Cross-attention setup - - cross_prefill_packed_key, \ - cross_prefill_packed_value, \ - cross_prefill_packed_ideal_output, \ - cross_decode_packed_ideal_output, \ - encoder_kv_seq_lens, \ - cross_decode_block_tables, \ - cross_decode_slot_mapping, \ - cross_prefill_slot_mapping, \ - cross_prefill_block_tables, \ - _ = enc_dec_cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=cross_block_base_addr) - - # PREFILL: self- and cross-attention tests - - context_lens = [0 for _ in range(batch_size)] - - prefill_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - prefill_q_seq_lens, - context_lens, - self_prefill_block_tables, - self_prefill_slot_mapping, - is_encoder_only_test=False, - num_prefills_or_decodes=len(prefill_q_seq_lens), - num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), - encoder_seq_lens=encoder_kv_seq_lens, - cross_block_tables=cross_prefill_block_tables, - cross_slot_mapping=cross_prefill_slot_mapping, - ) - - with pytest.raises(NotImplementedError) as exc_info: - run_encoder_decoder_cross_attention_test( - attn, prefill_packed_query, cross_prefill_packed_key, - cross_prefill_packed_value, kv_cache, prefill_attn_metadata) - - # "Encoder decoder models do not currently support ROCm/HIP" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP + # Force Attention wrapper backend + override_backend(monkeypatch, backend_name) + + # Num KV cache blocks + num_blocks = 4096 + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + scale, \ + attn_backend, \ + attn, \ + kv_cache = basic_setup(num_heads, + head_size, + num_blocks, + block_size, + backend_name) + + # Self-attention setup + + self_block_base_addr = 0 + + query, \ + prefill_packed_query, \ + self_prefill_packed_key, \ + self_prefill_packed_value, \ + self_prefill_packed_ideal_output, \ + prefill_q_seq_lens, \ + self_prefill_kv_seq_lens, \ + decode_packed_query, \ + self_decode_packed_key, \ + self_decode_packed_value, \ + self_decode_packed_ideal_output, \ + _, \ + _, \ + q_seq_lens, \ + _, \ + self_decode_block_tables, \ + self_decode_slot_mapping, \ + self_prefill_slot_mapping, \ + self_prefill_block_tables, \ + cross_block_base_addr = decoder_attn_setup(batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + block_base_addr=self_block_base_addr) + + # Cross-attention setup + + cross_prefill_packed_key, \ + cross_prefill_packed_value, \ + cross_prefill_packed_ideal_output, \ + cross_decode_packed_ideal_output, \ + encoder_kv_seq_lens, \ + cross_decode_block_tables, \ + cross_decode_slot_mapping, \ + cross_prefill_slot_mapping, \ + cross_prefill_block_tables, \ + _ = enc_dec_cross_attn_setup_reuses_query(query, + q_seq_lens, + prefill_q_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + max_kv_seq_len, + block_base_addr=cross_block_base_addr) + + # PREFILL: self- and cross-attention tests + + context_lens = [0 for _ in range(batch_size)] + + prefill_attn_metadata: AttentionMetadata = make_test_metadata( + attn_backend, + True, + prefill_q_seq_lens, + context_lens, + self_prefill_block_tables, + self_prefill_slot_mapping, + is_encoder_only_test=False, + num_prefills_or_decodes=len(prefill_q_seq_lens), + num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), + encoder_seq_lens=encoder_kv_seq_lens, + cross_block_tables=cross_prefill_block_tables, + cross_slot_mapping=cross_prefill_slot_mapping, + ) + + with pytest.raises(NotImplementedError) as exc_info: + run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, + cross_prefill_packed_key, + cross_prefill_packed_value, + kv_cache, + prefill_attn_metadata) + + # "Encoder decoder models do not currently support ROCm/HIP" + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP From ed8f8b3aa6eba958e0e527510f50aa3cc94f66c4 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 12:40:06 -0400 Subject: [PATCH 151/239] Comments & type hints --- tests/kernels/utils.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 3874fad57ae43..fb28924c5f9c4 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1,9 +1,21 @@ """Kernel test utils""" -STR_BACKEND_ENV_VAR = "VLLM_ATTENTION_BACKEND" -STR_FLASH_ATTN_VAL = "FLASH_ATTN" -STR_INVALID_VAL = "INVALID" +import pytest +STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" +STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" +STR_INVALID_VAL: str = "INVALID" -def override_backend(mpatch, backend_name): + +def override_backend(mpatch: pytest.MonkeyPatch, backend_name: str) -> None: + ''' + Override vLLM attention backend temporarily, + using pytest monkeypatch to ensure that the env vars get + reset once the test context exits. + + Arguments: + + * mpatch: pytest monkeypatch instance + * backend_name: attention backend name to force + ''' mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name) From 8abe51c8b6091de871600ea83d6fc1837eb4db79 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 13:15:28 -0400 Subject: [PATCH 152/239] small refactors per @sroy745 suggestions --- tests/kernels/test_attention_selector.py | 8 ++++---- tests/kernels/utils.py | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index ea3ccb026ea2b..79e03c7478de0 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -4,7 +4,7 @@ import torch from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL, - override_backend) + override_backend_env_variable) from vllm.attention.selector import which_attn_to_use @@ -16,7 +16,7 @@ def test_env(name: str, device: str, monkeypatch): Note that we do not test FlashAttn because it is the default backend. """ - override_backend(monkeypatch, name) + override_backend_env_variable(monkeypatch, name) if device == "cpu": with patch("vllm.attention.selector.is_cpu", return_value=True): @@ -37,7 +37,7 @@ def test_env(name: str, device: str, monkeypatch): def test_flash_attn(monkeypatch): """Test FlashAttn validation.""" - override_backend(monkeypatch, STR_FLASH_ATTN_VAL) + override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=[7, 5]): @@ -72,6 +72,6 @@ def test_flash_attn(monkeypatch): def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" - override_backend(monkeypatch, STR_INVALID_VAL) + override_backend_env_variable(monkeypatch, STR_INVALID_VAL) with pytest.raises(ValueError): which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index fb28924c5f9c4..b401eb87d3ec3 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -7,9 +7,10 @@ STR_INVALID_VAL: str = "INVALID" -def override_backend(mpatch: pytest.MonkeyPatch, backend_name: str) -> None: +def override_backend_env_variable(mpatch: pytest.MonkeyPatch, + backend_name: str) -> None: ''' - Override vLLM attention backend temporarily, + Override the environment variable indicating the vLLM backend temporarily, using pytest monkeypatch to ensure that the env vars get reset once the test context exits. From da1b64839dcd0c650410d83e520a26bf326d6913 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 13:58:11 -0400 Subject: [PATCH 153/239] merged backend env config --- tests/kernels/test_encoder_decoder_attn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 0a835940c3d75..2bcb0668f83dc 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -14,7 +14,7 @@ import pytest import torch -from tests.kernels.utils import override_backend +from tests.kernels.utils import override_backend_env_variable from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import ( @@ -1236,7 +1236,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, ''' # Force Attention wrapper backend - override_backend(monkeypatch, backend_name) + override_backend_env_variable(monkeypatch, backend_name) # Attention scale factor, attention backend instance, attention wrapper # instance. Encoder attention does not require KV cache. @@ -1342,7 +1342,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( ''' # Force Attention wrapper backend - override_backend(monkeypatch, backend_name) + override_backend_env_variable(monkeypatch, backend_name) # Num KV cache blocks num_blocks = 4096 @@ -1568,7 +1568,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, ''' # Force Attention wrapper backend - override_backend(monkeypatch, backend_name) + override_backend_env_variable(monkeypatch, backend_name) # Num KV cache blocks num_blocks = 4096 From 60a21e3cb7e4d4c96d6222be7d1fb0fa01634d47 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 14:35:23 -0400 Subject: [PATCH 154/239] fixed _get_seq_len_block_table_args() to change behavior based on is_prompt --- tests/kernels/test_encoder_decoder_attn.py | 5 +-- vllm/attention/backends/xformers.py | 42 ++++++++++++++-------- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 2bcb0668f83dc..fa4db135323fd 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -373,6 +373,7 @@ def make_backend(backend_name: str) -> AttentionBackend: def make_metadata_tensors(seq_lens: List[int], context_lens: List[int], + encoder_seq_lens: List[int], device: Union[torch.device, str] = \ CUDA_DEVICE) -> tuple: ''' @@ -621,7 +622,7 @@ def make_test_metadata( context_lens_tensor, \ _, \ _, \ - seq_start_loc = make_metadata_tensors(seq_lens, + _ = make_metadata_tensors(seq_lens, context_lens, device=device) @@ -652,7 +653,7 @@ def make_test_metadata( context_lens_tensor, \ _, \ _, \ - seq_start_loc = make_metadata_tensors(seq_lens, + _ = make_metadata_tensors(seq_lens, context_lens, device=device) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index b56e566a89fe1..01272c1d9628a 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -148,8 +148,8 @@ def __post_init__(self): self.encoder_attn_bias: Optional[List[AttentionBias]] = None self.cross_attn_bias: Optional[List[AttentionBias]] = None - if self.is_all_encoder_attn_metadata_set: - self._maybe_compute_implicit_encoder_attrs() + # if self.is_all_encoder_attn_metadata_set: + # self._maybe_compute_implicit_encoder_attrs() @property def is_all_encoder_attn_metadata_set(self): @@ -194,14 +194,14 @@ def attention_type(self, atype: AttentionType) -> None: "Must set self.encoder_seq_lens, self.cross_slot_mapping, " + \ "self.cross_block_tables in order to perform cross-attention" - self._maybe_compute_implicit_encoder_attrs() + # self._maybe_compute_implicit_encoder_attrs() self._attn_type = AttentionType.ENCODER_DECODER elif atype == AttentionType.ENCODER: assert self.is_all_encoder_attn_metadata_set, \ "Must set self.encoder_seq_lens in order to perform cross-attention" - self._maybe_compute_implicit_encoder_attrs() + # self._maybe_compute_implicit_encoder_attrs() self._attn_type = AttentionType.ENCODER else: @@ -346,26 +346,38 @@ def _set_attn_bias(attn_metadata: XFormersMetadata, f"Invalid attn_metadata.attention_type {str(attn_type)}") -def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata) -> tuple: +def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, is_prompt: bool) -> tuple: ''' - Extract appropriate attention bias from attention metadata - according to attention type. - - Depends on attn_metadata having a valid attention_type. - + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + Arguments: - * attn_metadata: Attention metadata structure associated with attention + * attn_metadata: Attention metadata structure associated with attention op Returns: - * Appropriate attention bias value + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) ''' attn_type = attn_metadata.attention_type if attn_type == AttentionType.DECODER: # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len return attn_metadata.seq_lens_tensor, \ - attn_metadata.max_decode_seq_len, \ + max_seq_len, \ attn_metadata.block_tables elif attn_type == AttentionType.ENCODER_DECODER: # Enc/dec cross-attention KVs match encoder sequence length; @@ -586,7 +598,7 @@ def forward( seq_lens_arg, \ max_seq_len_arg,\ - block_tables_arg = _get_seq_len_block_table_args(decode_meta) + block_tables_arg = _get_seq_len_block_table_args(decode_meta, False) output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, @@ -630,7 +642,7 @@ def _run_memory_efficient_xformers_forward( # (seq_lens or encoder_seq_lens) is set. seq_lens, \ _,\ - _ = _get_seq_len_block_table_args(attn_metadata) + _ = _get_seq_len_block_table_args(attn_metadata, True) assert seq_lens is not None original_query = query From 306ea5b75865ab05c1666a6dc2bab2b6b8eef4c9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 14:46:08 -0400 Subject: [PATCH 155/239] removed inference of encoder metadata attributes; removed guessing of encoder seq len tensor device --- tests/kernels/test_encoder_decoder_attn.py | 33 ++++++++++++++++------ vllm/attention/backends/xformers.py | 24 ++-------------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index fa4db135323fd..93675880defc5 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -397,19 +397,26 @@ def make_metadata_tensors(seq_lens: List[int], * query_start_loc: start idx of each query ''' seq_lens_tensor = None if seq_lens is None else \ - torch.tensor(seq_lens, dtype=torch.int, device=device) + torch.tensor(seq_lens, dtype=torch.int, device=device) context_lens_tensor = None if context_lens is None else torch.tensor( context_lens, dtype=torch.int, device=device) max_context_len = None if context_lens is None else max(context_lens) max_seq_len = None if seq_lens is None else max(seq_lens) + encoder_seq_lens_tensor = None if encoder_seq_lens is None else \ + torch.tensor(encoder_seq_lens, dtype=torch.int, device=device) + max_encoder_seq_len = None if encoder_seq_lens is None else \ + max(encoder_seq_lens) + seq_start_loc = None return seq_lens_tensor, \ context_lens_tensor, \ max_context_len, \ max_seq_len, \ - seq_start_loc + seq_start_loc, \ + encoder_seq_lens_tensor, \ + max_encoder_seq_len def make_kv_cache(num_blocks: int, @@ -622,9 +629,12 @@ def make_test_metadata( context_lens_tensor, \ _, \ _, \ - _ = make_metadata_tensors(seq_lens, - context_lens, - device=device) + _, \ + encoder_seq_lens_tensor, \ + max_encoder_seq_len = make_metadata_tensors(seq_lens, + context_lens, + encoder_seq_lens, + device=device) return attn_backend.make_metadata( num_prefills=num_prefills, @@ -640,6 +650,8 @@ def make_test_metadata( use_cuda_graph=False, _attn_type=default_attn_type, encoder_seq_lens=encoder_seq_lens, + encoder_seq_lens_tensor=encoder_seq_lens_tensor, + max_encoder_seq_len=max_encoder_seq_len, cross_slot_mapping=cross_slot_mapping, cross_block_tables=cross_block_tables) @@ -653,9 +665,12 @@ def make_test_metadata( context_lens_tensor, \ _, \ _, \ - _ = make_metadata_tensors(seq_lens, - context_lens, - device=device) + _, \ + encoder_seq_lens_tensor, \ + max_encoder_seq_len = make_metadata_tensors(seq_lens, + context_lens, + encoder_seq_lens, + device=device) return attn_backend.make_metadata( num_prefills=num_prefills, @@ -671,6 +686,8 @@ def make_test_metadata( use_cuda_graph=False, _attn_type=default_attn_type, encoder_seq_lens=encoder_seq_lens, + encoder_seq_lens_tensor=encoder_seq_lens_tensor, + max_encoder_seq_len=max_encoder_seq_len, cross_slot_mapping=cross_slot_mapping, cross_block_tables=cross_block_tables) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 01272c1d9628a..75634b10443a2 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -148,15 +148,14 @@ def __post_init__(self): self.encoder_attn_bias: Optional[List[AttentionBias]] = None self.cross_attn_bias: Optional[List[AttentionBias]] = None - # if self.is_all_encoder_attn_metadata_set: - # self._maybe_compute_implicit_encoder_attrs() - @property def is_all_encoder_attn_metadata_set(self): ''' All attention metadata required for encoder attention is set. ''' - return self.encoder_seq_lens is not None + return (self.encoder_seq_lens is not None) and \ + (self.encoder_seq_lens_tensor is not None) and \ + (self.max_encoder_seq_len is not None) @property def is_all_cross_attn_metadata_set(self): @@ -169,19 +168,6 @@ def is_all_cross_attn_metadata_set(self): (self.cross_slot_mapping is not None) and \ (self.cross_block_tables is not None) - def _maybe_compute_implicit_encoder_attrs(self): - ''' - Encoder attention and cross-attention require some encoder-related - metadata attributes which may or may not be been provided by the user. - This method infers the implicit attributes from provided attributes - ''' - if self.encoder_seq_lens_tensor is None: - self.encoder_seq_lens_tensor = torch.tensor(self.encoder_seq_lens, - dtype=torch.int32, - device="cuda:0") - if self.max_encoder_seq_len is None: - self.max_encoder_seq_len = max(self.encoder_seq_lens) - @property def attention_type(self) -> AttentionType: return self._attn_type @@ -194,15 +180,11 @@ def attention_type(self, atype: AttentionType) -> None: "Must set self.encoder_seq_lens, self.cross_slot_mapping, " + \ "self.cross_block_tables in order to perform cross-attention" - # self._maybe_compute_implicit_encoder_attrs() - self._attn_type = AttentionType.ENCODER_DECODER elif atype == AttentionType.ENCODER: assert self.is_all_encoder_attn_metadata_set, \ "Must set self.encoder_seq_lens in order to perform cross-attention" - # self._maybe_compute_implicit_encoder_attrs() - self._attn_type = AttentionType.ENCODER else: # AttentionType.{ENCODER,DECODER} From eda2273ed659e1fd00b23751a43b42e1bc19dd13 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 15:10:08 -0400 Subject: [PATCH 156/239] wip refactoring --- tests/kernels/test_encoder_decoder_attn.py | 75 ++++++++++++++++------ vllm/attention/backends/xformers.py | 3 +- 2 files changed, 56 insertions(+), 22 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 93675880defc5..163292e9ad4af 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -8,6 +8,7 @@ import copy import itertools +import numbers import random from typing import List, Optional, Union @@ -35,6 +36,46 @@ MAX_K_SEQ_LENS = [128] +def maybe_list_to_int_tensor(_list: List[int], + device: Union[torch.device, str] \ + = CUDA_DEVICE) \ + -> torch.Tensor: + ''' + Convert Python int list to a 1D int torch.Tensor on `device` + + Returns: + + * If _list is not None: 1D int torch.Tensor on `device` + * None otherwise + ''' + return None if _list is None else torch.tensor( + _list, dtype=torch.int, device=device) + +def maybe_list_to_long_tensor(_list: List[int], + device: Union[torch.device, str] \ + = CUDA_DEVICE) \ + -> torch.Tensor: + ''' + Convert Python int list to a 1D long torch.Tensor on `device` + + Returns: + + * If _list is not None: 1D long torch.Tensor on `device` + * None otherwise + ''' + return None if _list is None else torch.tensor( + _list, dtype=torch.long, device=device) + + +def maybe_max(_list: List) -> Optional[numbers.Number]: + ''' + Returns: + + * If _list is not None: max(_list) + * None otherwise + ''' + return None if _list is None else max(_list) + def build_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ -> torch.Tensor: ''' @@ -396,15 +437,13 @@ def make_metadata_tensors(seq_lens: List[int], * seq_start_loc: start idx of each sequence * query_start_loc: start idx of each query ''' - seq_lens_tensor = None if seq_lens is None else \ - torch.tensor(seq_lens, dtype=torch.int, device=device) - context_lens_tensor = None if context_lens is None else torch.tensor( - context_lens, dtype=torch.int, device=device) - max_context_len = None if context_lens is None else max(context_lens) - max_seq_len = None if seq_lens is None else max(seq_lens) - - encoder_seq_lens_tensor = None if encoder_seq_lens is None else \ - torch.tensor(encoder_seq_lens, dtype=torch.int, device=device) + seq_lens_tensor = maybe_list_to_int_tensor(seq_lens, device) + context_lens_tensor = maybe_list_to_int_tensor(context_lens, device) + max_context_len = maybe_max(context_lens) + max_seq_len = maybe_max(seq_lens) + + encoder_seq_lens_tensor = maybe_list_to_int_tensor(encoder_seq_lens, + device) max_encoder_seq_len = None if encoder_seq_lens is None else \ max(encoder_seq_lens) @@ -547,18 +586,12 @@ def make_block_tables_slot_mapping(block_size: int, dtype=torch.int, device=device, ) - prefill_slot_mapping_tensor = torch.tensor(prefill_slot_mapping, - dtype=torch.long, - device=device) - decode_slot_mapping_tensor = torch.tensor(decode_slot_mapping, - dtype=torch.long, - device=device) - slot_mapping_tensor = torch.tensor(slot_mapping, - dtype=torch.long, - device=device) - empty_slot_mapping_tensor = torch.tensor([], - dtype=torch.long, - device=device) + prefill_slot_mapping_tensor = maybe_list_to_long_tensor( + prefill_slot_mapping, device) + decode_slot_mapping_tensor = maybe_list_to_long_tensor( + decode_slot_mapping, device) + slot_mapping_tensor = maybe_list_to_long_tensor(slot_mapping, device) + empty_slot_mapping_tensor = maybe_list_to_long_tensor([], device) return decode_block_tables_tensor, \ decode_slot_mapping_tensor, \ diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 75634b10443a2..f165f7922017f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -328,7 +328,8 @@ def _set_attn_bias(attn_metadata: XFormersMetadata, f"Invalid attn_metadata.attention_type {str(attn_type)}") -def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, is_prompt: bool) -> tuple: +def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, + is_prompt: bool) -> tuple: ''' The particular choice of sequence-length- and block-table-related attributes which should be extracted from attn_metadata is dependent From 9425f0cd05069338ca5835196000fdcd878e9233 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 15:53:33 -0400 Subject: [PATCH 157/239] refactored helper functions into diffferent utils files --- tests/kernels/test_encoder_decoder_attn.py | 780 ++------------------- tests/kernels/utils.py | 638 +++++++++++++++++ vllm/utils.py | 61 ++ 3 files changed, 759 insertions(+), 720 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 163292e9ad4af..3c152c8988536 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -7,21 +7,20 @@ """ import copy -import itertools -import numbers -import random -from typing import List, Optional, Union +from typing import List, Optional import pytest import torch -from tests.kernels.utils import override_backend_env_variable +from tests.kernels.utils import (make_backend, make_block_tables_slot_mapping, + make_kv_cache, make_qkv, make_test_metadata, + override_backend_env_variable, pack_qkv, + pack_tensor, ref_masked_attention) from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.abstract import AttentionBackend, AttentionType +from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_ROCM_HIP) -from vllm.attention.backends.xformers import XFormersBackend -from vllm.utils import is_hip, make_tensor_with_pad +from vllm.utils import is_hip, make_causal_mask HEAD_SIZES = [64, 256] @@ -36,695 +35,6 @@ MAX_K_SEQ_LENS = [128] -def maybe_list_to_int_tensor(_list: List[int], - device: Union[torch.device, str] \ - = CUDA_DEVICE) \ - -> torch.Tensor: - ''' - Convert Python int list to a 1D int torch.Tensor on `device` - - Returns: - - * If _list is not None: 1D int torch.Tensor on `device` - * None otherwise - ''' - return None if _list is None else torch.tensor( - _list, dtype=torch.int, device=device) - -def maybe_list_to_long_tensor(_list: List[int], - device: Union[torch.device, str] \ - = CUDA_DEVICE) \ - -> torch.Tensor: - ''' - Convert Python int list to a 1D long torch.Tensor on `device` - - Returns: - - * If _list is not None: 1D long torch.Tensor on `device` - * None otherwise - ''' - return None if _list is None else torch.tensor( - _list, dtype=torch.long, device=device) - - -def maybe_max(_list: List) -> Optional[numbers.Number]: - ''' - Returns: - - * If _list is not None: max(_list) - * None otherwise - ''' - return None if _list is None else max(_list) - -def build_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ - -> torch.Tensor: - ''' - Create a q_max_seq_len x kv_max_seq_len causal mask - - Arguments: - - * q_max_seq_len: query max seq len - * kv_max_seq_len: key/value max seq len - - Returns: - - * 2D tensor, q_max_seq_len x kv_max_seq_len - ''' - - # Create a matrix where entry (i, j) is True if i >= j - mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) - # Replace True with float('-inf') and False with 0 - mask = mask.masked_fill(mask == 1, - float('-inf')).masked_fill(mask == 0, 0.0) - return mask - - -def ref_masked_attention(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - custom_mask: Optional[torch.Tensor] = None, - q_seq_lens: Optional[List] = None, - kv_seq_lens: Optional[List] = None) -> torch.Tensor: - ''' - "Golden" masked attention reference. Supports two types of masking: - - * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out - padding elements - * Custom attention mask, which can force an arbitrary mask tensor, i.e. - causal - - Arguments: - - * query: batch_size x q_padded_seq_len x num_heads x head_size - * key: batch_size x kv_padded_seq_len x num_heads x head_size - * value: batch_size x kv_padded_seq_len x num_heads x head_size - * scale: Attention scale factor - * Custom mask: custom attention mask; good place to inject a causal - attention mask - * q_seq_lens: list of unpadded query seq_lens for each batch index - * kv_seq_lens: list of unpadded key/value seq_lens for each batch index - - Returns: - - * Attention result, batch_size x q_padded_seq_len x num_heads x head_size - ''' - - batch_size = query.shape[0] - assert (len(q_seq_lens) == batch_size) - assert (len(kv_seq_lens) == batch_size) - - attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() - - # Basic attention mask, derived from seq lens - if (q_seq_lens is not None) or (kv_seq_lens is not None): - attn_mask = torch.zeros_like(attn_weights) - if q_seq_lens is not None: - for bdx, plen in enumerate(q_seq_lens): - attn_mask[bdx, :, plen:, :] = -torch.inf - if kv_seq_lens is not None: - for bdx, plen in enumerate(kv_seq_lens): - attn_mask[bdx, :, :, plen:] = -torch.inf - - attn_weights = attn_weights + attn_mask.float() - - # Custom attention mask - if custom_mask is not None: - attn_weights = attn_weights + custom_mask.float() - - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) - return out - - -def make_qkv(batch_size: int, - max_q_seq_len: int, - max_kv_seq_len: int, - num_heads: int, - head_size: int, - attn_type: AttentionType = AttentionType.ENCODER_DECODER, - force_max_len: bool = False, - device: Union[torch.device, str] = CUDA_DEVICE) -> tuple: - ''' - Construct QKV test tensors for self- and cross-attention. - - Generates three query/key/value triplets: - - * "Baseline" query/key/value (for input to reference attention function) - * "Prefill" query/key/value (last sequence offset zero'd out, for use as - input to prefill kernel) - * "Decode" query/key/value (only the last sequence offset from baseline, - for use as input to decode kernel) - - Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v - seqlens - - Arguments: - - * batch_size - * max_q_seq_len: max query seq len - * max_kv_seq_len: max key/value seq len - * num_heads - * head_size - * is_encoder_decoder_attn: if True, query seqlen may differ from - key/value seqlen (as is often the case for cross-attention); - o/w, query/key/value seqlens match at each batch index - (max_kv_seq_len is unused) - * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query - seqlens are random in [2,max_q_seq_lens]. Same for key/value seqlens - and max_kv_seq_len, unless forced by is_encoder_decoder_attn=False - * device: CPU or CUDA device - - Returns: - - * query: "baseline" query; batch_size x max_q_seq_len x num_heads x - head_size - * key: "baseline" key; batch_size x max_kv_seq_len x num_heads x - head_size - * value: "baseline" value; batch_size x max_kv_seq_len x num_heads x - head_size - * prefill_query: batch_size x (max_q_seq_len-1) x num_heads x head_size - * prefill_key: batch_size x (max_kv_seq_len-1) x num_heads x head_size - * prefill_value: batch_size x (max_kv_seq_len-1) x num_heads x head_size - * decode_query: batch_size x 1 x num_heads x head_size - * decode_key: batch_size x 1 x num_heads x head_size - * decode_value: batch_size x 1 x num_heads x head_size - * q_seq_lens: "baseline" query seqlen list - * kv_seq_lens: "baseline" key/value seqlen list - * actual_max_q_seq_len: actual "baseline" query max seq len (may be <= - max_q_seq_len due to randomness) - * actual_max_kv_seq_len: actual "baseline" key/value max seq len (may - be <= max_kv_seq_len due to randomness) - * prefill_q_seq_lens: "prefill" query seqlen list - * prefill_kv_seq_lens: "prefill" key/value seqlen list - * decode_q_seq_lens: "decode" query seqlen list (all ones) - * decode_kv_seq_lens: "decode" key/value seqlen list - ''' - - if force_max_len: - q_seq_lens = [max_q_seq_len for _ in range(batch_size)] - else: - q_seq_lens = [ - random.randint(2, max_q_seq_len) for _ in range(batch_size) - ] - kv_seq_lens = None - if attn_type != AttentionType.ENCODER_DECODER: - # K,V seq lens match Q for self-attention - kv_seq_lens = q_seq_lens - else: - # K,V seq lens are distinct from Q seq lens & random - if force_max_len: - kv_seq_lens = [max_kv_seq_len for _ in range(batch_size)] - else: - kv_seq_lens = [ - random.randint(2, max_kv_seq_len) for _ in range(batch_size) - ] - - actual_max_q_seq_len = max(q_seq_lens) - actual_max_kv_seq_len = max(kv_seq_lens) - - query = torch.rand( - (batch_size, max_q_seq_len, num_heads, head_size)).to(device) - key = torch.rand( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - value = torch.rand( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - - prefill_query = torch.zeros( - (batch_size, max_q_seq_len, num_heads, head_size)).to(device) - prefill_key = torch.zeros( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - prefill_value = torch.zeros( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - - decode_query = torch.zeros( - (batch_size, 1, num_heads, head_size)).to(device) - decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device) - decode_value = torch.zeros( - (batch_size, 1, num_heads, head_size)).to(device) - - for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, - kv_seq_lens)): - query[bdx, q_seq_len:, :, :] = 0 - key[bdx, kv_seq_len:, :, :] = 0 - value[bdx, kv_seq_len:, :, :] = 0 - - prefill_query[bdx, - 0:(q_seq_len - 1), :, :] = query[bdx, - 0:(q_seq_len - 1), :, :] - prefill_key[bdx, - 0:(kv_seq_len - 1), :, :] = key[bdx, - 0:(kv_seq_len - 1), :, :] - prefill_value[bdx, 0:(kv_seq_len - - 1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :] - - decode_query[bdx, :, :, :] = query[bdx, - (q_seq_len - 1):q_seq_len, :, :] - decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :] - decode_value[bdx, :, :, :] = value[bdx, - (kv_seq_len - 1):kv_seq_len, :, :] - - prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens] - prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens] - - decode_q_seq_lens = [1 for _ in q_seq_lens] - decode_kv_seq_lens = [1 for _ in kv_seq_lens] - - return query, \ - key, \ - value, \ - prefill_query, \ - prefill_key, \ - prefill_value, \ - decode_query, \ - decode_key, \ - decode_value, \ - q_seq_lens, \ - kv_seq_lens, \ - actual_max_q_seq_len, \ - actual_max_kv_seq_len, \ - prefill_q_seq_lens, \ - prefill_kv_seq_lens, \ - decode_q_seq_lens, \ - decode_kv_seq_lens - - -def pack_tensor(unpacked_tensor: torch.Tensor, - seq_lens: List[int], - device: Union[torch.device, str] = CUDA_DEVICE) -> tuple: - ''' - Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an - unpadded number_of_tokens x num_heads x head_size tensor, where - number_of_tokens = sum(seq_lens) - - Arguments: - - * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size - * seq_lens: list of token counts for each seq - * device: CPU or CUDA device - - Returns - - * packed_tensor: number_of_tokens x num_heads x head_size - * start_loc_list: start idx of each batch elt in packed_tensor; [0] + - list(itertools.accumulate(seq_lens)) - ''' - - num_tok = sum(seq_lens) - num_heads = unpacked_tensor.shape[-2] - head_size = unpacked_tensor.shape[-1] - start_loc_list = [0] + list(itertools.accumulate(seq_lens)) - packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device) - - for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)): - - packed_tensor[start_loc:( - start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :] - - return packed_tensor, start_loc_list - - -def pack_qkv(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - q_seq_lens: List[int], kv_seq_lens: List[int]) -> tuple: - ''' - Individually pack each of Q, K and V, each with dimensions batch_size x - padded_seq_len x num_heads x head_size, into respective number_of_tokens x - num_heads x head_size tensors. - - For Q, number_of_tokens = sum(q_seq_lens). - - For K and V, number_of_tokens = sum(kv_seq_lens) - - Arguments: - - * query: batch_size x padded_seq_len x num_heads x head_size - * key: batch_size x padded_seq_len x num_heads x head_size - * value: batch_size x padded_seq_len x num_heads x head_size - * q_seq_lens: list of token counts for each query - * kv_seq_lens: list of token counts for each key/value - - Returns - - * packed_query: number_of_tokens x num_heads x head_size - * packed_key: number_of_tokens x num_heads x head_size - * packed_value: number_of_tokens x num_heads x head_size - * q_start_loc_list: start idx of each query in packed_query - * kv_start_loc_list: start idx of each {key,value} in packed_{key,value} - ''' - - if query is None: - packed_query = None - q_start_loc_list = None - else: - packed_query, q_start_loc_list = pack_tensor(query, q_seq_lens) - packed_key, kv_start_loc_list = pack_tensor(key, kv_seq_lens) - packed_value, _ = pack_tensor(value, kv_seq_lens) - return packed_query, \ - packed_key, \ - packed_value, \ - q_start_loc_list, \ - kv_start_loc_list - - -def make_backend(backend_name: str) -> AttentionBackend: - ''' - Construct the backend instance determined by the backend_name string - argument. - - "XFORMERS" -> construct xformers backend - - TODO: other backends - - Note: at time of writing the Attention wrapper automatically selects - its own backend for Attention.forward(); so the backend instance which - you generate with this function is not meant to be used for *running* - inference, but rather for generating compatible metadata structures - using backend.make_metadata() - - - Returns: - - * Backend instance - ''' - if backend_name == "XFORMERS": - return XFormersBackend() - raise AssertionError( - f"Unrecognized backend_name {backend_name} for unit test") - - -def make_metadata_tensors(seq_lens: List[int], - context_lens: List[int], - encoder_seq_lens: List[int], - device: Union[torch.device, str] = \ - CUDA_DEVICE) -> tuple: - ''' - Build scalar & tensor values required to build attention metadata structure. - - Arguments: - - * is_prompt: True -> Prefill, False -> Decode - * seq_lens: list of token-counts for each seq - * context_lens: list of context length values for each seq - * device: CPU or CUDA device - - Returns: - - * seq_lens_tensor: seq_lens list, as tensor - * context_lens_tensor: context_lens list, as tensor - * max_query_len: max(seq_lens) if is_seq, o/w 1 - * max_context_len: max(context_lens) - * max_seq_len: max(seq_lens) - * seq_start_loc: start idx of each sequence - * query_start_loc: start idx of each query - ''' - seq_lens_tensor = maybe_list_to_int_tensor(seq_lens, device) - context_lens_tensor = maybe_list_to_int_tensor(context_lens, device) - max_context_len = maybe_max(context_lens) - max_seq_len = maybe_max(seq_lens) - - encoder_seq_lens_tensor = maybe_list_to_int_tensor(encoder_seq_lens, - device) - max_encoder_seq_len = None if encoder_seq_lens is None else \ - max(encoder_seq_lens) - - seq_start_loc = None - - return seq_lens_tensor, \ - context_lens_tensor, \ - max_context_len, \ - max_seq_len, \ - seq_start_loc, \ - encoder_seq_lens_tensor, \ - max_encoder_seq_len - - -def make_kv_cache(num_blocks: int, - num_heads: int, - head_size: int, - block_size: int, - device: Union[torch.device, str] = \ - CUDA_DEVICE, - default_val: float=0.0) -> torch.Tensor: - ''' - Create a fake KV cache. - - Arguments: - - * num_blocks: number of blocks in the KV cache - * num_heads: number of attention heads - * head_size: head dimension - * block_size: number of offsets within a block - * device: CPU or CUDA device - * default_val: initialization value for KV cache elements - - Returns: - - * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size) - ''' - - kv_cache = torch.rand( - (2, num_blocks, block_size * num_heads * head_size)).to(device) - if default_val is not None: - kv_cache[:, :, :] = default_val - return kv_cache - - -def num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: - ''' - Compute the minimum number of blocks required to hold num_tokens tokens, - given block_size - ''' - return (num_tokens + block_size) // block_size - - -def make_block_tables_slot_mapping(block_size: int, - seq_lens: List, - block_base_addr: int=0, - device: Union[torch.device, str] = \ - CUDA_DEVICE) -> tuple: - ''' - Construct fake block tables & slot mappings. - - For a sequence with num_tokens tokens the minimum number - of required KV cache blocks is - - num_blocks = (num_tokens + block_size) // block_size - - Then the minimum KV cache size in blocks is - - total_cache_blocks = sum(num_blocks for all seqs) - - Then, the blocktable mapping counts downward from - - block_base_addr + total_cache_blocks - - to - - block_base_addr - - - Arguments: - - * block_size: number of offsets per block - * seq_lens: list of token-counts for each sequence - * block_base_addr: the block table base address - * device: CPU or CUDA device - - Return: - - * decode_block_tables_tensor: fake the state of the block tables during - decode - * decode_slot_mapping_tensor: fake the state of the slot mapping during - decode - * prefill_slot_mapping_tensor: fake the state of the slot mapping during - prefill - * prefill_block_tables_tensor: fake the state of the block tables during - prefill - * slot_mapping_tensor: union of prefill and decode slot mappings - * empty_slot_mapping_tensor: empty slot mapping (useful for decode phase - cross attention) - * max_block_idx: the highest block address within this block table - ''' - - # Provision minimum number of KV cache blocks - num_blocks_list = [ - num_tokens_to_min_blocks(num_tokens, block_size) - for num_tokens in seq_lens - ] - max_block_table_len = max(num_blocks_list) - block_table_pad_tokens = 10 - - block_tables = [] - prefill_slot_mapping = [] - decode_slot_mapping = [] - slot_mapping = [] - # Compute uppermost address of block table - total_cache_blocks = sum(num_blocks_list) - block_base_idx = block_base_addr + total_cache_blocks - max_block_idx = block_base_idx - for sdx, num_tokens in enumerate(seq_lens): - num_blocks = num_blocks_list[sdx] - block_table = list( - range(block_base_idx, block_base_idx - num_blocks, -1)) - for idx in range(num_tokens): - mapping_value = ( - idx % block_size) + block_table[idx // block_size] * block_size - slot_mapping.append(mapping_value) - if idx < num_tokens - 1: - prefill_slot_mapping.append(mapping_value) - elif idx == num_tokens - 1: - decode_slot_mapping.append(mapping_value) - - block_base_idx -= num_blocks - block_tables.append(block_table) - - prefill_block_tables_tensor = torch.tensor([], device=CUDA_DEVICE) - decode_block_tables_tensor = make_tensor_with_pad( - block_tables, - max_len=max_block_table_len + block_table_pad_tokens, - pad=0, - dtype=torch.int, - device=device, - ) - prefill_slot_mapping_tensor = maybe_list_to_long_tensor( - prefill_slot_mapping, device) - decode_slot_mapping_tensor = maybe_list_to_long_tensor( - decode_slot_mapping, device) - slot_mapping_tensor = maybe_list_to_long_tensor(slot_mapping, device) - empty_slot_mapping_tensor = maybe_list_to_long_tensor([], device) - - return decode_block_tables_tensor, \ - decode_slot_mapping_tensor, \ - prefill_slot_mapping_tensor, \ - prefill_block_tables_tensor, \ - slot_mapping_tensor, \ - empty_slot_mapping_tensor, \ - max_block_idx - - -def make_test_metadata( - attn_backend: AttentionBackend, - is_prompt: bool, - seq_lens: List[int], - context_lens: List[int], - block_tables: torch.Tensor, - slot_mapping: torch.Tensor, - is_encoder_only_test: bool, - num_prefills_or_decodes: int, - num_prefill_or_decode_tokens: int, - device: Union[torch.device, str] = CUDA_DEVICE, - encoder_seq_lens: Optional[List[int]] = None, - cross_block_tables: Optional[torch.Tensor] = None, - cross_slot_mapping: Optional[List[int]] = None, -) -> AttentionMetadata: - ''' - Construct fake attention metadata for a combined self-/cross-attention - scenario i.e. an encoder/decoder model. - - is_encoder_only_test=True causes the default attention metadata attention - type to be AttentionType.ENCODER. False causes the default to - be AttentionType.DECODER. - - Assumptions: - - * No chunked prefill -> a batch is 100% prefill or 100% decode, never both - - Arguments: - - * attn_backend: Backend for sourcing attention kernels - * is_prompt: prefill if True, o/w decode - * seq_lens: list of token counts for each sequence - * context_lens: list of context lengths for each sequence - * block_tables: self-attention block tables - * slot_mapping: self-attention slot_mapping - * is_encoder_only_test: True if testing encoder; False if testing - decoder self-attention or encoder/decoder cross-attention. - * device: CPU or CUDA device - * encoder_seq_lens: list of token counts for each encoder sequence, if any - exist - * cross_block_tables: cross-attention block tables, if required - * cross_slot_mapping: cross-attention slot mapping, if required - - Return: - - * AttentionMetadata structure supporting self- and cross-attention - ''' - - default_attn_type = AttentionType.ENCODER if is_encoder_only_test \ - else AttentionType.DECODER - - if is_prompt: - num_prefills = num_prefills_or_decodes - num_prefill_tokens = num_prefill_or_decode_tokens - num_decode_tokens = 0 - - seq_lens_tensor, \ - context_lens_tensor, \ - _, \ - _, \ - _, \ - encoder_seq_lens_tensor, \ - max_encoder_seq_len = make_metadata_tensors(seq_lens, - context_lens, - encoder_seq_lens, - device=device) - - return attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_prefill_seq_len=None if seq_lens is None else max(seq_lens), - max_decode_seq_len=0, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - _attn_type=default_attn_type, - encoder_seq_lens=encoder_seq_lens, - encoder_seq_lens_tensor=encoder_seq_lens_tensor, - max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=cross_slot_mapping, - cross_block_tables=cross_block_tables) - - else: # not is_prompt - - num_prefills = 0 - num_prefill_tokens = 0 - num_decode_tokens = num_prefill_or_decode_tokens - - seq_lens_tensor, \ - context_lens_tensor, \ - _, \ - _, \ - _, \ - encoder_seq_lens_tensor, \ - max_encoder_seq_len = make_metadata_tensors(seq_lens, - context_lens, - encoder_seq_lens, - device=device) - - return attn_backend.make_metadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_prefill_seq_len=0, - max_decode_seq_len=max(seq_lens), - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - _attn_type=default_attn_type, - encoder_seq_lens=encoder_seq_lens, - encoder_seq_lens_tensor=encoder_seq_lens_tensor, - max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=cross_slot_mapping, - cross_block_tables=cross_block_tables) - - def basic_setup(num_heads: int, head_size: int, num_blocks: int, block_size: int, backend_name: str) -> tuple: ''' @@ -761,7 +71,11 @@ def basic_setup(num_heads: int, head_size: int, num_blocks: int, return scale, attn_backend, attn, None # Construct KV cache - kv_cache = make_kv_cache(num_blocks, num_heads, head_size, block_size) + kv_cache = make_kv_cache(num_blocks, + num_heads, + head_size, + block_size, + device=CUDA_DEVICE) return scale, attn_backend, attn, kv_cache @@ -830,7 +144,8 @@ def encoder_attn_setup(batch_size: int, max_kv_seq_len, num_heads, head_size, - attn_type=AttentionType.ENCODER) + attn_type=AttentionType.ENCODER, + device=CUDA_DEVICE) # No causal attention mask ideal_output = ref_masked_attention(query, @@ -840,7 +155,9 @@ def encoder_attn_setup(batch_size: int, q_seq_lens=q_seq_lens, kv_seq_lens=kv_seq_lens) - packed_ideal_output, _ = pack_tensor(ideal_output, q_seq_lens) + packed_ideal_output, _ = pack_tensor(ideal_output, + q_seq_lens, + device=CUDA_DEVICE) block_tables, \ _, \ @@ -849,13 +166,17 @@ def encoder_attn_setup(batch_size: int, slot_mapping, \ _, \ _ = make_block_tables_slot_mapping( - block_size, q_seq_lens, block_base_addr=block_base_addr) + block_size, + q_seq_lens, + block_base_addr=block_base_addr, + device=CUDA_DEVICE) packed_query, \ packed_key, \ packed_value, _, _ = pack_qkv( query, key, value, q_seq_lens, - kv_seq_lens) + kv_seq_lens, + device=CUDA_DEVICE) return packed_query, \ packed_key, \ @@ -973,10 +294,11 @@ def decoder_attn_setup(batch_size: int, max_kv_seq_len, num_heads, head_size, - attn_type=AttentionType.DECODER) + attn_type=AttentionType.DECODER, + device=CUDA_DEVICE) - causal_mask = build_causal_mask(max_q_seq_len, - max_kv_seq_len).to(CUDA_DEVICE) + causal_mask = make_causal_mask(max_q_seq_len, + max_kv_seq_len).to(CUDA_DEVICE) ideal_output = ref_masked_attention(query, key, @@ -995,9 +317,11 @@ def decoder_attn_setup(batch_size: int, prefill_q_seq_len + 1)] prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_seq_lens) + prefill_q_seq_lens, + device=CUDA_DEVICE) decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, - [1 for _ in range(batch_size)]) + [1 for _ in range(batch_size)], + device=CUDA_DEVICE) decode_block_tables, \ decode_slot_mapping, \ @@ -1006,13 +330,17 @@ def decoder_attn_setup(batch_size: int, _, \ _, \ max_block_idx = make_block_tables_slot_mapping( - block_size, q_seq_lens, block_base_addr=block_base_addr) + block_size, + q_seq_lens, + block_base_addr=block_base_addr, + device=CUDA_DEVICE) prefill_packed_query, \ prefill_packed_key, \ prefill_packed_value, _, _ = pack_qkv( prefill_query, prefill_key, prefill_value, prefill_q_seq_lens, - prefill_kv_seq_lens) + prefill_kv_seq_lens, + device=CUDA_DEVICE) decode_packed_query, \ decode_packed_key, \ @@ -1020,7 +348,8 @@ def decoder_attn_setup(batch_size: int, _, \ _ = pack_qkv( decode_query, decode_key, decode_value, decode_q_seq_lens, - decode_kv_seq_lens) + decode_kv_seq_lens, + device=CUDA_DEVICE) return query, \ prefill_packed_query, \ @@ -1138,7 +467,8 @@ def enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, max_kv_seq_len, num_heads, head_size, - attn_type=AttentionType.ENCODER_DECODER) + attn_type=AttentionType.ENCODER_DECODER, + device=CUDA_DEVICE) ideal_output = ref_masked_attention(query, key, @@ -1156,9 +486,11 @@ def enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, prefill_q_seq_len + 1)] prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_seq_lens) + prefill_q_seq_lens, + device=CUDA_DEVICE) decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, - [1 for _ in range(batch_size)]) + [1 for _ in range(batch_size)], + device=CUDA_DEVICE) # Unlike self-attention: # - Prefill slot-mapping includes all key slots @@ -1170,11 +502,18 @@ def enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, prefill_slot_mapping, \ decode_slot_mapping, \ max_block_idx = make_block_tables_slot_mapping( - block_size, kv_seq_lens, block_base_addr=block_base_addr) + block_size, + kv_seq_lens, + block_base_addr=block_base_addr, + device=CUDA_DEVICE) # Packed key/value (query is already provided) - _, packed_key, packed_value, _, _ = pack_qkv(None, key, value, None, - kv_seq_lens) + _, packed_key, packed_value, _, _ = pack_qkv(None, + key, + value, + None, + kv_seq_lens, + device=CUDA_DEVICE) return packed_key, \ packed_value, \ @@ -1337,7 +676,8 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, is_encoder_only_test=True, num_prefills_or_decodes=len(q_seq_lens), num_prefill_or_decode_tokens=sum(q_seq_lens), - encoder_seq_lens=q_seq_lens) + encoder_seq_lens=q_seq_lens, + device=CUDA_DEVICE) packed_actual_output: torch.Tensor = \ run_encoder_or_decoder_self_attention_test( @@ -1480,7 +820,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, - ) + device=CUDA_DEVICE) self_prefill_packed_actual_output: torch.Tensor = \ run_encoder_or_decoder_self_attention_test( @@ -1526,7 +866,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_decode_block_tables, cross_slot_mapping=cross_decode_slot_mapping, - ) + device=CUDA_DEVICE) self_decode_packed_actual_output: torch.Tensor = \ run_encoder_or_decoder_self_attention_test( @@ -1706,7 +1046,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, encoder_seq_lens=encoder_kv_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, - ) + device=CUDA_DEVICE) with pytest.raises(NotImplementedError) as exc_info: run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index b401eb87d3ec3..2b752e4cbcd76 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1,6 +1,17 @@ """Kernel test utils""" +import itertools +import random +from typing import List, Optional, Union + import pytest +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, AttentionType) +from vllm.attention.backends.xformers import XFormersBackend +from vllm.utils import (make_tensor_with_pad, maybe_make_int_tensor, + maybe_make_long_tensor, maybe_max) STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" @@ -20,3 +31,630 @@ def override_backend_env_variable(mpatch: pytest.MonkeyPatch, * backend_name: attention backend name to force ''' mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name) + + +def ref_masked_attention(query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + custom_mask: Optional[torch.Tensor] = None, + q_seq_lens: Optional[List] = None, + kv_seq_lens: Optional[List] = None) -> torch.Tensor: + ''' + "Golden" masked attention reference. Supports two types of masking: + + * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out + padding elements + * Custom attention mask, which can force an arbitrary mask tensor, i.e. + causal + + Arguments: + + * query: batch_size x q_padded_seq_len x num_heads x head_size + * key: batch_size x kv_padded_seq_len x num_heads x head_size + * value: batch_size x kv_padded_seq_len x num_heads x head_size + * scale: Attention scale factor + * Custom mask: custom attention mask; good place to inject a causal + attention mask + * q_seq_lens: list of unpadded query seq_lens for each batch index + * kv_seq_lens: list of unpadded key/value seq_lens for each batch index + + Returns: + + * Attention result, batch_size x q_padded_seq_len x num_heads x head_size + ''' + + batch_size = query.shape[0] + assert (len(q_seq_lens) == batch_size) + assert (len(kv_seq_lens) == batch_size) + + attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() + + # Basic attention mask, derived from seq lens + if (q_seq_lens is not None) or (kv_seq_lens is not None): + attn_mask = torch.zeros_like(attn_weights) + if q_seq_lens is not None: + for bdx, plen in enumerate(q_seq_lens): + attn_mask[bdx, :, plen:, :] = -torch.inf + if kv_seq_lens is not None: + for bdx, plen in enumerate(kv_seq_lens): + attn_mask[bdx, :, :, plen:] = -torch.inf + + attn_weights = attn_weights + attn_mask.float() + + # Custom attention mask + if custom_mask is not None: + attn_weights = attn_weights + custom_mask.float() + + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) + return out + + +def make_qkv( + batch_size: int, + max_q_seq_len: int, + max_kv_seq_len: int, + num_heads: int, + head_size: int, + device: Union[torch.device, str], + attn_type: AttentionType = AttentionType.ENCODER_DECODER, + force_max_len: bool = False, +) -> tuple: + ''' + Construct QKV test tensors for self- and cross-attention. + + Generates three query/key/value triplets: + + * "Baseline" query/key/value (for input to reference attention function) + * "Prefill" query/key/value (last sequence offset zero'd out, for use as + input to prefill kernel) + * "Decode" query/key/value (only the last sequence offset from baseline, + for use as input to decode kernel) + + Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v + seqlens + + Arguments: + + * batch_size + * max_q_seq_len: max query seq len + * max_kv_seq_len: max key/value seq len + * num_heads + * head_size + * is_encoder_decoder_attn: if True, query seqlen may differ from + key/value seqlen (as is often the case for cross-attention); + o/w, query/key/value seqlens match at each batch index + (max_kv_seq_len is unused) + * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query + seqlens are random in [2,max_q_seq_lens]. Same for key/value seqlens + and max_kv_seq_len, unless forced by is_encoder_decoder_attn=False + * device: CPU or CUDA device + + Returns: + + * query: "baseline" query; batch_size x max_q_seq_len x num_heads x + head_size + * key: "baseline" key; batch_size x max_kv_seq_len x num_heads x + head_size + * value: "baseline" value; batch_size x max_kv_seq_len x num_heads x + head_size + * prefill_query: batch_size x (max_q_seq_len-1) x num_heads x head_size + * prefill_key: batch_size x (max_kv_seq_len-1) x num_heads x head_size + * prefill_value: batch_size x (max_kv_seq_len-1) x num_heads x head_size + * decode_query: batch_size x 1 x num_heads x head_size + * decode_key: batch_size x 1 x num_heads x head_size + * decode_value: batch_size x 1 x num_heads x head_size + * q_seq_lens: "baseline" query seqlen list + * kv_seq_lens: "baseline" key/value seqlen list + * actual_max_q_seq_len: actual "baseline" query max seq len (may be <= + max_q_seq_len due to randomness) + * actual_max_kv_seq_len: actual "baseline" key/value max seq len (may + be <= max_kv_seq_len due to randomness) + * prefill_q_seq_lens: "prefill" query seqlen list + * prefill_kv_seq_lens: "prefill" key/value seqlen list + * decode_q_seq_lens: "decode" query seqlen list (all ones) + * decode_kv_seq_lens: "decode" key/value seqlen list + ''' + + if force_max_len: + q_seq_lens = [max_q_seq_len for _ in range(batch_size)] + else: + q_seq_lens = [ + random.randint(2, max_q_seq_len) for _ in range(batch_size) + ] + kv_seq_lens = None + if attn_type != AttentionType.ENCODER_DECODER: + # K,V seq lens match Q for self-attention + kv_seq_lens = q_seq_lens + else: + # K,V seq lens are distinct from Q seq lens & random + if force_max_len: + kv_seq_lens = [max_kv_seq_len for _ in range(batch_size)] + else: + kv_seq_lens = [ + random.randint(2, max_kv_seq_len) for _ in range(batch_size) + ] + + actual_max_q_seq_len = max(q_seq_lens) + actual_max_kv_seq_len = max(kv_seq_lens) + + query = torch.rand( + (batch_size, max_q_seq_len, num_heads, head_size)).to(device) + key = torch.rand( + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + value = torch.rand( + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + + prefill_query = torch.zeros( + (batch_size, max_q_seq_len, num_heads, head_size)).to(device) + prefill_key = torch.zeros( + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + prefill_value = torch.zeros( + (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + + decode_query = torch.zeros( + (batch_size, 1, num_heads, head_size)).to(device) + decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device) + decode_value = torch.zeros( + (batch_size, 1, num_heads, head_size)).to(device) + + for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, + kv_seq_lens)): + query[bdx, q_seq_len:, :, :] = 0 + key[bdx, kv_seq_len:, :, :] = 0 + value[bdx, kv_seq_len:, :, :] = 0 + + prefill_query[bdx, + 0:(q_seq_len - 1), :, :] = query[bdx, + 0:(q_seq_len - 1), :, :] + prefill_key[bdx, + 0:(kv_seq_len - 1), :, :] = key[bdx, + 0:(kv_seq_len - 1), :, :] + prefill_value[bdx, 0:(kv_seq_len - + 1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :] + + decode_query[bdx, :, :, :] = query[bdx, + (q_seq_len - 1):q_seq_len, :, :] + decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :] + decode_value[bdx, :, :, :] = value[bdx, + (kv_seq_len - 1):kv_seq_len, :, :] + + prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens] + prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens] + + decode_q_seq_lens = [1 for _ in q_seq_lens] + decode_kv_seq_lens = [1 for _ in kv_seq_lens] + + return query, \ + key, \ + value, \ + prefill_query, \ + prefill_key, \ + prefill_value, \ + decode_query, \ + decode_key, \ + decode_value, \ + q_seq_lens, \ + kv_seq_lens, \ + actual_max_q_seq_len, \ + actual_max_kv_seq_len, \ + prefill_q_seq_lens, \ + prefill_kv_seq_lens, \ + decode_q_seq_lens, \ + decode_kv_seq_lens + + +def pack_tensor(unpacked_tensor: torch.Tensor, seq_lens: List[int], + device: Union[torch.device, str]) -> tuple: + ''' + Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an + unpadded number_of_tokens x num_heads x head_size tensor, where + number_of_tokens = sum(seq_lens) + + Arguments: + + * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size + * seq_lens: list of token counts for each seq + * device: CPU or CUDA device + + Returns + + * packed_tensor: number_of_tokens x num_heads x head_size + * start_loc_list: start idx of each batch elt in packed_tensor; [0] + + list(itertools.accumulate(seq_lens)) + ''' + + num_tok = sum(seq_lens) + num_heads = unpacked_tensor.shape[-2] + head_size = unpacked_tensor.shape[-1] + start_loc_list = [0] + list(itertools.accumulate(seq_lens)) + packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device) + + for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)): + + packed_tensor[start_loc:( + start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :] + + return packed_tensor, start_loc_list + + +def pack_qkv(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + q_seq_lens: List[int], kv_seq_lens: List[int], + device: Union[torch.device, str]) -> tuple: + ''' + Individually pack each of Q, K and V, each with dimensions batch_size x + padded_seq_len x num_heads x head_size, into respective number_of_tokens x + num_heads x head_size tensors. + + For Q, number_of_tokens = sum(q_seq_lens). + + For K and V, number_of_tokens = sum(kv_seq_lens) + + Arguments: + + * query: batch_size x padded_seq_len x num_heads x head_size + * key: batch_size x padded_seq_len x num_heads x head_size + * value: batch_size x padded_seq_len x num_heads x head_size + * q_seq_lens: list of token counts for each query + * kv_seq_lens: list of token counts for each key/value + + Returns + + * packed_query: number_of_tokens x num_heads x head_size + * packed_key: number_of_tokens x num_heads x head_size + * packed_value: number_of_tokens x num_heads x head_size + * q_start_loc_list: start idx of each query in packed_query + * kv_start_loc_list: start idx of each {key,value} in packed_{key,value} + ''' + + if query is None: + packed_query = None + q_start_loc_list = None + else: + packed_query, q_start_loc_list = pack_tensor(query, + q_seq_lens, + device=device) + packed_key, kv_start_loc_list = pack_tensor(key, + kv_seq_lens, + device=device) + packed_value, _ = pack_tensor(value, kv_seq_lens, device=device) + return packed_query, \ + packed_key, \ + packed_value, \ + q_start_loc_list, \ + kv_start_loc_list + + +def make_backend(backend_name: str) -> AttentionBackend: + ''' + Construct the backend instance determined by the backend_name string + argument. + + "XFORMERS" -> construct xformers backend + + TODO: other backends + + Note: at time of writing the Attention wrapper automatically selects + its own backend for Attention.forward(); so the backend instance which + you generate with this function is not meant to be used for *running* + inference, but rather for generating compatible metadata structures + using backend.make_metadata() + + + Returns: + + * Backend instance + ''' + if backend_name == "XFORMERS": + return XFormersBackend() + raise AssertionError( + f"Unrecognized backend_name {backend_name} for unit test") + + +def make_metadata_tensors(seq_lens: List[int], context_lens: List[int], + encoder_seq_lens: List[int], + device: Union[torch.device, str]) -> tuple: + ''' + Build scalar & tensor values required to build attention metadata structure. + + Arguments: + + * is_prompt: True -> Prefill, False -> Decode + * seq_lens: list of token-counts for each seq + * context_lens: list of context length values for each seq + * device: CPU or CUDA device + + Returns: + + * seq_lens_tensor: seq_lens list, as tensor + * context_lens_tensor: context_lens list, as tensor + * max_query_len: max(seq_lens) if is_seq, o/w 1 + * max_context_len: max(context_lens) + * max_seq_len: max(seq_lens) + * seq_start_loc: start idx of each sequence + * query_start_loc: start idx of each query + ''' + seq_lens_tensor = maybe_make_int_tensor(seq_lens, device) + context_lens_tensor = maybe_make_int_tensor(context_lens, device) + max_context_len = maybe_max(context_lens) + max_seq_len = maybe_max(seq_lens) + + encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device) + max_encoder_seq_len = None if encoder_seq_lens is None else \ + max(encoder_seq_lens) + + seq_start_loc = None + + return seq_lens_tensor, \ + context_lens_tensor, \ + max_context_len, \ + max_seq_len, \ + seq_start_loc, \ + encoder_seq_lens_tensor, \ + max_encoder_seq_len + + +def make_kv_cache(num_blocks: int, + num_heads: int, + head_size: int, + block_size: int, + device: Union[torch.device, str], + default_val: float = 0.0) -> torch.Tensor: + ''' + Create a fake KV cache. + + Arguments: + + * num_blocks: number of blocks in the KV cache + * num_heads: number of attention heads + * head_size: head dimension + * block_size: number of offsets within a block + * device: CPU or CUDA device + * default_val: initialization value for KV cache elements + + Returns: + + * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size) + ''' + + kv_cache = torch.rand( + (2, num_blocks, block_size * num_heads * head_size)).to(device) + if default_val is not None: + kv_cache[:, :, :] = default_val + return kv_cache + + +def num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: + ''' + Compute the minimum number of blocks required to hold num_tokens tokens, + given block_size + ''' + return (num_tokens + block_size) // block_size + + +def make_block_tables_slot_mapping(block_size: int, + seq_lens: List, + device: Union[torch.device, str], + block_base_addr: int = 0) -> tuple: + ''' + Construct fake block tables & slot mappings. + + For a sequence with num_tokens tokens the minimum number + of required KV cache blocks is + + num_blocks = (num_tokens + block_size) // block_size + + Then the minimum KV cache size in blocks is + + total_cache_blocks = sum(num_blocks for all seqs) + + Then, the blocktable mapping counts downward from + + block_base_addr + total_cache_blocks + + to + + block_base_addr + + + Arguments: + + * block_size: number of offsets per block + * seq_lens: list of token-counts for each sequence + * block_base_addr: the block table base address + * device: CPU or CUDA device + + Return: + + * decode_block_tables_tensor: fake the state of the block tables during + decode + * decode_slot_mapping_tensor: fake the state of the slot mapping during + decode + * prefill_slot_mapping_tensor: fake the state of the slot mapping during + prefill + * prefill_block_tables_tensor: fake the state of the block tables during + prefill + * slot_mapping_tensor: union of prefill and decode slot mappings + * empty_slot_mapping_tensor: empty slot mapping (useful for decode phase + cross attention) + * max_block_idx: the highest block address within this block table + ''' + + # Provision minimum number of KV cache blocks + num_blocks_list = [ + num_tokens_to_min_blocks(num_tokens, block_size) + for num_tokens in seq_lens + ] + max_block_table_len = max(num_blocks_list) + block_table_pad_tokens = 10 + + block_tables = [] + prefill_slot_mapping = [] + decode_slot_mapping = [] + slot_mapping = [] + # Compute uppermost address of block table + total_cache_blocks = sum(num_blocks_list) + block_base_idx = block_base_addr + total_cache_blocks + max_block_idx = block_base_idx + for sdx, num_tokens in enumerate(seq_lens): + num_blocks = num_blocks_list[sdx] + block_table = list( + range(block_base_idx, block_base_idx - num_blocks, -1)) + for idx in range(num_tokens): + mapping_value = ( + idx % block_size) + block_table[idx // block_size] * block_size + slot_mapping.append(mapping_value) + if idx < num_tokens - 1: + prefill_slot_mapping.append(mapping_value) + elif idx == num_tokens - 1: + decode_slot_mapping.append(mapping_value) + + block_base_idx -= num_blocks + block_tables.append(block_table) + + prefill_block_tables_tensor = torch.tensor([], device=device) + decode_block_tables_tensor = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len + block_table_pad_tokens, + pad=0, + dtype=torch.int, + device=device, + ) + prefill_slot_mapping_tensor = maybe_make_long_tensor( + prefill_slot_mapping, device) + decode_slot_mapping_tensor = maybe_make_long_tensor( + decode_slot_mapping, device) + slot_mapping_tensor = maybe_make_long_tensor(slot_mapping, device) + empty_slot_mapping_tensor = maybe_make_long_tensor([], device) + + return decode_block_tables_tensor, \ + decode_slot_mapping_tensor, \ + prefill_slot_mapping_tensor, \ + prefill_block_tables_tensor, \ + slot_mapping_tensor, \ + empty_slot_mapping_tensor, \ + max_block_idx + + +def make_test_metadata( + attn_backend: AttentionBackend, + is_prompt: bool, + seq_lens: List[int], + context_lens: List[int], + block_tables: torch.Tensor, + slot_mapping: torch.Tensor, + is_encoder_only_test: bool, + num_prefills_or_decodes: int, + num_prefill_or_decode_tokens: int, + device: Union[torch.device, str], + encoder_seq_lens: Optional[List[int]] = None, + cross_block_tables: Optional[torch.Tensor] = None, + cross_slot_mapping: Optional[List[int]] = None, +) -> AttentionMetadata: + ''' + Construct fake attention metadata for a combined self-/cross-attention + scenario i.e. an encoder/decoder model. + + is_encoder_only_test=True causes the default attention metadata attention + type to be AttentionType.ENCODER. False causes the default to + be AttentionType.DECODER. + + Assumptions: + + * No chunked prefill -> a batch is 100% prefill or 100% decode, never both + + Arguments: + + * attn_backend: Backend for sourcing attention kernels + * is_prompt: prefill if True, o/w decode + * seq_lens: list of token counts for each sequence + * context_lens: list of context lengths for each sequence + * block_tables: self-attention block tables + * slot_mapping: self-attention slot_mapping + * is_encoder_only_test: True if testing encoder; False if testing + decoder self-attention or encoder/decoder cross-attention. + * device: CPU or CUDA device + * encoder_seq_lens: list of token counts for each encoder sequence, if any + exist + * cross_block_tables: cross-attention block tables, if required + * cross_slot_mapping: cross-attention slot mapping, if required + + Return: + + * AttentionMetadata structure supporting self- and cross-attention + ''' + + default_attn_type = AttentionType.ENCODER if is_encoder_only_test \ + else AttentionType.DECODER + + if is_prompt: + num_prefills = num_prefills_or_decodes + num_prefill_tokens = num_prefill_or_decode_tokens + num_decode_tokens = 0 + + seq_lens_tensor, \ + context_lens_tensor, \ + _, \ + _, \ + _, \ + encoder_seq_lens_tensor, \ + max_encoder_seq_len = make_metadata_tensors(seq_lens, + context_lens, + encoder_seq_lens, + device=device) + + return attn_backend.make_metadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_prefill_seq_len=None if seq_lens is None else max(seq_lens), + max_decode_seq_len=0, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + _attn_type=default_attn_type, + encoder_seq_lens=encoder_seq_lens, + encoder_seq_lens_tensor=encoder_seq_lens_tensor, + max_encoder_seq_len=max_encoder_seq_len, + cross_slot_mapping=cross_slot_mapping, + cross_block_tables=cross_block_tables) + + else: # not is_prompt + + num_prefills = 0 + num_prefill_tokens = 0 + num_decode_tokens = num_prefill_or_decode_tokens + + seq_lens_tensor, \ + context_lens_tensor, \ + _, \ + _, \ + _, \ + encoder_seq_lens_tensor, \ + max_encoder_seq_len = make_metadata_tensors(seq_lens, + context_lens, + encoder_seq_lens, + device=device) + + return attn_backend.make_metadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_prefill_seq_len=0, + max_decode_seq_len=max(seq_lens), + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + _attn_type=default_attn_type, + encoder_seq_lens=encoder_seq_lens, + encoder_seq_lens_tensor=encoder_seq_lens_tensor, + max_encoder_seq_len=max_encoder_seq_len, + cross_slot_mapping=cross_slot_mapping, + cross_block_tables=cross_block_tables) diff --git a/vllm/utils.py b/vllm/utils.py index 2781eceb7ba98..1986ba2b3d8c6 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -12,6 +12,7 @@ import warnings from collections import defaultdict from functools import lru_cache, partial, wraps +from numbers import Number from platform import uname from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Hashable, List, Optional, OrderedDict, Tuple, TypeVar, @@ -674,3 +675,63 @@ def inner(*args, **kwargs): return inner # type: ignore return wrapper + +def maybe_make_int_tensor(_list: List[int], + device: Union[torch.device, str]) \ + -> torch.Tensor: + ''' + Convert Python int list to a 1D int torch.Tensor on `device` + + Returns: + + * If _list is not None: 1D int torch.Tensor on `device` + * None otherwise + ''' + return None if _list is None else torch.tensor( + _list, dtype=torch.int, device=device) + +def maybe_make_long_tensor(_list: List[int], + device: Union[torch.device, str]) \ + -> torch.Tensor: + ''' + Convert Python int list to a 1D long torch.Tensor on `device` + + Returns: + + * If _list is not None: 1D long torch.Tensor on `device` + * None otherwise + ''' + return None if _list is None else torch.tensor( + _list, dtype=torch.long, device=device) + + +def maybe_max(_list: List) -> Optional[Number]: + ''' + Returns: + + * If _list is not None: max(_list) + * None otherwise + ''' + return None if _list is None else max(_list) + +def make_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ + -> torch.Tensor: + ''' + Create a q_max_seq_len x kv_max_seq_len causal mask + + Arguments: + + * q_max_seq_len: query max seq len + * kv_max_seq_len: key/value max seq len + + Returns: + + * 2D tensor, q_max_seq_len x kv_max_seq_len + ''' + + # Create a matrix where entry (i, j) is True if i >= j + mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) + # Replace True with float('-inf') and False with 0 + mask = mask.masked_fill(mask == 1, + float('-inf')).masked_fill(mask == 0, 0.0) + return mask \ No newline at end of file From 62fb8d1a63cc1f501d5ee22948dcc4853f26df50 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 15:55:25 -0400 Subject: [PATCH 158/239] _ for private functions in test_encoder_decoder_attn --- tests/kernels/test_encoder_decoder_attn.py | 42 +++++++++++----------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 3c152c8988536..33149cf38e866 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -35,7 +35,7 @@ MAX_K_SEQ_LENS = [128] -def basic_setup(num_heads: int, head_size: int, num_blocks: int, +def _basic_setup(num_heads: int, head_size: int, num_blocks: int, block_size: int, backend_name: str) -> tuple: ''' Compute & build entities required for the self-/cross-attention test. @@ -79,7 +79,7 @@ def basic_setup(num_heads: int, head_size: int, num_blocks: int, return scale, attn_backend, attn, kv_cache -def encoder_attn_setup(batch_size: int, +def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, block_size: int, @@ -187,7 +187,7 @@ def encoder_attn_setup(batch_size: int, q_seq_lens -def decoder_attn_setup(batch_size: int, +def _decoder_attn_setup(batch_size: int, num_heads: int, head_size: int, block_size: int, @@ -373,7 +373,7 @@ def decoder_attn_setup(batch_size: int, max_block_idx -def enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, +def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, q_seq_lens: List, prefill_q_seq_lens: List, batch_size: int, @@ -527,7 +527,7 @@ def enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, max_block_idx -def run_encoder_or_decoder_self_attention_test( +def _run_encoder_or_decoder_self_attention_test( attn: Attention, packed_query: torch.Tensor, packed_key: torch.Tensor, packed_value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, @@ -562,7 +562,7 @@ def run_encoder_or_decoder_self_attention_test( attn_metadata) -def run_encoder_decoder_cross_attention_test( +def _run_encoder_decoder_cross_attention_test( attn: Attention, packed_query: torch.Tensor, packed_key: torch.Tensor, packed_value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: @@ -633,7 +633,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, scale, \ attn_backend, \ attn, \ - _ = basic_setup(num_heads, + _ = _basic_setup(num_heads, head_size, None, None, @@ -651,7 +651,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, packed_ideal_output, \ block_tables, \ slot_mapping, \ - q_seq_lens = encoder_attn_setup(batch_size, + q_seq_lens = _encoder_attn_setup(batch_size, num_heads, head_size, block_size, @@ -680,7 +680,7 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, device=CUDA_DEVICE) packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( + _run_encoder_or_decoder_self_attention_test( attn, packed_query, packed_key, @@ -743,7 +743,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( scale, \ attn_backend, \ attn, \ - kv_cache = basic_setup(num_heads, + kv_cache = _basic_setup(num_heads, head_size, num_blocks, block_size, @@ -772,7 +772,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_decode_slot_mapping, \ self_prefill_slot_mapping, \ self_prefill_block_tables, \ - cross_block_base_addr = decoder_attn_setup(batch_size, + cross_block_base_addr = _decoder_attn_setup(batch_size, num_heads, head_size, block_size, @@ -791,7 +791,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - _ = enc_dec_cross_attn_setup_reuses_query(query, + _ = _enc_dec_cross_attn_setup_reuses_query(query, q_seq_lens, prefill_q_seq_lens, batch_size, @@ -823,7 +823,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( device=CUDA_DEVICE) self_prefill_packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( + _run_encoder_or_decoder_self_attention_test( attn, prefill_packed_query, self_prefill_packed_key, @@ -839,7 +839,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_prefill_packed_ideal_output)) cross_prefill_packed_actual_output: torch.Tensor = \ - run_encoder_decoder_cross_attention_test( + _run_encoder_decoder_cross_attention_test( attn, prefill_packed_query, cross_prefill_packed_key, cross_prefill_packed_value, kv_cache, prefill_attn_metadata) @@ -869,7 +869,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( device=CUDA_DEVICE) self_decode_packed_actual_output: torch.Tensor = \ - run_encoder_or_decoder_self_attention_test( + _run_encoder_or_decoder_self_attention_test( attn, decode_packed_query, self_decode_packed_key, @@ -885,7 +885,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_decode_packed_ideal_output)) cross_decode_packed_actual_output: torch.Tensor = \ - run_encoder_decoder_cross_attention_test( + _run_encoder_decoder_cross_attention_test( attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) @@ -912,7 +912,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # of prefill and decode tokens. decode_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: - run_encoder_decoder_cross_attention_test(attn, decode_packed_query, + _run_encoder_decoder_cross_attention_test(attn, decode_packed_query, None, None, kv_cache, decode_attn_metadata) @@ -969,7 +969,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, scale, \ attn_backend, \ attn, \ - kv_cache = basic_setup(num_heads, + kv_cache = _basic_setup(num_heads, head_size, num_blocks, block_size, @@ -998,7 +998,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, self_decode_slot_mapping, \ self_prefill_slot_mapping, \ self_prefill_block_tables, \ - cross_block_base_addr = decoder_attn_setup(batch_size, + cross_block_base_addr = _decoder_attn_setup(batch_size, num_heads, head_size, block_size, @@ -1017,7 +1017,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - _ = enc_dec_cross_attn_setup_reuses_query(query, + _ = _enc_dec_cross_attn_setup_reuses_query(query, q_seq_lens, prefill_q_seq_lens, batch_size, @@ -1049,7 +1049,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, device=CUDA_DEVICE) with pytest.raises(NotImplementedError) as exc_info: - run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, + _run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, cross_prefill_packed_key, cross_prefill_packed_value, kv_cache, From 2730daa37de2ea8ef9531f906661c5592c9b1307 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 15:58:46 -0400 Subject: [PATCH 159/239] _ refactor --- tests/kernels/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 2b752e4cbcd76..54ffa4ff0e700 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -352,7 +352,7 @@ def make_backend(backend_name: str) -> AttentionBackend: f"Unrecognized backend_name {backend_name} for unit test") -def make_metadata_tensors(seq_lens: List[int], context_lens: List[int], +def _make_metadata_tensors(seq_lens: List[int], context_lens: List[int], encoder_seq_lens: List[int], device: Union[torch.device, str]) -> tuple: ''' @@ -425,7 +425,7 @@ def make_kv_cache(num_blocks: int, return kv_cache -def num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: +def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: ''' Compute the minimum number of blocks required to hold num_tokens tokens, given block_size @@ -483,7 +483,7 @@ def make_block_tables_slot_mapping(block_size: int, # Provision minimum number of KV cache blocks num_blocks_list = [ - num_tokens_to_min_blocks(num_tokens, block_size) + _num_tokens_to_min_blocks(num_tokens, block_size) for num_tokens in seq_lens ] max_block_table_len = max(num_blocks_list) @@ -599,7 +599,7 @@ def make_test_metadata( _, \ _, \ encoder_seq_lens_tensor, \ - max_encoder_seq_len = make_metadata_tensors(seq_lens, + max_encoder_seq_len = _make_metadata_tensors(seq_lens, context_lens, encoder_seq_lens, device=device) @@ -635,7 +635,7 @@ def make_test_metadata( _, \ _, \ encoder_seq_lens_tensor, \ - max_encoder_seq_len = make_metadata_tensors(seq_lens, + max_encoder_seq_len = _make_metadata_tensors(seq_lens, context_lens, encoder_seq_lens, device=device) From 5face2ab0e6f77fc3ee98f307369e848fdea1a4f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 15:59:26 -0400 Subject: [PATCH 160/239] formatting --- tests/kernels/test_encoder_decoder_attn.py | 38 +++++++++++----------- tests/kernels/utils.py | 4 +-- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 33149cf38e866..c1ff1327af423 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -36,7 +36,7 @@ def _basic_setup(num_heads: int, head_size: int, num_blocks: int, - block_size: int, backend_name: str) -> tuple: + block_size: int, backend_name: str) -> tuple: ''' Compute & build entities required for the self-/cross-attention test. @@ -80,12 +80,12 @@ def _basic_setup(num_heads: int, head_size: int, num_blocks: int, def _encoder_attn_setup(batch_size: int, - num_heads: int, - head_size: int, - block_size: int, - scale: float, - max_q_seq_len: int, - block_base_addr: int = 0) -> tuple: + num_heads: int, + head_size: int, + block_size: int, + scale: float, + max_q_seq_len: int, + block_base_addr: int = 0) -> tuple: ''' Set up test vectors & data structures for encoder attention test. @@ -188,12 +188,12 @@ def _encoder_attn_setup(batch_size: int, def _decoder_attn_setup(batch_size: int, - num_heads: int, - head_size: int, - block_size: int, - scale: float, - max_q_seq_len: int, - block_base_addr: int = 0) -> tuple: + num_heads: int, + head_size: int, + block_size: int, + scale: float, + max_q_seq_len: int, + block_base_addr: int = 0) -> tuple: ''' Set up test vectors & data structures for self-attention test. @@ -913,8 +913,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decode_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: _run_encoder_decoder_cross_attention_test(attn, decode_packed_query, - None, None, kv_cache, - decode_attn_metadata) + None, None, kv_cache, + decode_attn_metadata) # "Encoder decoder models do not currently support chunked prefill" assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL @@ -1050,10 +1050,10 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, with pytest.raises(NotImplementedError) as exc_info: _run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, - cross_prefill_packed_key, - cross_prefill_packed_value, - kv_cache, - prefill_attn_metadata) + cross_prefill_packed_key, + cross_prefill_packed_value, + kv_cache, + prefill_attn_metadata) # "Encoder decoder models do not currently support ROCm/HIP" assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 54ffa4ff0e700..b7951d4b5da28 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -353,8 +353,8 @@ def make_backend(backend_name: str) -> AttentionBackend: def _make_metadata_tensors(seq_lens: List[int], context_lens: List[int], - encoder_seq_lens: List[int], - device: Union[torch.device, str]) -> tuple: + encoder_seq_lens: List[int], + device: Union[torch.device, str]) -> tuple: ''' Build scalar & tensor values required to build attention metadata structure. From f39155a3b981fab2d53a18394a20211ce9d51dab Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 3 Jun 2024 18:17:01 -0400 Subject: [PATCH 161/239] constructing attn md with minimum number of arguments --- tests/kernels/test_encoder_decoder_attn.py | 7 +----- tests/kernels/utils.py | 2 +- vllm/attention/backends/xformers.py | 28 ++++++++++------------ 3 files changed, 15 insertions(+), 22 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index c1ff1327af423..85a29f61ab159 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -669,7 +669,6 @@ def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, - None, context_lens, block_tables, slot_mapping, @@ -805,13 +804,10 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # PREFILL: self- and cross-attention tests - context_lens = [0 for _ in range(batch_size)] - prefill_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, prefill_q_seq_lens, - context_lens, self_prefill_block_tables, self_prefill_slot_mapping, is_encoder_only_test=False, @@ -857,10 +853,10 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_backend, False, q_seq_lens, - context_lens, self_decode_block_tables, self_decode_slot_mapping, is_encoder_only_test=False, + context_lens=context_lens, num_prefills_or_decodes=len(q_seq_lens), num_prefill_or_decode_tokens=len(q_seq_lens), encoder_seq_lens=encoder_kv_seq_lens, @@ -1037,7 +1033,6 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, attn_backend, True, prefill_q_seq_lens, - context_lens, self_prefill_block_tables, self_prefill_slot_mapping, is_encoder_only_test=False, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index b7951d4b5da28..74b33aaec5d2f 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -541,13 +541,13 @@ def make_test_metadata( attn_backend: AttentionBackend, is_prompt: bool, seq_lens: List[int], - context_lens: List[int], block_tables: torch.Tensor, slot_mapping: torch.Tensor, is_encoder_only_test: bool, num_prefills_or_decodes: int, num_prefill_or_decode_tokens: int, device: Union[torch.device, str], + context_lens: Optional[List[int]] = None, encoder_seq_lens: Optional[List[int]] = None, cross_block_tables: Optional[torch.Tensor] = None, cross_slot_mapping: Optional[List[int]] = None, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index f165f7922017f..81e31d4c38c2b 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -67,9 +67,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] @@ -93,6 +91,10 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] = None + # FIXME: It is for flash attn. # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is @@ -204,7 +206,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: (self.encoder_seq_lens is not None) assert (self.seq_lens_tensor is not None) or \ (self.encoder_seq_lens_tensor is not None) - assert self.context_lens_tensor is not None + #assert self.context_lens_tensor is not None assert self.block_tables is not None query_start_loc = None if self.query_start_loc is None \ @@ -216,19 +218,20 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=0, slot_mapping=self.slot_mapping[:self.num_prefill_tokens], seq_lens=None - if self.seq_lens is None else self.seq_lens[:self.num_prefills], + if self.seq_lens is None \ + else self.seq_lens[:self.num_prefills], seq_lens_tensor=None if self.seq_lens_tensor is None else self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, query_start_loc=query_start_loc, - seq_start_loc=None, - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + context_lens_tensor=None if self.context_lens_tensor is None else \ + self.context_lens_tensor[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, _attn_type=self.attention_type, - # Begin cross-attention fields below... + # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, max_encoder_seq_len=self.max_encoder_seq_len, @@ -254,19 +257,14 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - seq_lens=None, seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, + self.seq_lens_tensor[self.num_prefills:], max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, _attn_type=self. - _attn_type, # Begin cross-attention fields below... + _attn_type, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, max_encoder_seq_len=self.max_encoder_seq_len, From c7edbc6d962f702fc2d5c996e773a6fd8e121e59 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 4 Jun 2024 09:24:21 -0400 Subject: [PATCH 162/239] Formatting --- tests/kernels/test_encoder_decoder_attn.py | 2 -- vllm/attention/backends/xformers.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 85a29f61ab159..ec8945f1c7257 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -1027,8 +1027,6 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, # PREFILL: self- and cross-attention tests - context_lens = [0 for _ in range(batch_size)] - prefill_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 81e31d4c38c2b..6d4aff2bba37d 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -258,7 +258,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:], + self.seq_lens_tensor[self.num_prefills:], max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, block_tables=self.block_tables[self.num_prefills:], From b023557e87b3a826f92aa8ff39896e23974990b8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 4 Jun 2024 12:14:28 -0400 Subject: [PATCH 163/239] typing and formatting --- vllm/attention/backends/xformers.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6d4aff2bba37d..6948bfea33ce3 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -273,7 +273,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: return self._cached_decode_metadata def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ - Optional[List[Optional[AttentionBias]]]: + Optional[AttentionBias]: ''' Extract appropriate attention bias from attention metadata according to attention type. @@ -621,10 +621,10 @@ def _run_memory_efficient_xformers_forward( # Enforce that the appropriate *_seq_lens attribute of attn_metadata # (seq_lens or encoder_seq_lens) is set. - seq_lens, \ - _,\ - _ = _get_seq_len_block_table_args(attn_metadata, True) - assert seq_lens is not None + # seq_lens, \ + # _,\ + # _ = _get_seq_len_block_table_args(attn_metadata, True) + # assert seq_lens is not None original_query = query if self.num_kv_heads != self.num_heads: @@ -648,15 +648,22 @@ def _run_memory_efficient_xformers_forward( if self.alibi_slopes is None: if attn_metadata.attention_type == \ AttentionType.ENCODER_DECODER: + assert attn_metadata.seq_lens is not None + assert attn_metadata.encoder_seq_lens is not None + # Default enc/dec cross-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens, attn_metadata.encoder_seq_lens) else: if attn_metadata.attention_type == AttentionType.ENCODER: + assert attn_metadata.encoder_seq_lens is not None + # Default encoder self-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.encoder_seq_lens) else: + assert attn_metadata.seq_lens is not None + # Default decoder self-attention mask is causal attn_bias = BlockDiagonalCausalMask.from_seqlens( attn_metadata.seq_lens) @@ -665,6 +672,7 @@ def _run_memory_efficient_xformers_forward( self.sliding_window) attn_bias = [attn_bias] else: + assert attn_metadata.seq_lens is not None attn_bias = _make_alibi_bias(self.alibi_slopes, self.num_kv_heads, query.dtype, attn_metadata.seq_lens) @@ -692,6 +700,7 @@ def _run_memory_efficient_xformers_forward( # FIXME(woosuk): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. + assert attn_metadata.seq_lens is not None output = torch.empty_like(original_query) start = 0 for i, seq_len in enumerate(attn_metadata.seq_lens): From af0c0b94d7e374159ed343b9c125497a94571d6c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 4 Jun 2024 18:27:02 -0400 Subject: [PATCH 164/239] refactored block table/slot mapping construction process for decoder into two steps --- tests/kernels/test_encoder_decoder_attn.py | 42 +++++++-- tests/kernels/utils.py | 105 ++++++++++++++------- 2 files changed, 102 insertions(+), 45 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index ec8945f1c7257..42198a4e162db 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -15,7 +15,10 @@ from tests.kernels.utils import (make_backend, make_block_tables_slot_mapping, make_kv_cache, make_qkv, make_test_metadata, override_backend_env_variable, pack_qkv, - pack_tensor, ref_masked_attention) + pack_tensor, ref_masked_attention, + make_empty_slot_mapping_tensor, + make_empty_block_tables_tensor, + split_slot_mapping) from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( @@ -323,17 +326,36 @@ def _decoder_attn_setup(batch_size: int, [1 for _ in range(batch_size)], device=CUDA_DEVICE) + # Build prefill- & decode-phase data structures + # for decoder self-attention. Block tables and + # slot mapping must be in a format compatible + # with KV caching & attention kernels + # + # Prefill: + # + # * Empty block-tables tensor + # * Slot-mapping with entries for prompt tokens + # + # Decode: + # * Block-tables tensor with minimum number of blocks + # required by total num. tokens in the entirety of all sequences + # (including both prefill & decode) + # * Slot-mapping with entries for tokens that will be decoded in the + # current decode iteration + + prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) + decode_block_tables, \ - decode_slot_mapping, \ + slot_mapping_list, \ + max_block_idx = make_block_tables_slot_mapping(block_size, + q_seq_lens, + device=CUDA_DEVICE, + block_base_addr = block_base_addr) + prefill_slot_mapping, \ - prefill_block_tables, \ - _, \ - _, \ - max_block_idx = make_block_tables_slot_mapping( - block_size, - q_seq_lens, - block_base_addr=block_base_addr, - device=CUDA_DEVICE) + decode_slot_mapping = split_slot_mapping(slot_mapping_list, + q_seq_lens, + device=CUDA_DEVICE) prefill_packed_query, \ prefill_packed_key, \ diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 74b33aaec5d2f..0ac052c42a75a 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -432,9 +432,70 @@ def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: ''' return (num_tokens + block_size) // block_size +def make_empty_slot_mapping_tensor(device: Union[torch.device, str]): + return maybe_make_long_tensor([], device) + +def make_empty_block_tables_tensor(device: Union[torch.device, str]): + return torch.tensor([], device=device) + +def split_slot_mapping(slot_mapping_list: torch.Tensor, + seq_lens: List[int], + device: Union[torch.device, str]): + ''' + Split a slot mapping into valid prefill- and decode-phase slot mappings. + + Context: + * Your goal is to test (1) prefill of N prompts, with prompt-lengths + {K_i \forall i \in [0,N)}, followed by (2) decoding of a single token + for all N prompts (N tokens total); the resultant sequence lengths + after decode would be {K_i + 1 for i \in [0,N)} + * The test you want to do requires (1) having the prefill slot mapping + for all tokens present during prefill, the number of which is + M = \sum_i{K_i}, and (2) having the decode slot mapping for all N + decoded tokens + + This function consumes a single 1D slot mapping, which is the + concatenation of N slot mappings each of length K_i + 1 (corresponding + to the sequence lengths after decode), with a total length of + P = \sum_i{K_i + 1} = M + N + + The prefill-phase slot mapping results from excising the (K_i + 1)-th entry + from each of the N subsequences in the slot mapping (i.e. omitting the + decoded token's mapping.) + + The N excised entries are appended to obtain the decode-phase slot mapping + + Arguments: + + * slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N + post-decode sequences + * seq_lens: List of N post-decode sequence lengths (K_i + 1 in the + description above) + * device: cuda, cpu, etc. + + Returns: + + * prefill_slot_mapping: Length-M 1D slot mapping (as Tensor) + reflecting all N prefill prompts + * decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting + all N decoded tokens + ''' + + prefill_slot_mapping = [] + decode_slot_mapping = [] + + base_idx=0 + for seq_len in seq_lens: + prefill_slot_mapping.extend( + slot_mapping_list[range(base_idx,base_idx+seq_len-1)]) + decode_slot_mapping.append(slot_mapping_list[base_idx+seq_len-1]) + base_idx += seq_len + + return maybe_make_long_tensor(prefill_slot_mapping, device), \ + maybe_make_long_tensor(decode_slot_mapping, device) def make_block_tables_slot_mapping(block_size: int, - seq_lens: List, + seq_lens: List[int], device: Union[torch.device, str], block_base_addr: int = 0) -> tuple: ''' @@ -467,17 +528,8 @@ def make_block_tables_slot_mapping(block_size: int, Return: - * decode_block_tables_tensor: fake the state of the block tables during - decode - * decode_slot_mapping_tensor: fake the state of the slot mapping during - decode - * prefill_slot_mapping_tensor: fake the state of the slot mapping during - prefill - * prefill_block_tables_tensor: fake the state of the block tables during - prefill - * slot_mapping_tensor: union of prefill and decode slot mappings - * empty_slot_mapping_tensor: empty slot mapping (useful for decode phase - cross attention) + * block_tables_tensor: block table for sequence + * slot_mapping_list: slot mapping for sequence * max_block_idx: the highest block address within this block table ''' @@ -490,9 +542,7 @@ def make_block_tables_slot_mapping(block_size: int, block_table_pad_tokens = 10 block_tables = [] - prefill_slot_mapping = [] - decode_slot_mapping = [] - slot_mapping = [] + slot_mapping_list = [] # Compute uppermost address of block table total_cache_blocks = sum(num_blocks_list) block_base_idx = block_base_addr + total_cache_blocks @@ -504,36 +554,21 @@ def make_block_tables_slot_mapping(block_size: int, for idx in range(num_tokens): mapping_value = ( idx % block_size) + block_table[idx // block_size] * block_size - slot_mapping.append(mapping_value) - if idx < num_tokens - 1: - prefill_slot_mapping.append(mapping_value) - elif idx == num_tokens - 1: - decode_slot_mapping.append(mapping_value) + slot_mapping_list.append(mapping_value) block_base_idx -= num_blocks block_tables.append(block_table) - prefill_block_tables_tensor = torch.tensor([], device=device) - decode_block_tables_tensor = make_tensor_with_pad( + block_tables_tensor = make_tensor_with_pad( block_tables, max_len=max_block_table_len + block_table_pad_tokens, pad=0, dtype=torch.int, device=device, ) - prefill_slot_mapping_tensor = maybe_make_long_tensor( - prefill_slot_mapping, device) - decode_slot_mapping_tensor = maybe_make_long_tensor( - decode_slot_mapping, device) - slot_mapping_tensor = maybe_make_long_tensor(slot_mapping, device) - empty_slot_mapping_tensor = maybe_make_long_tensor([], device) - - return decode_block_tables_tensor, \ - decode_slot_mapping_tensor, \ - prefill_slot_mapping_tensor, \ - prefill_block_tables_tensor, \ - slot_mapping_tensor, \ - empty_slot_mapping_tensor, \ + + return block_tables_tensor, \ + slot_mapping_list, \ max_block_idx From 50bca0886cf988dfc43712d3ad65ab9adce9c228 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 4 Jun 2024 19:01:18 -0400 Subject: [PATCH 165/239] finished breaking block table/slot mapping construction into steps; formatting --- tests/kernels/test_encoder_decoder_attn.py | 62 ++++++++++++++-------- tests/kernels/utils.py | 15 +++--- 2 files changed, 50 insertions(+), 27 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 42198a4e162db..fdf48f12e33aa 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -13,17 +13,17 @@ import torch from tests.kernels.utils import (make_backend, make_block_tables_slot_mapping, - make_kv_cache, make_qkv, make_test_metadata, + make_empty_block_tables_tensor, + make_empty_slot_mapping_tensor, make_kv_cache, + make_qkv, make_test_metadata, override_backend_env_variable, pack_qkv, pack_tensor, ref_masked_attention, - make_empty_slot_mapping_tensor, - make_empty_block_tables_tensor, split_slot_mapping) from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_ROCM_HIP) -from vllm.utils import is_hip, make_causal_mask +from vllm.utils import is_hip, make_causal_mask, maybe_make_long_tensor HEAD_SIZES = [64, 256] @@ -163,11 +163,7 @@ def _encoder_attn_setup(batch_size: int, device=CUDA_DEVICE) block_tables, \ - _, \ - _, \ - _, \ slot_mapping, \ - _, \ _ = make_block_tables_slot_mapping( block_size, q_seq_lens, @@ -331,12 +327,12 @@ def _decoder_attn_setup(batch_size: int, # slot mapping must be in a format compatible # with KV caching & attention kernels # - # Prefill: - # + # Prefill-phase: + # # * Empty block-tables tensor # * Slot-mapping with entries for prompt tokens # - # Decode: + # Decode-phase: # * Block-tables tensor with minimum number of blocks # required by total num. tokens in the entirety of all sequences # (including both prefill & decode) @@ -353,8 +349,8 @@ def _decoder_attn_setup(batch_size: int, block_base_addr = block_base_addr) prefill_slot_mapping, \ - decode_slot_mapping = split_slot_mapping(slot_mapping_list, - q_seq_lens, + decode_slot_mapping = split_slot_mapping(slot_mapping_list, + q_seq_lens, device=CUDA_DEVICE) prefill_packed_query, \ @@ -514,21 +510,45 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, [1 for _ in range(batch_size)], device=CUDA_DEVICE) - # Unlike self-attention: - # - Prefill slot-mapping includes all key slots - # - Decode slot-mapping is empty + # Build prefill- & decode-phase data structures + # for encoder/decoder cross-attention. Block tables and + # slot mapping must be in a format compatible + # with KV caching & attention kernels + # + # Whereas decoder self-attention extracts relationships between + # equal-length Q/K/V sequences, which mutually grow in length + # with each decoded token, cross-attention relates the Q sequence + # - which grows with each new decoded token - to fixed-length + # K and V sequences derived from the encoder hidden states. + # + # Prefill-phase: + # + # * Empty block-tables tensor + # * Slot-mapping with as many entries as there are tokens in the encoder + # prompt. + # + # Decode-phase: + # * Block-tables tensor with minimum number of blocks to + # accommodate K & V tensors which are equal in lnegth + # to the encoder prompt length + # * Empty slot-mapping tensor (since K & V are fixed in size, + # new decoded tokens are not KV-cached and require no slot- + # mapping) + + prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) + decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE) + decode_block_tables, \ - _, \ - _, \ - prefill_block_tables, \ - prefill_slot_mapping, \ - decode_slot_mapping, \ + prefill_slot_mapping_list, \ max_block_idx = make_block_tables_slot_mapping( block_size, kv_seq_lens, block_base_addr=block_base_addr, device=CUDA_DEVICE) + prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list, + device=CUDA_DEVICE) + # Packed key/value (query is already provided) _, packed_key, packed_value, _, _ = pack_qkv(None, key, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 0ac052c42a75a..575e97f24c0c4 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -432,14 +432,16 @@ def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: ''' return (num_tokens + block_size) // block_size + def make_empty_slot_mapping_tensor(device: Union[torch.device, str]): return maybe_make_long_tensor([], device) + def make_empty_block_tables_tensor(device: Union[torch.device, str]): return torch.tensor([], device=device) -def split_slot_mapping(slot_mapping_list: torch.Tensor, - seq_lens: List[int], + +def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int], device: Union[torch.device, str]): ''' Split a slot mapping into valid prefill- and decode-phase slot mappings. @@ -484,16 +486,17 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, prefill_slot_mapping = [] decode_slot_mapping = [] - base_idx=0 + base_idx = 0 for seq_len in seq_lens: - prefill_slot_mapping.extend( - slot_mapping_list[range(base_idx,base_idx+seq_len-1)]) - decode_slot_mapping.append(slot_mapping_list[base_idx+seq_len-1]) + prefill_slot_mapping.extend(slot_mapping_list[base_idx:(base_idx + + seq_len - 1)]) + decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1]) base_idx += seq_len return maybe_make_long_tensor(prefill_slot_mapping, device), \ maybe_make_long_tensor(decode_slot_mapping, device) + def make_block_tables_slot_mapping(block_size: int, seq_lens: List[int], device: Union[torch.device, str], From 90610daa1b4163872d5a243b95d7127309e1d91e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 4 Jun 2024 19:19:32 -0400 Subject: [PATCH 166/239] slight refactor --- tests/kernels/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 575e97f24c0c4..ddeafbc654073 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -522,6 +522,10 @@ def make_block_tables_slot_mapping(block_size: int, block_base_addr + The constructed block-tables and slot-mapping are sized to the + lengths of the sequences in their entirety (as reflected by seq_lens), + i.e. the total of prefill prompt tokens + decoded tokens. + Arguments: * block_size: number of offsets per block From a006cc892b9438d77f8d3f992075bc73dfcb9bb5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 4 Jun 2024 22:18:00 -0400 Subject: [PATCH 167/239] refactored encoder test into the cross-attention test --- tests/kernels/test_encoder_decoder_attn.py | 353 +++++++++------------ tests/kernels/utils.py | 12 +- vllm/attention/backends/xformers.py | 14 +- 3 files changed, 169 insertions(+), 210 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index fdf48f12e33aa..7d9d6d32c0c40 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -34,8 +34,8 @@ BACKEND_NAMES = ["XFORMERS"] CUDA_DEVICE = "cuda:0" -MAX_Q_SEQ_LENS = [128] -MAX_K_SEQ_LENS = [128] +MAX_DEC_SEQ_LENS = [128] +MAX_ENC_SEQ_LENS = [128] def _basic_setup(num_heads: int, head_size: int, num_blocks: int, @@ -82,13 +82,8 @@ def _basic_setup(num_heads: int, head_size: int, num_blocks: int, return scale, attn_backend, attn, kv_cache -def _encoder_attn_setup(batch_size: int, - num_heads: int, - head_size: int, - block_size: int, - scale: float, - max_q_seq_len: int, - block_base_addr: int = 0) -> tuple: +def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, + scale: float, max_q_seq_len: int) -> tuple: ''' Set up test vectors & data structures for encoder attention test. @@ -162,14 +157,6 @@ def _encoder_attn_setup(batch_size: int, q_seq_lens, device=CUDA_DEVICE) - block_tables, \ - slot_mapping, \ - _ = make_block_tables_slot_mapping( - block_size, - q_seq_lens, - block_base_addr=block_base_addr, - device=CUDA_DEVICE) - packed_query, \ packed_key, \ packed_value, _, _ = pack_qkv( @@ -181,8 +168,6 @@ def _encoder_attn_setup(batch_size: int, packed_key, \ packed_value, \ packed_ideal_output, \ - block_tables, \ - slot_mapping, \ q_seq_lens @@ -380,10 +365,7 @@ def _decoder_attn_setup(batch_size: int, decode_packed_key, \ decode_packed_value, \ decode_packed_ideal_output, \ - decode_q_seq_lens, \ - decode_kv_seq_lens, \ q_seq_lens, \ - kv_seq_lens, \ decode_block_tables, \ decode_slot_mapping, \ prefill_slot_mapping, \ @@ -392,15 +374,16 @@ def _decoder_attn_setup(batch_size: int, def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, - q_seq_lens: List, + decoder_seq_lens: List[int], + encoder_seq_lens: Optional[List[int]], prefill_q_seq_lens: List, batch_size: int, num_heads: int, head_size: int, block_size: int, scale: float, - max_q_seq_len: int, - max_kv_seq_len: int, + max_decoder_seq_len: int, + max_encoder_seq_len: int, block_base_addr: Optional[int]=0) \ -> tuple: ''' @@ -481,10 +464,11 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, _, \ _, \ _ = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, + max_decoder_seq_len, + max_encoder_seq_len, num_heads, head_size, + force_kv_seq_lens=encoder_seq_lens, attn_type=AttentionType.ENCODER_DECODER, device=CUDA_DEVICE) @@ -492,7 +476,7 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, key, value, scale=scale, - q_seq_lens=q_seq_lens, + q_seq_lens=decoder_seq_lens, kv_seq_lens=kv_seq_lens) prefill_ideal_output = torch.zeros_like(ideal_output) @@ -561,28 +545,57 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, packed_value, \ prefill_packed_ideal_output, \ decode_packed_ideal_output, \ - kv_seq_lens, \ decode_block_tables, \ decode_slot_mapping, \ prefill_slot_mapping, \ - prefill_block_tables, \ - max_block_idx + prefill_block_tables -def _run_encoder_or_decoder_self_attention_test( - attn: Attention, packed_query: torch.Tensor, packed_key: torch.Tensor, - packed_value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - attn_type: AttentionType) -> torch.Tensor: +def _run_encoder_attention_test(attn: Attention, packed_query: torch.Tensor, + packed_key: torch.Tensor, + packed_value: torch.Tensor, + attn_metadata: AttentionMetadata, + attn_type: AttentionType) -> torch.Tensor: ''' - Run encoder attention or decoder self-attention test. + Run encoder attention. attn_metadata.attention_type is assigned attn_type in order to configure - the kernel invocation for either encoder or decoder self-attention. + the kernel invocation for either encoder attention - attn_type must be AttentionType.ENCODER or DECODER; if ENCODER, - attn_metadata.num_decode_tokens must be 0 (i.e. there is no such thing as - "decode-phase encoder attention".) + attn_type must be AttentionType.ENCODER + + Arguments: + + * attn: Attention wrapper instance + * packed_{query,key,value}: total_num_tokens x (num_heads*head_size) + * attn_metadata: attention metadata for encoder/decoder-self attention + * attn_type: AttentionType.DECODER or AttentionType.ENCODER + + Returns: + * Attention.forward() applied to packed_{query,key,value}, kv_cache + & attn_metadata + ''' + assert attn_type == AttentionType.ENCODER + assert attn_metadata.num_decode_tokens == 0 + attn_metadata.attention_type = attn_type + return attn.forward(packed_query, packed_key, packed_value, None, + attn_metadata) + + +def _run_decoder_self_attention_test(attn: Attention, + packed_query: torch.Tensor, + packed_key: torch.Tensor, + packed_value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + attn_type: AttentionType) -> torch.Tensor: + ''' + Run decoder self-attention test. + + attn_metadata.attention_type is assigned attn_type in order to configure + the kernel invocation for decoder self-attention. + + attn_type must be AttentionType.DECODER Arguments: @@ -596,9 +609,7 @@ def _run_encoder_or_decoder_self_attention_test( * Attention.forward() applied to packed_{query,key,value}, kv_cache & attn_metadata ''' - assert attn_type in [AttentionType.DECODER, AttentionType.ENCODER] - assert attn_metadata.num_decode_tokens == 0 or \ - attn_type != AttentionType.ENCODER + assert attn_type == AttentionType.DECODER attn_metadata.attention_type = attn_type return attn.forward(packed_query, packed_key, packed_value, kv_cache, attn_metadata) @@ -637,115 +648,11 @@ def _run_encoder_decoder_cross_attention_test( @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_seq_len", MAX_Q_SEQ_LENS) -def test_encoder_attention(num_heads: int, head_size: int, backend_name: str, - batch_size: int, block_size: int, max_seq_len: int, - monkeypatch) -> None: - ''' - Encoder-only attention test: - - * Construct fake test vectors for encoder attention - * Construct attention metadata structure with encoder-attention- - specific attributes - * Run encoder attention with metadata structure & test vectors - * Validate output correctness against ideal reference attention - implementation - - Encoder attention (by default) does not restrict which sequence offsets - may attend to each other. Thus the reference ideal reference - implementation does not employ a causal attention mask. - - Encoder attention does not utilize KV cache however the XFormer backend - requires block_tables & slot_mapping to be non-None and have a valid - structure, thus this test constructs dummy memory-mapping structures. - - Encoder attention is basically structured like decoder self-attention - in that Q/K/V are all derived from the previous layer output & have - the same sequence length (in contrast to encoder/decoder cross- - attention where K/V are drawn from the encoder hidden states and - may have a different length than Q derived from decoder previous - layer output.) - ''' - - # Force Attention wrapper backend - override_backend_env_variable(monkeypatch, backend_name) - - # Attention scale factor, attention backend instance, attention wrapper - # instance. Encoder attention does not require KV cache. - scale, \ - attn_backend, \ - attn, \ - _ = _basic_setup(num_heads, - head_size, - None, - None, - backend_name) - - # Self-attention setup - # Let encoder_attn_setup() choose default block table - # base address; the block_tables and slot_mapping - # tensors are not actually utilized by encoder attention - # anyway but are required to be present & valid by the - # backend. - packed_query, \ - packed_key, \ - packed_value, \ - packed_ideal_output, \ - block_tables, \ - slot_mapping, \ - q_seq_lens = _encoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_seq_len) - - context_lens = [0 for _ in range(batch_size)] - - # Metadata config for encoder attention: - # - # * Use prefill kernel - # * Signal that this is an encoder-only test so that - # metadata attention_type property is correctly - # configured as AttentionType.ENCODER - attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - context_lens, - block_tables, - slot_mapping, - is_encoder_only_test=True, - num_prefills_or_decodes=len(q_seq_lens), - num_prefill_or_decode_tokens=sum(q_seq_lens), - encoder_seq_lens=q_seq_lens, - device=CUDA_DEVICE) - - packed_actual_output: torch.Tensor = \ - _run_encoder_or_decoder_self_attention_test( - attn, - packed_query, - packed_key, - packed_value, - None, - attn_metadata, - attn_type=AttentionType.ENCODER) - - # - Is encoder attention result correct? - assert torch.allclose(packed_ideal_output, - packed_actual_output.view_as(packed_ideal_output)) - - -@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_q_seq_len", MAX_Q_SEQ_LENS) -@pytest.mark.parametrize("max_kv_seq_len", MAX_K_SEQ_LENS) +@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) +@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) def test_enc_dec_self_and_cross_attention_prefill_decode_phases( num_heads: int, head_size: int, backend_name: str, batch_size: int, - block_size: int, max_q_seq_len: int, max_kv_seq_len: int, + block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, monkeypatch) -> None: ''' Encoder/decoder attention test: @@ -790,25 +697,37 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( block_size, backend_name) - # Self-attention setup + # Encoder attention setup - self_block_base_addr = 0 + # Let encoder_attn_setup() choose default block table + # base address; the block_tables and slot_mapping + # tensors are not actually utilized by encoder attention + # anyway but are required to be present & valid by the + # backend. + enc_packed_query, \ + enc_packed_key, \ + enc_packed_value, \ + enc_packed_ideal_output, \ + encoder_seq_lens = _encoder_attn_setup(batch_size, + num_heads, + head_size, + scale, + max_enc_seq_len) + + # Decoder self-attention setup query, \ prefill_packed_query, \ self_prefill_packed_key, \ self_prefill_packed_value, \ self_prefill_packed_ideal_output, \ - prefill_q_seq_lens, \ - self_prefill_kv_seq_lens, \ + prefill_decoder_seq_lens, \ + self_prefill_encoder_seq_lens, \ decode_packed_query, \ self_decode_packed_key, \ self_decode_packed_value, \ self_decode_packed_ideal_output, \ - _, \ - _, \ - q_seq_lens, \ - _, \ + decoder_seq_lens, \ self_decode_block_tables, \ self_decode_slot_mapping, \ self_prefill_slot_mapping, \ @@ -818,8 +737,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( head_size, block_size, scale, - max_q_seq_len, - block_base_addr=self_block_base_addr) + max_dec_seq_len) # Cross-attention setup @@ -827,47 +745,68 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_prefill_packed_value, \ cross_prefill_packed_ideal_output, \ cross_decode_packed_ideal_output, \ - encoder_kv_seq_lens, \ cross_decode_block_tables, \ cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - _ = _enc_dec_cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=cross_block_base_addr) - - # PREFILL: self- and cross-attention tests + = _enc_dec_cross_attn_setup_reuses_query(query, + decoder_seq_lens, + encoder_seq_lens, + prefill_decoder_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_dec_seq_len, + max_enc_seq_len, + block_base_addr = \ + cross_block_base_addr) + + # Shared prefill metadata structure + # - prefill_attn_metadata: AttentionMetadata = make_test_metadata( + enc_and_dec_prefill_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, - prefill_q_seq_lens, + prefill_decoder_seq_lens, self_prefill_block_tables, self_prefill_slot_mapping, - is_encoder_only_test=False, - num_prefills_or_decodes=len(prefill_q_seq_lens), - num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), - encoder_seq_lens=encoder_kv_seq_lens, + is_encoder_only_test=True, + num_prefills_or_decodes=len(prefill_decoder_seq_lens), + num_prefill_or_decode_tokens=sum(prefill_decoder_seq_lens), + encoder_seq_lens=encoder_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, device=CUDA_DEVICE) + # PREFILL: encoder attention + # * Use prefill kernel + + enc_packed_actual_output: torch.Tensor = \ + _run_encoder_attention_test( + attn, + enc_packed_query, + enc_packed_key, + enc_packed_value, + enc_and_dec_prefill_attn_metadata, + attn_type=AttentionType.ENCODER) + + # - Is encoder attention result correct? + assert torch.allclose( + enc_packed_ideal_output, + enc_packed_actual_output.view_as(enc_packed_ideal_output)) + + # PREFILL: self-attention test + self_prefill_packed_actual_output: torch.Tensor = \ - _run_encoder_or_decoder_self_attention_test( + _run_decoder_self_attention_test( attn, prefill_packed_query, self_prefill_packed_key, self_prefill_packed_value, kv_cache, - prefill_attn_metadata, + enc_and_dec_prefill_attn_metadata, attn_type=AttentionType.DECODER) # - Prefill self-attention correct? @@ -876,10 +815,12 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_prefill_packed_actual_output.view_as( self_prefill_packed_ideal_output)) + # PREFILL: cross-attention test + cross_prefill_packed_actual_output: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( attn, prefill_packed_query, cross_prefill_packed_key, - cross_prefill_packed_value, kv_cache, prefill_attn_metadata) + cross_prefill_packed_value, kv_cache, enc_and_dec_prefill_attn_metadata) # - Prefill cross-attention correct? assert torch.allclose( @@ -887,27 +828,29 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_prefill_packed_actual_output.view_as( cross_prefill_packed_ideal_output)) - context_lens = copy.deepcopy(self_prefill_kv_seq_lens) + context_lens = copy.deepcopy(self_prefill_encoder_seq_lens) - # DECODE: self- and cross-attention tests + # DECODE: build decode-phase attention metadata decode_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, False, - q_seq_lens, + decoder_seq_lens, self_decode_block_tables, self_decode_slot_mapping, is_encoder_only_test=False, context_lens=context_lens, - num_prefills_or_decodes=len(q_seq_lens), - num_prefill_or_decode_tokens=len(q_seq_lens), - encoder_seq_lens=encoder_kv_seq_lens, + num_prefills_or_decodes=len(decoder_seq_lens), + num_prefill_or_decode_tokens=len(decoder_seq_lens), + encoder_seq_lens=encoder_seq_lens, cross_block_tables=cross_decode_block_tables, cross_slot_mapping=cross_decode_slot_mapping, device=CUDA_DEVICE) + # DECODE: self-attention test + self_decode_packed_actual_output: torch.Tensor = \ - _run_encoder_or_decoder_self_attention_test( + _run_decoder_self_attention_test( attn, decode_packed_query, self_decode_packed_key, @@ -922,6 +865,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_decode_packed_actual_output.view_as( self_decode_packed_ideal_output)) + # DECODE: cross-attention test + cross_decode_packed_actual_output: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( attn, decode_packed_query, None, @@ -1028,10 +973,7 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, self_decode_packed_key, \ self_decode_packed_value, \ self_decode_packed_ideal_output, \ - _, \ - _, \ q_seq_lens, \ - _, \ self_decode_block_tables, \ self_decode_slot_mapping, \ self_prefill_slot_mapping, \ @@ -1055,17 +997,18 @@ def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - _ = _enc_dec_cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr=cross_block_base_addr) + = _enc_dec_cross_attn_setup_reuses_query(query, + q_seq_lens, + prefill_q_seq_lens, + batch_size, + num_heads, + head_size, + block_size, + scale, + max_q_seq_len, + max_kv_seq_len, + block_base_addr = \ + cross_block_base_addr) # PREFILL: self- and cross-attention tests diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index ddeafbc654073..14e9549287604 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -94,10 +94,11 @@ def ref_masked_attention(query: torch.Tensor, def make_qkv( batch_size: int, max_q_seq_len: int, - max_kv_seq_len: int, + max_kv_seq_len: Optional[int], num_heads: int, head_size: int, device: Union[torch.device, str], + force_kv_seq_lens: List[int] = None, attn_type: AttentionType = AttentionType.ENCODER_DECODER, force_max_len: bool = False, ) -> tuple: @@ -126,6 +127,8 @@ def make_qkv( key/value seqlen (as is often the case for cross-attention); o/w, query/key/value seqlens match at each batch index (max_kv_seq_len is unused) + * force_kv_seq_lens: if not None, overrides kv sequence lengths + * attn_type: encoder, decoder self, or enc/dec cross attention * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query seqlens are random in [2,max_q_seq_lens]. Same for key/value seqlens and max_kv_seq_len, unless forced by is_encoder_decoder_attn=False @@ -146,7 +149,8 @@ def make_qkv( * decode_key: batch_size x 1 x num_heads x head_size * decode_value: batch_size x 1 x num_heads x head_size * q_seq_lens: "baseline" query seqlen list - * kv_seq_lens: "baseline" key/value seqlen list + * kv_seq_lens: "baseline" key/value seqlen list; overridden by non-None + force_encoder_kv_seq_lens * actual_max_q_seq_len: actual "baseline" query max seq len (may be <= max_q_seq_len due to randomness) * actual_max_kv_seq_len: actual "baseline" key/value max seq len (may @@ -164,7 +168,9 @@ def make_qkv( random.randint(2, max_q_seq_len) for _ in range(batch_size) ] kv_seq_lens = None - if attn_type != AttentionType.ENCODER_DECODER: + if force_kv_seq_lens is not None: + kv_seq_lens = force_kv_seq_lens + elif attn_type != AttentionType.ENCODER_DECODER: # K,V seq lens match Q for self-attention kv_seq_lens = q_seq_lens else: diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6948bfea33ce3..cc240242f7eae 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -515,8 +515,18 @@ def forward( self.kv_cache_dtype, kv_scale) - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens + if attn_metadata.attention_type != AttentionType.ENCODER: + # Decoder self-attention supports chunked prefill. + # Encoder/decoder cross-attention requires no chunked + # prefill (100% prefill or 100% decode tokens, no mix) + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + else: + # Encoder attention - chunked prefill is not applicable; + # derive token-count from query shape & and treat them + # as 100% prefill tokens + num_prefill_tokens = query.shape[0] + num_decode_tokens = 0 if attn_type != AttentionType.ENCODER_DECODER: # Only enforce this shape-constraint for decoder From 20b95b00b4e85516ee077c08ba5fbe2c5dbd2811 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 4 Jun 2024 22:24:49 -0400 Subject: [PATCH 168/239] slight refactoring --- tests/kernels/test_encoder_decoder_attn.py | 8 ++------ tests/kernels/utils.py | 2 -- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 7d9d6d32c0c40..6f91f055a80bc 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -12,6 +12,8 @@ import pytest import torch +import collections + from tests.kernels.utils import (make_backend, make_block_tables_slot_mapping, make_empty_block_tables_tensor, make_empty_slot_mapping_tensor, make_kv_cache, @@ -135,8 +137,6 @@ def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, _, \ _, \ _, \ - _, \ - _, \ _ = make_qkv(batch_size, max_q_seq_len, max_kv_seq_len, @@ -268,8 +268,6 @@ def _decoder_attn_setup(batch_size: int, decode_value, \ q_seq_lens, \ kv_seq_lens, \ - _, \ - _, \ prefill_q_seq_lens, \ prefill_kv_seq_lens, \ decode_q_seq_lens, \ @@ -461,8 +459,6 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, _, \ _, \ _, \ - _, \ - _, \ _ = make_qkv(batch_size, max_decoder_seq_len, max_encoder_seq_len, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 14e9549287604..c36a969149e73 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -243,8 +243,6 @@ def make_qkv( decode_value, \ q_seq_lens, \ kv_seq_lens, \ - actual_max_q_seq_len, \ - actual_max_kv_seq_len, \ prefill_q_seq_lens, \ prefill_kv_seq_lens, \ decode_q_seq_lens, \ From 6d52d606e7eaf650cbff541e691e9e831cab63cc Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 00:33:14 -0400 Subject: [PATCH 169/239] QKVInputs and PackedQKVInputs named tuple integration to simplify test logic --- tests/kernels/test_encoder_decoder_attn.py | 695 ++++++++++----------- tests/kernels/utils.py | 114 ++-- 2 files changed, 378 insertions(+), 431 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 6f91f055a80bc..431fe42cd6efe 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -20,7 +20,8 @@ make_qkv, make_test_metadata, override_backend_env_variable, pack_qkv, pack_tensor, ref_masked_attention, - split_slot_mapping) + split_slot_mapping, QKVInputs, + PackedQKVInputs) from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( @@ -85,7 +86,8 @@ def _basic_setup(num_heads: int, head_size: int, num_blocks: int, def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, - scale: float, max_q_seq_len: int) -> tuple: + scale: float, max_q_seq_len: int) \ + -> tuple[PackedQKVInputs,torch.Tensor]: ''' Set up test vectors & data structures for encoder attention test. @@ -122,53 +124,33 @@ def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, ''' max_kv_seq_len = max_q_seq_len - - query, \ - key, \ - value, \ - _, \ - _, \ - _, \ - _, \ - _, \ - _, \ - q_seq_lens, \ - kv_seq_lens, \ - _, \ - _, \ - _, \ - _ = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, - attn_type=AttentionType.ENCODER, - device=CUDA_DEVICE) + + qkv_in, _, _ = make_qkv(batch_size, + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + attn_type=AttentionType.ENCODER, + device=CUDA_DEVICE) # No causal attention mask - ideal_output = ref_masked_attention(query, - key, - value, + ideal_output = ref_masked_attention(qkv_in.query, + qkv_in.key, + qkv_in.value, scale=scale, - q_seq_lens=q_seq_lens, - kv_seq_lens=kv_seq_lens) + q_seq_lens=qkv_in.q_seq_lens, + kv_seq_lens=qkv_in.kv_seq_lens) packed_ideal_output, _ = pack_tensor(ideal_output, - q_seq_lens, + qkv_in.q_seq_lens, device=CUDA_DEVICE) - packed_query, \ - packed_key, \ - packed_value, _, _ = pack_qkv( - query, key, value, q_seq_lens, - kv_seq_lens, + packed_qkv = pack_qkv( + qkv_in, device=CUDA_DEVICE) - return packed_query, \ - packed_key, \ - packed_value, \ - packed_ideal_output, \ - q_seq_lens + return packed_qkv, \ + packed_ideal_output def _decoder_attn_setup(batch_size: int, @@ -257,49 +239,37 @@ def _decoder_attn_setup(batch_size: int, max_kv_seq_len = max_q_seq_len - query, \ - key, \ - value, \ - prefill_query, \ - prefill_key, \ - prefill_value, \ - decode_query, \ - decode_key, \ - decode_value, \ - q_seq_lens, \ - kv_seq_lens, \ - prefill_q_seq_lens, \ - prefill_kv_seq_lens, \ - decode_q_seq_lens, \ - decode_kv_seq_lens = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, - attn_type=AttentionType.DECODER, - device=CUDA_DEVICE) + qkv, \ + prefill_qkv, \ + decode_qkv = make_qkv(batch_size, + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + attn_type=AttentionType.DECODER, + device=CUDA_DEVICE) causal_mask = make_causal_mask(max_q_seq_len, max_kv_seq_len).to(CUDA_DEVICE) - ideal_output = ref_masked_attention(query, - key, - value, + ideal_output = ref_masked_attention(qkv.query, + qkv.key, + qkv.value, scale=scale, custom_mask=causal_mask, - q_seq_lens=q_seq_lens, - kv_seq_lens=kv_seq_lens) + q_seq_lens=qkv.q_seq_lens, + kv_seq_lens=qkv.kv_seq_lens) prefill_ideal_output = torch.zeros_like(ideal_output) decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) - for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens): + for bdx, prefill_q_seq_len in enumerate(prefill_qkv.q_seq_lens): prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[ bdx, :prefill_q_seq_len] decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:( prefill_q_seq_len + 1)] prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output, - prefill_q_seq_lens, + prefill_qkv.q_seq_lens, device=CUDA_DEVICE) decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output, [1 for _ in range(batch_size)], @@ -327,62 +297,43 @@ def _decoder_attn_setup(batch_size: int, decode_block_tables, \ slot_mapping_list, \ max_block_idx = make_block_tables_slot_mapping(block_size, - q_seq_lens, + qkv.q_seq_lens, device=CUDA_DEVICE, block_base_addr = block_base_addr) prefill_slot_mapping, \ decode_slot_mapping = split_slot_mapping(slot_mapping_list, - q_seq_lens, + qkv.q_seq_lens, device=CUDA_DEVICE) - prefill_packed_query, \ - prefill_packed_key, \ - prefill_packed_value, _, _ = pack_qkv( - prefill_query, prefill_key, prefill_value, prefill_q_seq_lens, - prefill_kv_seq_lens, - device=CUDA_DEVICE) + prefill_pckd_qkv = pack_qkv(prefill_qkv, + device=CUDA_DEVICE) - decode_packed_query, \ - decode_packed_key, \ - decode_packed_value, \ - _, \ - _ = pack_qkv( - decode_query, decode_key, decode_value, decode_q_seq_lens, - decode_kv_seq_lens, - device=CUDA_DEVICE) + decode_pckd_qkv = pack_qkv(decode_qkv, + device=CUDA_DEVICE) - return query, \ - prefill_packed_query, \ - prefill_packed_key, \ - prefill_packed_value, \ + return qkv, \ + prefill_pckd_qkv, \ prefill_packed_ideal_output, \ - prefill_q_seq_lens, \ - prefill_kv_seq_lens, \ - decode_packed_query, \ - decode_packed_key, \ - decode_packed_value, \ + decode_pckd_qkv, \ decode_packed_ideal_output, \ - q_seq_lens, \ decode_block_tables, \ decode_slot_mapping, \ prefill_slot_mapping, \ prefill_block_tables, \ max_block_idx - -def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, - decoder_seq_lens: List[int], - encoder_seq_lens: Optional[List[int]], - prefill_q_seq_lens: List, - batch_size: int, - num_heads: int, - head_size: int, - block_size: int, - scale: float, - max_decoder_seq_len: int, - max_encoder_seq_len: int, - block_base_addr: Optional[int]=0) \ +def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, + encoder_packed_qkv: PackedQKVInputs, + prefill_phase_decoder_packed_qkv: PackedQKVInputs, + batch_size: int, + num_heads: int, + head_size: int, + block_size: int, + scale: float, + max_decoder_seq_len: int, + max_encoder_seq_len: int, + block_base_addr: Optional[int]=0) \ -> tuple: ''' Set up test vectors & data structures for cross-attention test. @@ -445,19 +396,13 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, * max_block_idx: highest block address in the cross-attention block-table ''' - _, \ - key, \ - value, \ - _, \ - _, \ - _, \ - _, \ - _, \ - _, \ - _, \ - kv_seq_lens, \ - _, \ - _, \ + decoder_query = decoder_qkv.query + decoder_seq_lens = decoder_qkv.q_seq_lens + encoder_seq_lens = encoder_packed_qkv.q_seq_lens + prefill_q_seq_lens = prefill_phase_decoder_packed_qkv.q_seq_lens + + + cross_kv, \ _, \ _ = make_qkv(batch_size, max_decoder_seq_len, @@ -468,12 +413,12 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, attn_type=AttentionType.ENCODER_DECODER, device=CUDA_DEVICE) - ideal_output = ref_masked_attention(query, - key, - value, + ideal_output = ref_masked_attention(decoder_query, + cross_kv.key, + cross_kv.value, scale=scale, q_seq_lens=decoder_seq_lens, - kv_seq_lens=kv_seq_lens) + kv_seq_lens=cross_kv.kv_seq_lens) prefill_ideal_output = torch.zeros_like(ideal_output) decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) @@ -520,9 +465,9 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, decode_block_tables, \ prefill_slot_mapping_list, \ - max_block_idx = make_block_tables_slot_mapping( + _ = make_block_tables_slot_mapping( block_size, - kv_seq_lens, + cross_kv.kv_seq_lens, block_base_addr=block_base_addr, device=CUDA_DEVICE) @@ -530,26 +475,20 @@ def _enc_dec_cross_attn_setup_reuses_query(query: torch.Tensor, device=CUDA_DEVICE) # Packed key/value (query is already provided) - _, packed_key, packed_value, _, _ = pack_qkv(None, - key, - value, - None, - kv_seq_lens, - device=CUDA_DEVICE) + packed_cross_kv = pack_qkv(cross_kv, + device=CUDA_DEVICE) - return packed_key, \ - packed_value, \ - prefill_packed_ideal_output, \ - decode_packed_ideal_output, \ - decode_block_tables, \ - decode_slot_mapping, \ - prefill_slot_mapping, \ - prefill_block_tables + return packed_cross_kv, \ + prefill_packed_ideal_output, \ + decode_packed_ideal_output, \ + decode_block_tables, \ + decode_slot_mapping, \ + prefill_slot_mapping, \ + prefill_block_tables -def _run_encoder_attention_test(attn: Attention, packed_query: torch.Tensor, - packed_key: torch.Tensor, - packed_value: torch.Tensor, +def _run_encoder_attention_test(attn: Attention, + pckd_qkv: PackedQKVInputs, attn_metadata: AttentionMetadata, attn_type: AttentionType) -> torch.Tensor: ''' @@ -563,7 +502,7 @@ def _run_encoder_attention_test(attn: Attention, packed_query: torch.Tensor, Arguments: * attn: Attention wrapper instance - * packed_{query,key,value}: total_num_tokens x (num_heads*head_size) + * pckd_qkv: Packed query/key/value inputs * attn_metadata: attention metadata for encoder/decoder-self attention * attn_type: AttentionType.DECODER or AttentionType.ENCODER @@ -574,14 +513,15 @@ def _run_encoder_attention_test(attn: Attention, packed_query: torch.Tensor, assert attn_type == AttentionType.ENCODER assert attn_metadata.num_decode_tokens == 0 attn_metadata.attention_type = attn_type - return attn.forward(packed_query, packed_key, packed_value, None, + return attn.forward(pckd_qkv.query, + pckd_qkv.key, + pckd_qkv.value, + None, attn_metadata) def _run_decoder_self_attention_test(attn: Attention, - packed_query: torch.Tensor, - packed_key: torch.Tensor, - packed_value: torch.Tensor, + pckd_qkv: PackedQKVInputs, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, attn_type: AttentionType) -> torch.Tensor: @@ -596,7 +536,7 @@ def _run_decoder_self_attention_test(attn: Attention, Arguments: * attn: Attention wrapper instance - * packed_{query,key,value}: total_num_tokens x (num_heads*head_size) + * pckd_qkv: Packed query/key/value inputs * kv_cache * attn_metadata: attention metadata for encoder/decoder-self attention * attn_type: AttentionType.DECODER or AttentionType.ENCODER @@ -607,13 +547,18 @@ def _run_decoder_self_attention_test(attn: Attention, ''' assert attn_type == AttentionType.DECODER attn_metadata.attention_type = attn_type - return attn.forward(packed_query, packed_key, packed_value, kv_cache, + return attn.forward(pckd_qkv.query, + pckd_qkv.key, + pckd_qkv.value, + kv_cache, attn_metadata) def _run_encoder_decoder_cross_attention_test( - attn: Attention, packed_query: torch.Tensor, packed_key: torch.Tensor, - packed_value: torch.Tensor, kv_cache: torch.Tensor, + attn: Attention, + dec_pckd_qkv: PackedQKVInputs, + cross_pckd_qkv: PackedQKVInputs, + kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -634,7 +579,14 @@ def _run_encoder_decoder_cross_attention_test( & attn_metadata ''' attn_metadata.attention_type = AttentionType.ENCODER_DECODER - return attn.forward(packed_query, packed_key, packed_value, kv_cache, + key = None if cross_pckd_qkv is None else \ + cross_pckd_qkv.key + value = None if cross_pckd_qkv is None else \ + cross_pckd_qkv.value + return attn.forward(dec_pckd_qkv.query, + key, + value, + kv_cache, attn_metadata) @@ -700,34 +652,33 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # tensors are not actually utilized by encoder attention # anyway but are required to be present & valid by the # backend. - enc_packed_query, \ - enc_packed_key, \ - enc_packed_value, \ - enc_packed_ideal_output, \ - encoder_seq_lens = _encoder_attn_setup(batch_size, - num_heads, - head_size, - scale, - max_enc_seq_len) + + # encoder_packed_query, \ + # enc_packed_key, \ + # enc_packed_value, \ + # encoder_packed_ideal_output, \ + # encoder_seq_lens = + + + enc_pckd_qkv, \ + enc_pckd_idl_out = \ + _encoder_attn_setup(batch_size, + num_heads, + head_size, + scale, + max_enc_seq_len) # Decoder self-attention setup - query, \ - prefill_packed_query, \ - self_prefill_packed_key, \ - self_prefill_packed_value, \ - self_prefill_packed_ideal_output, \ - prefill_decoder_seq_lens, \ - self_prefill_encoder_seq_lens, \ - decode_packed_query, \ - self_decode_packed_key, \ - self_decode_packed_value, \ - self_decode_packed_ideal_output, \ - decoder_seq_lens, \ - self_decode_block_tables, \ - self_decode_slot_mapping, \ - self_prefill_slot_mapping, \ - self_prefill_block_tables, \ + dec_qkv, \ + prephase_dec_pckd_qkv, \ + prephase_dec_pckd_idl_out, \ + decphase_dec_pckd_qkv, \ + decphase_dec_pckd_idl_out, \ + decphase_dec_blk_tbls, \ + decphase_dec_slt_map, \ + prephase_dec_slt_map, \ + prephase_dec_blk_tbls, \ cross_block_base_addr = _decoder_attn_setup(batch_size, num_heads, head_size, @@ -737,18 +688,16 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # Cross-attention setup - cross_prefill_packed_key, \ - cross_prefill_packed_value, \ - cross_prefill_packed_ideal_output, \ - cross_decode_packed_ideal_output, \ + prephase_cross_pckd_qkv, \ + prephase_cross_pckd_idl_out, \ + decphase_cross_pckd_idl_out, \ cross_decode_block_tables, \ cross_decode_slot_mapping, \ cross_prefill_slot_mapping, \ cross_prefill_block_tables, \ - = _enc_dec_cross_attn_setup_reuses_query(query, - decoder_seq_lens, - encoder_seq_lens, - prefill_decoder_seq_lens, + = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, + enc_pckd_qkv, + prephase_dec_pckd_qkv, batch_size, num_heads, head_size, @@ -762,16 +711,16 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # Shared prefill metadata structure # - enc_and_dec_prefill_attn_metadata: AttentionMetadata = make_test_metadata( + prephase_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, - prefill_decoder_seq_lens, - self_prefill_block_tables, - self_prefill_slot_mapping, - is_encoder_only_test=True, - num_prefills_or_decodes=len(prefill_decoder_seq_lens), - num_prefill_or_decode_tokens=sum(prefill_decoder_seq_lens), - encoder_seq_lens=encoder_seq_lens, + prephase_dec_pckd_qkv.q_seq_lens, + prephase_dec_blk_tbls, + prephase_dec_slt_map, + default_attn_type=AttentionType.ENCODER, + num_prefills_or_decodes=len(prephase_dec_pckd_qkv.q_seq_lens), + num_prefill_or_decode_tokens=sum(prephase_dec_pckd_qkv.q_seq_lens), + encoder_seq_lens=enc_pckd_qkv.q_seq_lens, cross_block_tables=cross_prefill_block_tables, cross_slot_mapping=cross_prefill_slot_mapping, device=CUDA_DEVICE) @@ -782,97 +731,99 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( enc_packed_actual_output: torch.Tensor = \ _run_encoder_attention_test( attn, - enc_packed_query, - enc_packed_key, - enc_packed_value, - enc_and_dec_prefill_attn_metadata, + enc_pckd_qkv, + prephase_attn_metadata, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? assert torch.allclose( - enc_packed_ideal_output, - enc_packed_actual_output.view_as(enc_packed_ideal_output)) + enc_pckd_idl_out, + enc_packed_actual_output.view_as(enc_pckd_idl_out)) # PREFILL: self-attention test self_prefill_packed_actual_output: torch.Tensor = \ _run_decoder_self_attention_test( attn, - prefill_packed_query, - self_prefill_packed_key, - self_prefill_packed_value, + prephase_dec_pckd_qkv, kv_cache, - enc_and_dec_prefill_attn_metadata, + prephase_attn_metadata, attn_type=AttentionType.DECODER) # - Prefill self-attention correct? assert torch.allclose( - self_prefill_packed_ideal_output, + prephase_dec_pckd_idl_out, self_prefill_packed_actual_output.view_as( - self_prefill_packed_ideal_output)) + prephase_dec_pckd_idl_out)) # PREFILL: cross-attention test - cross_prefill_packed_actual_output: torch.Tensor = \ + prephase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( - attn, prefill_packed_query, cross_prefill_packed_key, - cross_prefill_packed_value, kv_cache, enc_and_dec_prefill_attn_metadata) + attn, + prephase_dec_pckd_qkv, + prephase_cross_pckd_qkv, + kv_cache, + prephase_attn_metadata) # - Prefill cross-attention correct? assert torch.allclose( - cross_prefill_packed_ideal_output, - cross_prefill_packed_actual_output.view_as( - cross_prefill_packed_ideal_output)) - - context_lens = copy.deepcopy(self_prefill_encoder_seq_lens) + prephase_cross_pckd_idl_out, + prephase_cross_pckd_act_out.view_as( + prephase_cross_pckd_idl_out)) # DECODE: build decode-phase attention metadata - decode_attn_metadata: AttentionMetadata = make_test_metadata( + # - Cross-attention KV context is equal in length to + # encoder input + context_lens = copy.deepcopy(enc_pckd_qkv.q_seq_lens) + + decphase_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, False, - decoder_seq_lens, - self_decode_block_tables, - self_decode_slot_mapping, - is_encoder_only_test=False, + dec_qkv.q_seq_lens, + decphase_dec_blk_tbls, + decphase_dec_slt_map, + default_attn_type=AttentionType.DECODER, context_lens=context_lens, - num_prefills_or_decodes=len(decoder_seq_lens), - num_prefill_or_decode_tokens=len(decoder_seq_lens), - encoder_seq_lens=encoder_seq_lens, + num_prefills_or_decodes=len(dec_qkv.q_seq_lens), + num_prefill_or_decode_tokens=len(dec_qkv.q_seq_lens), + encoder_seq_lens=enc_pckd_qkv.q_seq_lens, cross_block_tables=cross_decode_block_tables, cross_slot_mapping=cross_decode_slot_mapping, device=CUDA_DEVICE) # DECODE: self-attention test - self_decode_packed_actual_output: torch.Tensor = \ + decphase_dec_pckd_act_out: torch.Tensor = \ _run_decoder_self_attention_test( attn, - decode_packed_query, - self_decode_packed_key, - self_decode_packed_value, + decphase_dec_pckd_qkv, kv_cache, - decode_attn_metadata, + decphase_attn_metadata, attn_type=AttentionType.DECODER) # - Decode self-attention correct? assert torch.allclose( - self_decode_packed_ideal_output, - self_decode_packed_actual_output.view_as( - self_decode_packed_ideal_output)) + decphase_dec_pckd_idl_out, + decphase_dec_pckd_act_out.view_as( + decphase_dec_pckd_idl_out)) # DECODE: cross-attention test - cross_decode_packed_actual_output: torch.Tensor = \ + decphase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( - attn, decode_packed_query, None, - None, kv_cache, decode_attn_metadata) + attn, + decphase_dec_pckd_qkv, + None, + kv_cache, + decphase_attn_metadata) # - Decode cross-attention correct? assert torch.allclose( - cross_decode_packed_ideal_output, - cross_decode_packed_actual_output.view_as( - cross_decode_packed_ideal_output)) + decphase_cross_pckd_idl_out, + decphase_cross_pckd_act_out.view_as( + decphase_cross_pckd_idl_out)) # The following test conditions could in principle be a # standalone test, however the test setup is @@ -889,145 +840,147 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # already; the line below sets up a chunked prefill # metadata configuration where there is nominally a mix # of prefill and decode tokens. - decode_attn_metadata.num_prefill_tokens = 1 + decphase_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: - _run_encoder_decoder_cross_attention_test(attn, decode_packed_query, - None, None, kv_cache, - decode_attn_metadata) + _run_encoder_decoder_cross_attention_test(attn, + decphase_dec_pckd_qkv, + None, + kv_cache, + decphase_attn_metadata) # "Encoder decoder models do not currently support chunked prefill" assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL -@pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") -@pytest.mark.parametrize("num_heads", [256]) -@pytest.mark.parametrize("head_size", [16]) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", [16]) -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("max_q_seq_len", [64]) -@pytest.mark.parametrize("max_kv_seq_len", [64]) -def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, - backend_name: str, batch_size: int, - block_size: int, max_q_seq_len: int, - max_kv_seq_len: int, monkeypatch) -> None: - ''' - Encoder/decoder not-implemented-for-ROCm-HIP test: - - * Construct fake test vectors for self- and cross-attention - * Construct attention metadata structure with self- and cross-attention - attributes - * Test self- and cross-attention in the following order +# @pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") +# @pytest.mark.parametrize("num_heads", [256]) +# @pytest.mark.parametrize("head_size", [16]) +# @pytest.mark.parametrize("backend_name", BACKEND_NAMES) +# @pytest.mark.parametrize("batch_size", [16]) +# @pytest.mark.parametrize("block_size", [16]) +# @pytest.mark.parametrize("max_q_seq_len", [64]) +# @pytest.mark.parametrize("max_kv_seq_len", [64]) +# def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, +# backend_name: str, batch_size: int, +# block_size: int, max_q_seq_len: int, +# max_kv_seq_len: int, monkeypatch) -> None: +# ''' +# Encoder/decoder not-implemented-for-ROCm-HIP test: + +# * Construct fake test vectors for self- and cross-attention +# * Construct attention metadata structure with self- and cross-attention +# attributes +# * Test self- and cross-attention in the following order - * Prefill self-attention - * Prefill cross-attention - * Decode self-attention - * Decode cross-attention - * This order would exacerbate any accidental overlap in the - self-/cross-attention block tables, which we attempt to avoid - * Validate output correctness against ideal reference attention - implementation - - Block tables are constructed such that cross-attention KV cache is in a - higher, non-intersecting address-space than self-attention KV cache. - - Self- and cross-attention share the same query tensor but not the K/V - tensors. Self-attention K/Vs must have the same seq len as Q while - cross-attention K/Vs are allowed to differ in seq len, as is often the case - for cross-attention. - ''' - - # Force Attention wrapper backend - override_backend_env_variable(monkeypatch, backend_name) - - # Num KV cache blocks - num_blocks = 4096 - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - scale, \ - attn_backend, \ - attn, \ - kv_cache = _basic_setup(num_heads, - head_size, - num_blocks, - block_size, - backend_name) - - # Self-attention setup - - self_block_base_addr = 0 - - query, \ - prefill_packed_query, \ - self_prefill_packed_key, \ - self_prefill_packed_value, \ - self_prefill_packed_ideal_output, \ - prefill_q_seq_lens, \ - self_prefill_kv_seq_lens, \ - decode_packed_query, \ - self_decode_packed_key, \ - self_decode_packed_value, \ - self_decode_packed_ideal_output, \ - q_seq_lens, \ - self_decode_block_tables, \ - self_decode_slot_mapping, \ - self_prefill_slot_mapping, \ - self_prefill_block_tables, \ - cross_block_base_addr = _decoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - block_base_addr=self_block_base_addr) - - # Cross-attention setup - - cross_prefill_packed_key, \ - cross_prefill_packed_value, \ - cross_prefill_packed_ideal_output, \ - cross_decode_packed_ideal_output, \ - encoder_kv_seq_lens, \ - cross_decode_block_tables, \ - cross_decode_slot_mapping, \ - cross_prefill_slot_mapping, \ - cross_prefill_block_tables, \ - = _enc_dec_cross_attn_setup_reuses_query(query, - q_seq_lens, - prefill_q_seq_lens, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_q_seq_len, - max_kv_seq_len, - block_base_addr = \ - cross_block_base_addr) - - # PREFILL: self- and cross-attention tests - - prefill_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, - True, - prefill_q_seq_lens, - self_prefill_block_tables, - self_prefill_slot_mapping, - is_encoder_only_test=False, - num_prefills_or_decodes=len(prefill_q_seq_lens), - num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), - encoder_seq_lens=encoder_kv_seq_lens, - cross_block_tables=cross_prefill_block_tables, - cross_slot_mapping=cross_prefill_slot_mapping, - device=CUDA_DEVICE) - - with pytest.raises(NotImplementedError) as exc_info: - _run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, - cross_prefill_packed_key, - cross_prefill_packed_value, - kv_cache, - prefill_attn_metadata) - - # "Encoder decoder models do not currently support ROCm/HIP" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP +# * Prefill self-attention +# * Prefill cross-attention +# * Decode self-attention +# * Decode cross-attention +# * This order would exacerbate any accidental overlap in the +# self-/cross-attention block tables, which we attempt to avoid +# * Validate output correctness against ideal reference attention +# implementation + +# Block tables are constructed such that cross-attention KV cache is in a +# higher, non-intersecting address-space than self-attention KV cache. + +# Self- and cross-attention share the same query tensor but not the K/V +# tensors. Self-attention K/Vs must have the same seq len as Q while +# cross-attention K/Vs are allowed to differ in seq len, as is often the case +# for cross-attention. +# ''' + +# # Force Attention wrapper backend +# override_backend_env_variable(monkeypatch, backend_name) + +# # Num KV cache blocks +# num_blocks = 4096 + +# # Attention scale factor, attention backend instance, attention wrapper +# # instance, KV cache init +# scale, \ +# attn_backend, \ +# attn, \ +# kv_cache = _basic_setup(num_heads, +# head_size, +# num_blocks, +# block_size, +# backend_name) + +# # Self-attention setup + +# self_block_base_addr = 0 + +# query, \ +# prefill_packed_query, \ +# self_prefill_packed_key, \ +# self_prefill_packed_value, \ +# self_prefill_packed_ideal_output, \ +# prefill_q_seq_lens, \ +# self_prefill_kv_seq_lens, \ +# decode_packed_query, \ +# self_decode_packed_key, \ +# self_decode_packed_value, \ +# self_decode_packed_ideal_output, \ +# q_seq_lens, \ +# self_decode_block_tables, \ +# self_decode_slot_mapping, \ +# self_prefill_slot_mapping, \ +# self_prefill_block_tables, \ +# cross_block_base_addr = _decoder_attn_setup(batch_size, +# num_heads, +# head_size, +# block_size, +# scale, +# max_q_seq_len, +# block_base_addr=self_block_base_addr) + +# # Cross-attention setup + +# cross_prefill_packed_key, \ +# cross_prefill_packed_value, \ +# cross_prefill_packed_ideal_output, \ +# cross_decode_packed_ideal_output, \ +# encoder_kv_seq_lens, \ +# cross_decode_block_tables, \ +# cross_decode_slot_mapping, \ +# cross_prefill_slot_mapping, \ +# cross_prefill_block_tables, \ +# = _enc_dec_cross_attn_setup_reuses_query(query, +# q_seq_lens, +# prefill_q_seq_lens, +# batch_size, +# num_heads, +# head_size, +# block_size, +# scale, +# max_q_seq_len, +# max_kv_seq_len, +# block_base_addr = \ +# cross_block_base_addr) + +# # PREFILL: self- and cross-attention tests + +# prefill_attn_metadata: AttentionMetadata = make_test_metadata( +# attn_backend, +# True, +# prefill_q_seq_lens, +# self_prefill_block_tables, +# self_prefill_slot_mapping, +# is_encoder_only_test=False, +# num_prefills_or_decodes=len(prefill_q_seq_lens), +# num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), +# encoder_seq_lens=encoder_kv_seq_lens, +# cross_block_tables=cross_prefill_block_tables, +# cross_slot_mapping=cross_prefill_slot_mapping, +# device=CUDA_DEVICE) + +# with pytest.raises(NotImplementedError) as exc_info: +# _run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, +# cross_prefill_packed_key, +# cross_prefill_packed_value, +# kv_cache, +# prefill_attn_metadata) + +# # "Encoder decoder models do not currently support ROCm/HIP" +# assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index c36a969149e73..142bc64d34b99 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -7,6 +7,8 @@ import pytest import torch +from collections import namedtuple + from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.backends.xformers import XFormersBackend @@ -90,6 +92,23 @@ def ref_masked_attention(query: torch.Tensor, out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) return out +# batch_size x max_q_seq_len x num_heads x head_size +QKVInputs = namedtuple("QKVInputs", + ["query", + "key", + "value", + "q_seq_lens", + "kv_seq_lens"]) + +# total_num_tokens x (num_heads*head_size) +PackedQKVInputs = namedtuple("PackedQKVInputs", + ["query", + "key", + "value", + "q_start_loc_list", + "kv_start_loc_list", + "q_seq_lens", + "kv_seq_lens"]) def make_qkv( batch_size: int, @@ -101,7 +120,7 @@ def make_qkv( force_kv_seq_lens: List[int] = None, attn_type: AttentionType = AttentionType.ENCODER_DECODER, force_max_len: bool = False, -) -> tuple: +) -> tuple[QKVInputs,QKVInputs,QKVInputs]: ''' Construct QKV test tensors for self- and cross-attention. @@ -136,29 +155,7 @@ def make_qkv( Returns: - * query: "baseline" query; batch_size x max_q_seq_len x num_heads x - head_size - * key: "baseline" key; batch_size x max_kv_seq_len x num_heads x - head_size - * value: "baseline" value; batch_size x max_kv_seq_len x num_heads x - head_size - * prefill_query: batch_size x (max_q_seq_len-1) x num_heads x head_size - * prefill_key: batch_size x (max_kv_seq_len-1) x num_heads x head_size - * prefill_value: batch_size x (max_kv_seq_len-1) x num_heads x head_size - * decode_query: batch_size x 1 x num_heads x head_size - * decode_key: batch_size x 1 x num_heads x head_size - * decode_value: batch_size x 1 x num_heads x head_size - * q_seq_lens: "baseline" query seqlen list - * kv_seq_lens: "baseline" key/value seqlen list; overridden by non-None - force_encoder_kv_seq_lens - * actual_max_q_seq_len: actual "baseline" query max seq len (may be <= - max_q_seq_len due to randomness) - * actual_max_kv_seq_len: actual "baseline" key/value max seq len (may - be <= max_kv_seq_len due to randomness) - * prefill_q_seq_lens: "prefill" query seqlen list - * prefill_kv_seq_lens: "prefill" key/value seqlen list - * decode_q_seq_lens: "decode" query seqlen list (all ones) - * decode_kv_seq_lens: "decode" key/value seqlen list + * QKVInputs structure ''' if force_max_len: @@ -182,9 +179,6 @@ def make_qkv( random.randint(2, max_kv_seq_len) for _ in range(batch_size) ] - actual_max_q_seq_len = max(q_seq_lens) - actual_max_kv_seq_len = max(kv_seq_lens) - query = torch.rand( (batch_size, max_q_seq_len, num_heads, head_size)).to(device) key = torch.rand( @@ -232,21 +226,22 @@ def make_qkv( decode_q_seq_lens = [1 for _ in q_seq_lens] decode_kv_seq_lens = [1 for _ in kv_seq_lens] - return query, \ - key, \ - value, \ - prefill_query, \ - prefill_key, \ - prefill_value, \ - decode_query, \ - decode_key, \ - decode_value, \ - q_seq_lens, \ - kv_seq_lens, \ - prefill_q_seq_lens, \ - prefill_kv_seq_lens, \ - decode_q_seq_lens, \ - decode_kv_seq_lens + return QKVInputs(query, + key, + value, + q_seq_lens, + kv_seq_lens), \ + QKVInputs(prefill_query, + prefill_key, + prefill_value, + prefill_q_seq_lens, + prefill_kv_seq_lens), \ + QKVInputs( + decode_query, + decode_key, + decode_value, + decode_q_seq_lens, + decode_kv_seq_lens) def pack_tensor(unpacked_tensor: torch.Tensor, seq_lens: List[int], @@ -283,9 +278,8 @@ def pack_tensor(unpacked_tensor: torch.Tensor, seq_lens: List[int], return packed_tensor, start_loc_list -def pack_qkv(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - q_seq_lens: List[int], kv_seq_lens: List[int], - device: Union[torch.device, str]) -> tuple: +def pack_qkv(qkv: QKVInputs, + device: Union[torch.device, str]) -> PackedQKVInputs: ''' Individually pack each of Q, K and V, each with dimensions batch_size x padded_seq_len x num_heads x head_size, into respective number_of_tokens x @@ -312,22 +306,25 @@ def pack_qkv(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, * kv_start_loc_list: start idx of each {key,value} in packed_{key,value} ''' - if query is None: + if qkv.query is None: packed_query = None q_start_loc_list = None else: - packed_query, q_start_loc_list = pack_tensor(query, - q_seq_lens, + packed_query, q_start_loc_list = pack_tensor(qkv.query, + qkv.q_seq_lens, device=device) - packed_key, kv_start_loc_list = pack_tensor(key, - kv_seq_lens, + packed_key, kv_start_loc_list = pack_tensor(qkv.key, + qkv.kv_seq_lens, device=device) - packed_value, _ = pack_tensor(value, kv_seq_lens, device=device) - return packed_query, \ - packed_key, \ - packed_value, \ - q_start_loc_list, \ - kv_start_loc_list + packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device) + return PackedQKVInputs(packed_query, \ + packed_key, \ + packed_value, \ + q_start_loc_list, \ + kv_start_loc_list, \ + None if q_start_loc_list is None else \ + qkv.q_seq_lens, \ + qkv.kv_seq_lens) def make_backend(backend_name: str) -> AttentionBackend: @@ -589,7 +586,7 @@ def make_test_metadata( seq_lens: List[int], block_tables: torch.Tensor, slot_mapping: torch.Tensor, - is_encoder_only_test: bool, + default_attn_type: AttentionType, num_prefills_or_decodes: int, num_prefill_or_decode_tokens: int, device: Union[torch.device, str], @@ -631,9 +628,6 @@ def make_test_metadata( * AttentionMetadata structure supporting self- and cross-attention ''' - default_attn_type = AttentionType.ENCODER if is_encoder_only_test \ - else AttentionType.DECODER - if is_prompt: num_prefills = num_prefills_or_decodes num_prefill_tokens = num_prefill_or_decode_tokens From a81712b14126ff7563ed0e6ce145a8401c0efaec Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 00:36:49 -0400 Subject: [PATCH 170/239] refactoring --- tests/kernels/test_encoder_decoder_attn.py | 246 ++++----------------- tests/kernels/utils.py | 31 +-- 2 files changed, 56 insertions(+), 221 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 431fe42cd6efe..955fa1629440f 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -7,21 +7,19 @@ """ import copy -from typing import List, Optional +from typing import Optional import pytest import torch -import collections - -from tests.kernels.utils import (make_backend, make_block_tables_slot_mapping, +from tests.kernels.utils import (PackedQKVInputs, QKVInputs, make_backend, + make_block_tables_slot_mapping, make_empty_block_tables_tensor, make_empty_slot_mapping_tensor, make_kv_cache, make_qkv, make_test_metadata, override_backend_env_variable, pack_qkv, pack_tensor, ref_masked_attention, - split_slot_mapping, QKVInputs, - PackedQKVInputs) + split_slot_mapping) from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( @@ -124,14 +122,14 @@ def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, ''' max_kv_seq_len = max_q_seq_len - + qkv_in, _, _ = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, - attn_type=AttentionType.ENCODER, - device=CUDA_DEVICE) + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + attn_type=AttentionType.ENCODER, + device=CUDA_DEVICE) # No causal attention mask ideal_output = ref_masked_attention(qkv_in.query, @@ -145,9 +143,7 @@ def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, qkv_in.q_seq_lens, device=CUDA_DEVICE) - packed_qkv = pack_qkv( - qkv_in, - device=CUDA_DEVICE) + packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE) return packed_qkv, \ packed_ideal_output @@ -306,11 +302,9 @@ def _decoder_attn_setup(batch_size: int, qkv.q_seq_lens, device=CUDA_DEVICE) - prefill_pckd_qkv = pack_qkv(prefill_qkv, - device=CUDA_DEVICE) + prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE) - decode_pckd_qkv = pack_qkv(decode_qkv, - device=CUDA_DEVICE) + decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE) return qkv, \ prefill_pckd_qkv, \ @@ -324,8 +318,10 @@ def _decoder_attn_setup(batch_size: int, max_block_idx def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, - encoder_packed_qkv: PackedQKVInputs, - prefill_phase_decoder_packed_qkv: PackedQKVInputs, + encoder_packed_qkv: + PackedQKVInputs, + prefill_phase_decoder_packed_qkv: + PackedQKVInputs, batch_size: int, num_heads: int, head_size: int, @@ -475,8 +471,7 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, device=CUDA_DEVICE) # Packed key/value (query is already provided) - packed_cross_kv = pack_qkv(cross_kv, - device=CUDA_DEVICE) + packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE) return packed_cross_kv, \ prefill_packed_ideal_output, \ @@ -487,8 +482,7 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, prefill_block_tables -def _run_encoder_attention_test(attn: Attention, - pckd_qkv: PackedQKVInputs, +def _run_encoder_attention_test(attn: Attention, pckd_qkv: PackedQKVInputs, attn_metadata: AttentionMetadata, attn_type: AttentionType) -> torch.Tensor: ''' @@ -513,10 +507,7 @@ def _run_encoder_attention_test(attn: Attention, assert attn_type == AttentionType.ENCODER assert attn_metadata.num_decode_tokens == 0 attn_metadata.attention_type = attn_type - return attn.forward(pckd_qkv.query, - pckd_qkv.key, - pckd_qkv.value, - None, + return attn.forward(pckd_qkv.query, pckd_qkv.key, pckd_qkv.value, None, attn_metadata) @@ -547,18 +538,13 @@ def _run_decoder_self_attention_test(attn: Attention, ''' assert attn_type == AttentionType.DECODER attn_metadata.attention_type = attn_type - return attn.forward(pckd_qkv.query, - pckd_qkv.key, - pckd_qkv.value, - kv_cache, + return attn.forward(pckd_qkv.query, pckd_qkv.key, pckd_qkv.value, kv_cache, attn_metadata) def _run_encoder_decoder_cross_attention_test( - attn: Attention, - dec_pckd_qkv: PackedQKVInputs, - cross_pckd_qkv: PackedQKVInputs, - kv_cache: torch.Tensor, + attn: Attention, dec_pckd_qkv: PackedQKVInputs, + cross_pckd_qkv: PackedQKVInputs, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -583,10 +569,7 @@ def _run_encoder_decoder_cross_attention_test( cross_pckd_qkv.key value = None if cross_pckd_qkv is None else \ cross_pckd_qkv.value - return attn.forward(dec_pckd_qkv.query, - key, - value, - kv_cache, + return attn.forward(dec_pckd_qkv.query, key, value, kv_cache, attn_metadata) @@ -657,9 +640,9 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # enc_packed_key, \ # enc_packed_value, \ # encoder_packed_ideal_output, \ - # encoder_seq_lens = - - + # encoder_seq_lens = + + enc_pckd_qkv, \ enc_pckd_idl_out = \ _encoder_attn_setup(batch_size, @@ -736,9 +719,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - assert torch.allclose( - enc_pckd_idl_out, - enc_packed_actual_output.view_as(enc_pckd_idl_out)) + assert torch.allclose(enc_pckd_idl_out, + enc_packed_actual_output.view_as(enc_pckd_idl_out)) # PREFILL: self-attention test @@ -753,24 +735,22 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # - Prefill self-attention correct? assert torch.allclose( prephase_dec_pckd_idl_out, - self_prefill_packed_actual_output.view_as( - prephase_dec_pckd_idl_out)) + self_prefill_packed_actual_output.view_as(prephase_dec_pckd_idl_out)) # PREFILL: cross-attention test prephase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( - attn, - prephase_dec_pckd_qkv, - prephase_cross_pckd_qkv, - kv_cache, + attn, + prephase_dec_pckd_qkv, + prephase_cross_pckd_qkv, + kv_cache, prephase_attn_metadata) # - Prefill cross-attention correct? assert torch.allclose( prephase_cross_pckd_idl_out, - prephase_cross_pckd_act_out.view_as( - prephase_cross_pckd_idl_out)) + prephase_cross_pckd_act_out.view_as(prephase_cross_pckd_idl_out)) # DECODE: build decode-phase attention metadata @@ -806,24 +786,22 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # - Decode self-attention correct? assert torch.allclose( decphase_dec_pckd_idl_out, - decphase_dec_pckd_act_out.view_as( - decphase_dec_pckd_idl_out)) + decphase_dec_pckd_act_out.view_as(decphase_dec_pckd_idl_out)) # DECODE: cross-attention test decphase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( - attn, + attn, decphase_dec_pckd_qkv, - None, - kv_cache, + None, + kv_cache, decphase_attn_metadata) # - Decode cross-attention correct? assert torch.allclose( decphase_cross_pckd_idl_out, - decphase_cross_pckd_act_out.view_as( - decphase_cross_pckd_idl_out)) + decphase_cross_pckd_act_out.view_as(decphase_cross_pckd_idl_out)) # The following test conditions could in principle be a # standalone test, however the test setup is @@ -842,145 +820,9 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # of prefill and decode tokens. decphase_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: - _run_encoder_decoder_cross_attention_test(attn, - decphase_dec_pckd_qkv, - None, - kv_cache, + _run_encoder_decoder_cross_attention_test(attn, decphase_dec_pckd_qkv, + None, kv_cache, decphase_attn_metadata) # "Encoder decoder models do not currently support chunked prefill" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL - - -# @pytest.mark.skipif(not is_hip(), reason="This test requires ROCm/HIP") -# @pytest.mark.parametrize("num_heads", [256]) -# @pytest.mark.parametrize("head_size", [16]) -# @pytest.mark.parametrize("backend_name", BACKEND_NAMES) -# @pytest.mark.parametrize("batch_size", [16]) -# @pytest.mark.parametrize("block_size", [16]) -# @pytest.mark.parametrize("max_q_seq_len", [64]) -# @pytest.mark.parametrize("max_kv_seq_len", [64]) -# def test_enc_dec_no_rocm_hip_support(num_heads: int, head_size: int, -# backend_name: str, batch_size: int, -# block_size: int, max_q_seq_len: int, -# max_kv_seq_len: int, monkeypatch) -> None: -# ''' -# Encoder/decoder not-implemented-for-ROCm-HIP test: - -# * Construct fake test vectors for self- and cross-attention -# * Construct attention metadata structure with self- and cross-attention -# attributes -# * Test self- and cross-attention in the following order - -# * Prefill self-attention -# * Prefill cross-attention -# * Decode self-attention -# * Decode cross-attention -# * This order would exacerbate any accidental overlap in the -# self-/cross-attention block tables, which we attempt to avoid -# * Validate output correctness against ideal reference attention -# implementation - -# Block tables are constructed such that cross-attention KV cache is in a -# higher, non-intersecting address-space than self-attention KV cache. - -# Self- and cross-attention share the same query tensor but not the K/V -# tensors. Self-attention K/Vs must have the same seq len as Q while -# cross-attention K/Vs are allowed to differ in seq len, as is often the case -# for cross-attention. -# ''' - -# # Force Attention wrapper backend -# override_backend_env_variable(monkeypatch, backend_name) - -# # Num KV cache blocks -# num_blocks = 4096 - -# # Attention scale factor, attention backend instance, attention wrapper -# # instance, KV cache init -# scale, \ -# attn_backend, \ -# attn, \ -# kv_cache = _basic_setup(num_heads, -# head_size, -# num_blocks, -# block_size, -# backend_name) - -# # Self-attention setup - -# self_block_base_addr = 0 - -# query, \ -# prefill_packed_query, \ -# self_prefill_packed_key, \ -# self_prefill_packed_value, \ -# self_prefill_packed_ideal_output, \ -# prefill_q_seq_lens, \ -# self_prefill_kv_seq_lens, \ -# decode_packed_query, \ -# self_decode_packed_key, \ -# self_decode_packed_value, \ -# self_decode_packed_ideal_output, \ -# q_seq_lens, \ -# self_decode_block_tables, \ -# self_decode_slot_mapping, \ -# self_prefill_slot_mapping, \ -# self_prefill_block_tables, \ -# cross_block_base_addr = _decoder_attn_setup(batch_size, -# num_heads, -# head_size, -# block_size, -# scale, -# max_q_seq_len, -# block_base_addr=self_block_base_addr) - -# # Cross-attention setup - -# cross_prefill_packed_key, \ -# cross_prefill_packed_value, \ -# cross_prefill_packed_ideal_output, \ -# cross_decode_packed_ideal_output, \ -# encoder_kv_seq_lens, \ -# cross_decode_block_tables, \ -# cross_decode_slot_mapping, \ -# cross_prefill_slot_mapping, \ -# cross_prefill_block_tables, \ -# = _enc_dec_cross_attn_setup_reuses_query(query, -# q_seq_lens, -# prefill_q_seq_lens, -# batch_size, -# num_heads, -# head_size, -# block_size, -# scale, -# max_q_seq_len, -# max_kv_seq_len, -# block_base_addr = \ -# cross_block_base_addr) - -# # PREFILL: self- and cross-attention tests - -# prefill_attn_metadata: AttentionMetadata = make_test_metadata( -# attn_backend, -# True, -# prefill_q_seq_lens, -# self_prefill_block_tables, -# self_prefill_slot_mapping, -# is_encoder_only_test=False, -# num_prefills_or_decodes=len(prefill_q_seq_lens), -# num_prefill_or_decode_tokens=sum(prefill_q_seq_lens), -# encoder_seq_lens=encoder_kv_seq_lens, -# cross_block_tables=cross_prefill_block_tables, -# cross_slot_mapping=cross_prefill_slot_mapping, -# device=CUDA_DEVICE) - -# with pytest.raises(NotImplementedError) as exc_info: -# _run_encoder_decoder_cross_attention_test(attn, prefill_packed_query, -# cross_prefill_packed_key, -# cross_prefill_packed_value, -# kv_cache, -# prefill_attn_metadata) - -# # "Encoder decoder models do not currently support ROCm/HIP" -# assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_ROCM_HIP + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL \ No newline at end of file diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 142bc64d34b99..e57d6c412537c 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -2,13 +2,12 @@ import itertools import random +from collections import namedtuple from typing import List, Optional, Union import pytest import torch -from collections import namedtuple - from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.backends.xformers import XFormersBackend @@ -92,23 +91,17 @@ def ref_masked_attention(query: torch.Tensor, out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value) return out + # batch_size x max_q_seq_len x num_heads x head_size -QKVInputs = namedtuple("QKVInputs", - ["query", - "key", - "value", - "q_seq_lens", - "kv_seq_lens"]) +QKVInputs = namedtuple("QKVInputs", + ["query", "key", "value", "q_seq_lens", "kv_seq_lens"]) # total_num_tokens x (num_heads*head_size) -PackedQKVInputs = namedtuple("PackedQKVInputs", - ["query", - "key", - "value", - "q_start_loc_list", - "kv_start_loc_list", - "q_seq_lens", - "kv_seq_lens"]) +PackedQKVInputs = namedtuple("PackedQKVInputs", [ + "query", "key", "value", "q_start_loc_list", "kv_start_loc_list", + "q_seq_lens", "kv_seq_lens" +]) + def make_qkv( batch_size: int, @@ -120,7 +113,7 @@ def make_qkv( force_kv_seq_lens: List[int] = None, attn_type: AttentionType = AttentionType.ENCODER_DECODER, force_max_len: bool = False, -) -> tuple[QKVInputs,QKVInputs,QKVInputs]: +) -> tuple[QKVInputs, QKVInputs, QKVInputs]: ''' Construct QKV test tensors for self- and cross-attention. @@ -278,8 +271,8 @@ def pack_tensor(unpacked_tensor: torch.Tensor, seq_lens: List[int], return packed_tensor, start_loc_list -def pack_qkv(qkv: QKVInputs, - device: Union[torch.device, str]) -> PackedQKVInputs: +def pack_qkv(qkv: QKVInputs, device: Union[torch.device, + str]) -> PackedQKVInputs: ''' Individually pack each of Q, K and V, each with dimensions batch_size x padded_seq_len x num_heads x head_size, into respective number_of_tokens x From d35ea41a5fd8d55c902dcbfeaac505e8362ef227 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 00:40:26 -0400 Subject: [PATCH 171/239] format --- tests/kernels/test_encoder_decoder_attn.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 955fa1629440f..db0dd3ec9ade8 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -636,13 +636,6 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # anyway but are required to be present & valid by the # backend. - # encoder_packed_query, \ - # enc_packed_key, \ - # enc_packed_value, \ - # encoder_packed_ideal_output, \ - # encoder_seq_lens = - - enc_pckd_qkv, \ enc_pckd_idl_out = \ _encoder_attn_setup(batch_size, @@ -692,7 +685,6 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( cross_block_base_addr) # Shared prefill metadata structure - # prephase_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, @@ -825,4 +817,4 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decphase_attn_metadata) # "Encoder decoder models do not currently support chunked prefill" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL \ No newline at end of file + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL From 27782df3ac02df531fe71730000e1ad40e34c975 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 00:42:56 -0400 Subject: [PATCH 172/239] yapf fix --- tests/kernels/test_encoder_decoder_attn.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index db0dd3ec9ade8..bc475599163c2 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -12,14 +12,11 @@ import pytest import torch -from tests.kernels.utils import (PackedQKVInputs, QKVInputs, make_backend, - make_block_tables_slot_mapping, - make_empty_block_tables_tensor, - make_empty_slot_mapping_tensor, make_kv_cache, - make_qkv, make_test_metadata, - override_backend_env_variable, pack_qkv, - pack_tensor, ref_masked_attention, - split_slot_mapping) +from tests.kernels.utils import ( + PackedQKVInputs, QKVInputs, make_backend, make_block_tables_slot_mapping, + make_empty_block_tables_tensor, make_empty_slot_mapping_tensor, + make_kv_cache, make_qkv, make_test_metadata, override_backend_env_variable, + pack_qkv, pack_tensor, ref_masked_attention, split_slot_mapping) from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( From c3a2e7afb82448299ffd3caf351d5b3d3b11cfdd Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 00:44:26 -0400 Subject: [PATCH 173/239] import reorg --- tests/kernels/test_encoder_decoder_attn.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index bc475599163c2..db0dd3ec9ade8 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -12,11 +12,14 @@ import pytest import torch -from tests.kernels.utils import ( - PackedQKVInputs, QKVInputs, make_backend, make_block_tables_slot_mapping, - make_empty_block_tables_tensor, make_empty_slot_mapping_tensor, - make_kv_cache, make_qkv, make_test_metadata, override_backend_env_variable, - pack_qkv, pack_tensor, ref_masked_attention, split_slot_mapping) +from tests.kernels.utils import (PackedQKVInputs, QKVInputs, make_backend, + make_block_tables_slot_mapping, + make_empty_block_tables_tensor, + make_empty_slot_mapping_tensor, make_kv_cache, + make_qkv, make_test_metadata, + override_backend_env_variable, pack_qkv, + pack_tensor, ref_masked_attention, + split_slot_mapping) from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( From 8babfda338e5f941054b630cf6973bfb6d56f747 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 00:47:25 -0400 Subject: [PATCH 174/239] switched to star import to avoid unsatisfiable formatting constraints --- tests/kernels/test_encoder_decoder_attn.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index db0dd3ec9ade8..0b1546905926b 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -12,14 +12,7 @@ import pytest import torch -from tests.kernels.utils import (PackedQKVInputs, QKVInputs, make_backend, - make_block_tables_slot_mapping, - make_empty_block_tables_tensor, - make_empty_slot_mapping_tensor, make_kv_cache, - make_qkv, make_test_metadata, - override_backend_env_variable, pack_qkv, - pack_tensor, ref_masked_attention, - split_slot_mapping) +from tests.kernels.utils import * from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( From ce2422be980c7bac12355a6bf8f4fbac8c471136 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 01:09:37 -0400 Subject: [PATCH 175/239] progress on memory map structure integration --- tests/kernels/test_encoder_decoder_attn.py | 22 ++++++++++------------ tests/kernels/utils.py | 14 ++++++++------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 0b1546905926b..3690666371625 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -302,12 +302,14 @@ def _decoder_attn_setup(batch_size: int, return qkv, \ prefill_pckd_qkv, \ prefill_packed_ideal_output, \ + KVMemoryMap( + prefill_block_tables, \ + prefill_slot_mapping), \ decode_pckd_qkv, \ decode_packed_ideal_output, \ - decode_block_tables, \ - decode_slot_mapping, \ - prefill_slot_mapping, \ - prefill_block_tables, \ + KVMemoryMap( + decode_block_tables, \ + decode_slot_mapping), \ max_block_idx def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, @@ -642,12 +644,10 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( dec_qkv, \ prephase_dec_pckd_qkv, \ prephase_dec_pckd_idl_out, \ + prephase_dec_kv_mmap, \ decphase_dec_pckd_qkv, \ decphase_dec_pckd_idl_out, \ - decphase_dec_blk_tbls, \ - decphase_dec_slt_map, \ - prephase_dec_slt_map, \ - prephase_dec_blk_tbls, \ + decphase_dec_kv_mmap, \ cross_block_base_addr = _decoder_attn_setup(batch_size, num_heads, head_size, @@ -683,8 +683,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_backend, True, prephase_dec_pckd_qkv.q_seq_lens, - prephase_dec_blk_tbls, - prephase_dec_slt_map, + prephase_dec_kv_mmap, default_attn_type=AttentionType.ENCODER, num_prefills_or_decodes=len(prephase_dec_pckd_qkv.q_seq_lens), num_prefill_or_decode_tokens=sum(prephase_dec_pckd_qkv.q_seq_lens), @@ -747,8 +746,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_backend, False, dec_qkv.q_seq_lens, - decphase_dec_blk_tbls, - decphase_dec_slt_map, + decphase_dec_kv_mmap, default_attn_type=AttentionType.DECODER, context_lens=context_lens, num_prefills_or_decodes=len(dec_qkv.q_seq_lens), diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index e57d6c412537c..15baa68c8fb41 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -102,6 +102,9 @@ def ref_masked_attention(query: torch.Tensor, "q_seq_lens", "kv_seq_lens" ]) +KVMemoryMap = namedtuple("KVMemoryMap", [ + "block_tables", "slot_mapping" +]) def make_qkv( batch_size: int, @@ -577,8 +580,7 @@ def make_test_metadata( attn_backend: AttentionBackend, is_prompt: bool, seq_lens: List[int], - block_tables: torch.Tensor, - slot_mapping: torch.Tensor, + kv_mmap: KVMemoryMap, default_attn_type: AttentionType, num_prefills_or_decodes: int, num_prefill_or_decode_tokens: int, @@ -639,7 +641,7 @@ def make_test_metadata( return attn_backend.make_metadata( num_prefills=num_prefills, - slot_mapping=slot_mapping, + slot_mapping=kv_mmap.slot_mapping, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -647,7 +649,7 @@ def make_test_metadata( max_prefill_seq_len=None if seq_lens is None else max(seq_lens), max_decode_seq_len=0, context_lens_tensor=context_lens_tensor, - block_tables=block_tables, + block_tables=kv_mmap.block_tables, use_cuda_graph=False, _attn_type=default_attn_type, encoder_seq_lens=encoder_seq_lens, @@ -675,7 +677,7 @@ def make_test_metadata( return attn_backend.make_metadata( num_prefills=num_prefills, - slot_mapping=slot_mapping, + slot_mapping=kv_mmap.slot_mapping, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -683,7 +685,7 @@ def make_test_metadata( max_prefill_seq_len=0, max_decode_seq_len=max(seq_lens), context_lens_tensor=context_lens_tensor, - block_tables=block_tables, + block_tables=kv_mmap.block_tables, use_cuda_graph=False, _attn_type=default_attn_type, encoder_seq_lens=encoder_seq_lens, From 00450517ee415b8193c612faae041291ce9daccc Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 01:19:50 -0400 Subject: [PATCH 176/239] completed integration of KVMemoryMap into tests --- tests/kernels/test_encoder_decoder_attn.py | 23 ++++++++++------------ tests/kernels/utils.py | 15 ++++++++------ 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 3690666371625..c53eb9ed85970 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -470,12 +470,13 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, return packed_cross_kv, \ prefill_packed_ideal_output, \ + KVMemoryMap( + prefill_block_tables, \ + prefill_slot_mapping), \ decode_packed_ideal_output, \ - decode_block_tables, \ - decode_slot_mapping, \ - prefill_slot_mapping, \ - prefill_block_tables - + KVMemoryMap( + decode_block_tables, \ + decode_slot_mapping), \ def _run_encoder_attention_test(attn: Attention, pckd_qkv: PackedQKVInputs, attn_metadata: AttentionMetadata, @@ -659,11 +660,9 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( prephase_cross_pckd_qkv, \ prephase_cross_pckd_idl_out, \ + prephase_cross_kv_mmap, \ decphase_cross_pckd_idl_out, \ - cross_decode_block_tables, \ - cross_decode_slot_mapping, \ - cross_prefill_slot_mapping, \ - cross_prefill_block_tables, \ + decphase_cross_kv_mmap \ = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, enc_pckd_qkv, prephase_dec_pckd_qkv, @@ -688,8 +687,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( num_prefills_or_decodes=len(prephase_dec_pckd_qkv.q_seq_lens), num_prefill_or_decode_tokens=sum(prephase_dec_pckd_qkv.q_seq_lens), encoder_seq_lens=enc_pckd_qkv.q_seq_lens, - cross_block_tables=cross_prefill_block_tables, - cross_slot_mapping=cross_prefill_slot_mapping, + cross_kv_mmap=prephase_cross_kv_mmap, device=CUDA_DEVICE) # PREFILL: encoder attention @@ -752,8 +750,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( num_prefills_or_decodes=len(dec_qkv.q_seq_lens), num_prefill_or_decode_tokens=len(dec_qkv.q_seq_lens), encoder_seq_lens=enc_pckd_qkv.q_seq_lens, - cross_block_tables=cross_decode_block_tables, - cross_slot_mapping=cross_decode_slot_mapping, + cross_kv_mmap=decphase_cross_kv_mmap, device=CUDA_DEVICE) # DECODE: self-attention test diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 15baa68c8fb41..d0f5f20e12b1d 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -587,8 +587,7 @@ def make_test_metadata( device: Union[torch.device, str], context_lens: Optional[List[int]] = None, encoder_seq_lens: Optional[List[int]] = None, - cross_block_tables: Optional[torch.Tensor] = None, - cross_slot_mapping: Optional[List[int]] = None, + cross_kv_mmap: Optional[KVMemoryMap] = None, ) -> AttentionMetadata: ''' Construct fake attention metadata for a combined self-/cross-attention @@ -655,8 +654,10 @@ def make_test_metadata( encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=cross_slot_mapping, - cross_block_tables=cross_block_tables) + cross_slot_mapping=None if cross_kv_mmap is None else \ + cross_kv_mmap.slot_mapping, + cross_block_tables=None if cross_kv_mmap is None else \ + cross_kv_mmap.block_tables) else: # not is_prompt @@ -691,5 +692,7 @@ def make_test_metadata( encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=cross_slot_mapping, - cross_block_tables=cross_block_tables) + cross_slot_mapping=None if cross_kv_mmap is None else \ + cross_kv_mmap.slot_mapping, + cross_block_tables=None if cross_kv_mmap is None else \ + cross_kv_mmap.block_tables) From 91eb0671a960a6f52ca8de22d2c6a604c4ae477e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 01:25:57 -0400 Subject: [PATCH 177/239] first step toward QKVO integration into tests --- tests/kernels/test_encoder_decoder_attn.py | 34 +++++++++++----------- tests/kernels/utils.py | 12 ++++++++ 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index c53eb9ed85970..1cdd02ad5aa87 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -78,7 +78,7 @@ def _basic_setup(num_heads: int, head_size: int, num_blocks: int, def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, scale: float, max_q_seq_len: int) \ - -> tuple[PackedQKVInputs,torch.Tensor]: + -> PackedQKVO: ''' Set up test vectors & data structures for encoder attention test. @@ -138,8 +138,9 @@ def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE) - return packed_qkv, \ - packed_ideal_output + return PackedQKVO( + packed_qkv, \ + packed_ideal_output) def _decoder_attn_setup(batch_size: int, @@ -632,13 +633,11 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # anyway but are required to be present & valid by the # backend. - enc_pckd_qkv, \ - enc_pckd_idl_out = \ - _encoder_attn_setup(batch_size, - num_heads, - head_size, - scale, - max_enc_seq_len) + enc_pckd_qkvo = _encoder_attn_setup(batch_size, + num_heads, + head_size, + scale, + max_enc_seq_len) # Decoder self-attention setup @@ -664,7 +663,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decphase_cross_pckd_idl_out, \ decphase_cross_kv_mmap \ = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, - enc_pckd_qkv, + enc_pckd_qkvo.packed_qkv, prephase_dec_pckd_qkv, batch_size, num_heads, @@ -686,7 +685,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( default_attn_type=AttentionType.ENCODER, num_prefills_or_decodes=len(prephase_dec_pckd_qkv.q_seq_lens), num_prefill_or_decode_tokens=sum(prephase_dec_pckd_qkv.q_seq_lens), - encoder_seq_lens=enc_pckd_qkv.q_seq_lens, + encoder_seq_lens=enc_pckd_qkvo.packed_qkv.q_seq_lens, cross_kv_mmap=prephase_cross_kv_mmap, device=CUDA_DEVICE) @@ -696,13 +695,14 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( enc_packed_actual_output: torch.Tensor = \ _run_encoder_attention_test( attn, - enc_pckd_qkv, + enc_pckd_qkvo.packed_qkv, prephase_attn_metadata, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - assert torch.allclose(enc_pckd_idl_out, - enc_packed_actual_output.view_as(enc_pckd_idl_out)) + assert torch.allclose(enc_pckd_qkvo.ideal_output, + enc_packed_actual_output + .view_as(enc_pckd_qkvo.ideal_output)) # PREFILL: self-attention test @@ -738,7 +738,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # - Cross-attention KV context is equal in length to # encoder input - context_lens = copy.deepcopy(enc_pckd_qkv.q_seq_lens) + context_lens = copy.deepcopy(enc_pckd_qkvo.packed_qkv.q_seq_lens) decphase_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, @@ -749,7 +749,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( context_lens=context_lens, num_prefills_or_decodes=len(dec_qkv.q_seq_lens), num_prefill_or_decode_tokens=len(dec_qkv.q_seq_lens), - encoder_seq_lens=enc_pckd_qkv.q_seq_lens, + encoder_seq_lens=enc_pckd_qkvo.packed_qkv.q_seq_lens, cross_kv_mmap=decphase_cross_kv_mmap, device=CUDA_DEVICE) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index d0f5f20e12b1d..14ae12a061e55 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -96,12 +96,24 @@ def ref_masked_attention(query: torch.Tensor, QKVInputs = namedtuple("QKVInputs", ["query", "key", "value", "q_seq_lens", "kv_seq_lens"]) +QKVO = namedtuple("QKVO", + [ + "qkv", + "ideal_output" + ]) + # total_num_tokens x (num_heads*head_size) PackedQKVInputs = namedtuple("PackedQKVInputs", [ "query", "key", "value", "q_start_loc_list", "kv_start_loc_list", "q_seq_lens", "kv_seq_lens" ]) +PackedQKVO = namedtuple("PackedQKVO", + [ + "packed_qkv", + "ideal_output" + ]) + KVMemoryMap = namedtuple("KVMemoryMap", [ "block_tables", "slot_mapping" ]) From a6aee8002125115ae41dc22a0cbf2a9a210bd5ce Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 15:28:44 -0400 Subject: [PATCH 178/239] wip test params structure integration --- tests/kernels/test_encoder_decoder_attn.py | 107 +++++++++++---------- tests/kernels/utils.py | 5 + 2 files changed, 63 insertions(+), 49 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 1cdd02ad5aa87..0ac2e47660db5 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -78,7 +78,7 @@ def _basic_setup(num_heads: int, head_size: int, num_blocks: int, def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, scale: float, max_q_seq_len: int) \ - -> PackedQKVO: + -> PhaseTestParameters: ''' Set up test vectors & data structures for encoder attention test. @@ -138,9 +138,13 @@ def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE) - return PackedQKVO( - packed_qkv, \ - packed_ideal_output) + return PhaseTestParameters( + PackedQKVO( + packed_qkv, \ + packed_ideal_output), + + None # No KV cache + ) def _decoder_attn_setup(batch_size: int, @@ -149,7 +153,10 @@ def _decoder_attn_setup(batch_size: int, block_size: int, scale: float, max_q_seq_len: int, - block_base_addr: int = 0) -> tuple: + block_base_addr: int = 0) -> tuple[QKVInputs, + PhaseTestParameters, + PhaseTestParameters, + int]: ''' Set up test vectors & data structures for self-attention test. @@ -301,23 +308,27 @@ def _decoder_attn_setup(batch_size: int, decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE) return qkv, \ - prefill_pckd_qkv, \ - prefill_packed_ideal_output, \ - KVMemoryMap( - prefill_block_tables, \ - prefill_slot_mapping), \ - decode_pckd_qkv, \ - decode_packed_ideal_output, \ - KVMemoryMap( - decode_block_tables, \ - decode_slot_mapping), \ - max_block_idx + PhaseTestParameters( # Prefill test params + PackedQKVO( + prefill_pckd_qkv, \ + prefill_packed_ideal_output), \ + KVMemoryMap( + prefill_block_tables, \ + prefill_slot_mapping)), \ + PhaseTestParameters( # Decode test params + PackedQKVO( + decode_pckd_qkv, \ + decode_packed_ideal_output), \ + KVMemoryMap( + decode_block_tables, \ + decode_slot_mapping)), \ + max_block_idx def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, - encoder_packed_qkv: - PackedQKVInputs, - prefill_phase_decoder_packed_qkv: - PackedQKVInputs, + encoder_test_params: + PhaseTestParameters, + prefill_phase_test_params: + PhaseTestParameters, batch_size: int, num_heads: int, head_size: int, @@ -390,8 +401,8 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, decoder_query = decoder_qkv.query decoder_seq_lens = decoder_qkv.q_seq_lens - encoder_seq_lens = encoder_packed_qkv.q_seq_lens - prefill_q_seq_lens = prefill_phase_decoder_packed_qkv.q_seq_lens + encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens + prefill_q_seq_lens = prefill_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens cross_kv, \ @@ -469,15 +480,20 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, # Packed key/value (query is already provided) packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE) - return packed_cross_kv, \ - prefill_packed_ideal_output, \ - KVMemoryMap( - prefill_block_tables, \ - prefill_slot_mapping), \ - decode_packed_ideal_output, \ - KVMemoryMap( - decode_block_tables, \ - decode_slot_mapping), \ + return PhaseTestParameters( # Prefill-phase test params + PackedQKVO( + packed_cross_kv, \ + prefill_packed_ideal_output), \ + KVMemoryMap( + prefill_block_tables, \ + prefill_slot_mapping)), \ + PhaseTestParameters( # Decode-phase test params + PackedQKVO( + None, + decode_packed_ideal_output), \ + KVMemoryMap( + decode_block_tables, \ + decode_slot_mapping)) def _run_encoder_attention_test(attn: Attention, pckd_qkv: PackedQKVInputs, attn_metadata: AttentionMetadata, @@ -633,21 +649,17 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # anyway but are required to be present & valid by the # backend. - enc_pckd_qkvo = _encoder_attn_setup(batch_size, - num_heads, - head_size, - scale, - max_enc_seq_len) + enc_test_params = _encoder_attn_setup(batch_size, + num_heads, + head_size, + scale, + max_enc_seq_len) # Decoder self-attention setup dec_qkv, \ - prephase_dec_pckd_qkv, \ - prephase_dec_pckd_idl_out, \ - prephase_dec_kv_mmap, \ - decphase_dec_pckd_qkv, \ - decphase_dec_pckd_idl_out, \ - decphase_dec_kv_mmap, \ + prephase_dec_test_params, \ + decphase_dec_test_params, \ cross_block_base_addr = _decoder_attn_setup(batch_size, num_heads, head_size, @@ -657,14 +669,11 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # Cross-attention setup - prephase_cross_pckd_qkv, \ - prephase_cross_pckd_idl_out, \ - prephase_cross_kv_mmap, \ - decphase_cross_pckd_idl_out, \ - decphase_cross_kv_mmap \ + prephase_cross_test_params, \ + decphase_cross_test_params, \ = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, - enc_pckd_qkvo.packed_qkv, - prephase_dec_pckd_qkv, + enc_test_params, + prephase_dec_test_params, batch_size, num_heads, head_size, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 14ae12a061e55..e2d09522db601 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -118,6 +118,11 @@ def ref_masked_attention(query: torch.Tensor, "block_tables", "slot_mapping" ]) +PhaseTestParameters = namedtuple("PhaseTestParameters", [ + "packed_qkvo", + "kv_mmap" +]) + def make_qkv( batch_size: int, max_q_seq_len: int, From ee512605696e5170fabd3b17f6e7db86143230a7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 19:01:45 -0400 Subject: [PATCH 179/239] prephase md struct using test params --- tests/kernels/test_encoder_decoder_attn.py | 9 ++--- tests/kernels/utils.py | 38 ++++++++++++++++++---- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 0ac2e47660db5..e9136909f8af6 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -689,13 +689,10 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( prephase_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, - prephase_dec_pckd_qkv.q_seq_lens, - prephase_dec_kv_mmap, + decoder_test_params=prephase_dec_test_params, + encoder_test_params=enc_test_params, + cross_test_params=prephase_cross_test_params, default_attn_type=AttentionType.ENCODER, - num_prefills_or_decodes=len(prephase_dec_pckd_qkv.q_seq_lens), - num_prefill_or_decode_tokens=sum(prephase_dec_pckd_qkv.q_seq_lens), - encoder_seq_lens=enc_pckd_qkvo.packed_qkv.q_seq_lens, - cross_kv_mmap=prephase_cross_kv_mmap, device=CUDA_DEVICE) # PREFILL: encoder attention diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index e2d09522db601..2d1c274a49c07 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -596,15 +596,11 @@ def make_block_tables_slot_mapping(block_size: int, def make_test_metadata( attn_backend: AttentionBackend, is_prompt: bool, - seq_lens: List[int], - kv_mmap: KVMemoryMap, + decoder_test_params: PhaseTestParameters, default_attn_type: AttentionType, - num_prefills_or_decodes: int, - num_prefill_or_decode_tokens: int, device: Union[torch.device, str], - context_lens: Optional[List[int]] = None, - encoder_seq_lens: Optional[List[int]] = None, - cross_kv_mmap: Optional[KVMemoryMap] = None, + encoder_test_params: Optional[PhaseTestParameters]=None, + cross_test_params: Optional[PhaseTestParameters]=None ) -> AttentionMetadata: ''' Construct fake attention metadata for a combined self-/cross-attention @@ -639,6 +635,34 @@ def make_test_metadata( * AttentionMetadata structure supporting self- and cross-attention ''' + # Extract + # * Decoder input sequence lengths (seq_lens) + # * Decoder self-attention slot mapping & block tables (kv_mmap) + seq_lens = decoder_test_params.packed_qkvo.packed_qkv.seq_lens + kv_mmap = decoder_test_params.kv_mmap + + # is_prompt determines whether input tokens are treated + # as 100% prefill or 100% decode. In either case, + # the number of {prefills, decodes} and the number of + # {prefill, decode} tokens can be inferred from seq_lens + num_prefills_or_decodes = len(seq_lens) + num_prefill_or_decode_tokens = sum(seq_lens) + + if encoder_test_params is None: + encoder_seq_lens = None + else: + # Encoder/decoder models only: + # * Extract encoder input sequence lengths + encoder_seq_lens = encoder_test_params.q_seq_lens + + if cross_test_params is None: + cross_kv_mmap = None + else: + # Encoder/decoder models only: + # * Extract *cross-attention* slot_mapping and block table + # (kv_mmap) + cross_kv_mmap = cross_test_params.kv_mmap + if is_prompt: num_prefills = num_prefills_or_decodes num_prefill_tokens = num_prefill_or_decode_tokens From 50a45cce0a5114ae12dedeca0b39de349e960ebc Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 19:22:26 -0400 Subject: [PATCH 180/239] correctness check helper function --- tests/kernels/test_encoder_decoder_attn.py | 59 +++++++++++++--------- 1 file changed, 36 insertions(+), 23 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index e9136909f8af6..19e44273c37f4 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -495,7 +495,7 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, decode_block_tables, \ decode_slot_mapping)) -def _run_encoder_attention_test(attn: Attention, pckd_qkv: PackedQKVInputs, +def _run_encoder_attention_test(attn: Attention, encoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, attn_type: AttentionType) -> torch.Tensor: ''' @@ -520,12 +520,13 @@ def _run_encoder_attention_test(attn: Attention, pckd_qkv: PackedQKVInputs, assert attn_type == AttentionType.ENCODER assert attn_metadata.num_decode_tokens == 0 attn_metadata.attention_type = attn_type - return attn.forward(pckd_qkv.query, pckd_qkv.key, pckd_qkv.value, None, + packed_qkv = encoder_test_params.packed_qkvo.packed_qkv + return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, None, attn_metadata) def _run_decoder_self_attention_test(attn: Attention, - pckd_qkv: PackedQKVInputs, + decoder_test_params: PhaseTestParameters, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, attn_type: AttentionType) -> torch.Tensor: @@ -551,13 +552,14 @@ def _run_decoder_self_attention_test(attn: Attention, ''' assert attn_type == AttentionType.DECODER attn_metadata.attention_type = attn_type - return attn.forward(pckd_qkv.query, pckd_qkv.key, pckd_qkv.value, kv_cache, + packed_qkv = decoder_test_params.packed_qkvo.packed_qkv + return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, kv_cache, attn_metadata) def _run_encoder_decoder_cross_attention_test( - attn: Attention, dec_pckd_qkv: PackedQKVInputs, - cross_pckd_qkv: PackedQKVInputs, kv_cache: torch.Tensor, + attn: Attention, decoder_test_params: PhaseTestParameters, + cross_test_params: PhaseTestParameters, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -578,13 +580,29 @@ def _run_encoder_decoder_cross_attention_test( & attn_metadata ''' attn_metadata.attention_type = AttentionType.ENCODER_DECODER + cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv key = None if cross_pckd_qkv is None else \ cross_pckd_qkv.key value = None if cross_pckd_qkv is None else \ cross_pckd_qkv.value - return attn.forward(dec_pckd_qkv.query, key, value, kv_cache, + return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, value, kv_cache, attn_metadata) +def _assert_actual_match_ideal(test_params: PhaseTestParameters, + output_under_test: torch.Tensor) -> None: + ''' + Assert that observed output matches the ideal output + contained in the test parameters data structure. + + Arguments: + + * test_params: Test parameters including packed ideal output + * output_under_test: actually observed output value + ''' + ideal_output = test_params.packed_qkvo.ideal_output + assert torch.allclose(ideal_output, + output_under_test + .view_as(ideal_output)) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -701,14 +719,13 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( enc_packed_actual_output: torch.Tensor = \ _run_encoder_attention_test( attn, - enc_pckd_qkvo.packed_qkv, + enc_test_params, prephase_attn_metadata, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - assert torch.allclose(enc_pckd_qkvo.ideal_output, - enc_packed_actual_output - .view_as(enc_pckd_qkvo.ideal_output)) + _assert_actual_match_ideal(enc_test_params, + enc_packed_actual_output) # PREFILL: self-attention test @@ -721,9 +738,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_type=AttentionType.DECODER) # - Prefill self-attention correct? - assert torch.allclose( - prephase_dec_pckd_idl_out, - self_prefill_packed_actual_output.view_as(prephase_dec_pckd_idl_out)) + _assert_actual_match_ideal(prephase_dec_test_params, + self_prefill_packed_actual_output) # PREFILL: cross-attention test @@ -736,9 +752,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( prephase_attn_metadata) # - Prefill cross-attention correct? - assert torch.allclose( - prephase_cross_pckd_idl_out, - prephase_cross_pckd_act_out.view_as(prephase_cross_pckd_idl_out)) + _assert_actual_match_ideal(prephase_cross_test_params, + prephase_cross_pckd_act_out) # DECODE: build decode-phase attention metadata @@ -770,9 +785,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_type=AttentionType.DECODER) # - Decode self-attention correct? - assert torch.allclose( - decphase_dec_pckd_idl_out, - decphase_dec_pckd_act_out.view_as(decphase_dec_pckd_idl_out)) + _assert_actual_match_ideal(decphase_dec_test_params, + decphase_dec_pckd_act_out) # DECODE: cross-attention test @@ -785,9 +799,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decphase_attn_metadata) # - Decode cross-attention correct? - assert torch.allclose( - decphase_cross_pckd_idl_out, - decphase_cross_pckd_act_out.view_as(decphase_cross_pckd_idl_out)) + _assert_actual_match_ideal(decphase_cross_test_params, + decphase_cross_pckd_act_out) # The following test conditions could in principle be a # standalone test, however the test setup is From cd0a1aaee6ec9af5f543e7f7e8e57acad1577038 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 19:30:59 -0400 Subject: [PATCH 181/239] wip --- tests/kernels/test_encoder_decoder_attn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 19e44273c37f4..cd72645adcb12 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -732,7 +732,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_prefill_packed_actual_output: torch.Tensor = \ _run_decoder_self_attention_test( attn, - prephase_dec_pckd_qkv, + prephase_dec_test_params, kv_cache, prephase_attn_metadata, attn_type=AttentionType.DECODER) @@ -746,8 +746,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( prephase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( attn, - prephase_dec_pckd_qkv, - prephase_cross_pckd_qkv, + prephase_dec_test_params, + prephase_cross_test_params, kv_cache, prephase_attn_metadata) From ec5977d76fdce71e8a7ef7955446a52846c6806c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 20:00:31 -0400 Subject: [PATCH 182/239] debugging test params integration --- tests/kernels/test_encoder_decoder_attn.py | 18 ++++++++---------- tests/kernels/utils.py | 9 +++++++-- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index cd72645adcb12..23400819de398 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -707,6 +707,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( prephase_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, True, + prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, decoder_test_params=prephase_dec_test_params, encoder_test_params=enc_test_params, cross_test_params=prephase_cross_test_params, @@ -759,19 +760,16 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # - Cross-attention KV context is equal in length to # encoder input - context_lens = copy.deepcopy(enc_pckd_qkvo.packed_qkv.q_seq_lens) + # context_lens = copy.deepcopy(enc_pckd_qkvo.packed_qkv.q_seq_lens) decphase_attn_metadata: AttentionMetadata = make_test_metadata( attn_backend, False, dec_qkv.q_seq_lens, - decphase_dec_kv_mmap, + decoder_test_params=decphase_dec_test_params, + encoder_test_params=enc_test_params, + cross_test_params=decphase_cross_test_params, default_attn_type=AttentionType.DECODER, - context_lens=context_lens, - num_prefills_or_decodes=len(dec_qkv.q_seq_lens), - num_prefill_or_decode_tokens=len(dec_qkv.q_seq_lens), - encoder_seq_lens=enc_pckd_qkvo.packed_qkv.q_seq_lens, - cross_kv_mmap=decphase_cross_kv_mmap, device=CUDA_DEVICE) # DECODE: self-attention test @@ -779,7 +777,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decphase_dec_pckd_act_out: torch.Tensor = \ _run_decoder_self_attention_test( attn, - decphase_dec_pckd_qkv, + decphase_dec_test_params, kv_cache, decphase_attn_metadata, attn_type=AttentionType.DECODER) @@ -793,7 +791,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decphase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( attn, - decphase_dec_pckd_qkv, + decphase_dec_test_params, None, kv_cache, decphase_attn_metadata) @@ -819,7 +817,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # of prefill and decode tokens. decphase_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: - _run_encoder_decoder_cross_attention_test(attn, decphase_dec_pckd_qkv, + _run_encoder_decoder_cross_attention_test(attn, decphase_dec_test_params, None, kv_cache, decphase_attn_metadata) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 2d1c274a49c07..38d363e7adcce 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -596,6 +596,7 @@ def make_block_tables_slot_mapping(block_size: int, def make_test_metadata( attn_backend: AttentionBackend, is_prompt: bool, + seq_lens: List[int], decoder_test_params: PhaseTestParameters, default_attn_type: AttentionType, device: Union[torch.device, str], @@ -638,7 +639,7 @@ def make_test_metadata( # Extract # * Decoder input sequence lengths (seq_lens) # * Decoder self-attention slot mapping & block tables (kv_mmap) - seq_lens = decoder_test_params.packed_qkvo.packed_qkv.seq_lens + #seq_lens = decoder_test_params.packed_qkvo.packed_qkv.q_seq_lens kv_mmap = decoder_test_params.kv_mmap # is_prompt determines whether input tokens are treated @@ -648,12 +649,16 @@ def make_test_metadata( num_prefills_or_decodes = len(seq_lens) num_prefill_or_decode_tokens = sum(seq_lens) + # Seems for non-prefix-caching scenarios context_lens + # is never needed + context_lens = None + if encoder_test_params is None: encoder_seq_lens = None else: # Encoder/decoder models only: # * Extract encoder input sequence lengths - encoder_seq_lens = encoder_test_params.q_seq_lens + encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens if cross_test_params is None: cross_kv_mmap = None From 1f7b2ebe2d36c56693b75e1bdffe1a24e2e62dd7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 20:21:10 -0400 Subject: [PATCH 183/239] passing tests with test params integration --- tests/kernels/test_encoder_decoder_attn.py | 14 +++++++++----- tests/kernels/utils.py | 8 +++++++- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 23400819de398..6b287ce4bc72f 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -580,11 +580,15 @@ def _run_encoder_decoder_cross_attention_test( & attn_metadata ''' attn_metadata.attention_type = AttentionType.ENCODER_DECODER - cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv - key = None if cross_pckd_qkv is None else \ - cross_pckd_qkv.key - value = None if cross_pckd_qkv is None else \ - cross_pckd_qkv.value + if cross_test_params is None: + key = None + value = None + else: + cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv + key = None if cross_pckd_qkv is None else \ + cross_pckd_qkv.key + value = None if cross_pckd_qkv is None else \ + cross_pckd_qkv.value return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, value, kv_cache, attn_metadata) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 38d363e7adcce..e18de66264204 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -647,7 +647,13 @@ def make_test_metadata( # the number of {prefills, decodes} and the number of # {prefill, decode} tokens can be inferred from seq_lens num_prefills_or_decodes = len(seq_lens) - num_prefill_or_decode_tokens = sum(seq_lens) + if is_prompt: + # Prefill: operate on total num. of prompt + # tokens + num_prefill_or_decode_tokens = sum(seq_lens) + else: + # Decode: operate on one token per seq + num_prefill_or_decode_tokens = len(seq_lens) # Seems for non-prefix-caching scenarios context_lens # is never needed From 76b0b9ea8bb1247ee6324a03b3a1b0fa11578b02 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 20:32:06 -0400 Subject: [PATCH 184/239] format --- tests/kernels/test_encoder_decoder_attn.py | 57 +++++++++++----------- tests/kernels/utils.py | 40 ++++++--------- 2 files changed, 42 insertions(+), 55 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 6b287ce4bc72f..8462b1580794b 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -6,7 +6,6 @@ * Encoder/decoder cross-attention """ -import copy from typing import Optional import pytest @@ -147,16 +146,15 @@ def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, ) -def _decoder_attn_setup(batch_size: int, - num_heads: int, - head_size: int, - block_size: int, - scale: float, - max_q_seq_len: int, - block_base_addr: int = 0) -> tuple[QKVInputs, - PhaseTestParameters, - PhaseTestParameters, - int]: +def _decoder_attn_setup( + batch_size: int, + num_heads: int, + head_size: int, + block_size: int, + scale: float, + max_q_seq_len: int, + block_base_addr: int = 0 +) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: ''' Set up test vectors & data structures for self-attention test. @@ -402,7 +400,8 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, decoder_query = decoder_qkv.query decoder_seq_lens = decoder_qkv.q_seq_lens encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens - prefill_q_seq_lens = prefill_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens + prefill_q_seq_lens = \ + prefill_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens cross_kv, \ @@ -495,7 +494,9 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, decode_block_tables, \ decode_slot_mapping)) -def _run_encoder_attention_test(attn: Attention, encoder_test_params: PhaseTestParameters, + +def _run_encoder_attention_test(attn: Attention, + encoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, attn_type: AttentionType) -> torch.Tensor: ''' @@ -521,8 +522,8 @@ def _run_encoder_attention_test(attn: Attention, encoder_test_params: PhaseTestP assert attn_metadata.num_decode_tokens == 0 attn_metadata.attention_type = attn_type packed_qkv = encoder_test_params.packed_qkvo.packed_qkv - return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, None, - attn_metadata) + return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, + None, attn_metadata) def _run_decoder_self_attention_test(attn: Attention, @@ -553,8 +554,8 @@ def _run_decoder_self_attention_test(attn: Attention, assert attn_type == AttentionType.DECODER attn_metadata.attention_type = attn_type packed_qkv = decoder_test_params.packed_qkvo.packed_qkv - return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, kv_cache, - attn_metadata) + return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, + kv_cache, attn_metadata) def _run_encoder_decoder_cross_attention_test( @@ -589,8 +590,9 @@ def _run_encoder_decoder_cross_attention_test( cross_pckd_qkv.key value = None if cross_pckd_qkv is None else \ cross_pckd_qkv.value - return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, value, kv_cache, - attn_metadata) + return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, + value, kv_cache, attn_metadata) + def _assert_actual_match_ideal(test_params: PhaseTestParameters, output_under_test: torch.Tensor) -> None: @@ -605,8 +607,8 @@ def _assert_actual_match_ideal(test_params: PhaseTestParameters, ''' ideal_output = test_params.packed_qkvo.ideal_output assert torch.allclose(ideal_output, - output_under_test - .view_as(ideal_output)) + output_under_test.view_as(ideal_output)) + @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -671,11 +673,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # anyway but are required to be present & valid by the # backend. - enc_test_params = _encoder_attn_setup(batch_size, - num_heads, - head_size, - scale, - max_enc_seq_len) + enc_test_params = _encoder_attn_setup(batch_size, num_heads, head_size, + scale, max_enc_seq_len) # Decoder self-attention setup @@ -729,8 +728,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - _assert_actual_match_ideal(enc_test_params, - enc_packed_actual_output) + _assert_actual_match_ideal(enc_test_params, enc_packed_actual_output) # PREFILL: self-attention test @@ -821,7 +819,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # of prefill and decode tokens. decphase_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: - _run_encoder_decoder_cross_attention_test(attn, decphase_dec_test_params, + _run_encoder_decoder_cross_attention_test(attn, + decphase_dec_test_params, None, kv_cache, decphase_attn_metadata) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index e18de66264204..cf4b41b96996a 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -96,11 +96,7 @@ def ref_masked_attention(query: torch.Tensor, QKVInputs = namedtuple("QKVInputs", ["query", "key", "value", "q_seq_lens", "kv_seq_lens"]) -QKVO = namedtuple("QKVO", - [ - "qkv", - "ideal_output" - ]) +QKVO = namedtuple("QKVO", ["qkv", "ideal_output"]) # total_num_tokens x (num_heads*head_size) PackedQKVInputs = namedtuple("PackedQKVInputs", [ @@ -108,20 +104,13 @@ def ref_masked_attention(query: torch.Tensor, "q_seq_lens", "kv_seq_lens" ]) -PackedQKVO = namedtuple("PackedQKVO", - [ - "packed_qkv", - "ideal_output" - ]) +PackedQKVO = namedtuple("PackedQKVO", ["packed_qkv", "ideal_output"]) -KVMemoryMap = namedtuple("KVMemoryMap", [ - "block_tables", "slot_mapping" -]) +KVMemoryMap = namedtuple("KVMemoryMap", ["block_tables", "slot_mapping"]) + +PhaseTestParameters = namedtuple("PhaseTestParameters", + ["packed_qkvo", "kv_mmap"]) -PhaseTestParameters = namedtuple("PhaseTestParameters", [ - "packed_qkvo", - "kv_mmap" -]) def make_qkv( batch_size: int, @@ -600,8 +589,8 @@ def make_test_metadata( decoder_test_params: PhaseTestParameters, default_attn_type: AttentionType, device: Union[torch.device, str], - encoder_test_params: Optional[PhaseTestParameters]=None, - cross_test_params: Optional[PhaseTestParameters]=None + encoder_test_params: Optional[PhaseTestParameters] = None, + cross_test_params: Optional[PhaseTestParameters] = None ) -> AttentionMetadata: ''' Construct fake attention metadata for a combined self-/cross-attention @@ -647,13 +636,12 @@ def make_test_metadata( # the number of {prefills, decodes} and the number of # {prefill, decode} tokens can be inferred from seq_lens num_prefills_or_decodes = len(seq_lens) - if is_prompt: - # Prefill: operate on total num. of prompt - # tokens - num_prefill_or_decode_tokens = sum(seq_lens) - else: - # Decode: operate on one token per seq - num_prefill_or_decode_tokens = len(seq_lens) + + # Prefill: operate on total num. of prompt + # tokens + # Decode: operate on one token per seq + num_prefill_or_decode_tokens = \ + sum(seq_lens) if is_prompt else len(seq_lens) # Seems for non-prefix-caching scenarios context_lens # is never needed From aa5363a589209a306b78dc3037eb0e278416386e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 21:32:57 -0400 Subject: [PATCH 185/239] test points and test resources structures integrated --- tests/kernels/test_encoder_decoder_attn.py | 188 +++++++++++++-------- 1 file changed, 113 insertions(+), 75 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 8462b1580794b..65bc7e55ee503 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -10,7 +10,7 @@ import pytest import torch - +from collections import namedtuple from tests.kernels.utils import * from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType @@ -31,8 +31,25 @@ MAX_ENC_SEQ_LENS = [128] -def _basic_setup(num_heads: int, head_size: int, num_blocks: int, - block_size: int, backend_name: str) -> tuple: +TestPoint = namedtuple("TestPoint",[ + "num_heads", + "head_size", + "backend_name", + "batch_size", + "block_size", + "max_dec_seq_len", + "max_enc_seq_len", + "num_blocks" +]) + +TestResources = namedtuple("TestResources",[ + "scale", + "attn_backend", + "attn", + "kv_cache" +]) + +def _make_test_resources(test_pt: TestPoint) -> TestResources: ''' Compute & build entities required for the self-/cross-attention test. @@ -55,29 +72,45 @@ def _basic_setup(num_heads: int, head_size: int, num_blocks: int, * None if num_blocks or block_size is None ''' - scale = float(1.0 / (head_size**0.5)) - attn_backend = make_backend(backend_name) + scale = float(1.0 / (test_pt.head_size**0.5)) + attn_backend = make_backend(test_pt.backend_name) attn = Attention( - num_heads, - head_size, + test_pt.num_heads, + test_pt.head_size, scale=scale, ) - if num_blocks is None or num_heads is None: + if test_pt.num_blocks is None or test_pt.num_heads is None: # Caller does not require a KV cache - return scale, attn_backend, attn, None + return TestResources(scale, + attn_backend, + attn, + None) # Construct KV cache - kv_cache = make_kv_cache(num_blocks, - num_heads, - head_size, - block_size, + kv_cache = make_kv_cache(test_pt.num_blocks, + test_pt.num_heads, + test_pt.head_size, + test_pt.block_size, device=CUDA_DEVICE) - return scale, attn_backend, attn, kv_cache + return TestResources(scale, + attn_backend, + attn, + kv_cache) -def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, - scale: float, max_q_seq_len: int) \ +def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ -> PhaseTestParameters: + (num_heads, + head_size, + _, + batch_size, + _, + _, + max_q_seq_len, + _) = test_pt + + scale=test_rsrcs.scale + ''' Set up test vectors & data structures for encoder attention test. @@ -147,14 +180,11 @@ def _encoder_attn_setup(batch_size: int, num_heads: int, head_size: int, def _decoder_attn_setup( - batch_size: int, - num_heads: int, - head_size: int, - block_size: int, - scale: float, - max_q_seq_len: int, - block_base_addr: int = 0 + test_pt: TestPoint, + test_rsrcs: TestResources, + block_base_addr: int = 0, ) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: + ''' Set up test vectors & data structures for self-attention test. @@ -232,6 +262,17 @@ def _decoder_attn_setup( * max_block_idx: highest block address in the self-attention block-table ''' + (num_heads, + head_size, + _, + batch_size, + block_size, + max_q_seq_len, + _, + _) = test_pt + + scale = test_rsrcs.scale + max_kv_seq_len = max_q_seq_len qkv, \ @@ -325,17 +366,13 @@ def _decoder_attn_setup( def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, encoder_test_params: PhaseTestParameters, - prefill_phase_test_params: + prefill_decoder_phase_test_params: PhaseTestParameters, - batch_size: int, - num_heads: int, - head_size: int, - block_size: int, - scale: float, - max_decoder_seq_len: int, - max_encoder_seq_len: int, + test_pt: TestPoint, + test_rsrcs: TestResources, block_base_addr: Optional[int]=0) \ -> tuple: + ''' Set up test vectors & data structures for cross-attention test. @@ -397,11 +434,22 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, * max_block_idx: highest block address in the cross-attention block-table ''' + (num_heads, + head_size, + _, + batch_size, + block_size, + max_decoder_seq_len, + max_encoder_seq_len, + _) = test_pt + + scale = test_rsrcs.scale + decoder_query = decoder_qkv.query decoder_seq_lens = decoder_qkv.q_seq_lens encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens prefill_q_seq_lens = \ - prefill_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens + prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens cross_kv, \ @@ -526,11 +574,11 @@ def _run_encoder_attention_test(attn: Attention, None, attn_metadata) -def _run_decoder_self_attention_test(attn: Attention, +def _run_decoder_self_attention_test(test_rsrcs: TestResources, decoder_test_params: PhaseTestParameters, - kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, attn_type: AttentionType) -> torch.Tensor: + ''' Run decoder self-attention test. @@ -552,6 +600,8 @@ def _run_decoder_self_attention_test(attn: Attention, & attn_metadata ''' assert attn_type == AttentionType.DECODER + attn = test_rsrcs.attn + kv_cache = test_rsrcs.kv_cache attn_metadata.attention_type = attn_type packed_qkv = decoder_test_params.packed_qkvo.packed_qkv return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, @@ -559,9 +609,11 @@ def _run_decoder_self_attention_test(attn: Attention, def _run_encoder_decoder_cross_attention_test( - attn: Attention, decoder_test_params: PhaseTestParameters, - cross_test_params: PhaseTestParameters, kv_cache: torch.Tensor, + test_rsrcs: TestResources, + decoder_test_params: PhaseTestParameters, + cross_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata) -> torch.Tensor: + ''' Run encoder/decoder cross-attention test. @@ -581,6 +633,8 @@ def _run_encoder_decoder_cross_attention_test( & attn_metadata ''' attn_metadata.attention_type = AttentionType.ENCODER_DECODER + attn = test_rsrcs.attn + kv_cache = test_rsrcs.kv_cache if cross_test_params is None: key = None value = None @@ -609,7 +663,6 @@ def _assert_actual_match_ideal(test_params: PhaseTestParameters, assert torch.allclose(ideal_output, output_under_test.view_as(ideal_output)) - @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -622,6 +675,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, monkeypatch) -> None: + ''' Encoder/decoder attention test: @@ -651,19 +705,18 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # Force Attention wrapper backend override_backend_env_variable(monkeypatch, backend_name) - # Num KV cache blocks - num_blocks = 4096 + test_pt = TestPoint(num_heads, + head_size, + backend_name, + batch_size, + block_size, + max_dec_seq_len, + max_dec_seq_len, + 4096) # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init - scale, \ - attn_backend, \ - attn, \ - kv_cache = _basic_setup(num_heads, - head_size, - num_blocks, - block_size, - backend_name) + test_rsrcs = _make_test_resources(test_pt) # Encoder attention setup @@ -673,20 +726,14 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # anyway but are required to be present & valid by the # backend. - enc_test_params = _encoder_attn_setup(batch_size, num_heads, head_size, - scale, max_enc_seq_len) + enc_test_params = _encoder_attn_setup(test_pt,test_rsrcs) # Decoder self-attention setup dec_qkv, \ prephase_dec_test_params, \ decphase_dec_test_params, \ - cross_block_base_addr = _decoder_attn_setup(batch_size, - num_heads, - head_size, - block_size, - scale, - max_dec_seq_len) + cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) # Cross-attention setup @@ -695,20 +742,15 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, enc_test_params, prephase_dec_test_params, - batch_size, - num_heads, - head_size, - block_size, - scale, - max_dec_seq_len, - max_enc_seq_len, + test_pt, + test_rsrcs, block_base_addr = \ cross_block_base_addr) # Shared prefill metadata structure prephase_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, + test_rsrcs.attn_backend, True, prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, decoder_test_params=prephase_dec_test_params, @@ -722,7 +764,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( enc_packed_actual_output: torch.Tensor = \ _run_encoder_attention_test( - attn, + test_rsrcs.attn, enc_test_params, prephase_attn_metadata, attn_type=AttentionType.ENCODER) @@ -734,9 +776,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( self_prefill_packed_actual_output: torch.Tensor = \ _run_decoder_self_attention_test( - attn, + test_rsrcs, prephase_dec_test_params, - kv_cache, prephase_attn_metadata, attn_type=AttentionType.DECODER) @@ -748,10 +789,9 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( prephase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( - attn, + test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, - kv_cache, prephase_attn_metadata) # - Prefill cross-attention correct? @@ -765,7 +805,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # context_lens = copy.deepcopy(enc_pckd_qkvo.packed_qkv.q_seq_lens) decphase_attn_metadata: AttentionMetadata = make_test_metadata( - attn_backend, + test_rsrcs.attn_backend, False, dec_qkv.q_seq_lens, decoder_test_params=decphase_dec_test_params, @@ -778,9 +818,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decphase_dec_pckd_act_out: torch.Tensor = \ _run_decoder_self_attention_test( - attn, + test_rsrcs, decphase_dec_test_params, - kv_cache, decphase_attn_metadata, attn_type=AttentionType.DECODER) @@ -792,10 +831,9 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decphase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( - attn, + test_rsrcs, decphase_dec_test_params, None, - kv_cache, decphase_attn_metadata) # - Decode cross-attention correct? @@ -819,9 +857,9 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # of prefill and decode tokens. decphase_attn_metadata.num_prefill_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: - _run_encoder_decoder_cross_attention_test(attn, + _run_encoder_decoder_cross_attention_test(test_rsrcs, decphase_dec_test_params, - None, kv_cache, + None, decphase_attn_metadata) # "Encoder decoder models do not currently support chunked prefill" From 53514169c2d59b81b97d834f5264634ebd84687a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 21:34:30 -0400 Subject: [PATCH 186/239] formatting --- tests/kernels/test_encoder_decoder_attn.py | 91 ++++++---------------- 1 file changed, 22 insertions(+), 69 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 65bc7e55ee503..e14c4d16b141e 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -6,11 +6,12 @@ * Encoder/decoder cross-attention """ +from collections import namedtuple from typing import Optional import pytest import torch -from collections import namedtuple + from tests.kernels.utils import * from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType @@ -30,24 +31,14 @@ MAX_DEC_SEQ_LENS = [128] MAX_ENC_SEQ_LENS = [128] - -TestPoint = namedtuple("TestPoint",[ - "num_heads", - "head_size", - "backend_name", - "batch_size", - "block_size", - "max_dec_seq_len", - "max_enc_seq_len", - "num_blocks" +TestPoint = namedtuple("TestPoint", [ + "num_heads", "head_size", "backend_name", "batch_size", "block_size", + "max_dec_seq_len", "max_enc_seq_len", "num_blocks" ]) -TestResources = namedtuple("TestResources",[ - "scale", - "attn_backend", - "attn", - "kv_cache" -]) +TestResources = namedtuple("TestResources", + ["scale", "attn_backend", "attn", "kv_cache"]) + def _make_test_resources(test_pt: TestPoint) -> TestResources: ''' @@ -81,10 +72,7 @@ def _make_test_resources(test_pt: TestPoint) -> TestResources: ) if test_pt.num_blocks is None or test_pt.num_heads is None: # Caller does not require a KV cache - return TestResources(scale, - attn_backend, - attn, - None) + return TestResources(scale, attn_backend, attn, None) # Construct KV cache kv_cache = make_kv_cache(test_pt.num_blocks, @@ -92,25 +80,14 @@ def _make_test_resources(test_pt: TestPoint) -> TestResources: test_pt.head_size, test_pt.block_size, device=CUDA_DEVICE) - return TestResources(scale, - attn_backend, - attn, - kv_cache) + return TestResources(scale, attn_backend, attn, kv_cache) def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ -> PhaseTestParameters: - (num_heads, - head_size, - _, - batch_size, - _, - _, - max_q_seq_len, - _) = test_pt - - scale=test_rsrcs.scale + (num_heads, head_size, _, batch_size, _, _, max_q_seq_len, _) = test_pt + scale = test_rsrcs.scale ''' Set up test vectors & data structures for encoder attention test. @@ -184,7 +161,6 @@ def _decoder_attn_setup( test_rsrcs: TestResources, block_base_addr: int = 0, ) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: - ''' Set up test vectors & data structures for self-attention test. @@ -262,13 +238,7 @@ def _decoder_attn_setup( * max_block_idx: highest block address in the self-attention block-table ''' - (num_heads, - head_size, - _, - batch_size, - block_size, - max_q_seq_len, - _, + (num_heads, head_size, _, batch_size, block_size, max_q_seq_len, _, _) = test_pt scale = test_rsrcs.scale @@ -372,7 +342,6 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, test_rsrcs: TestResources, block_base_addr: Optional[int]=0) \ -> tuple: - ''' Set up test vectors & data structures for cross-attention test. @@ -434,14 +403,8 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, * max_block_idx: highest block address in the cross-attention block-table ''' - (num_heads, - head_size, - _, - batch_size, - block_size, - max_decoder_seq_len, - max_encoder_seq_len, - _) = test_pt + (num_heads, head_size, _, batch_size, block_size, max_decoder_seq_len, + max_encoder_seq_len, _) = test_pt scale = test_rsrcs.scale @@ -578,7 +541,6 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, decoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, attn_type: AttentionType) -> torch.Tensor: - ''' Run decoder self-attention test. @@ -601,7 +563,7 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, ''' assert attn_type == AttentionType.DECODER attn = test_rsrcs.attn - kv_cache = test_rsrcs.kv_cache + kv_cache = test_rsrcs.kv_cache attn_metadata.attention_type = attn_type packed_qkv = decoder_test_params.packed_qkvo.packed_qkv return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, @@ -609,11 +571,9 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, def _run_encoder_decoder_cross_attention_test( - test_rsrcs: TestResources, - decoder_test_params: PhaseTestParameters, + test_rsrcs: TestResources, decoder_test_params: PhaseTestParameters, cross_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata) -> torch.Tensor: - ''' Run encoder/decoder cross-attention test. @@ -663,6 +623,7 @@ def _assert_actual_match_ideal(test_params: PhaseTestParameters, assert torch.allclose(ideal_output, output_under_test.view_as(ideal_output)) + @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -675,7 +636,6 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, monkeypatch) -> None: - ''' Encoder/decoder attention test: @@ -705,14 +665,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # Force Attention wrapper backend override_backend_env_variable(monkeypatch, backend_name) - test_pt = TestPoint(num_heads, - head_size, - backend_name, - batch_size, - block_size, - max_dec_seq_len, - max_dec_seq_len, - 4096) + test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, + block_size, max_dec_seq_len, max_dec_seq_len, 4096) # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init @@ -726,7 +680,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # anyway but are required to be present & valid by the # backend. - enc_test_params = _encoder_attn_setup(test_pt,test_rsrcs) + enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) # Decoder self-attention setup @@ -859,8 +813,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( with pytest.raises(NotImplementedError) as exc_info: _run_encoder_decoder_cross_attention_test(test_rsrcs, decphase_dec_test_params, - None, - decphase_attn_metadata) + None, decphase_attn_metadata) # "Encoder decoder models do not currently support chunked prefill" assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL From 8d390a075bb7435aed10d1af97fada08aacea489 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 22:00:52 -0400 Subject: [PATCH 187/239] first attempt at chunked prefill failure test --- tests/kernels/test_encoder_decoder_attn.py | 133 +++++++++++++++++++++ vllm/attention/backends/utils.py | 3 +- vllm/attention/backends/xformers.py | 4 +- 3 files changed, 136 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index e14c4d16b141e..fd0f96658e24f 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -817,3 +817,136 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( # "Encoder decoder models do not currently support chunked prefill" assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL + + + + +@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) +@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) +def test_backend_fails_for_chunked_prefill_enc_dec( + num_heads: int, head_size: int, backend_name: str, batch_size: int, + block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, + monkeypatch) -> None: + ''' + Encoder/decoder attention test: + + * Construct fake test vectors for self- and cross-attention + * Construct attention metadata structure with self- and cross-attention + attributes + * Test self- and cross-attention in the following order + + * Prefill self-attention + * Prefill cross-attention + * Decode self-attention + * Decode cross-attention + * This order would exacerbate any accidental overlap in the + self-/cross-attention block tables, which we attempt to avoid + * Validate output correctness against ideal reference attention + implementation + + Block tables are constructed such that cross-attention KV cache is in a + higher, non-intersecting address-space than self-attention KV cache. + + Self- and cross-attention share the same query tensor but not the K/V + tensors. Self-attention K/Vs must have the same seq len as Q while + cross-attention K/Vs are allowed to differ in seq len, as is often the case + for cross-attention. + ''' + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + + test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, + block_size, max_dec_seq_len, max_dec_seq_len, 4096) + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + test_rsrcs = _make_test_resources(test_pt) + + # Encoder attention setup + + # Let encoder_attn_setup() choose default block table + # base address; the block_tables and slot_mapping + # tensors are not actually utilized by encoder attention + # anyway but are required to be present & valid by the + # backend. + + enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) + + # Decoder self-attention setup + + dec_qkv, \ + prephase_dec_test_params, \ + decphase_dec_test_params, \ + cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) + + # Cross-attention setup + + prephase_cross_test_params, \ + decphase_cross_test_params, \ + = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, + enc_test_params, + prephase_dec_test_params, + test_pt, + test_rsrcs, + block_base_addr = \ + cross_block_base_addr) + + # Shared prefill metadata structure + + prephase_attn_metadata: AttentionMetadata = make_test_metadata( + test_rsrcs.attn_backend, + True, + prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, + decoder_test_params=prephase_dec_test_params, + encoder_test_params=enc_test_params, + cross_test_params=prephase_cross_test_params, + default_attn_type=AttentionType.ENCODER, + device=CUDA_DEVICE) + + # PREFILL: encoder attention + # * Use prefill kernel + + enc_packed_actual_output: torch.Tensor = \ + _run_encoder_attention_test( + test_rsrcs.attn, + enc_test_params, + prephase_attn_metadata, + attn_type=AttentionType.ENCODER) + + # - Is encoder attention result correct? + _assert_actual_match_ideal(enc_test_params, enc_packed_actual_output) + + + + # PREFILL: self-attention test + + # The following test conditions could in principle be a + # standalone test, however the test setup is + # so involved that it is easier + # to piggyback off of the test vectors & other data structures + # created for testing decode-phase encoder/decoder cross- + # attention above. + # ---- + # Set up a contrived scenario where the attention metadata + # is configured for chunked prefill & encoder/decoder cross- + # attention. Required that this triggers a NotImplementedError. + # + # We assume that decode_attn_metadata.num_decode_tokens > 1 + # already; the line below sets up a chunked prefill + # metadata configuration where there is nominally a mix + # of prefill and decode tokens. + prephase_attn_metadata.num_decode_tokens = 1 + with pytest.raises(NotImplementedError) as exc_info: + + _run_decoder_self_attention_test( + test_rsrcs, + prephase_dec_test_params, + prephase_attn_metadata, + attn_type=AttentionType.DECODER) \ No newline at end of file diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index ad88b4f964a54..66916c28d7685 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -44,8 +44,7 @@ def check_hip_or_chunked_prefill_attention_encdec( # xFormers backend. raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND) - if attn_metadata.attention_type != AttentionType.DECODER \ - and attn_metadata.num_prefill_tokens > 0 and \ + if attn_metadata.num_prefill_tokens > 0 and \ attn_metadata.num_decode_tokens > 0: # Encoder/decoder models are currently incompatible # with chunked prefill. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index cc240242f7eae..31630f0dbecc3 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -480,7 +480,7 @@ def forward( # seqlen datastructures we utilize attn_type = attn_metadata.attention_type - if attn_type != AttentionType.DECODER: + if attn_metadata.is_all_encoder_attn_metadata_set: # Raise NotImplementedError for unsupported encoder/decoder # scenarios from vllm.attention.backends.utils import ( @@ -528,7 +528,7 @@ def forward( num_prefill_tokens = query.shape[0] num_decode_tokens = 0 - if attn_type != AttentionType.ENCODER_DECODER: + if attn_type == AttentionType.DECODER: # Only enforce this shape-constraint for decoder # self-attention assert key.shape[0] == num_prefill_tokens + num_decode_tokens From 68b6d4b29ffb787713beacd66ee922b1d8ec9326 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 22:06:58 -0400 Subject: [PATCH 188/239] narrowed the space of test-cases for unsupported scenarios --- tests/kernels/test_encoder_decoder_attn.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index fd0f96658e24f..14feb1f948b2a 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -31,6 +31,10 @@ MAX_DEC_SEQ_LENS = [128] MAX_ENC_SEQ_LENS = [128] +# Narrow teest-cases for unsupported-scenario +# tests +HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]] + TestPoint = namedtuple("TestPoint", [ "num_heads", "head_size", "backend_name", "batch_size", "block_size", "max_dec_seq_len", "max_enc_seq_len", "num_blocks" @@ -823,7 +827,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("head_size", HEAD_SIZES_FOR_UNSUPP) @pytest.mark.parametrize("backend_name", BACKEND_NAMES) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) From 5923002add95a787a6ae5604bd8a5aa2f14c66d1 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 22:12:59 -0400 Subject: [PATCH 189/239] format --- tests/kernels/test_encoder_decoder_attn.py | 52 +++++++--------------- vllm/attention/backends/utils.py | 1 - 2 files changed, 15 insertions(+), 38 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 14feb1f948b2a..70cc807277af8 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -798,32 +798,6 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( _assert_actual_match_ideal(decphase_cross_test_params, decphase_cross_pckd_act_out) - # The following test conditions could in principle be a - # standalone test, however the test setup is - # so involved that it is easier - # to piggyback off of the test vectors & other data structures - # created for testing decode-phase encoder/decoder cross- - # attention above. - # ---- - # Set up a contrived scenario where the attention metadata - # is configured for chunked prefill & encoder/decoder cross- - # attention. Required that this triggers a NotImplementedError. - # - # We assume that decode_attn_metadata.num_decode_tokens > 1 - # already; the line below sets up a chunked prefill - # metadata configuration where there is nominally a mix - # of prefill and decode tokens. - decphase_attn_metadata.num_prefill_tokens = 1 - with pytest.raises(NotImplementedError) as exc_info: - _run_encoder_decoder_cross_attention_test(test_rsrcs, - decphase_dec_test_params, - None, decphase_attn_metadata) - - # "Encoder decoder models do not currently support chunked prefill" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL - - - @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -833,10 +807,14 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_backend_fails_for_chunked_prefill_enc_dec( - num_heads: int, head_size: int, backend_name: str, batch_size: int, - block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, - monkeypatch) -> None: +def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, + head_size: int, + backend_name: str, + batch_size: int, + block_size: int, + max_dec_seq_len: int, + max_enc_seq_len: int, + monkeypatch) -> None: ''' Encoder/decoder attention test: @@ -927,8 +905,6 @@ def test_backend_fails_for_chunked_prefill_enc_dec( # - Is encoder attention result correct? _assert_actual_match_ideal(enc_test_params, enc_packed_actual_output) - - # PREFILL: self-attention test # The following test conditions could in principle be a @@ -949,8 +925,10 @@ def test_backend_fails_for_chunked_prefill_enc_dec( prephase_attn_metadata.num_decode_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: - _run_decoder_self_attention_test( - test_rsrcs, - prephase_dec_test_params, - prephase_attn_metadata, - attn_type=AttentionType.DECODER) \ No newline at end of file + _run_decoder_self_attention_test(test_rsrcs, + prephase_dec_test_params, + prephase_attn_metadata, + attn_type=AttentionType.DECODER) + + # "Encoder decoder models do not currently support chunked prefill" + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 66916c28d7685..9f38147799045 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,7 +1,6 @@ """Attention utils""" from vllm.attention import AttentionMetadata -from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.xformers import XFormersMetadata from vllm.utils import is_hip From c3e5d2aa0520f371a55de158af542c2ae22b4373 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 5 Jun 2024 22:26:25 -0400 Subject: [PATCH 190/239] skeleton of encdec prefix cache failure test; fixed bug where max enc seq len was unused --- tests/kernels/test_encoder_decoder_attn.py | 139 ++++++++++++++++++++- 1 file changed, 137 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 70cc807277af8..b71934cb116ae 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -670,7 +670,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( override_backend_env_variable(monkeypatch, backend_name) test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, - block_size, max_dec_seq_len, max_dec_seq_len, 4096) + block_size, max_dec_seq_len, max_enc_seq_len, 4096) # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init @@ -845,7 +845,142 @@ def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, override_backend_env_variable(monkeypatch, backend_name) test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, - block_size, max_dec_seq_len, max_dec_seq_len, 4096) + block_size, max_dec_seq_len, max_enc_seq_len, 4096) + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + test_rsrcs = _make_test_resources(test_pt) + + # Encoder attention setup + + # Let encoder_attn_setup() choose default block table + # base address; the block_tables and slot_mapping + # tensors are not actually utilized by encoder attention + # anyway but are required to be present & valid by the + # backend. + + enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) + + # Decoder self-attention setup + + dec_qkv, \ + prephase_dec_test_params, \ + decphase_dec_test_params, \ + cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) + + # Cross-attention setup + + prephase_cross_test_params, \ + decphase_cross_test_params, \ + = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, + enc_test_params, + prephase_dec_test_params, + test_pt, + test_rsrcs, + block_base_addr = \ + cross_block_base_addr) + + # Shared prefill metadata structure + + prephase_attn_metadata: AttentionMetadata = make_test_metadata( + test_rsrcs.attn_backend, + True, + prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, + decoder_test_params=prephase_dec_test_params, + encoder_test_params=enc_test_params, + cross_test_params=prephase_cross_test_params, + default_attn_type=AttentionType.ENCODER, + device=CUDA_DEVICE) + + # PREFILL: encoder attention + # * Use prefill kernel + + enc_packed_actual_output: torch.Tensor = \ + _run_encoder_attention_test( + test_rsrcs.attn, + enc_test_params, + prephase_attn_metadata, + attn_type=AttentionType.ENCODER) + + # - Is encoder attention result correct? + _assert_actual_match_ideal(enc_test_params, enc_packed_actual_output) + + # PREFILL: self-attention test + + # The following test conditions could in principle be a + # standalone test, however the test setup is + # so involved that it is easier + # to piggyback off of the test vectors & other data structures + # created for testing decode-phase encoder/decoder cross- + # attention above. + # ---- + # Set up a contrived scenario where the attention metadata + # is configured for chunked prefill & encoder/decoder cross- + # attention. Required that this triggers a NotImplementedError. + # + # We assume that decode_attn_metadata.num_decode_tokens > 1 + # already; the line below sets up a chunked prefill + # metadata configuration where there is nominally a mix + # of prefill and decode tokens. + prephase_attn_metadata.num_decode_tokens = 1 + with pytest.raises(NotImplementedError) as exc_info: + + _run_decoder_self_attention_test(test_rsrcs, + prephase_dec_test_params, + prephase_attn_metadata, + attn_type=AttentionType.DECODER) + + # "Encoder decoder models do not currently support chunked prefill" + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL + + +@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES_FOR_UNSUPP) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) +@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) +def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, + head_size: int, + backend_name: str, + batch_size: int, + block_size: int, + max_dec_seq_len: int, + max_enc_seq_len: int, + monkeypatch) -> None: + ''' + Encoder/decoder attention test: + + * Construct fake test vectors for self- and cross-attention + * Construct attention metadata structure with self- and cross-attention + attributes + * Test self- and cross-attention in the following order + + * Prefill self-attention + * Prefill cross-attention + * Decode self-attention + * Decode cross-attention + * This order would exacerbate any accidental overlap in the + self-/cross-attention block tables, which we attempt to avoid + * Validate output correctness against ideal reference attention + implementation + + Block tables are constructed such that cross-attention KV cache is in a + higher, non-intersecting address-space than self-attention KV cache. + + Self- and cross-attention share the same query tensor but not the K/V + tensors. Self-attention K/Vs must have the same seq len as Q while + cross-attention K/Vs are allowed to differ in seq len, as is often the case + for cross-attention. + ''' + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + + test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, + block_size, max_dec_seq_len, max_enc_seq_len, 4096) # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init From 739ab3ca7ccc945a569c7f05a270b8eb2de78317 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 6 Jun 2024 10:29:23 -0400 Subject: [PATCH 191/239] wip prefill test --- tests/kernels/test_encoder_decoder_attn.py | 24 ++++++++++--- vllm/attention/backends/utils.py | 39 ++++++++++++++-------- vllm/attention/backends/xformers.py | 14 +++++--- 3 files changed, 54 insertions(+), 23 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index b71934cb116ae..2d8942afb7adb 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -16,7 +16,9 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( - STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_ROCM_HIP) + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, + STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING, + STR_NOT_IMPL_ENC_DEC_ROCM_HIP) from vllm.utils import is_hip, make_causal_mask, maybe_make_long_tensor HEAD_SIZES = [64, 256] @@ -1006,7 +1008,7 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # Cross-attention setup prephase_cross_test_params, \ - decphase_cross_test_params, \ + _, \ = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, enc_test_params, prephase_dec_test_params, @@ -1057,8 +1059,22 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # already; the line below sets up a chunked prefill # metadata configuration where there is nominally a mix # of prefill and decode tokens. - prephase_attn_metadata.num_decode_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: + # Fake a non-empty block_tables + # prephase_dec_test_params.kv_mmap.block_tables = \ + # decphase_dec_test_params.kv_mmap.block_tables + + # prefix_block_tables = decphase_dec_test_params.kv_mmap.block_tables + + # prefix_kv_mmap = KVMemoryMap(prefix_block_tables, + # prephase_dec_test_params.kv_mmap.slot_mapping) + + # prefix_test_params = PhaseTestParameters( + # prephase_dec_test_params.packed_qkvo, + # prefix_kv_mmap + # ) + + prephase_attn_metadata.block_tables = decphase_dec_test_params.kv_mmap.block_tables _run_decoder_self_attention_test(test_rsrcs, prephase_dec_test_params, @@ -1066,4 +1082,4 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, attn_type=AttentionType.DECODER) # "Encoder decoder models do not currently support chunked prefill" - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 9f38147799045..3200ef9113820 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,28 +1,39 @@ """Attention utils""" -from vllm.attention import AttentionMetadata -from vllm.attention.backends.xformers import XFormersMetadata +# from vllm.attention import AttentionMetadata +# from vllm.attention.backends.xformers import XFormersMetadata from vllm.utils import is_hip # Error string(s) for encoder/decoder # unsupported attention scenarios STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ -"Encoder/decoder models " + \ -"currently do not support chunked prefill." +"Chunked prefill is not currently " + \ +"supported with encoder/decoder models." STR_NOT_IMPL_ENC_DEC_ROCM_HIP = \ -"Encoder/decoder models currently" + \ -"do not support ROCm/HIP." +"ROCm/HIP is not currently supported" + \ +"with encoder/decoder models." STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND = \ -"Encoder/decoder models currently support only the XFormers backend." +"Currently only the XFormers backend " + \ + "supports encoder/decoder models." + +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING = \ +"Prefix caching is not currently supported " + \ +"with encoder/decoder models" # Check for unsupported encoder/decoder scenarios +def is_encoder_decoder_metadata(attn_metadata) -> bool: + return attn_metadata.is_all_encoder_attn_metadata_set + +def fail_encoder_decoder_prefix_caching() -> None: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING) + def check_hip_or_chunked_prefill_attention_encdec( - attn_metadata: AttentionMetadata): + attn_metadata) -> None: ''' Check for unsupported encoder/decoder scenarios when invoking attention. @@ -36,12 +47,12 @@ def check_hip_or_chunked_prefill_attention_encdec( # encoder/decoder models raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ROCM_HIP) - if not isinstance(attn_metadata, XFormersMetadata): - # Right now encoder/decoder support is only implemented - # for the XFormers backend. Pretty unlikely to encounter - # this case currently given this function will be invoked inside - # xFormers backend. - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND) + # if not isinstance(attn_metadata, XFormersMetadata): + # # Right now encoder/decoder support is only implemented + # # for the XFormers backend. Pretty unlikely to encounter + # # this case currently given this function will be invoked inside + # # xFormers backend. + # raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND) if attn_metadata.num_prefill_tokens > 0 and \ attn_metadata.num_decode_tokens > 0: diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 31630f0dbecc3..6b94ef10764ea 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -15,6 +15,11 @@ PagedAttentionMetadata) from vllm.logger import init_logger +from vllm.attention.backends.utils import ( + check_hip_or_chunked_prefill_attention_encdec, + is_encoder_decoder_metadata, + fail_encoder_decoder_prefix_caching) + logger = init_logger(__name__) @@ -480,11 +485,9 @@ def forward( # seqlen datastructures we utilize attn_type = attn_metadata.attention_type - if attn_metadata.is_all_encoder_attn_metadata_set: + if is_encoder_decoder_metadata(attn_metadata): # Raise NotImplementedError for unsupported encoder/decoder # scenarios - from vllm.attention.backends.utils import ( - check_hip_or_chunked_prefill_attention_encdec) check_hip_or_chunked_prefill_attention_encdec(attn_metadata) if (kv_cache is not None): @@ -562,12 +565,13 @@ def forward( assert prefill_meta.query_start_loc is not None assert prefill_meta.max_query_len is not None + if is_encoder_decoder_metadata(attn_metadata): + fail_encoder_decoder_prefix_caching() + # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, # to be addressed separately. - # - # TODO(afeldman-nm): support cross-attention out = PagedAttention.forward_prefix( query, key, From 22652249deac4d19ab9116b0b943767501d538f8 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 6 Jun 2024 10:53:38 -0400 Subject: [PATCH 192/239] passing prefix cache failure test --- tests/kernels/test_encoder_decoder_attn.py | 5 ++++- vllm/attention/backends/xformers.py | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 2d8942afb7adb..c5a888c8be8a3 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -1074,8 +1074,11 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # prefix_kv_mmap # ) - prephase_attn_metadata.block_tables = decphase_dec_test_params.kv_mmap.block_tables + num_seqs = len(prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens) + prephase_attn_metadata._cached_prefill_metadata.block_tables = torch.randint(0,10,(num_seqs,1)) + + _run_decoder_self_attention_test(test_rsrcs, prephase_dec_test_params, prephase_attn_metadata, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6b94ef10764ea..a930643afb0d1 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -562,12 +562,12 @@ def forward( assert out.shape == output[:num_prefill_tokens].shape output[:num_prefill_tokens] = out else: - assert prefill_meta.query_start_loc is not None - assert prefill_meta.max_query_len is not None - if is_encoder_decoder_metadata(attn_metadata): fail_encoder_decoder_prefix_caching() + assert prefill_meta.query_start_loc is not None + assert prefill_meta.max_query_len is not None + # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, From d72aaa9098694fda0771477429735bc35985d6c0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 6 Jun 2024 10:55:00 -0400 Subject: [PATCH 193/239] format --- tests/kernels/test_encoder_decoder_attn.py | 29 +++++++++++----------- vllm/attention/backends/utils.py | 5 ++-- vllm/attention/backends/xformers.py | 8 +++--- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index c5a888c8be8a3..4e1f9a2ec7709 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -16,8 +16,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.utils import ( - STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, - STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING, + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING, STR_NOT_IMPL_ENC_DEC_ROCM_HIP) from vllm.utils import is_hip, make_causal_mask, maybe_make_long_tensor @@ -945,13 +944,13 @@ def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, - head_size: int, - backend_name: str, - batch_size: int, - block_size: int, - max_dec_seq_len: int, - max_enc_seq_len: int, - monkeypatch) -> None: + head_size: int, + backend_name: str, + batch_size: int, + block_size: int, + max_dec_seq_len: int, + max_enc_seq_len: int, + monkeypatch) -> None: ''' Encoder/decoder attention test: @@ -1063,7 +1062,7 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # Fake a non-empty block_tables # prephase_dec_test_params.kv_mmap.block_tables = \ # decphase_dec_test_params.kv_mmap.block_tables - + # prefix_block_tables = decphase_dec_test_params.kv_mmap.block_tables # prefix_kv_mmap = KVMemoryMap(prefix_block_tables, @@ -1072,13 +1071,15 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # prefix_test_params = PhaseTestParameters( # prephase_dec_test_params.packed_qkvo, # prefix_kv_mmap - # ) + # ) - num_seqs = len(prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens) + num_seqs = len( + prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens) - prephase_attn_metadata._cached_prefill_metadata.block_tables = torch.randint(0,10,(num_seqs,1)) + prephase_attn_metadata._cached_prefill_metadata.block_tables = \ + torch.randint( + 0, 10, (num_seqs, 1)) - _run_decoder_self_attention_test(test_rsrcs, prephase_dec_test_params, prephase_attn_metadata, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 3200ef9113820..3423587be4889 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -29,11 +29,12 @@ def is_encoder_decoder_metadata(attn_metadata) -> bool: return attn_metadata.is_all_encoder_attn_metadata_set + def fail_encoder_decoder_prefix_caching() -> None: raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING) -def check_hip_or_chunked_prefill_attention_encdec( - attn_metadata) -> None: + +def check_hip_or_chunked_prefill_attention_encdec(attn_metadata) -> None: ''' Check for unsupported encoder/decoder scenarios when invoking attention. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index a930643afb0d1..149de709141a5 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -11,15 +11,13 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import ( + check_hip_or_chunked_prefill_attention_encdec, + fail_encoder_decoder_prefix_caching, is_encoder_decoder_metadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger -from vllm.attention.backends.utils import ( - check_hip_or_chunked_prefill_attention_encdec, - is_encoder_decoder_metadata, - fail_encoder_decoder_prefix_caching) - logger = init_logger(__name__) From 1c19d36f5f9dde347909bb3b5f38d9421d772461 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 6 Jun 2024 11:31:45 -0400 Subject: [PATCH 194/239] type annotations; formatting --- tests/kernels/test_encoder_decoder_attn.py | 45 ++++++---------------- tests/kernels/utils.py | 22 ++++++----- vllm/attention/backends/utils.py | 13 ++----- vllm/attention/backends/xformers.py | 2 - 4 files changed, 27 insertions(+), 55 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 4e1f9a2ec7709..26a9c49f8f069 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -1,9 +1,11 @@ """ -Test +Tests: + +* E2E Encoder attention + Decoder self-attention + + Encoder/decoder cross-attention +* Confirm enc/dec models will fail for chunked prefill +* Confirm enc/dec models will fail for prefix caching -* Encoder attention -* Decoder self-attention -* Encoder/decoder cross-attention """ from collections import namedtuple @@ -346,7 +348,8 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, test_pt: TestPoint, test_rsrcs: TestResources, block_base_addr: Optional[int]=0) \ - -> tuple: + -> tuple[PhaseTestParameters, + PhaseTestParameters]: ''' Set up test vectors & data structures for cross-attention test. @@ -866,13 +869,13 @@ def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, dec_qkv, \ prephase_dec_test_params, \ - decphase_dec_test_params, \ + _, \ cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) # Cross-attention setup prephase_cross_test_params, \ - decphase_cross_test_params, \ + _, \ = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, enc_test_params, prephase_dec_test_params, @@ -908,13 +911,6 @@ def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, # PREFILL: self-attention test - # The following test conditions could in principle be a - # standalone test, however the test setup is - # so involved that it is easier - # to piggyback off of the test vectors & other data structures - # created for testing decode-phase encoder/decoder cross- - # attention above. - # ---- # Set up a contrived scenario where the attention metadata # is configured for chunked prefill & encoder/decoder cross- # attention. Required that this triggers a NotImplementedError. @@ -1001,7 +997,7 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, dec_qkv, \ prephase_dec_test_params, \ - decphase_dec_test_params, \ + _, \ cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) # Cross-attention setup @@ -1043,13 +1039,6 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # PREFILL: self-attention test - # The following test conditions could in principle be a - # standalone test, however the test setup is - # so involved that it is easier - # to piggyback off of the test vectors & other data structures - # created for testing decode-phase encoder/decoder cross- - # attention above. - # ---- # Set up a contrived scenario where the attention metadata # is configured for chunked prefill & encoder/decoder cross- # attention. Required that this triggers a NotImplementedError. @@ -1060,18 +1049,6 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # of prefill and decode tokens. with pytest.raises(NotImplementedError) as exc_info: # Fake a non-empty block_tables - # prephase_dec_test_params.kv_mmap.block_tables = \ - # decphase_dec_test_params.kv_mmap.block_tables - - # prefix_block_tables = decphase_dec_test_params.kv_mmap.block_tables - - # prefix_kv_mmap = KVMemoryMap(prefix_block_tables, - # prephase_dec_test_params.kv_mmap.slot_mapping) - - # prefix_test_params = PhaseTestParameters( - # prephase_dec_test_params.packed_qkvo, - # prefix_kv_mmap - # ) num_seqs = len( prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index cf4b41b96996a..1680231c41fbf 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -246,8 +246,9 @@ def make_qkv( decode_kv_seq_lens) -def pack_tensor(unpacked_tensor: torch.Tensor, seq_lens: List[int], - device: Union[torch.device, str]) -> tuple: +def pack_tensor( + unpacked_tensor: torch.Tensor, seq_lens: List[int], + device: Union[torch.device, str]) -> tuple[torch.Tensor, List[int]]: ''' Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an unpadded number_of_tokens x num_heads x head_size tensor, where @@ -355,9 +356,11 @@ def make_backend(backend_name: str) -> AttentionBackend: f"Unrecognized backend_name {backend_name} for unit test") -def _make_metadata_tensors(seq_lens: List[int], context_lens: List[int], - encoder_seq_lens: List[int], - device: Union[torch.device, str]) -> tuple: +def _make_metadata_tensors( + seq_lens: List[int], context_lens: List[int], encoder_seq_lens: List[int], + device: Union[torch.device, str] +) -> tuple[torch.Tensor, torch.Tensor, int, int, Optional[List[int]], + torch.Tensor, int]: ''' Build scalar & tensor values required to build attention metadata structure. @@ -500,10 +503,11 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int], maybe_make_long_tensor(decode_slot_mapping, device) -def make_block_tables_slot_mapping(block_size: int, - seq_lens: List[int], - device: Union[torch.device, str], - block_base_addr: int = 0) -> tuple: +def make_block_tables_slot_mapping( + block_size: int, + seq_lens: List[int], + device: Union[torch.device, str], + block_base_addr: int = 0) -> tuple[torch.Tensor, List[int], int]: ''' Construct fake block tables & slot mappings. diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 3423587be4889..d67251dd17b23 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,7 +1,6 @@ """Attention utils""" -# from vllm.attention import AttentionMetadata -# from vllm.attention.backends.xformers import XFormersMetadata +from vllm.attention import AttentionMetadata from vllm.utils import is_hip # Error string(s) for encoder/decoder @@ -34,7 +33,8 @@ def fail_encoder_decoder_prefix_caching() -> None: raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING) -def check_hip_or_chunked_prefill_attention_encdec(attn_metadata) -> None: +def check_hip_or_chunked_prefill_attention_encdec( + attn_metadata: AttentionMetadata) -> None: ''' Check for unsupported encoder/decoder scenarios when invoking attention. @@ -48,13 +48,6 @@ def check_hip_or_chunked_prefill_attention_encdec(attn_metadata) -> None: # encoder/decoder models raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ROCM_HIP) - # if not isinstance(attn_metadata, XFormersMetadata): - # # Right now encoder/decoder support is only implemented - # # for the XFormers backend. Pretty unlikely to encounter - # # this case currently given this function will be invoked inside - # # xFormers backend. - # raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND) - if attn_metadata.num_prefill_tokens > 0 and \ attn_metadata.num_decode_tokens > 0: # Encoder/decoder models are currently incompatible diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 149de709141a5..72fe333021f08 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -281,8 +281,6 @@ def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ Extract appropriate attention bias from attention metadata according to attention type. - Depends on attn_metadata having a valid attention_type. - Arguments: * attn_metadata: Attention metadata structure associated with attention From e10340d83e98de0588b2f3a83d6ebc95de7ec608 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 6 Jun 2024 11:49:14 -0400 Subject: [PATCH 195/239] completely replaced collections.namedtuple with typing.NamedTuple w/ type annotations; formatting --- tests/kernels/test_encoder_decoder_attn.py | 27 ++++++---- tests/kernels/utils.py | 61 ++++++++++++++-------- 2 files changed, 57 insertions(+), 31 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 26a9c49f8f069..a923ba4e18a4e 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -8,15 +8,14 @@ """ -from collections import namedtuple -from typing import Optional +from typing import NamedTuple, Optional import pytest import torch from tests.kernels.utils import * from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.abstract import AttentionType +from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import ( STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING, STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @@ -38,13 +37,23 @@ # tests HEAD_SIZES_FOR_UNSUPP = [HEAD_SIZES[0]] -TestPoint = namedtuple("TestPoint", [ - "num_heads", "head_size", "backend_name", "batch_size", "block_size", - "max_dec_seq_len", "max_enc_seq_len", "num_blocks" -]) -TestResources = namedtuple("TestResources", - ["scale", "attn_backend", "attn", "kv_cache"]) +class TestPoint(NamedTuple): + num_heads: int + head_size: int + backend_name: str + batch_size: int + block_size: int + max_dec_seq_len: int + max_enc_seq_len: int + num_blocks: int + + +class TestResources(NamedTuple): + scale: float + attn_backend: AttentionBackend + attn: Attention + kv_cache: torch.Tensor def _make_test_resources(test_pt: TestPoint) -> TestResources: diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 1680231c41fbf..2d802645de060 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -2,8 +2,7 @@ import itertools import random -from collections import namedtuple -from typing import List, Optional, Union +from typing import List, NamedTuple, Optional, Union import pytest import torch @@ -19,6 +18,44 @@ STR_INVALID_VAL: str = "INVALID" +class QKVInputs(NamedTuple): + query: torch.Tensor + key: torch.Tensor + value: torch.Tensor + q_seq_lens: List[int] + kv_seq_lens: List[int] + + +class QKVO(NamedTuple): + qkv: QKVInputs + ideal_output: torch.Tensor + + +class PackedQKVInputs(NamedTuple): + query: torch.Tensor + key: torch.Tensor + value: torch.Tensor + q_start_loc_list: List[int] + kv_start_loc_list: List[int] + q_seq_lens: List[int] + kv_seq_lens: List[int] + + +class PackedQKVO(NamedTuple): + packed_qkv: PackedQKVInputs + ideal_output: torch.Tensor + + +class KVMemoryMap(NamedTuple): + block_tables: torch.Tensor + slot_mapping: torch.Tensor + + +class PhaseTestParameters(NamedTuple): + packed_qkvo: PackedQKVO + kv_mmap: KVMemoryMap + + def override_backend_env_variable(mpatch: pytest.MonkeyPatch, backend_name: str) -> None: ''' @@ -92,26 +129,6 @@ def ref_masked_attention(query: torch.Tensor, return out -# batch_size x max_q_seq_len x num_heads x head_size -QKVInputs = namedtuple("QKVInputs", - ["query", "key", "value", "q_seq_lens", "kv_seq_lens"]) - -QKVO = namedtuple("QKVO", ["qkv", "ideal_output"]) - -# total_num_tokens x (num_heads*head_size) -PackedQKVInputs = namedtuple("PackedQKVInputs", [ - "query", "key", "value", "q_start_loc_list", "kv_start_loc_list", - "q_seq_lens", "kv_seq_lens" -]) - -PackedQKVO = namedtuple("PackedQKVO", ["packed_qkv", "ideal_output"]) - -KVMemoryMap = namedtuple("KVMemoryMap", ["block_tables", "slot_mapping"]) - -PhaseTestParameters = namedtuple("PhaseTestParameters", - ["packed_qkvo", "kv_mmap"]) - - def make_qkv( batch_size: int, max_q_seq_len: int, From 67ab576fdbc13815b4e7ce60520c48d84da4bb05 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 6 Jun 2024 12:08:31 -0400 Subject: [PATCH 196/239] removed HIP check; clarified assumptions about supported backends in enc/dec supported feature checks --- tests/kernels/test_encoder_decoder_attn.py | 38 ++++++------------ tests/kernels/utils.py | 16 ++++++++ vllm/attention/backends/utils.py | 46 +++++++++++++++++----- vllm/attention/backends/xformers.py | 16 ++++---- 4 files changed, 73 insertions(+), 43 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index a923ba4e18a4e..9e07611cba5c7 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -625,22 +625,6 @@ def _run_encoder_decoder_cross_attention_test( value, kv_cache, attn_metadata) -def _assert_actual_match_ideal(test_params: PhaseTestParameters, - output_under_test: torch.Tensor) -> None: - ''' - Assert that observed output matches the ideal output - contained in the test parameters data structure. - - Arguments: - - * test_params: Test parameters including packed ideal output - * output_under_test: actually observed output value - ''' - ideal_output = test_params.packed_qkvo.ideal_output - assert torch.allclose(ideal_output, - output_under_test.view_as(ideal_output)) - - @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -741,7 +725,7 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - _assert_actual_match_ideal(enc_test_params, enc_packed_actual_output) + assert_actual_matches_ideal(enc_test_params, enc_packed_actual_output) # PREFILL: self-attention test @@ -753,8 +737,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_type=AttentionType.DECODER) # - Prefill self-attention correct? - _assert_actual_match_ideal(prephase_dec_test_params, - self_prefill_packed_actual_output) + assert_actual_matches_ideal(prephase_dec_test_params, + self_prefill_packed_actual_output) # PREFILL: cross-attention test @@ -766,8 +750,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( prephase_attn_metadata) # - Prefill cross-attention correct? - _assert_actual_match_ideal(prephase_cross_test_params, - prephase_cross_pckd_act_out) + assert_actual_matches_ideal(prephase_cross_test_params, + prephase_cross_pckd_act_out) # DECODE: build decode-phase attention metadata @@ -795,8 +779,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( attn_type=AttentionType.DECODER) # - Decode self-attention correct? - _assert_actual_match_ideal(decphase_dec_test_params, - decphase_dec_pckd_act_out) + assert_actual_matches_ideal(decphase_dec_test_params, + decphase_dec_pckd_act_out) # DECODE: cross-attention test @@ -808,8 +792,8 @@ def test_enc_dec_self_and_cross_attention_prefill_decode_phases( decphase_attn_metadata) # - Decode cross-attention correct? - _assert_actual_match_ideal(decphase_cross_test_params, - decphase_cross_pckd_act_out) + assert_actual_matches_ideal(decphase_cross_test_params, + decphase_cross_pckd_act_out) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @@ -916,7 +900,7 @@ def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - _assert_actual_match_ideal(enc_test_params, enc_packed_actual_output) + assert_actual_matches_ideal(enc_test_params, enc_packed_actual_output) # PREFILL: self-attention test @@ -1044,7 +1028,7 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, attn_type=AttentionType.ENCODER) # - Is encoder attention result correct? - _assert_actual_match_ideal(enc_test_params, enc_packed_actual_output) + assert_actual_matches_ideal(enc_test_params, enc_packed_actual_output) # PREFILL: self-attention test diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 2d802645de060..5e6d5cb2c9bd6 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -757,3 +757,19 @@ def make_test_metadata( cross_kv_mmap.slot_mapping, cross_block_tables=None if cross_kv_mmap is None else \ cross_kv_mmap.block_tables) + + +def assert_actual_matches_ideal(test_params: PhaseTestParameters, + output_under_test: torch.Tensor) -> None: + ''' + Assert that observed output matches the ideal output + contained in the test parameters data structure. + + Arguments: + + * test_params: Test parameters including packed ideal output + * output_under_test: actually observed output value + ''' + ideal_output = test_params.packed_qkvo.ideal_output + assert torch.allclose(ideal_output, + output_under_test.view_as(ideal_output)) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index d67251dd17b23..45a6f4af37d13 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,7 +1,6 @@ """Attention utils""" from vllm.attention import AttentionMetadata -from vllm.utils import is_hip # Error string(s) for encoder/decoder # unsupported attention scenarios @@ -25,28 +24,57 @@ # Check for unsupported encoder/decoder scenarios -def is_encoder_decoder_metadata(attn_metadata) -> bool: +def is_encoder_decoder_metadata_assuming_supported_backend( + attn_metadata) -> bool: + ''' + Return True of the attn_metadata argument contains + the metadata fields that would be required for + encoder attention, which proves that the user is + not running a purely decoder-only model. + + Assumes attn_metadata is derived from a backend that supports + encoder/decoder models. + + Arguments: + + * attn_metadata: instance of supported backend metadata. + Type annotation omitted to avoid circular import. + + + Returns: + + * True if attn_metadata is configured for an encoder/decoder model + ''' return attn_metadata.is_all_encoder_attn_metadata_set def fail_encoder_decoder_prefix_caching() -> None: + ''' + Fail with NotImplementedError & a message indicating + enc/dec + prefix caching is unsupported + ''' raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING) -def check_hip_or_chunked_prefill_attention_encdec( +def assert_no_encdec_chunked_prefill_assuming_supported_backend( attn_metadata: AttentionMetadata) -> None: ''' - Check for unsupported encoder/decoder scenarios when invoking - attention. + Fail if encoder/decoder model is being executed with + chunked prefill. + Assumes we already know that the particular attention + backend in-use is supported. + Arguments: * attn_metadata: Attention metadata structure ''' - if is_hip(): - # AMD ROCm/HIP support currently not implemented for - # encoder/decoder models - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ROCM_HIP) + + if not is_encoder_decoder_metadata_assuming_supported_backend( + attn_metadata): + # Only care about encoder/decoder + # scenarios. + return if attn_metadata.num_prefill_tokens > 0 and \ attn_metadata.num_decode_tokens > 0: diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 72fe333021f08..32ea44f74d106 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -12,8 +12,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.attention.backends.utils import ( - check_hip_or_chunked_prefill_attention_encdec, - fail_encoder_decoder_prefix_caching, is_encoder_decoder_metadata) + assert_no_encdec_chunked_prefill_assuming_supported_backend, + fail_encoder_decoder_prefix_caching, + is_encoder_decoder_metadata_assuming_supported_backend) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -481,10 +482,10 @@ def forward( # seqlen datastructures we utilize attn_type = attn_metadata.attention_type - if is_encoder_decoder_metadata(attn_metadata): - # Raise NotImplementedError for unsupported encoder/decoder - # scenarios - check_hip_or_chunked_prefill_attention_encdec(attn_metadata) + # Raise NotImplementedError for unsupported encoder/decoder + # scenarios (has no effect on decoder-only models) + assert_no_encdec_chunked_prefill_assuming_supported_backend( + attn_metadata) if (kv_cache is not None): # Even if there are no new key/value pairs to cache, @@ -558,7 +559,8 @@ def forward( assert out.shape == output[:num_prefill_tokens].shape output[:num_prefill_tokens] = out else: - if is_encoder_decoder_metadata(attn_metadata): + if is_encoder_decoder_metadata_assuming_supported_backend( + attn_metadata): fail_encoder_decoder_prefix_caching() assert prefill_meta.query_start_loc is not None From dc7d3c8371a4a186abcc660ee6b9b1df8ed3298c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sat, 8 Jun 2024 15:01:56 -0400 Subject: [PATCH 197/239] wip --- tests/kernels/test_encoder_decoder_attn.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 9e07611cba5c7..04412f78d8d84 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -39,6 +39,22 @@ class TestPoint(NamedTuple): + """ + Encapsulates the attributes which define the + test_enc_dec_self_and_cross_attention_prefill_decode_phases() + test + + Attributes: + num_heads: The number of heads in the model. + head_size: Head dimension + backend_name: Name of the backend framework used. + batch_size: Number of samples per batch. + block_size: Size of each block of data processed. + max_dec_seq_len: Maximum sequence length for the decoder. + max_enc_seq_len: Maximum sequence length for the encoder. + num_blocks: Number of blocks in the model. + """ + num_heads: int head_size: int backend_name: str From e9c2a8571241e1d256001c90ed492b3439b7de1f Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 10 Jun 2024 07:06:09 -0400 Subject: [PATCH 198/239] wip comments --- tests/kernels/test_encoder_decoder_attn.py | 40 ++++++++++++++++++---- 1 file changed, 34 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 04412f78d8d84..3a86d07277b02 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -1,8 +1,9 @@ """ Tests: -* E2E Encoder attention + Decoder self-attention + - Encoder/decoder cross-attention +* E2E test of Encoder attention + Decoder self-attention + + Encoder/decoder cross-attention (collectively + "encoder/decoder attention") * Confirm enc/dec models will fail for chunked prefill * Confirm enc/dec models will fail for prefix caching @@ -40,9 +41,8 @@ class TestPoint(NamedTuple): """ - Encapsulates the attributes which define the - test_enc_dec_self_and_cross_attention_prefill_decode_phases() - test + Encapsulates the attributes which define a single invocation + of the test_e2e_enc_dec_attn() test Attributes: num_heads: The number of heads in the model. @@ -66,6 +66,34 @@ class TestPoint(NamedTuple): class TestResources(NamedTuple): + ''' + Encapsuates key components for performing an + encoder/decoder attention test + + Note that + (1) attn automatically selects an attention backend + based on platform info & a set of canned + heuristics + (2) attn_backend is thus *not the same backend + instance* used by attn, but rather it is + intended to be a + *different instance* of the *same backend class*; + it is assumed that the user of TestResources + will leverage attn_backend for the purpose of + constructing backend-compatible attention + metadata instances + + Attributes: + + * scale: 1/sqrt(d) scale factor for attn + * attn_backend: implementatino of abstraction + attention interface using + a particular kernel library + i.e. XFormers + * attn: Attention layer instance + * kv_cache: shared key/value cache for all attention + ''' + scale: float attn_backend: AttentionBackend attn: Attention @@ -649,7 +677,7 @@ def _run_encoder_decoder_cross_attention_test( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_enc_dec_self_and_cross_attention_prefill_decode_phases( +def test_e2e_enc_dec_attn( num_heads: int, head_size: int, backend_name: str, batch_size: int, block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, monkeypatch) -> None: From 57910284fe69f9dea9bf1783ce889a305f12c7c7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 10 Jun 2024 07:06:25 -0400 Subject: [PATCH 199/239] small fix --- tests/kernels/test_encoder_decoder_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 3a86d07277b02..ee76064ec6005 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -67,7 +67,7 @@ class TestPoint(NamedTuple): class TestResources(NamedTuple): ''' - Encapsuates key components for performing an + Encapsulates key components for performing an encoder/decoder attention test Note that From f0cd5eab267107978559facdcbe6ba6fa014657a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 10 Jun 2024 08:00:07 -0400 Subject: [PATCH 200/239] formatting --- tests/kernels/test_encoder_decoder_attn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index ee76064ec6005..1d042b9fa6fe4 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -677,10 +677,10 @@ def _run_encoder_decoder_cross_attention_test( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_e2e_enc_dec_attn( - num_heads: int, head_size: int, backend_name: str, batch_size: int, - block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, - monkeypatch) -> None: +def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, + batch_size: int, block_size: int, + max_dec_seq_len: int, max_enc_seq_len: int, + monkeypatch) -> None: ''' Encoder/decoder attention test: From 2a7fd866e313842de13371dbfae813fc7714cd55 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 10 Jun 2024 09:36:45 -0400 Subject: [PATCH 201/239] enc/dec test comment updates; some function arg changes; formatting --- tests/kernels/test_encoder_decoder_attn.py | 512 +++++++++++---------- tests/kernels/utils.py | 2 +- 2 files changed, 264 insertions(+), 250 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 1d042b9fa6fe4..950b8244365a3 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -102,25 +102,29 @@ class TestResources(NamedTuple): def _make_test_resources(test_pt: TestPoint) -> TestResources: ''' - Compute & build entities required for the self-/cross-attention test. + Build key components for performing encoder/decoder attention test. + + Note that + (1) The Attention instance constructed here, automatically selects + an attention backend class based on platform info & a set of canned + heuristics, so + (2) The attention backend instance constructed here is thus *not + the same backend instance* used by attn, but rather it is + intended to be a *different instance* of the *same backend class*; + therefore, + (3) This function requires that test_pt.backend_name matches the backend + class that Attention will automatically select when it is constructed. + Arguments: - * num_heads: Number of attention heads - * head_size: Head dimension - * num_blocks: Number of KV cache blocks (no KV cache if None) - * block_size: Number of offsets within a KV cache block - (no KV cache if None) - * backend_name: selection of backend + * test_pt: TestPoint data structure; this function relies on the + following fields: num_heads, head_size, num_blocks, + block_size, backend_name Returns: - * scale: 1/sqrt(head_size) - * attn_backend: backend instance - * attn: Attention wrapper instance - * kv_cache: fake KV cache, 2 x num_blocks x (block_size * num_heads * - head_size) - * None if num_blocks or block_size is None + * TestResources data structure. ''' scale = float(1.0 / (test_pt.head_size**0.5)) @@ -158,33 +162,29 @@ def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ The query/key/value tensors are passed to an ideal reference self-attention implementation to generate an ideal output tensor. - This function also constructs the self-attention KV cache memory mapping - (slot mapping and block table), ensuring that the block table starts at - block_base_addr + Encoder inference does not populate the KV cache, therefore + no KV cache memory mapping is constructed Arguments: - * batch_size - * num_heads: Number of attention heads - * head_size: Head dimension - * block_size: Number of offsets per KV cache block - * scale: attention scale parameter - * max_q_seq_len: upper limit on query length for synthetic test vectors - * block_base_addr: self-attention block table base address + * test_pt: TestPoint data structure; this function relies on the + following fields: batch_size, num_heads, head_size, + block_size, max_q_seq_len + * test_rsrcs: TestResources data structure; this function relies on the + scale field + Returns: - * packed_query: number_of_tokens x num_heads x head_size - * packed_key: number_of_tokens x num_heads x head_size - * packed_value: number_of_tokens x num_heads x head_size - * packed_ideal_output: number_of_tokens x num_heads x head_size - * block_tables: fake self-attn decode-phase block table - * slot_mapping: fake self-attn decode-phase slot mapping - * q_seq_lens: list of query sequence lengths + * PhaseTestParameters data structure comprising (1) packed query/key/value + tensors, (2) the ideal output of attention computed using a naive + implementation, and (3) KVCache field set to None ''' max_kv_seq_len = max_q_seq_len + # Make test tensors + qkv_in, _, _ = make_qkv(batch_size, max_q_seq_len, max_kv_seq_len, @@ -193,7 +193,9 @@ def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ attn_type=AttentionType.ENCODER, device=CUDA_DEVICE) - # No causal attention mask + # Compute correct answer using naive non-causal attention + # implementation + ideal_output = ref_masked_attention(qkv_in.query, qkv_in.key, qkv_in.value, @@ -251,51 +253,30 @@ def _decoder_attn_setup( Arguments: - * batch_size - * num_heads: Number of attention heads - * head_size: Head dimension - * block_size: Number of offsets per KV cache block - * scale: attention scale parameter - * max_q_seq_len: upper limit on query length for synthetic test vectors - * block_base_addr: self-attention block table base address + * test_pt: TestPoint data structure; this function relies on the + following fields: batch_size, num_heads, head_size, + block_size, max_q_seq_len + * test_rsrcs: TestResources data structure; this function relies on the + scale field + * block_base_addr: decoder self-attention block-table base address Returns: - - * query: "baseline" query; batch_size x padded_seq_len x num_heads x - head_size - * prefill_packed_query: "prefill" query; number_of_tokens x num_heads x - head_size - * prefill_packed_key: self-attn "prefill" key; number_of_tokens x num_heads - x head_size - * prefill_packed_value: self-attn "prefill" value; number_of_tokens x - num_heads x head_size - * prefill_packed_ideal_output: self-attn "prefill" ideal output; - number_of_tokens x num_heads x head_size - * prefill_q_seq_lens: list of token counts for each *prefill query* (one - less than baseline query) - * prefill_kv_seq_lens: list of token counts for each self-attn *prefill - key/value* (should match prefill_q_seq_lens) - * decode_packed_query: "decode" query; number_of_tokens x num_heads x - head_size - * decode_packed_key: self-attn "decode" key; number_of_tokens x num_heads x - head_size - * decode_packed_value: self-attn "decode" key; number_of_tokens x num_heads - x head_size - * decode_packed_ideal_output: self-attn "decode" ideal output; - number_of_tokens x num_heads x head_size - * decode_q_seq_lens: list of token counts for each *decode query* (should - be 1) - * decode_kv_seq_lens: list of token counts for each self-attn *decode - key/value* (should match decode_q_seq_lens) - * q_seq_lens: "baseline" query seq lens; number_of_tokens x num_heads x - head_size - * kv_seq_lens: self-attn "baseline" key/value seq lens; number_of_tokens - x num_heads x head_size - * decode_block_tables: fake self-attn decode-phase block table - * decode_slot_mapping: fake self-attn decode-phase slot mapping - * prefill_slot_mapping: fake self-attn prefill-phase slot mapping - * prefill_block_tables: fake self-attn prefill-phase block table - * max_block_idx: highest block address in the self-attention block-table + * qkv: Unpacked (batch_size x padded_seq_len x num_heads x + head_size) query/key/value tensors + * Prefill-phase decoder self-attention PhaseTestParameters data structure, + including (1) packed (number_of_tokens x num_heads x head_size) + query/key/value tensors along with (2) ideal attention output + computed using a naive implementation, and (3) memory-mapping data + structures appropriate for prefill phase. + * Decode-phase decoder self-attention PhaseTestParameters data structure, + including (1) packed (number_of_tokens x num_heads x head_size) + query/key/value tensors along with (2) ideal attention output + computed using a naive implementation, and (3) memory-mapping data + structures appropriate for decode phase. + * max_block_idx: max physical address in decoder self-attention block-table + (intended to be used as the base address for the encoder/ + decoder cross-attention block-table, which is not + constructed in this function) ''' (num_heads, head_size, _, batch_size, block_size, max_q_seq_len, _, @@ -305,6 +286,8 @@ def _decoder_attn_setup( max_kv_seq_len = max_q_seq_len + # Build test tensors + qkv, \ prefill_qkv, \ decode_qkv = make_qkv(batch_size, @@ -315,6 +298,9 @@ def _decoder_attn_setup( attn_type=AttentionType.DECODER, device=CUDA_DEVICE) + # Compute correct answer using naive attention implementation + # with causal attention mask + causal_mask = make_causal_mask(max_q_seq_len, max_kv_seq_len).to(CUDA_DEVICE) @@ -326,6 +312,8 @@ def _decoder_attn_setup( q_seq_lens=qkv.q_seq_lens, kv_seq_lens=qkv.kv_seq_lens) + # Split out the prefill- & decode-phase ideal answers & pack them + prefill_ideal_output = torch.zeros_like(ideal_output) decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1]) for bdx, prefill_q_seq_len in enumerate(prefill_qkv.q_seq_lens): @@ -357,6 +345,9 @@ def _decoder_attn_setup( # (including both prefill & decode) # * Slot-mapping with entries for tokens that will be decoded in the # current decode iteration + # + # Note: the format described above is simply mirroring what ModelRunner + # produces prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) @@ -432,36 +423,38 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, Arguments: - * query: pre-existing "baseline" query; batch_size x padded_seq_len x - num_heads x head_size - * q_seq_lens: list of token-counts for each "baseline" query sequence - * prefill_q_seq_lens: list of token-counts for each "prefill" query - sequence - * batch_size - * num_heads: Number of attention heads - * head_size: Head dimension - * block_size: Number of offsets per KV cache block - * scale: attention scale parameter - * max_q_seq_len: upper limit on query length for synthetic test vectors - * max_kv_seq_len: upper limit on key/value length for synthetic test - vectors - * block_base_addr: cross-attention block table base address + * decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x + num_heads x head_size) decoder self-attention inputs; + this function relies on the query and q_seq_lens + fields + * encoder_test_params: PhaseTestParameters data structure which was + used for encoder inference; KV cache field + is not used by this function + * prefill_decoder_phase_test_params: PhaseTestParameters data structure + used for prefill-phase decoder + self-attention; all fields + including KV cache required + * test_pt: TestPoint data structure; this function relies on the + following fields: batch_size, num_heads, head_size, + block_size, max_q_seq_len + * test_rsrcs: TestResources data structure; this function relies on the + scale field + * block_base_addr: decoder self-attention block-table base address Returns: - * packed_key: cross-attention key; number_of_tokens x num_heads x head_size - * packed_value: cross-attention value; number_of_tokens x num_heads x - head_size - * prefill_packed_ideal_output: "prefill" ideal output; number_of_tokens x - num_heads x head_size - * decode_packed_ideal_output: "decode" ideal output; number_of_tokens x - num_heads x head_size - * kv_seq_lens: list of token-counts for each key/value - * decode_block_tables: fake decode-phase block tables - * decode_slot_mapping: fake decode-phase slot mapping - * prefill_slot_mapping: fake prefill-phase slot mapping - * prefill_block_tables: fake prefill-phase block tables - * max_block_idx: highest block address in the cross-attention block-table + * Prefill-phase encoder/decoder cross-attention PhaseTestParameters data + structure, including (1) packed + (number_of_tokens x num_heads x head_size) query/key/value tensors + along with (2) ideal attention output computed using a + naive implementation, and (3) memory-mapping data structures appropriate + for prefill phase. + * Decode-phase encoder/decoder cross-attention PhaseTestParameters data + structure, including (1) packed + (number_of_tokens x num_heads x head_size) query/key/value tensors + along with (2) ideal attention output computed using a + naive implementation, and (3) memory-mapping data structures appropriate + for decode phase. ''' (num_heads, head_size, _, batch_size, block_size, max_decoder_seq_len, @@ -533,6 +526,9 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, # * Empty slot-mapping tensor (since K & V are fixed in size, # new decoded tokens are not KV-cached and require no slot- # mapping) + # + # Note: the format above is simply an extension of what ModelRunner + # produces for decoder-only models prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE) @@ -569,30 +565,32 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, def _run_encoder_attention_test(attn: Attention, encoder_test_params: PhaseTestParameters, - attn_metadata: AttentionMetadata, - attn_type: AttentionType) -> torch.Tensor: + attn_metadata: AttentionMetadata) \ + -> torch.Tensor: ''' Run encoder attention. - attn_metadata.attention_type is assigned attn_type in order to configure - the kernel invocation for either encoder attention + attn_metadata.attention_type is assigned AttentionType.ENCODER in order + to configure the kernel invocation for encoder attention - attn_type must be AttentionType.ENCODER + Requires attn_metadata.num_decode_tokens == 0 + (There is no encoder execution in the decode-phase) Arguments: * attn: Attention wrapper instance - * pckd_qkv: Packed query/key/value inputs + * encoder_test_params: encoder PhaseTestParameters data structure; + this function relies on the packed + (number_of_tokens x num_heads x head_size) + query/key/value fields * attn_metadata: attention metadata for encoder/decoder-self attention - * attn_type: AttentionType.DECODER or AttentionType.ENCODER Returns: - * Attention.forward() applied to packed_{query,key,value}, kv_cache + * Attention.forward() applied to packed {query,key,value} and & attn_metadata ''' - assert attn_type == AttentionType.ENCODER assert attn_metadata.num_decode_tokens == 0 - attn_metadata.attention_type = attn_type + attn_metadata.attention_type = AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, None, attn_metadata) @@ -600,32 +598,32 @@ def _run_encoder_attention_test(attn: Attention, def _run_decoder_self_attention_test(test_rsrcs: TestResources, decoder_test_params: PhaseTestParameters, - attn_metadata: AttentionMetadata, - attn_type: AttentionType) -> torch.Tensor: + attn_metadata: AttentionMetadata) \ + -> torch.Tensor: ''' Run decoder self-attention test. - attn_metadata.attention_type is assigned attn_type in order to configure - the kernel invocation for decoder self-attention. - - attn_type must be AttentionType.DECODER + attn_metadata.attention_type is assigned AttentionType.DECODER + in order to configure the kernel invocation for decoder self-attention. Arguments: - * attn: Attention wrapper instance - * pckd_qkv: Packed query/key/value inputs - * kv_cache - * attn_metadata: attention metadata for encoder/decoder-self attention - * attn_type: AttentionType.DECODER or AttentionType.ENCODER + * test_rsrcs: TestResources instance; this function relies on the kv_cache + and attn (Attention wrapper instance) fields + * decoder_test_params: decoder PhaseTestParameters data structure; + this function relies on the packed + (number_of_tokens x num_heads x head_size) + query/key/value fields + * attn_metadata: attention metadata for decoder-self attention + (contains KV cache memory-mapping) Returns: * Attention.forward() applied to packed_{query,key,value}, kv_cache & attn_metadata ''' - assert attn_type == AttentionType.DECODER attn = test_rsrcs.attn kv_cache = test_rsrcs.kv_cache - attn_metadata.attention_type = attn_type + attn_metadata.attention_type = AttentionType.DECODER packed_qkv = decoder_test_params.packed_qkvo.packed_qkv return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, kv_cache, attn_metadata) @@ -633,20 +631,34 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, def _run_encoder_decoder_cross_attention_test( test_rsrcs: TestResources, decoder_test_params: PhaseTestParameters, - cross_test_params: PhaseTestParameters, + cross_test_params: Optional[PhaseTestParameters], attn_metadata: AttentionMetadata) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. + Via PhaseTestParameters data structures, consumes the same query utilized + for decoder self-attention, plus a key/value specific to cross-attention. + + if cross_test_params is None or cross_test_params.packed_qkvo.packed_qkv + is None, this reflects that in decode-phase cross attention there + is no growth in the key and value tensors. + attn_metadata.attention_type is assigned AttentionType.ENCODER_DECODER in order to configure the kernel invocation for encoder/decoder cross- attention. Arguments: - * attn: Attention wrapper instance - * packed_{query,key,value}: total_num_tokens x (num_heads*head_size) - * kv_cache + * test_rsrcs: TestResources instance; this function relies on the kv_cache + and attn (Attention wrapper instance) fields + * decoder_test_params: decoder PhaseTestParameters data structure; + this function relies on the packed + (number_of_tokens x num_heads x head_size) + query field + * cross_test_params: encoder/decoder PhaseTestParameters data structure; + this function relies on the packed + (number_of_tokens x num_heads x head_size) + key/value fields * attn_metadata: attention metadata for encoder/decoder-self attention Returns: @@ -682,19 +694,25 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, max_dec_seq_len: int, max_enc_seq_len: int, monkeypatch) -> None: ''' - Encoder/decoder attention test: - - * Construct fake test vectors for self- and cross-attention - * Construct attention metadata structure with self- and cross-attention - attributes - * Test self- and cross-attention in the following order + End-to-end encoder/decoder test: + + * Construct fake test vectors for (1) encoder attention, + (2) decoder self-attention, and (3) encoder/decoder cross-attention + * Construct (1) attention metadata structure with self- and cross-attention + attributes for prefill-phase, and (2) an analogous attention metadata + structure but for decode-phase + * Test attention steps in the following order + * Encoder attention * Prefill self-attention * Prefill cross-attention * Decode self-attention * Decode cross-attention - * This order would exacerbate any accidental overlap in the - self-/cross-attention block tables, which we attempt to avoid + * Besides being reflective of realistic use-cases, this order would + exacerbate any accidental overlap in the self-/cross-attention + block tables, which one hopes to avoid + + * Validate output correctness against ideal reference attention implementation @@ -705,11 +723,32 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, tensors. Self-attention K/Vs must have the same seq len as Q while cross-attention K/Vs are allowed to differ in seq len, as is often the case for cross-attention. + + This test utilizes PyTest monkey patching to force the attention backend + via an environment variable. + + Note on ROCm/HIP: currently encoder/decoder models are not supported on + AMD GPUs, therefore this test simply is skipped if is_hip(). + + Note on metadata: there is a single attention metadata structure shared by + all prefill-phase attention operations (encoder, decoder, enc/dec cross), + and a single one shared by all decode-phase attention operations + (decoder & enc/dec cross.) This is intended to reflect the behavior + of ModelRunner, which constructs a single attention metadata structure for + each prefill or decode run. A realistic scenario would rely on the + attention backend to utilize the appropriate attention metadata fields + according to the value of attn_metadata.attention_type. Thus, this test is + organized so as to confirm that the backend-under-test can handle a + shared prefill attention metadata structure & a shared decode attention + metadata structure. ''' # Force Attention wrapper backend override_backend_env_variable(monkeypatch, backend_name) + # Note: KV cache size of 4096 is arbitrary & chosen intentionally + # to be more than necessary, since exceeding the kv cache size + # is not part of this test test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, block_size, max_dec_seq_len, max_enc_seq_len, 4096) @@ -717,24 +756,24 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, # instance, KV cache init test_rsrcs = _make_test_resources(test_pt) - # Encoder attention setup - - # Let encoder_attn_setup() choose default block table - # base address; the block_tables and slot_mapping - # tensors are not actually utilized by encoder attention - # anyway but are required to be present & valid by the - # backend. + # Construct encoder attention test params (only used + # during prefill) enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) - # Decoder self-attention setup + # Construct Decoder self-attention prefill-phase & decode-phase + # test params, including query/key/value tensors, decoder self-attention + # memory-mapping. cross_block_base_addr is the uppermost address in the + # decoder self-attention block-table, i.e. a base address which the + # encoder/decoder cross-attention block-table may build downward toward. dec_qkv, \ prephase_dec_test_params, \ decphase_dec_test_params, \ cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) - # Cross-attention setup + # Construct encoder/decoder cross-attention prefill-phase & decode-phase + # test params, including key/value tensors, cross-attention memory-mapping prephase_cross_test_params, \ decphase_cross_test_params, \ @@ -759,32 +798,29 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, device=CUDA_DEVICE) # PREFILL: encoder attention - # * Use prefill kernel - enc_packed_actual_output: torch.Tensor = \ + enc_pckd_act_out: torch.Tensor = \ _run_encoder_attention_test( test_rsrcs.attn, enc_test_params, - prephase_attn_metadata, - attn_type=AttentionType.ENCODER) + prephase_attn_metadata) # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_packed_actual_output) + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) - # PREFILL: self-attention test + # PREFILL: decoder self-attention test - self_prefill_packed_actual_output: torch.Tensor = \ + prephase_dec_pckd_act_out: torch.Tensor = \ _run_decoder_self_attention_test( test_rsrcs, prephase_dec_test_params, - prephase_attn_metadata, - attn_type=AttentionType.DECODER) + prephase_attn_metadata) - # - Prefill self-attention correct? + # - Is prefill decoder self-attention correct? assert_actual_matches_ideal(prephase_dec_test_params, - self_prefill_packed_actual_output) + prephase_dec_pckd_act_out) - # PREFILL: cross-attention test + # PREFILL: encoder/decoder cross-attention test prephase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( @@ -793,16 +829,12 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, prephase_cross_test_params, prephase_attn_metadata) - # - Prefill cross-attention correct? + # - Is prefill encoder/decoder cross-attention correct? assert_actual_matches_ideal(prephase_cross_test_params, prephase_cross_pckd_act_out) # DECODE: build decode-phase attention metadata - # - Cross-attention KV context is equal in length to - # encoder input - # context_lens = copy.deepcopy(enc_pckd_qkvo.packed_qkv.q_seq_lens) - decphase_attn_metadata: AttentionMetadata = make_test_metadata( test_rsrcs.attn_backend, False, @@ -813,20 +845,19 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, default_attn_type=AttentionType.DECODER, device=CUDA_DEVICE) - # DECODE: self-attention test + # DECODE: decoder self-attention test decphase_dec_pckd_act_out: torch.Tensor = \ _run_decoder_self_attention_test( test_rsrcs, decphase_dec_test_params, - decphase_attn_metadata, - attn_type=AttentionType.DECODER) + decphase_attn_metadata) - # - Decode self-attention correct? + # - Is decode-phase decoder self-attention correct? assert_actual_matches_ideal(decphase_dec_test_params, decphase_dec_pckd_act_out) - # DECODE: cross-attention test + # DECODE: encoder/decoder cross-attention test decphase_cross_pckd_act_out: torch.Tensor = \ _run_encoder_decoder_cross_attention_test( @@ -835,7 +866,7 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, None, decphase_attn_metadata) - # - Decode cross-attention correct? + # - Is decode-phase encoder/decoder cross-attention correct? assert_actual_matches_ideal(decphase_cross_test_params, decphase_cross_pckd_act_out) @@ -857,29 +888,25 @@ def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, max_enc_seq_len: int, monkeypatch) -> None: ''' - Encoder/decoder attention test: + Confirm encoder/decoder models will fail with NotImplemented + if chunked prefill is enabled. - * Construct fake test vectors for self- and cross-attention - * Construct attention metadata structure with self- and cross-attention - attributes - * Test self- and cross-attention in the following order - - * Prefill self-attention - * Prefill cross-attention - * Decode self-attention - * Decode cross-attention - * This order would exacerbate any accidental overlap in the - self-/cross-attention block tables, which we attempt to avoid - * Validate output correctness against ideal reference attention - implementation + This test + 1. Executes a subset of test setup code from + test_e2e_enc_dec_attn() (everything up to encoder + execution); see test_e2e_enc_dec_attn() for more context + on how this code works. - Block tables are constructed such that cross-attention KV cache is in a - higher, non-intersecting address-space than self-attention KV cache. + 2. Modifies the prefill-phase attention metadata structure + to imply a chunked-prefill scenario - Self- and cross-attention share the same query tensor but not the K/V - tensors. Self-attention K/Vs must have the same seq len as Q while - cross-attention K/Vs are allowed to differ in seq len, as is often the case - for cross-attention. + 3. Attempts to execute decoder self-attention + + 4. Asserts that that decoder self-attention fails & with the correct + error message + + Note on ROCm/HIP: currently encoder/decoder models are not supported on + AMD GPUs, therefore this test simply is skipped if is_hip(). ''' # Force Attention wrapper backend @@ -894,12 +921,6 @@ def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, # Encoder attention setup - # Let encoder_attn_setup() choose default block table - # base address; the block_tables and slot_mapping - # tensors are not actually utilized by encoder attention - # anyway but are required to be present & valid by the - # backend. - enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) # Decoder self-attention setup @@ -934,37 +955,35 @@ def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, device=CUDA_DEVICE) # PREFILL: encoder attention - # * Use prefill kernel enc_packed_actual_output: torch.Tensor = \ _run_encoder_attention_test( test_rsrcs.attn, enc_test_params, - prephase_attn_metadata, - attn_type=AttentionType.ENCODER) + prephase_attn_metadata) # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_packed_actual_output) - # PREFILL: self-attention test - + # Meat of the test: require that chunked prefill triggers failure. + # # Set up a contrived scenario where the attention metadata - # is configured for chunked prefill & encoder/decoder cross- + # is configured for chunked prefill & decoder self- # attention. Required that this triggers a NotImplementedError. # - # We assume that decode_attn_metadata.num_decode_tokens > 1 + # We assume that decode_attn_metadata.num_prefill_tokens > 1 # already; the line below sets up a chunked prefill # metadata configuration where there is nominally a mix # of prefill and decode tokens. prephase_attn_metadata.num_decode_tokens = 1 with pytest.raises(NotImplementedError) as exc_info: - _run_decoder_self_attention_test(test_rsrcs, - prephase_dec_test_params, - prephase_attn_metadata, - attn_type=AttentionType.DECODER) + # Doomed decoder self-attention + _run_decoder_self_attention_test(test_rsrcs, prephase_dec_test_params, + prephase_attn_metadata) # "Encoder decoder models do not currently support chunked prefill" + # or something to that effect assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL @@ -985,29 +1004,25 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, max_enc_seq_len: int, monkeypatch) -> None: ''' - Encoder/decoder attention test: + Confirm encoder/decoder models will fail with NotImplemented + if prefix caching is enabled. - * Construct fake test vectors for self- and cross-attention - * Construct attention metadata structure with self- and cross-attention - attributes - * Test self- and cross-attention in the following order - - * Prefill self-attention - * Prefill cross-attention - * Decode self-attention - * Decode cross-attention - * This order would exacerbate any accidental overlap in the - self-/cross-attention block tables, which we attempt to avoid - * Validate output correctness against ideal reference attention - implementation + This test + 1. Executes a subset of test setup code from + test_e2e_enc_dec_attn() (everything up to encoder + execution); see test_e2e_enc_dec_attn() for more context + on how this code works. - Block tables are constructed such that cross-attention KV cache is in a - higher, non-intersecting address-space than self-attention KV cache. + 2. Modifies the prefill-phase attention metadata structure + to imply a prefix caching scenario - Self- and cross-attention share the same query tensor but not the K/V - tensors. Self-attention K/Vs must have the same seq len as Q while - cross-attention K/Vs are allowed to differ in seq len, as is often the case - for cross-attention. + 3. Attempts to execute decoder self-attention + + 4. Asserts that that decoder self-attention fails & with the correct + error message + + Note on ROCm/HIP: currently encoder/decoder models are not supported on + AMD GPUs, therefore this test simply is skipped if is_hip(). ''' # Force Attention wrapper backend @@ -1022,12 +1037,6 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # Encoder attention setup - # Let encoder_attn_setup() choose default block table - # base address; the block_tables and slot_mapping - # tensors are not actually utilized by encoder attention - # anyway but are required to be present & valid by the - # backend. - enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) # Decoder self-attention setup @@ -1062,30 +1071,36 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, device=CUDA_DEVICE) # PREFILL: encoder attention - # * Use prefill kernel enc_packed_actual_output: torch.Tensor = \ _run_encoder_attention_test( test_rsrcs.attn, enc_test_params, - prephase_attn_metadata, - attn_type=AttentionType.ENCODER) + prephase_attn_metadata) # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_packed_actual_output) - # PREFILL: self-attention test - - # Set up a contrived scenario where the attention metadata - # is configured for chunked prefill & encoder/decoder cross- - # attention. Required that this triggers a NotImplementedError. + # Meat of the test: require that prefix caching triggers failure. # - # We assume that decode_attn_metadata.num_decode_tokens > 1 - # already; the line below sets up a chunked prefill - # metadata configuration where there is nominally a mix - # of prefill and decode tokens. + # Set up a contrived scenario where the attention metadata + # is configured for prefix caching & decoder self- + # attention. Require that this triggers a NotImplementedError. with pytest.raises(NotImplementedError) as exc_info: - # Fake a non-empty block_tables + # In XFormers backend, the trigger for utilizing the + # prefix caching kernel is + # + # kv_cache is not None and prefill_meta.block_tables.numel() > 0 + # + # We can shallowly emulate a prefix caching scenario by passing + # in a non-None KV cache in test_rsrcs (already the + # case) and then tweaking the cached prefill attention metadata + # from the encoder run to have a non-empty (gibberish) block + # table. This block table will never actually be used, because + # its presence will signify to the backend a prefix-caching + # scenario and (given that the attention metadata structure + # is configured for an encoder/decoder scenario too) trigger + # a NotImplemented a exception. num_seqs = len( prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens) @@ -1094,10 +1109,9 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, torch.randint( 0, 10, (num_seqs, 1)) - _run_decoder_self_attention_test(test_rsrcs, - prephase_dec_test_params, - prephase_attn_metadata, - attn_type=AttentionType.DECODER) + _run_decoder_self_attention_test(test_rsrcs, prephase_dec_test_params, + prephase_attn_metadata) - # "Encoder decoder models do not currently support chunked prefill" + # "Encoder decoder models do not currently support prefix caching" + # or something to that effect assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 5e6d5cb2c9bd6..4f06088da909e 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -92,7 +92,7 @@ def ref_masked_attention(query: torch.Tensor, * key: batch_size x kv_padded_seq_len x num_heads x head_size * value: batch_size x kv_padded_seq_len x num_heads x head_size * scale: Attention scale factor - * Custom mask: custom attention mask; good place to inject a causal + * custom_mask: custom attention mask; good place to inject a causal attention mask * q_seq_lens: list of unpadded query seq_lens for each batch index * kv_seq_lens: list of unpadded key/value seq_lens for each batch index From 8e1daa16b4fdceb9a10f138888cec2303673e501 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 10 Jun 2024 09:40:22 -0400 Subject: [PATCH 202/239] formatting --- tests/kernels/test_encoder_decoder_attn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index ee76064ec6005..1d042b9fa6fe4 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -677,10 +677,10 @@ def _run_encoder_decoder_cross_attention_test( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_e2e_enc_dec_attn( - num_heads: int, head_size: int, backend_name: str, batch_size: int, - block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, - monkeypatch) -> None: +def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, + batch_size: int, block_size: int, + max_dec_seq_len: int, max_enc_seq_len: int, + monkeypatch) -> None: ''' Encoder/decoder attention test: From a2a7ac5874083192c348853dba529c8eed4260ea Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 10 Jun 2024 09:45:29 -0400 Subject: [PATCH 203/239] fixed attention selector test to use FLASH_ATTN string constant var in all relevant locations --- tests/kernels/test_attention_selector.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 79e03c7478de0..d9000e58d1d43 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -42,32 +42,32 @@ def test_flash_attn(monkeypatch): # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=[7, 5]): backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # Unsupported data type backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # Unsupported kv cache data type backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # Unsupported block size backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # Unsupported sliding window backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # flash-attn is not installed with patch.dict('sys.modules', {'vllm_flash_attn': None}): backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL # Unsupported head size backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) - assert backend.name != "FLASH_ATTN" + assert backend.name != STR_FLASH_ATTN_VAL def test_invalid_env(monkeypatch): From d3575687c7ac4966e64f4c92c4fef9389aaa9afa Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 10 Jun 2024 10:41:54 -0400 Subject: [PATCH 204/239] additional commenting & added string constants for other backends --- tests/kernels/test_encoder_decoder_attn.py | 2 +- tests/kernels/utils.py | 168 +++++++++++++++------ 2 files changed, 126 insertions(+), 44 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 950b8244365a3..922e3ddb43dd8 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -28,7 +28,7 @@ BATCH_SIZES = [1, 16] BLOCK_SIZES = [16] -BACKEND_NAMES = ["XFORMERS"] +BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL] CUDA_DEVICE = "cuda:0" MAX_DEC_SEQ_LENS = [128] diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 4f06088da909e..ea226878e4b33 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -13,12 +13,34 @@ from vllm.utils import (make_tensor_with_pad, maybe_make_int_tensor, maybe_make_long_tensor, maybe_max) +# String name of register which may be set in order to +# force auto-selection of attention backend by Attention +# wrapper STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" + +# Possible string values of STR_BACKEND_ENV_VAR +# register, corresponding to possible backends +STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" +STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA" +STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" +STR_XFORMERS_ATTN_VAL: str = "XFORMERS" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_INVALID_VAL: str = "INVALID" class QKVInputs(NamedTuple): + ''' + Data structure for representing unpacked attention inputs, + query/key/value. + + Attributes: + + * {query,key,value}: unpacked (batch_size x padded_seq_len x + num_heads x head_size) attention inputs + * q_seq_lens: query sequence lengths list + * kv_seq_lens: shared key/value sequence lengths list + ''' + query: torch.Tensor key: torch.Tensor value: torch.Tensor @@ -27,11 +49,37 @@ class QKVInputs(NamedTuple): class QKVO(NamedTuple): + ''' + Data structure for representing unpacked attention inputs, + alongside unpacked known-correct attention output + + Attributes: + + * qkv: unpacked (batch_size x padded_seq_len x + num_heads x head_size) attention inputs + * ideal_output: unpacked (batch_size x padded_seq_len x + num_heads x head_size) known-correct attention output + ''' + qkv: QKVInputs ideal_output: torch.Tensor class PackedQKVInputs(NamedTuple): + ''' + Data structure for representing packed attention inputs + + Attributes: + + * {query,key,value}: packed (number_of_tokens x num_heads + x head_size) attention inputs + * q_seq_lens: list of query start locations within packed tensor + * kv_seq_lens: shared list of key/value start locations within + packed tensor + * q_seq_lens: query sequence lengths list + * kv_seq_lens: shared key/value sequence lengths list + ''' + query: torch.Tensor key: torch.Tensor value: torch.Tensor @@ -42,16 +90,51 @@ class PackedQKVInputs(NamedTuple): class PackedQKVO(NamedTuple): + ''' + Data structure for representing packed attention inputs, + alongside packed known-correct attention output + + Attributes: + + * packed_qkv: packed (number_of_tokens x num_heads + x head_size) attention inputs + * ideal_output: packed (number_of_tokens x num_heads + x head_size) known-correct attention output + ''' + packed_qkv: PackedQKVInputs ideal_output: torch.Tensor class KVMemoryMap(NamedTuple): + ''' + Data structure for encapsulating KV cache memory mapping. + + Attributes: + + * block_tables: KV cache block tables + * slot_mapping: mapping of sequence offset to physical address + ''' + block_tables: torch.Tensor slot_mapping: torch.Tensor class PhaseTestParameters(NamedTuple): + ''' + Data structure for encapsulating the test parameters + for a given test "phase" (prefill or decode phase) and attention + scenario (encoder, decoder-self, encoder/decoder-cross) + + Attributes: + + * packed_qkvo: packed (number_of_tokens x num_heads + x head_size) attention inputs & known-correct + output + * kv_mmap: KV cache memory mapping, specific to this test phase & + attention scenario + ''' + packed_qkvo: PackedQKVO kv_mmap: KVMemoryMap @@ -174,7 +257,9 @@ def make_qkv( Returns: - * QKVInputs structure + * Overall QKVInputs structure (containing full unpacked Q/K/V tensors) + * Prefill QKVInputs structure (containing all but the last sequence offset) + * Decode QKVInputs structure (containing all only the last sequence offset) ''' if force_max_len: @@ -245,18 +330,18 @@ def make_qkv( decode_q_seq_lens = [1 for _ in q_seq_lens] decode_kv_seq_lens = [1 for _ in kv_seq_lens] - return QKVInputs(query, + return QKVInputs(query, # Overall QKV inputs key, value, q_seq_lens, kv_seq_lens), \ - QKVInputs(prefill_query, + QKVInputs(prefill_query, # Prefill subset of QKV sequences prefill_key, prefill_value, prefill_q_seq_lens, prefill_kv_seq_lens), \ QKVInputs( - decode_query, + decode_query, # Decode subset of KV sequences decode_key, decode_value, decode_q_seq_lens, @@ -311,19 +396,14 @@ def pack_qkv(qkv: QKVInputs, device: Union[torch.device, Arguments: - * query: batch_size x padded_seq_len x num_heads x head_size - * key: batch_size x padded_seq_len x num_heads x head_size - * value: batch_size x padded_seq_len x num_heads x head_size - * q_seq_lens: list of token counts for each query - * kv_seq_lens: list of token counts for each key/value + * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size) + attention inputs + * device: CPU or CUDA device Returns - * packed_query: number_of_tokens x num_heads x head_size - * packed_key: number_of_tokens x num_heads x head_size - * packed_value: number_of_tokens x num_heads x head_size - * q_start_loc_list: start idx of each query in packed_query - * kv_start_loc_list: start idx of each {key,value} in packed_{key,value} + * Packed (number_of_tokens x num_heads x head_size) QKV inputs + derived from unpacked inputs ''' if qkv.query is None: @@ -367,7 +447,7 @@ def make_backend(backend_name: str) -> AttentionBackend: * Backend instance ''' - if backend_name == "XFORMERS": + if backend_name == STR_XFORMERS_ATTN_VAL: return XFormersBackend() raise AssertionError( f"Unrecognized backend_name {backend_name} for unit test") @@ -383,20 +463,19 @@ def _make_metadata_tensors( Arguments: - * is_prompt: True -> Prefill, False -> Decode - * seq_lens: list of token-counts for each seq + * seq_lens: list of token-counts for each decoder input seq * context_lens: list of context length values for each seq + * encoder_seq_lens: list of token-counts for each encoder input seq * device: CPU or CUDA device Returns: - * seq_lens_tensor: seq_lens list, as tensor + * seq_lens_tensor: decoder seq_lens list, as tensor * context_lens_tensor: context_lens list, as tensor - * max_query_len: max(seq_lens) if is_seq, o/w 1 * max_context_len: max(context_lens) * max_seq_len: max(seq_lens) * seq_start_loc: start idx of each sequence - * query_start_loc: start idx of each query + * max_encoder_seq_len: encoder seq_lens list, as tensor ''' seq_lens_tensor = maybe_make_int_tensor(seq_lens, device) context_lens_tensor = maybe_make_int_tensor(context_lens, device) @@ -614,12 +693,15 @@ def make_test_metadata( cross_test_params: Optional[PhaseTestParameters] = None ) -> AttentionMetadata: ''' - Construct fake attention metadata for a combined self-/cross-attention - scenario i.e. an encoder/decoder model. + Construct fake attention metadata for a given test phase + (prefill-phase or decode-phase). - is_encoder_only_test=True causes the default attention metadata attention - type to be AttentionType.ENCODER. False causes the default to - be AttentionType.DECODER. + encoder_test_params and cross_test_params arguments all encoder + attention and enc/dec cross-attention to use distinct metadata values + from decoder self-attention (decoder_test_params.) + + if encoder_test_params and cross_test_params are None, the attention + metadata will support decoder-only scenario. Assumptions: @@ -630,32 +712,29 @@ def make_test_metadata( * attn_backend: Backend for sourcing attention kernels * is_prompt: prefill if True, o/w decode * seq_lens: list of token counts for each sequence - * context_lens: list of context lengths for each sequence - * block_tables: self-attention block tables - * slot_mapping: self-attention slot_mapping - * is_encoder_only_test: True if testing encoder; False if testing - decoder self-attention or encoder/decoder cross-attention. + * decoder_test_params: decoder self-attention test params; + this function requires + kv_mmap (memory mapping) field + * default_attn_type: value of attn_metadata.attention_type at + construction time * device: CPU or CUDA device - * encoder_seq_lens: list of token counts for each encoder sequence, if any - exist - * cross_block_tables: cross-attention block tables, if required - * cross_slot_mapping: cross-attention slot mapping, if required + * encoder_test_params: encoder attention test params; + this function requires encoder query + sequence lengths field. If None, + encoder query sequence lengths are + treated as None + * cross_test_params: enc/dec cross-attention test params; + this function requires kv_mmap field. + If None, KV cache memory map data + structures are treated as None Return: - * AttentionMetadata structure supporting self- and cross-attention + * AttentionMetadata structure ''' - # Extract - # * Decoder input sequence lengths (seq_lens) - # * Decoder self-attention slot mapping & block tables (kv_mmap) - #seq_lens = decoder_test_params.packed_qkvo.packed_qkv.q_seq_lens kv_mmap = decoder_test_params.kv_mmap - # is_prompt determines whether input tokens are treated - # as 100% prefill or 100% decode. In either case, - # the number of {prefills, decodes} and the number of - # {prefill, decode} tokens can be inferred from seq_lens num_prefills_or_decodes = len(seq_lens) # Prefill: operate on total num. of prompt @@ -684,6 +763,8 @@ def make_test_metadata( cross_kv_mmap = cross_test_params.kv_mmap if is_prompt: + # Prefill-phase scenario + num_prefills = num_prefills_or_decodes num_prefill_tokens = num_prefill_or_decode_tokens num_decode_tokens = 0 @@ -721,6 +802,7 @@ def make_test_metadata( cross_kv_mmap.block_tables) else: # not is_prompt + # Decode-phase scenario num_prefills = 0 num_prefill_tokens = 0 From 97cad0b96dafe5cabe004bc4a119f00d7d9db5a6 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Wed, 12 Jun 2024 16:59:48 -0400 Subject: [PATCH 205/239] encoder-only unit test passes --- tests/kernels/test_encoder_decoder_attn.py | 55 +++++++++++++++++++++- tests/kernels/utils.py | 41 +++++++++++----- vllm/attention/backends/utils.py | 7 +++ vllm/attention/backends/xformers.py | 27 ++++++----- 4 files changed, 103 insertions(+), 27 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 922e3ddb43dd8..c8c72c21f5cc3 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -680,6 +680,59 @@ def _run_encoder_decoder_cross_attention_test( return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, value, kv_cache, attn_metadata) +@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) +@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) +def test_encoder_only(num_heads: int, head_size: int, backend_name: str, + batch_size: int, block_size: int, + max_dec_seq_len: int, max_enc_seq_len: int, + monkeypatch): + + # Force Attention wrapper backend + override_backend_env_variable(monkeypatch, backend_name) + + # Note: KV cache size of 4096 is arbitrary & chosen intentionally + # to be more than necessary, since exceeding the kv cache size + # is not part of this test + test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, + block_size, max_dec_seq_len, max_enc_seq_len, 4096) + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + test_rsrcs = _make_test_resources(test_pt) + + # Construct encoder attention test params (only used + # during prefill) + + enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) + + # Shared prefill metadata structure + + prephase_attn_metadata: AttentionMetadata = make_test_metadata( + test_rsrcs.attn_backend, + True, + None, + decoder_test_params=None, + encoder_test_params=enc_test_params, + cross_test_params=None, + default_attn_type=AttentionType.ENCODER, + device=CUDA_DEVICE) + + # PREFILL: encoder attention + + enc_pckd_act_out: torch.Tensor = \ + _run_encoder_attention_test( + test_rsrcs.attn, + enc_test_params, + prephase_attn_metadata) + + # - Is encoder attention result correct? + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -1114,4 +1167,4 @@ def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, # "Encoder decoder models do not currently support prefix caching" # or something to that effect - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING \ No newline at end of file diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index ea226878e4b33..3dfc70a46eeb4 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -733,15 +733,28 @@ def make_test_metadata( * AttentionMetadata structure ''' - kv_mmap = decoder_test_params.kv_mmap - - num_prefills_or_decodes = len(seq_lens) - - # Prefill: operate on total num. of prompt - # tokens - # Decode: operate on one token per seq - num_prefill_or_decode_tokens = \ - sum(seq_lens) if is_prompt else len(seq_lens) + # Decoder self-attention memory mapping + # decoder_test_params is None signals encoder-only + # scenario, so kv_mmap is None + kv_mmap = None if decoder_test_params is None else \ + decoder_test_params.kv_mmap + + # This function constructs metadata assuming no chunked prefill, + # i.e. 100% prefill tokens or 100% decode tokens + # + # - If is_prompt, num_prefills_or_decodes is the number of prefills + # and num_prefill_or_decode_tokens is the number of prefill tokens + # - If not is_prompt, num_prefills_or_decodes is the number of decodes + # and num_prefill_or_decode_tokens is the number of decode tokens + # + # seq_lens is None signals encoder-only + # scenario, in which case num_prefills_or_decodes and + # num_prefill_or_decode_tokens are unused + num_prefills_or_decodes = None if seq_lens is None else \ + len(seq_lens) + + num_prefill_or_decode_tokens = None if seq_lens is None else \ + (sum(seq_lens) if is_prompt else len(seq_lens)) # Seems for non-prefix-caching scenarios context_lens # is never needed @@ -750,14 +763,14 @@ def make_test_metadata( if encoder_test_params is None: encoder_seq_lens = None else: - # Encoder/decoder models only: + # Encoder/decoder or encoder-only models only: # * Extract encoder input sequence lengths encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens if cross_test_params is None: cross_kv_mmap = None else: - # Encoder/decoder models only: + # Encoder/decoder or encoder-only models only: # * Extract *cross-attention* slot_mapping and block table # (kv_mmap) cross_kv_mmap = cross_test_params.kv_mmap @@ -782,7 +795,8 @@ def make_test_metadata( return attn_backend.make_metadata( num_prefills=num_prefills, - slot_mapping=kv_mmap.slot_mapping, + slot_mapping=None if kv_mmap is None else \ + kv_mmap.slot_mapping, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -790,7 +804,8 @@ def make_test_metadata( max_prefill_seq_len=None if seq_lens is None else max(seq_lens), max_decode_seq_len=0, context_lens_tensor=context_lens_tensor, - block_tables=kv_mmap.block_tables, + block_tables=None if kv_mmap is None else \ + kv_mmap.block_tables, use_cuda_graph=False, _attn_type=default_attn_type, encoder_seq_lens=encoder_seq_lens, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 45a6f4af37d13..c34b35cf77c8a 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -76,6 +76,13 @@ def assert_no_encdec_chunked_prefill_assuming_supported_backend( # scenarios. return + if attn_metadata.num_prefill_tokens is None or \ + attn_metadata.num_decode_tokens is None: + # The metadata which would be + # indicative of chunked prefill is unset; + # this may be the case for encoder-only models + return + if attn_metadata.num_prefill_tokens > 0 and \ attn_metadata.num_decode_tokens > 0: # Encoder/decoder models are currently incompatible diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 32ea44f74d106..32fed02d3adc9 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -211,7 +211,6 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: assert (self.seq_lens_tensor is not None) or \ (self.encoder_seq_lens_tensor is not None) #assert self.context_lens_tensor is not None - assert self.block_tables is not None query_start_loc = None if self.query_start_loc is None \ else self.query_start_loc[:self.num_prefills + 1] @@ -220,19 +219,20 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - seq_lens=None - if self.seq_lens is None \ + slot_mapping=None if self.slot_mapping is None else \ + self.slot_mapping[:self.num_prefill_tokens], + seq_lens=None if self.seq_lens is None \ else self.seq_lens[:self.num_prefills], - seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills], + seq_lens_tensor=None if self.seq_lens_tensor is None else \ + self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, query_start_loc=query_start_loc, context_lens_tensor=None if self.context_lens_tensor is None else \ - self.context_lens_tensor[:self.num_prefills], - block_tables=self.block_tables[:self.num_prefills], + self.context_lens_tensor[:self.num_prefills], + block_tables=None if self.block_tables is None else \ + self.block_tables[:self.num_prefills], use_cuda_graph=False, _attn_type=self.attention_type, # Begin encoder & cross attn fields below... @@ -252,7 +252,6 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: self._cached_decode_metadata.attention_type = \ self.attention_type return self._cached_decode_metadata - assert self.block_tables is not None assert (self.seq_lens_tensor is not None) or \ (self.encoder_seq_lens_tensor is not None) @@ -260,15 +259,17 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + slot_mapping=None if self.slot_mapping is None else \ + self.slot_mapping[self.num_prefill_tokens:], seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:], + self.seq_lens_tensor[self.num_prefills:], max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - block_tables=self.block_tables[self.num_prefills:], + block_tables=None if self.block_tables is None else \ + self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, _attn_type=self. - _attn_type, # Begin encoder & cross attn fields below... + _attn_type, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, max_encoder_seq_len=self.max_encoder_seq_len, From 29fa1af416f1a6e18abe76ad05a68170f960cc6e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 13 Jun 2024 07:58:14 -0400 Subject: [PATCH 206/239] refactoring --- vllm/attention/backends/utils.py | 42 ++++++++++++++++------ vllm/attention/backends/xformers.py | 56 ++++++++++++++++++----------- 2 files changed, 68 insertions(+), 30 deletions(-) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index c34b35cf77c8a..b0c05fca285f8 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -7,33 +7,33 @@ STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ "Chunked prefill is not currently " + \ -"supported with encoder/decoder models." +"supported with encoder/decoder or encoder-only models." STR_NOT_IMPL_ENC_DEC_ROCM_HIP = \ "ROCm/HIP is not currently supported" + \ -"with encoder/decoder models." +"with encoder/decoder or encoder-only models." STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND = \ "Currently only the XFormers backend " + \ - "supports encoder/decoder models." + "supports encoder/decoder and encoder-only models." STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING = \ "Prefix caching is not currently supported " + \ -"with encoder/decoder models" +"with encoder/decoder or encoder-only models" # Check for unsupported encoder/decoder scenarios -def is_encoder_decoder_metadata_assuming_supported_backend( +def is_encoder_metadata_assuming_supported_backend( attn_metadata) -> bool: ''' - Return True of the attn_metadata argument contains + Return True if the attn_metadata argument contains the metadata fields that would be required for encoder attention, which proves that the user is - not running a purely decoder-only model. + not running a purely decoder-only model Assumes attn_metadata is derived from a backend that supports - encoder/decoder models. + encoder-only or encoder/decoder models. Arguments: @@ -43,10 +43,32 @@ def is_encoder_decoder_metadata_assuming_supported_backend( Returns: - * True if attn_metadata is configured for an encoder/decoder model + * True if attn_metadata is configured for an encoder-only model ''' return attn_metadata.is_all_encoder_attn_metadata_set +def is_encoder_decoder_metadata_assuming_supported_backend( + attn_metadata) -> bool: + ''' + Return True if the attn_metadata argument contains + the metadata fields that would be required for + encoder/decoder attention, which proves that the user is + running an encoder/decoder model + + Assumes attn_metadata is derived from a backend that supports + encoder-only or encoder/decoder models. + + Arguments: + + * attn_metadata: instance of supported backend metadata. + Type annotation omitted to avoid circular import. + + + Returns: + + * True if attn_metadata is configured for an encoder/decoder model + ''' + return attn_metadata.is_all_encoder_decoder_attn_metadata_set def fail_encoder_decoder_prefix_caching() -> None: ''' @@ -70,7 +92,7 @@ def assert_no_encdec_chunked_prefill_assuming_supported_backend( * attn_metadata: Attention metadata structure ''' - if not is_encoder_decoder_metadata_assuming_supported_backend( + if not is_encoder_metadata_assuming_supported_backend( attn_metadata): # Only care about encoder/decoder # scenarios. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 32fed02d3adc9..63fe10dc2502f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -14,7 +14,7 @@ from vllm.attention.backends.utils import ( assert_no_encdec_chunked_prefill_assuming_supported_backend, fail_encoder_decoder_prefix_caching, - is_encoder_decoder_metadata_assuming_supported_backend) + is_encoder_metadata_assuming_supported_backend) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -202,6 +202,8 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: return None if self._cached_prefill_metadata is not None: + # Recover cached prefill-phase attention + # metadata structure self._cached_prefill_metadata.attention_type = \ self.attention_type return self._cached_prefill_metadata @@ -210,29 +212,35 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: (self.encoder_seq_lens is not None) assert (self.seq_lens_tensor is not None) or \ (self.encoder_seq_lens_tensor is not None) - #assert self.context_lens_tensor is not None + # Compute some attn_metadata fields which default to None query_start_loc = None if self.query_start_loc is None \ else self.query_start_loc[:self.num_prefills + 1] - + slot_mapping=None if self.slot_mapping is None else \ + self.slot_mapping[:self.num_prefill_tokens] + seq_lens=None if self.seq_lens is None \ + else self.seq_lens[:self.num_prefills] + seq_lens_tensor=None if self.seq_lens_tensor is None else \ + self.seq_lens_tensor[:self.num_prefills] + context_lens_tensor=None if self.context_lens_tensor is None else \ + self.context_lens_tensor[:self.num_prefills] + block_tables=None if self.block_tables is None else \ + self.block_tables[:self.num_prefills] + + # Construct & cache prefill-phase attention metadata structure self._cached_prefill_metadata = XFormersMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, - slot_mapping=None if self.slot_mapping is None else \ - self.slot_mapping[:self.num_prefill_tokens], - seq_lens=None if self.seq_lens is None \ - else self.seq_lens[:self.num_prefills], - seq_lens_tensor=None if self.seq_lens_tensor is None else \ - self.seq_lens_tensor[:self.num_prefills], + slot_mapping=slot_mapping, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, query_start_loc=query_start_loc, - context_lens_tensor=None if self.context_lens_tensor is None else \ - self.context_lens_tensor[:self.num_prefills], - block_tables=None if self.block_tables is None else \ - self.block_tables[:self.num_prefills], + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, use_cuda_graph=False, _attn_type=self.attention_type, # Begin encoder & cross attn fields below... @@ -249,24 +257,32 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: return None if self._cached_decode_metadata is not None: + # Recover cached decode-phase attention + # metadata structure self._cached_decode_metadata.attention_type = \ self.attention_type return self._cached_decode_metadata assert (self.seq_lens_tensor is not None) or \ (self.encoder_seq_lens_tensor is not None) + # Compute some attn_metadata fields which default to None + slot_mapping=None if self.slot_mapping is None else \ + self.slot_mapping[self.num_prefill_tokens:] + seq_lens_tensor=None if self.seq_lens_tensor is None else \ + self.seq_lens_tensor[self.num_prefills:] + block_tables=None if self.block_tables is None else \ + self.block_tables[self.num_prefills:] + + # Construct & cache decode-phase attention metadata structure self._cached_decode_metadata = XFormersMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, - slot_mapping=None if self.slot_mapping is None else \ - self.slot_mapping[self.num_prefill_tokens:], - seq_lens_tensor=None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:], + slot_mapping=slot_mapping, + seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - block_tables=None if self.block_tables is None else \ - self.block_tables[self.num_prefills:], + block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, _attn_type=self. _attn_type, # Begin encoder & cross attn fields below... @@ -560,7 +576,7 @@ def forward( assert out.shape == output[:num_prefill_tokens].shape output[:num_prefill_tokens] = out else: - if is_encoder_decoder_metadata_assuming_supported_backend( + if is_encoder_metadata_assuming_supported_backend( attn_metadata): fail_encoder_decoder_prefix_caching() From c9f11ff89f49915524284126fe2a80c0d356af05 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Jun 2024 15:40:45 -0400 Subject: [PATCH 207/239] comment fixes --- tests/kernels/utils.py | 6 +++--- vllm/attention/backends/xformers.py | 15 +++------------ 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index ea226878e4b33..71cd93cd3df92 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -73,9 +73,9 @@ class PackedQKVInputs(NamedTuple): * {query,key,value}: packed (number_of_tokens x num_heads x head_size) attention inputs - * q_seq_lens: list of query start locations within packed tensor - * kv_seq_lens: shared list of key/value start locations within - packed tensor + * q_start_loc_list: list of query start locations within packed tensor + * kv_start_loc_list: shared list of key/value start locations within + packed tensor * q_seq_lens: query sequence lengths list * kv_seq_lens: shared key/value sequence lengths list ''' diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 18c324598bf9f..3e10b55fb6fcd 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -142,15 +142,11 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # _attn_type: AttentionType = AttentionType.DECODER - # (batch_size,). The "cross-sequence-length" per sequence,i.e. the key/value - # sequence length (usually encoder sequence length) in the cross-attention - # computation. None if this is self-attention + # Encoder sequence lengths representation encoder_seq_lens: Optional[List[int]] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None - # The maximum cross-sequence-length, if cross_seq_lens is specified. - # Note that for cross-attention there is no difference in key/value - # sequence length between prefill and decode + # Maximum sequence length among encoder sequences max_encoder_seq_len: Optional[int] = None # Cross-attention memory-mapping data structures: slot mapping @@ -200,15 +196,11 @@ def attention_type(self, atype: AttentionType) -> None: "Must set self.encoder_seq_lens, self.cross_slot_mapping, " + \ "self.cross_block_tables in order to perform cross-attention" - self._attn_type = AttentionType.ENCODER_DECODER elif atype == AttentionType.ENCODER: assert self.is_all_encoder_attn_metadata_set, \ "Must set self.encoder_seq_lens in order to perform cross-attention" - self._attn_type = AttentionType.ENCODER - else: - # AttentionType.{ENCODER,DECODER} - self._attn_type = atype + self._attn_type = atype @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: @@ -224,7 +216,6 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: (self.encoder_seq_lens is not None) assert (self.seq_lens_tensor is not None) or \ (self.encoder_seq_lens_tensor is not None) - #assert self.context_lens_tensor is not None assert self.block_tables is not None query_start_loc = None if self.query_start_loc is None \ From 196e671a2fa583c7b0beb65236b61c0df54bff4b Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Jun 2024 15:45:03 -0400 Subject: [PATCH 208/239] removed unnecessary tests --- tests/kernels/test_encoder_decoder_attn.py | 254 +-------------------- tests/kernels/utils.py | 6 +- 2 files changed, 8 insertions(+), 252 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 922e3ddb43dd8..8b93bb8fe6782 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -149,9 +149,7 @@ class that Attention will automatically select when it is constructed. def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ -> PhaseTestParameters: - (num_heads, head_size, _, batch_size, _, _, max_q_seq_len, _) = test_pt - scale = test_rsrcs.scale ''' Set up test vectors & data structures for encoder attention test. @@ -181,6 +179,10 @@ def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ implementation, and (3) KVCache field set to None ''' + (num_heads, head_size, _, batch_size, _, _, max_q_seq_len, _) = test_pt + + scale = test_rsrcs.scale + max_kv_seq_len = max_q_seq_len # Make test tensors @@ -868,250 +870,4 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, # - Is decode-phase encoder/decoder cross-attention correct? assert_actual_matches_ideal(decphase_cross_test_params, - decphase_cross_pckd_act_out) - - -@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES_FOR_UNSUPP) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) -@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, - head_size: int, - backend_name: str, - batch_size: int, - block_size: int, - max_dec_seq_len: int, - max_enc_seq_len: int, - monkeypatch) -> None: - ''' - Confirm encoder/decoder models will fail with NotImplemented - if chunked prefill is enabled. - - This test - 1. Executes a subset of test setup code from - test_e2e_enc_dec_attn() (everything up to encoder - execution); see test_e2e_enc_dec_attn() for more context - on how this code works. - - 2. Modifies the prefill-phase attention metadata structure - to imply a chunked-prefill scenario - - 3. Attempts to execute decoder self-attention - - 4. Asserts that that decoder self-attention fails & with the correct - error message - - Note on ROCm/HIP: currently encoder/decoder models are not supported on - AMD GPUs, therefore this test simply is skipped if is_hip(). - ''' - - # Force Attention wrapper backend - override_backend_env_variable(monkeypatch, backend_name) - - test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, - block_size, max_dec_seq_len, max_enc_seq_len, 4096) - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - test_rsrcs = _make_test_resources(test_pt) - - # Encoder attention setup - - enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) - - # Decoder self-attention setup - - dec_qkv, \ - prephase_dec_test_params, \ - _, \ - cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) - - # Cross-attention setup - - prephase_cross_test_params, \ - _, \ - = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, - enc_test_params, - prephase_dec_test_params, - test_pt, - test_rsrcs, - block_base_addr = \ - cross_block_base_addr) - - # Shared prefill metadata structure - - prephase_attn_metadata: AttentionMetadata = make_test_metadata( - test_rsrcs.attn_backend, - True, - prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, - decoder_test_params=prephase_dec_test_params, - encoder_test_params=enc_test_params, - cross_test_params=prephase_cross_test_params, - default_attn_type=AttentionType.ENCODER, - device=CUDA_DEVICE) - - # PREFILL: encoder attention - - enc_packed_actual_output: torch.Tensor = \ - _run_encoder_attention_test( - test_rsrcs.attn, - enc_test_params, - prephase_attn_metadata) - - # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_packed_actual_output) - - # Meat of the test: require that chunked prefill triggers failure. - # - # Set up a contrived scenario where the attention metadata - # is configured for chunked prefill & decoder self- - # attention. Required that this triggers a NotImplementedError. - # - # We assume that decode_attn_metadata.num_prefill_tokens > 1 - # already; the line below sets up a chunked prefill - # metadata configuration where there is nominally a mix - # of prefill and decode tokens. - prephase_attn_metadata.num_decode_tokens = 1 - with pytest.raises(NotImplementedError) as exc_info: - - # Doomed decoder self-attention - _run_decoder_self_attention_test(test_rsrcs, prephase_dec_test_params, - prephase_attn_metadata) - - # "Encoder decoder models do not currently support chunked prefill" - # or something to that effect - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL - - -@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES_FOR_UNSUPP) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) -@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_backend_fails_for_prefix_caching_enc_dec(num_heads: int, - head_size: int, - backend_name: str, - batch_size: int, - block_size: int, - max_dec_seq_len: int, - max_enc_seq_len: int, - monkeypatch) -> None: - ''' - Confirm encoder/decoder models will fail with NotImplemented - if prefix caching is enabled. - - This test - 1. Executes a subset of test setup code from - test_e2e_enc_dec_attn() (everything up to encoder - execution); see test_e2e_enc_dec_attn() for more context - on how this code works. - - 2. Modifies the prefill-phase attention metadata structure - to imply a prefix caching scenario - - 3. Attempts to execute decoder self-attention - - 4. Asserts that that decoder self-attention fails & with the correct - error message - - Note on ROCm/HIP: currently encoder/decoder models are not supported on - AMD GPUs, therefore this test simply is skipped if is_hip(). - ''' - - # Force Attention wrapper backend - override_backend_env_variable(monkeypatch, backend_name) - - test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, - block_size, max_dec_seq_len, max_enc_seq_len, 4096) - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - test_rsrcs = _make_test_resources(test_pt) - - # Encoder attention setup - - enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) - - # Decoder self-attention setup - - dec_qkv, \ - prephase_dec_test_params, \ - _, \ - cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) - - # Cross-attention setup - - prephase_cross_test_params, \ - _, \ - = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, - enc_test_params, - prephase_dec_test_params, - test_pt, - test_rsrcs, - block_base_addr = \ - cross_block_base_addr) - - # Shared prefill metadata structure - - prephase_attn_metadata: AttentionMetadata = make_test_metadata( - test_rsrcs.attn_backend, - True, - prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, - decoder_test_params=prephase_dec_test_params, - encoder_test_params=enc_test_params, - cross_test_params=prephase_cross_test_params, - default_attn_type=AttentionType.ENCODER, - device=CUDA_DEVICE) - - # PREFILL: encoder attention - - enc_packed_actual_output: torch.Tensor = \ - _run_encoder_attention_test( - test_rsrcs.attn, - enc_test_params, - prephase_attn_metadata) - - # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_packed_actual_output) - - # Meat of the test: require that prefix caching triggers failure. - # - # Set up a contrived scenario where the attention metadata - # is configured for prefix caching & decoder self- - # attention. Require that this triggers a NotImplementedError. - with pytest.raises(NotImplementedError) as exc_info: - # In XFormers backend, the trigger for utilizing the - # prefix caching kernel is - # - # kv_cache is not None and prefill_meta.block_tables.numel() > 0 - # - # We can shallowly emulate a prefix caching scenario by passing - # in a non-None KV cache in test_rsrcs (already the - # case) and then tweaking the cached prefill attention metadata - # from the encoder run to have a non-empty (gibberish) block - # table. This block table will never actually be used, because - # its presence will signify to the backend a prefix-caching - # scenario and (given that the attention metadata structure - # is configured for an encoder/decoder scenario too) trigger - # a NotImplemented a exception. - - num_seqs = len( - prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens) - - prephase_attn_metadata._cached_prefill_metadata.block_tables = \ - torch.randint( - 0, 10, (num_seqs, 1)) - - _run_decoder_self_attention_test(test_rsrcs, prephase_dec_test_params, - prephase_attn_metadata) - - # "Encoder decoder models do not currently support prefix caching" - # or something to that effect - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING + decphase_cross_pckd_act_out) \ No newline at end of file diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 71cd93cd3df92..608a2b68fee81 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -696,9 +696,9 @@ def make_test_metadata( Construct fake attention metadata for a given test phase (prefill-phase or decode-phase). - encoder_test_params and cross_test_params arguments all encoder - attention and enc/dec cross-attention to use distinct metadata values - from decoder self-attention (decoder_test_params.) + encoder_test_params and cross_test_params arguments allow encoder + attention and enc/dec cross-attention (respectively) to use distinct + metadata values from decoder self-attention (decoder_test_params.) if encoder_test_params and cross_test_params are None, the attention metadata will support decoder-only scenario. From 03e5d8135731854460b62cac16149cbc8bb8ade9 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Jun 2024 16:04:28 -0400 Subject: [PATCH 209/239] assert value None-ness matches key None-ness --- vllm/attention/backends/xformers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 3e10b55fb6fcd..c62b782d928f5 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -478,9 +478,11 @@ def forward( """ query = query.view(-1, self.num_heads, self.head_size) if key is not None: + assert value is not None key = key.view(-1, self.num_kv_heads, self.head_size) - if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None # Self-attention vs. cross-attention will impact # which KV cache memory-mapping & which From 1f3874db00ad9d3dc17239d92428107d9701c700 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Jun 2024 16:08:16 -0400 Subject: [PATCH 210/239] comment fix --- tests/kernels/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 608a2b68fee81..58938493603dd 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -31,7 +31,7 @@ class QKVInputs(NamedTuple): ''' Data structure for representing unpacked attention inputs, - query/key/value. + query/key/values and their sequence lengths. Attributes: From 528b4a71e3824d1aba1da9004c91aa9df1a713e6 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Jun 2024 16:19:25 -0400 Subject: [PATCH 211/239] Remove util fxns & error strings for unneeded tests --- tests/kernels/test_encoder_decoder_attn.py | 7 +- vllm/attention/backends/utils.py | 80 +--------------------- vllm/attention/backends/xformers.py | 16 +---- 3 files changed, 6 insertions(+), 97 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 8b93bb8fe6782..af3614aaa2221 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -17,9 +17,7 @@ from tests.kernels.utils import * from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType -from vllm.attention.backends.utils import ( - STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING, - STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.utils import is_hip, make_causal_mask, maybe_make_long_tensor HEAD_SIZES = [64, 256] @@ -149,7 +147,6 @@ class that Attention will automatically select when it is constructed. def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ -> PhaseTestParameters: - ''' Set up test vectors & data structures for encoder attention test. @@ -870,4 +867,4 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, # - Is decode-phase encoder/decoder cross-attention correct? assert_actual_matches_ideal(decphase_cross_test_params, - decphase_cross_pckd_act_out) \ No newline at end of file + decphase_cross_pckd_act_out) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 45a6f4af37d13..24c89fd7967a1 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,84 +1,8 @@ -"""Attention utils""" - -from vllm.attention import AttentionMetadata +"""Attention backend utils""" # Error string(s) for encoder/decoder # unsupported attention scenarios -STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ -"Chunked prefill is not currently " + \ -"supported with encoder/decoder models." - STR_NOT_IMPL_ENC_DEC_ROCM_HIP = \ "ROCm/HIP is not currently supported" + \ -"with encoder/decoder models." - -STR_NOT_IMPL_ENC_DEC_NON_XFORMERS_BACKEND = \ -"Currently only the XFormers backend " + \ - "supports encoder/decoder models." - -STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING = \ -"Prefix caching is not currently supported " + \ -"with encoder/decoder models" - -# Check for unsupported encoder/decoder scenarios - - -def is_encoder_decoder_metadata_assuming_supported_backend( - attn_metadata) -> bool: - ''' - Return True of the attn_metadata argument contains - the metadata fields that would be required for - encoder attention, which proves that the user is - not running a purely decoder-only model. - - Assumes attn_metadata is derived from a backend that supports - encoder/decoder models. - - Arguments: - - * attn_metadata: instance of supported backend metadata. - Type annotation omitted to avoid circular import. - - - Returns: - - * True if attn_metadata is configured for an encoder/decoder model - ''' - return attn_metadata.is_all_encoder_attn_metadata_set - - -def fail_encoder_decoder_prefix_caching() -> None: - ''' - Fail with NotImplementedError & a message indicating - enc/dec + prefix caching is unsupported - ''' - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHING) - - -def assert_no_encdec_chunked_prefill_assuming_supported_backend( - attn_metadata: AttentionMetadata) -> None: - ''' - Fail if encoder/decoder model is being executed with - chunked prefill. - - Assumes we already know that the particular attention - backend in-use is supported. - - Arguments: - - * attn_metadata: Attention metadata structure - ''' - - if not is_encoder_decoder_metadata_assuming_supported_backend( - attn_metadata): - # Only care about encoder/decoder - # scenarios. - return - - if attn_metadata.num_prefill_tokens > 0 and \ - attn_metadata.num_decode_tokens > 0: - # Encoder/decoder models are currently incompatible - # with chunked prefill. - raise NotImplementedError( \ - STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL) +"with encoder/decoder models." \ No newline at end of file diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index c62b782d928f5..cec6ee5867ae0 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -11,10 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import ( - assert_no_encdec_chunked_prefill_assuming_supported_backend, - fail_encoder_decoder_prefix_caching, - is_encoder_decoder_metadata_assuming_supported_backend) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -125,7 +121,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Attention type enum. # - # * Impact on XFormersImpl.forward(): + # * Impact on XFormersImpl.forward(): # # * DECODER: normal decoder-only behavior; # use decoder self-attention block table @@ -139,7 +135,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # will match encoder sequence lengths, pass encoder sequence # attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ # max_encoder_seq_len) - # + # _attn_type: AttentionType = AttentionType.DECODER # Encoder sequence lengths representation @@ -489,11 +485,6 @@ def forward( # seqlen datastructures we utilize attn_type = attn_metadata.attention_type - # Raise NotImplementedError for unsupported encoder/decoder - # scenarios (has no effect on decoder-only models) - assert_no_encdec_chunked_prefill_assuming_supported_backend( - attn_metadata) - if (attn_type != AttentionType.ENCODER and \ kv_cache is not None): # KV-cache during decoder-self- or @@ -571,9 +562,6 @@ def forward( assert out.shape == output[:num_prefill_tokens].shape output[:num_prefill_tokens] = out else: - if is_encoder_decoder_metadata_assuming_supported_backend( - attn_metadata): - fail_encoder_decoder_prefix_caching() assert prefill_meta.query_start_loc is not None assert prefill_meta.max_query_len is not None From b3c3411e26b7cf6f27604825d99a920c34605c9c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Fri, 14 Jun 2024 16:39:35 -0400 Subject: [PATCH 212/239] formatting --- tests/kernels/test_encoder_decoder_attn.py | 9 +++++---- vllm/attention/backends/xformers.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 3212f331c47b2..99a5ae7b5f808 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -679,6 +679,7 @@ def _run_encoder_decoder_cross_attention_test( return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, value, kv_cache, attn_metadata) + @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -688,10 +689,9 @@ def _run_encoder_decoder_cross_attention_test( @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) def test_encoder_only(num_heads: int, head_size: int, backend_name: str, - batch_size: int, block_size: int, - max_dec_seq_len: int, max_enc_seq_len: int, - monkeypatch): - + batch_size: int, block_size: int, max_dec_seq_len: int, + max_enc_seq_len: int, monkeypatch): + # Force Attention wrapper backend override_backend_env_variable(monkeypatch, backend_name) @@ -733,6 +733,7 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str, # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6745e3bba7c9f..c25957ea156a0 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -287,7 +287,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, _attn_type=self. - _attn_type, # Begin encoder & cross attn fields below... + _attn_type, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, max_encoder_seq_len=self.max_encoder_seq_len, From e229e0018138698bf13135f067eaf32a8cbf9167 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Sun, 16 Jun 2024 22:47:04 -0400 Subject: [PATCH 213/239] format --- tests/kernels/test_encoder_decoder_attn.py | 16 ++++++-- tests/kernels/utils.py | 45 +++++++++++++--------- 2 files changed, 39 insertions(+), 22 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 99a5ae7b5f808..ef6c0fa9876b1 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -221,7 +221,7 @@ def _decoder_attn_setup( test_pt: TestPoint, test_rsrcs: TestResources, block_base_addr: int = 0, -) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: +) -> Tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]: ''' Set up test vectors & data structures for self-attention test. @@ -390,8 +390,8 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, PhaseTestParameters, test_pt: TestPoint, test_rsrcs: TestResources, - block_base_addr: Optional[int]=0) \ - -> tuple[PhaseTestParameters, + block_base_addr: int=0) \ + -> Tuple[PhaseTestParameters, PhaseTestParameters]: ''' Set up test vectors & data structures for cross-attention test. @@ -456,6 +456,9 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, for decode phase. ''' + assert encoder_test_params.packed_qkvo.packed_qkv is not None + assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv is not None + (num_heads, head_size, _, batch_size, block_size, max_decoder_seq_len, max_encoder_seq_len, _) = test_pt @@ -467,6 +470,7 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, prefill_q_seq_lens = \ prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens + assert prefill_q_seq_lens is not None cross_kv, \ _, \ @@ -591,6 +595,7 @@ def _run_encoder_attention_test(attn: Attention, assert attn_metadata.num_decode_tokens == 0 attn_metadata.attention_type = AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv + assert packed_qkv is not None return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, None, attn_metadata) @@ -624,6 +629,7 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, kv_cache = test_rsrcs.kv_cache attn_metadata.attention_type = AttentionType.DECODER packed_qkv = decoder_test_params.packed_qkvo.packed_qkv + assert packed_qkv is not None return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, kv_cache, attn_metadata) @@ -664,6 +670,8 @@ def _run_encoder_decoder_cross_attention_test( * Attention.forward() applied to packed_{query,key,value}, kv_cache & attn_metadata ''' + assert decoder_test_params.packed_qkvo.packed_qkv is not None + attn_metadata.attention_type = AttentionType.ENCODER_DECODER attn = test_rsrcs.attn kv_cache = test_rsrcs.kv_cache @@ -839,7 +847,7 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, cross_block_base_addr) # Shared prefill metadata structure - + assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None prephase_attn_metadata: AttentionMetadata = make_test_metadata( test_rsrcs.attn_backend, True, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 9ff07bc9a6264..ffa6b69ef2374 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -2,7 +2,7 @@ import itertools import random -from typing import List, NamedTuple, Optional, Union +from typing import Any, List, NamedTuple, Optional, Tuple, Union import pytest import torch @@ -83,10 +83,10 @@ class PackedQKVInputs(NamedTuple): query: torch.Tensor key: torch.Tensor value: torch.Tensor - q_start_loc_list: List[int] - kv_start_loc_list: List[int] - q_seq_lens: List[int] - kv_seq_lens: List[int] + q_start_loc_list: Optional[List[int]] + kv_start_loc_list: Optional[List[int]] + q_seq_lens: Optional[List[int]] + kv_seq_lens: Optional[List[int]] class PackedQKVO(NamedTuple): @@ -102,7 +102,7 @@ class PackedQKVO(NamedTuple): x head_size) known-correct attention output ''' - packed_qkv: PackedQKVInputs + packed_qkv: Optional[PackedQKVInputs] ideal_output: torch.Tensor @@ -136,7 +136,7 @@ class PhaseTestParameters(NamedTuple): ''' packed_qkvo: PackedQKVO - kv_mmap: KVMemoryMap + kv_mmap: Optional[KVMemoryMap] def override_backend_env_variable(mpatch: pytest.MonkeyPatch, @@ -185,6 +185,9 @@ def ref_masked_attention(query: torch.Tensor, * Attention result, batch_size x q_padded_seq_len x num_heads x head_size ''' + assert q_seq_lens is not None + assert kv_seq_lens is not None + batch_size = query.shape[0] assert (len(q_seq_lens) == batch_size) assert (len(kv_seq_lens) == batch_size) @@ -219,10 +222,10 @@ def make_qkv( num_heads: int, head_size: int, device: Union[torch.device, str], - force_kv_seq_lens: List[int] = None, + force_kv_seq_lens: Optional[List[int]] = None, attn_type: AttentionType = AttentionType.ENCODER_DECODER, force_max_len: bool = False, -) -> tuple[QKVInputs, QKVInputs, QKVInputs]: +) -> Tuple[QKVInputs, QKVInputs, QKVInputs]: ''' Construct QKV test tensors for self- and cross-attention. @@ -276,8 +279,9 @@ def make_qkv( kv_seq_lens = q_seq_lens else: # K,V seq lens are distinct from Q seq lens & random + assert max_kv_seq_len is not None if force_max_len: - kv_seq_lens = [max_kv_seq_len for _ in range(batch_size)] + kv_seq_lens = [max_kv_seq_len] * batch_size else: kv_seq_lens = [ random.randint(2, max_kv_seq_len) for _ in range(batch_size) @@ -350,7 +354,7 @@ def make_qkv( def pack_tensor( unpacked_tensor: torch.Tensor, seq_lens: List[int], - device: Union[torch.device, str]) -> tuple[torch.Tensor, List[int]]: + device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]: ''' Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an unpadded number_of_tokens x num_heads x head_size tensor, where @@ -454,10 +458,10 @@ def make_backend(backend_name: str) -> AttentionBackend: def _make_metadata_tensors( - seq_lens: List[int], context_lens: List[int], encoder_seq_lens: List[int], - device: Union[torch.device, str] -) -> tuple[torch.Tensor, torch.Tensor, int, int, Optional[List[int]], - torch.Tensor, int]: + seq_lens: Optional[List[int]], context_lens: Optional[List[int]], + encoder_seq_lens: Optional[List[int]], device: Union[torch.device, str] +) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[List[int]], + torch.Tensor, Optional[int]]: ''' Build scalar & tensor values required to build attention metadata structure. @@ -603,7 +607,7 @@ def make_block_tables_slot_mapping( block_size: int, seq_lens: List[int], device: Union[torch.device, str], - block_base_addr: int = 0) -> tuple[torch.Tensor, List[int], int]: + block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]: ''' Construct fake block tables & slot mappings. @@ -685,8 +689,8 @@ def make_block_tables_slot_mapping( def make_test_metadata( attn_backend: AttentionBackend, is_prompt: bool, - seq_lens: List[int], - decoder_test_params: PhaseTestParameters, + seq_lens: Optional[List[int]], + decoder_test_params: Optional[PhaseTestParameters], default_attn_type: AttentionType, device: Union[torch.device, str], encoder_test_params: Optional[PhaseTestParameters] = None, @@ -765,6 +769,7 @@ def make_test_metadata( else: # Encoder/decoder or encoder-only models only: # * Extract encoder input sequence lengths + assert encoder_test_params.packed_qkvo.packed_qkv is not None encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens if cross_test_params is None: @@ -819,6 +824,10 @@ def make_test_metadata( else: # not is_prompt # Decode-phase scenario + assert kv_mmap is not None + assert num_prefill_or_decode_tokens is not None + assert seq_lens is not None + num_prefills = 0 num_prefill_tokens = 0 num_decode_tokens = num_prefill_or_decode_tokens From 7b9cb7f4339364b66180bf5cf7015f8fea67479d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Jun 2024 11:01:05 -0400 Subject: [PATCH 214/239] Replace attn_metadata.attention_type and attn_metadata._attn_type with attn_type argument to forward() --- tests/kernels/test_encoder_decoder_attn.py | 15 +-- tests/kernels/utils.py | 5 - vllm/attention/backends/xformers.py | 138 ++++++++++----------- vllm/attention/layer.py | 15 ++- 4 files changed, 85 insertions(+), 88 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index ef6c0fa9876b1..de33840bf57dd 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -593,11 +593,11 @@ def _run_encoder_attention_test(attn: Attention, & attn_metadata ''' assert attn_metadata.num_decode_tokens == 0 - attn_metadata.attention_type = AttentionType.ENCODER + attn_type=AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, - None, attn_metadata) + None, attn_metadata, attn_type=attn_type) def _run_decoder_self_attention_test(test_rsrcs: TestResources, @@ -625,13 +625,13 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, * Attention.forward() applied to packed_{query,key,value}, kv_cache & attn_metadata ''' + attn_type = AttentionType.DECODER attn = test_rsrcs.attn kv_cache = test_rsrcs.kv_cache - attn_metadata.attention_type = AttentionType.DECODER packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, - kv_cache, attn_metadata) + kv_cache, attn_metadata, attn_type=attn_type) def _run_encoder_decoder_cross_attention_test( @@ -672,7 +672,7 @@ def _run_encoder_decoder_cross_attention_test( ''' assert decoder_test_params.packed_qkvo.packed_qkv is not None - attn_metadata.attention_type = AttentionType.ENCODER_DECODER + attn_type = AttentionType.ENCODER_DECODER attn = test_rsrcs.attn kv_cache = test_rsrcs.kv_cache if cross_test_params is None: @@ -685,7 +685,7 @@ def _run_encoder_decoder_cross_attention_test( value = None if cross_pckd_qkv is None else \ cross_pckd_qkv.value return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, - value, kv_cache, attn_metadata) + value, kv_cache, attn_metadata, attn_type=attn_type) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @@ -727,7 +727,6 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str, decoder_test_params=None, encoder_test_params=enc_test_params, cross_test_params=None, - default_attn_type=AttentionType.ENCODER, device=CUDA_DEVICE) # PREFILL: encoder attention @@ -855,7 +854,6 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, decoder_test_params=prephase_dec_test_params, encoder_test_params=enc_test_params, cross_test_params=prephase_cross_test_params, - default_attn_type=AttentionType.ENCODER, device=CUDA_DEVICE) # PREFILL: encoder attention @@ -903,7 +901,6 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, decoder_test_params=decphase_dec_test_params, encoder_test_params=enc_test_params, cross_test_params=decphase_cross_test_params, - default_attn_type=AttentionType.DECODER, device=CUDA_DEVICE) # DECODE: decoder self-attention test diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index ffa6b69ef2374..49232b209a186 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -691,7 +691,6 @@ def make_test_metadata( is_prompt: bool, seq_lens: Optional[List[int]], decoder_test_params: Optional[PhaseTestParameters], - default_attn_type: AttentionType, device: Union[torch.device, str], encoder_test_params: Optional[PhaseTestParameters] = None, cross_test_params: Optional[PhaseTestParameters] = None @@ -719,8 +718,6 @@ def make_test_metadata( * decoder_test_params: decoder self-attention test params; this function requires kv_mmap (memory mapping) field - * default_attn_type: value of attn_metadata.attention_type at - construction time * device: CPU or CUDA device * encoder_test_params: encoder attention test params; this function requires encoder query @@ -812,7 +809,6 @@ def make_test_metadata( block_tables=None if kv_mmap is None else \ kv_mmap.block_tables, use_cuda_graph=False, - _attn_type=default_attn_type, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, max_encoder_seq_len=max_encoder_seq_len, @@ -855,7 +851,6 @@ def make_test_metadata( context_lens_tensor=context_lens_tensor, block_tables=kv_mmap.block_tables, use_cuda_graph=False, - _attn_type=default_attn_type, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, max_encoder_seq_len=max_encoder_seq_len, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index d03417f071510..832cd561c9932 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -119,25 +119,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Begin encoder attn & enc/dec cross-attn fields... - # Attention type enum. - # - # * Impact on XFormersImpl.forward(): - # - # * DECODER: normal decoder-only behavior; - # use decoder self-attention block table - # * ENCODER: no KV caching; pass encoder sequence - # attributes (encoder_seq_lens/encoder_seq_lens_tensor/ - # max_encoder_seq_len) to kernel, in lieu of decoder - # sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) - # * ENCODER_DECODER: cross-attention behavior; - # use cross-attention block table for caching KVs derived - # from encoder hidden states; since KV sequence lengths - # will match encoder sequence lengths, pass encoder sequence - # attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ - # max_encoder_seq_len) - # - _attn_type: AttentionType = AttentionType.DECODER - # Encoder sequence lengths representation encoder_seq_lens: Optional[List[int]] = None encoder_seq_lens_tensor: Optional[torch.Tensor] = None @@ -180,24 +161,6 @@ def is_all_cross_attn_metadata_set(self): (self.cross_slot_mapping is not None) and \ (self.cross_block_tables is not None) - @property - def attention_type(self) -> AttentionType: - return self._attn_type - - @attention_type.setter - def attention_type(self, atype: AttentionType) -> None: - - if atype == AttentionType.ENCODER_DECODER: - assert self.is_all_cross_attn_metadata_set, \ - "Must set self.encoder_seq_lens, self.cross_slot_mapping, " + \ - "self.cross_block_tables in order to perform cross-attention" - - elif atype == AttentionType.ENCODER: - assert self.is_all_encoder_attn_metadata_set, \ - "Must set self.encoder_seq_lens in order to perform cross-attention" - - self._attn_type = atype - @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self.num_prefills == 0: @@ -206,8 +169,6 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: if self._cached_prefill_metadata is not None: # Recover cached prefill-phase attention # metadata structure - self._cached_prefill_metadata.attention_type = \ - self.attention_type return self._cached_prefill_metadata assert (self.seq_lens is not None) or \ @@ -244,7 +205,6 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - _attn_type=self.attention_type, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, @@ -261,8 +221,6 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: if self._cached_decode_metadata is not None: # Recover cached decode-phase attention # metadata structure - self._cached_decode_metadata.attention_type = \ - self.attention_type return self._cached_decode_metadata assert (self.seq_lens_tensor is not None) or \ (self.encoder_seq_lens_tensor is not None) @@ -286,8 +244,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: max_decode_seq_len=self.max_decode_seq_len, block_tables=block_tables, use_cuda_graph=self.use_cuda_graph, - _attn_type=self. - _attn_type, # Begin encoder & cross attn fields below... + # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, max_encoder_seq_len=self.max_encoder_seq_len, @@ -295,7 +252,8 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata -def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ +def _get_attn_bias(attn_metadata: XFormersMetadata, + attn_type: AttentionType) -> \ Optional[AttentionBias]: ''' Extract appropriate attention bias from attention metadata @@ -304,12 +262,13 @@ def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ Arguments: * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention Returns: - * Appropriate attention bias value + * Appropriate attention bias value given the attention type ''' - attn_type = attn_metadata.attention_type if attn_type == AttentionType.DECODER: return attn_metadata.attn_bias elif attn_type == AttentionType.ENCODER: @@ -318,24 +277,24 @@ def _get_attn_bias(attn_metadata: XFormersMetadata) -> \ return attn_metadata.cross_attn_bias else: raise AttributeError( - f"Invalid attn_metadata.attention_type {str(attn_type)}") + f"Invalid attention type {str(attn_type)}") def _set_attn_bias(attn_metadata: XFormersMetadata, - attn_bias: List[Optional[AttentionBias]]) -> None: + attn_bias: List[Optional[AttentionBias]], + attn_type: AttentionType) -> None: ''' Update appropriate attention bias field of attention metadata, according to attention type. - Depends on attn_metadata having a valid attention_type. - Arguments: * attn_metadata: Attention metadata structure associated with attention * attn_bias: The desired attention bias value + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention ''' - attn_type = attn_metadata.attention_type if attn_type == AttentionType.DECODER: attn_metadata.attn_bias = attn_bias elif attn_type == AttentionType.ENCODER: @@ -344,11 +303,13 @@ def _set_attn_bias(attn_metadata: XFormersMetadata, attn_metadata.cross_attn_bias = attn_bias else: raise AttributeError( - f"Invalid attn_metadata.attention_type {str(attn_type)}") + f"Invalid attention type {str(attn_type)}") def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, - is_prompt: bool) -> tuple: + is_prompt: bool, + attn_type: AttentionType) \ + -> tuple: ''' The particular choice of sequence-length- and block-table-related attributes which should be extracted from attn_metadata is dependent @@ -362,6 +323,9 @@ def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, Arguments: * attn_metadata: Attention metadata structure associated with attention op + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention Returns: @@ -370,7 +334,6 @@ def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, * Appropriate block tables (or None) ''' - attn_type = attn_metadata.attention_type if attn_type == AttentionType.DECODER: # Decoder self-attention # Choose max_seq_len based on whether we are in prompt_run @@ -394,7 +357,7 @@ def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, None else: raise AttributeError( - f"Invalid attn_metadata.attention_type {str(attn_type)}") + f"Invalid attention type {str(attn_type)}") class XFormersImpl(AttentionImpl[XFormersMetadata]): @@ -463,6 +426,7 @@ def forward( kv_cache: Optional[torch.Tensor], attn_metadata: "XFormersMetadata", kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -481,15 +445,48 @@ def forward( (2) cross-attention key and value tensors do not grow during decode + A note on how the attn_type (attention type enum) argument impacts + attention forward() behavior: + + * DECODER: normal decoder-only behavior; + use decoder self-attention block table + * ENCODER: no KV caching; pass encoder sequence + attributes (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) to kernel, in lieu of decoder + sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) + * ENCODER_DECODER: cross-attention behavior; + use cross-attention block table for caching KVs derived + from encoder hidden states; since KV sequence lengths + will match encoder sequence lengths, pass encoder sequence + attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) + Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally Returns: shape = [num_tokens, num_heads * head_size] """ + + # Check that appropriate attention metadata attributes are + # selected for the desired attention type + if attn_type == AttentionType.ENCODER: + if not attn_metadata.is_all_encoder_attn_metadata_set: + raise AttributeError("Encoder attention requires setting " + \ + "encoder metadata attributes.") + elif attn_type == AttentionType.ENCODER_DECODER: + if not attn_metadata.is_all_cross_attn_metadata_set: + raise AttributeError("Encoder/decoder cross-attention " + \ + "requires setting cross-attention " + \ + "metadata attributes.") + query = query.view(-1, self.num_heads, self.head_size) if key is not None: assert value is not None @@ -501,7 +498,6 @@ def forward( # Self-attention vs. cross-attention will impact # which KV cache memory-mapping & which # seqlen datastructures we utilize - attn_type = attn_metadata.attention_type if (attn_type != AttentionType.ENCODER and \ kv_cache is not None): @@ -536,7 +532,7 @@ def forward( self.kv_cache_dtype, kv_scale) - if attn_metadata.attention_type != AttentionType.ENCODER: + if attn_type != AttentionType.ENCODER: # Decoder self-attention supports chunked prefill. # Encoder/decoder cross-attention requires no chunked # prefill (100% prefill or 100% decode tokens, no mix) @@ -576,7 +572,7 @@ def forward( # block tables are empty if the prompt does not have a cached # prefix. out = self._run_memory_efficient_xformers_forward( - query, key, value, prefill_meta) + query, key, value, prefill_meta, attn_type=attn_type) assert out.shape == output[:num_prefill_tokens].shape output[:num_prefill_tokens] = out else: @@ -609,7 +605,9 @@ def forward( seq_lens_arg, \ max_seq_len_arg,\ - block_tables_arg = _get_seq_len_block_table_args(decode_meta, False) + block_tables_arg = _get_seq_len_block_table_args(decode_meta, + False, + attn_type) output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, @@ -634,6 +632,7 @@ def _run_memory_efficient_xformers_forward( key: torch.Tensor, value: torch.Tensor, attn_metadata: XFormersMetadata, + attn_type: AttentionType = AttentionType.DECODER ) -> torch.Tensor: """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. @@ -647,15 +646,12 @@ def _run_memory_efficient_xformers_forward( key: shape = [num_prefill_tokens, num_kv_heads, head_size] value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally """ - # Enforce that the appropriate *_seq_lens attribute of attn_metadata - # (seq_lens or encoder_seq_lens) is set. - # seq_lens, \ - # _,\ - # _ = _get_seq_len_block_table_args(attn_metadata, True) - # assert seq_lens is not None - original_query = query if self.num_kv_heads != self.num_heads: # GQA/MQA requires the shape [B, M, G, H, K]. @@ -673,10 +669,10 @@ def _run_memory_efficient_xformers_forward( # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. - attn_bias = _get_attn_bias(attn_metadata) + attn_bias = _get_attn_bias(attn_metadata,attn_type) if attn_bias is None: if self.alibi_slopes is None: - if attn_metadata.attention_type == \ + if attn_type == \ AttentionType.ENCODER_DECODER: assert attn_metadata.seq_lens is not None assert attn_metadata.encoder_seq_lens is not None @@ -685,7 +681,7 @@ def _run_memory_efficient_xformers_forward( attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens, attn_metadata.encoder_seq_lens) else: - if attn_metadata.attention_type == AttentionType.ENCODER: + if attn_type == AttentionType.ENCODER: assert attn_metadata.encoder_seq_lens is not None # Default encoder self-attention mask is non-causal @@ -707,7 +703,7 @@ def _run_memory_efficient_xformers_forward( self.num_kv_heads, query.dtype, attn_metadata.seq_lens) - _set_attn_bias(attn_metadata, attn_bias) + _set_attn_bias(attn_metadata, attn_bias, attn_type) # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index db55a31476fed..77be19772601f 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.abstract import AttentionMetadata, AttentionType from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig from vllm.model_executor.layers.quantization.base_config import ( @@ -85,9 +85,18 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata, + attn_type: Optional[AttentionType] = None ) -> torch.Tensor: - return self.impl.forward(query, key, value, kv_cache, attn_metadata, - self._kv_scale) + if attn_type is None: + # Support backends without an attention type argument + return self.impl.forward(query, key, value, kv_cache, attn_metadata, + self._kv_scale) + else: + # Backends with encoder/decoder support require attention + # type argument to distinguish between encoder attention, + # decoder self-attention, or encoder/decoder cross-attention + return self.impl.forward(query, key, value, kv_cache, attn_metadata, + self._kv_scale, attn_type=attn_type) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore From 5f8c7f6cd6776cbda8289a5cee28e5cd8b858f4d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Jun 2024 11:26:24 -0400 Subject: [PATCH 215/239] Moved attention type for attn_metadata to attention forward(); added NotImplement failures to backends in non-decoder-only scenarios --- tests/kernels/test_encoder_decoder_attn.py | 26 ++++++--- vllm/attention/backends/abstract.py | 16 +++--- vllm/attention/backends/blocksparse_attn.py | 24 +++++--- vllm/attention/backends/flash_attn.py | 24 +++++--- vllm/attention/backends/flashinfer.py | 23 +++++--- vllm/attention/backends/ipex_attn.py | 23 +++++--- vllm/attention/backends/pallas.py | 23 +++++--- vllm/attention/backends/rocm_flash_attn.py | 24 +++++--- vllm/attention/backends/torch_sdpa.py | 23 +++++--- vllm/attention/backends/xformers.py | 63 ++++++++++----------- vllm/attention/layer.py | 35 ++++++------ 11 files changed, 173 insertions(+), 131 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index de33840bf57dd..f61b0a0dcc706 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -593,11 +593,15 @@ def _run_encoder_attention_test(attn: Attention, & attn_metadata ''' assert attn_metadata.num_decode_tokens == 0 - attn_type=AttentionType.ENCODER + attn_type = AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, - None, attn_metadata, attn_type=attn_type) + return attn.forward(packed_qkv.query, + packed_qkv.key, + packed_qkv.value, + None, + attn_metadata, + attn_type=attn_type) def _run_decoder_self_attention_test(test_rsrcs: TestResources, @@ -630,8 +634,12 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, kv_cache = test_rsrcs.kv_cache packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, - kv_cache, attn_metadata, attn_type=attn_type) + return attn.forward(packed_qkv.query, + packed_qkv.key, + packed_qkv.value, + kv_cache, + attn_metadata, + attn_type=attn_type) def _run_encoder_decoder_cross_attention_test( @@ -684,8 +692,12 @@ def _run_encoder_decoder_cross_attention_test( cross_pckd_qkv.key value = None if cross_pckd_qkv is None else \ cross_pckd_qkv.value - return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, - value, kv_cache, attn_metadata, attn_type=attn_type) + return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, + key, + value, + kv_cache, + attn_metadata, + attn_type=attn_type) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index ece0da25ee6f2..1ac8efc6b2584 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -124,12 +124,12 @@ def __init__( @abstractmethod def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: T, - kv_scale: float = 1.0, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index dce2b83615b7a..2afa4d286900e 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -4,7 +4,7 @@ import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) from vllm.attention.ops.blocksparse_attention.interface import ( LocalStridedBlockSparseAttn, get_head_sliding_step) from vllm.attention.ops.paged_attn import PagedAttention @@ -321,14 +321,14 @@ def __init__( ) def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: BlocksparseFlashAttentionMetadata, - kv_scale: float = 1.0, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: BlocksparseFlashAttentionMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. Args: @@ -340,6 +340,12 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + \ + "encoder/decoder cross-attention " + \ + "are not implemented for " + \ + "BlocksparseFlashAttentionImpl") + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 1c48e2a0bb33d..16098ca68213d 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -7,7 +7,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) class FlashAttentionBackend(AttentionBackend): @@ -250,14 +250,14 @@ def __init__( f"Supported head sizes are: {support_head_sizes}.") def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - kv_scale: float = 1.0, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: """Forward pass with FlashAttention. Args: @@ -269,6 +269,12 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + \ + "encoder/decoder cross-attention " + \ + "are not implemented for " + \ + "FlashAttentionImpl") + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 7b7959d257fac..2a7db5a35382e 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -8,7 +8,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) class FlashInferBackend(AttentionBackend): @@ -185,15 +185,20 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: FlashInferMetadata, - kv_scale: float = 1.0, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: FlashInferMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: assert kv_scale == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + \ + "encoder/decoder cross-attention " + \ + "are not implemented for " + \ + "FlashInferImpl") num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index f09b24f2a0304..c3328d6ed7665 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -7,7 +7,7 @@ from vllm._ipex_ops import ipex_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -150,14 +150,14 @@ def split_kv_cache( return key_cache, value_cache def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: IpexAttnMetadata, # type: ignore - kv_scale: float = 1.0, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: IpexAttnMetadata, # type: ignore + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. Args: @@ -170,6 +170,11 @@ def forward( shape = [num_tokens, num_heads * head_size] """ assert kv_scale == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + \ + "encoder/decoder cross-attention " + \ + "are not implemented for " + \ + "IpexAttnBackendImpl") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index b203c5ec54c92..a453f642509a6 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -6,7 +6,7 @@ import torch_xla.experimental.dynamo_set_buffer_donor from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) class PallasAttentionBackend(AttentionBackend): @@ -120,14 +120,14 @@ def __init__( self.megacore_mode = "batch" def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], - attn_metadata: PallasMetadata, - kv_scale: float = 1.0, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], + attn_metadata: PallasMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: """Forward pass with Pallas attention. Args: @@ -141,6 +141,11 @@ def forward( shape = [batch_size, seq_len, num_heads * head_size] """ assert kv_scale == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + \ + "encoder/decoder cross-attention " + \ + "are not implemented for " + \ + "PallasAttentionBackendImpl") batch_size, seq_len, hidden_size = query.shape query = query.view(batch_size, seq_len, self.num_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 9294068c64d1a..9a1d90107745d 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -6,7 +6,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -259,14 +259,14 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: head_dim)) def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: ROCmFlashAttentionMetadata, - kv_scale: float = 1.0, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: ROCmFlashAttentionMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. Args: @@ -278,6 +278,12 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + \ + "encoder/decoder cross-attention " + \ + "are not implemented for " + \ + "ROCmFlashAttentionImpl") + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index c01e0a0a3a19c..efafc233da7df 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -7,7 +7,7 @@ from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, AttentionType) from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.utils import is_cpu @@ -138,14 +138,14 @@ def __init__( "Please use xFormers backend instead.") def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: TorchSDPAMetadata, # type: ignore - kv_scale: float = 1.0, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: TorchSDPAMetadata, # type: ignore + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. Args: @@ -158,6 +158,11 @@ def forward( shape = [num_tokens, num_heads * head_size] """ assert kv_scale == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + \ + "encoder/decoder cross-attention " + \ + "are not implemented for " + \ + "TorchSDPABackendImpl") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 832cd561c9932..bf4d755d2a72f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -276,8 +276,7 @@ def _get_attn_bias(attn_metadata: XFormersMetadata, elif attn_type == AttentionType.ENCODER_DECODER: return attn_metadata.cross_attn_bias else: - raise AttributeError( - f"Invalid attention type {str(attn_type)}") + raise AttributeError(f"Invalid attention type {str(attn_type)}") def _set_attn_bias(attn_metadata: XFormersMetadata, @@ -302,8 +301,7 @@ def _set_attn_bias(attn_metadata: XFormersMetadata, elif attn_type == AttentionType.ENCODER_DECODER: attn_metadata.cross_attn_bias = attn_bias else: - raise AttributeError( - f"Invalid attention type {str(attn_type)}") + raise AttributeError(f"Invalid attention type {str(attn_type)}") def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, @@ -356,8 +354,7 @@ def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, attn_metadata.max_encoder_seq_len, \ None else: - raise AttributeError( - f"Invalid attention type {str(attn_type)}") + raise AttributeError(f"Invalid attention type {str(attn_type)}") class XFormersImpl(AttentionImpl[XFormersMetadata]): @@ -419,15 +416,14 @@ def __init__( f"Supported head sizes are: {suppored_head_sizes}.") def forward( - self, - query: torch.Tensor, - key: Optional[torch.Tensor], - value: Optional[torch.Tensor], - kv_cache: Optional[torch.Tensor], - attn_metadata: "XFormersMetadata", - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor], + attn_metadata: "XFormersMetadata", + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. For decoder-only models: query, key and value must be non-None. @@ -477,15 +473,15 @@ def forward( # Check that appropriate attention metadata attributes are # selected for the desired attention type - if attn_type == AttentionType.ENCODER: - if not attn_metadata.is_all_encoder_attn_metadata_set: - raise AttributeError("Encoder attention requires setting " + \ - "encoder metadata attributes.") - elif attn_type == AttentionType.ENCODER_DECODER: - if not attn_metadata.is_all_cross_attn_metadata_set: - raise AttributeError("Encoder/decoder cross-attention " + \ - "requires setting cross-attention " + \ - "metadata attributes.") + if attn_type == AttentionType.ENCODER and \ + (not attn_metadata.is_all_encoder_attn_metadata_set): + raise AttributeError("Encoder attention requires setting " + \ + "encoder metadata attributes.") + elif attn_type == AttentionType.ENCODER_DECODER and \ + (not attn_metadata.is_all_cross_attn_metadata_set): + raise AttributeError("Encoder/decoder cross-attention " + \ + "requires setting cross-attention " + \ + "metadata attributes.") query = query.view(-1, self.num_heads, self.head_size) if key is not None: @@ -605,8 +601,8 @@ def forward( seq_lens_arg, \ max_seq_len_arg,\ - block_tables_arg = _get_seq_len_block_table_args(decode_meta, - False, + block_tables_arg = _get_seq_len_block_table_args(decode_meta, + False, attn_type) output[num_prefill_tokens:] = PagedAttention.forward_decode( @@ -627,13 +623,12 @@ def forward( return output.view(-1, self.num_heads * self.head_size) def _run_memory_efficient_xformers_forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_metadata: XFormersMetadata, - attn_type: AttentionType = AttentionType.DECODER - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: XFormersMetadata, + attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. @@ -669,7 +664,7 @@ def _run_memory_efficient_xformers_forward( # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. - attn_bias = _get_attn_bias(attn_metadata,attn_type) + attn_bias = _get_attn_bias(attn_metadata, attn_type) if attn_bias is None: if self.alibi_slopes is None: if attn_type == \ diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 77be19772601f..984c8d77b94e7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -78,25 +78,22 @@ def __init__( alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params) - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata, - attn_type: Optional[AttentionType] = None - ) -> torch.Tensor: - if attn_type is None: - # Support backends without an attention type argument - return self.impl.forward(query, key, value, kv_cache, attn_metadata, - self._kv_scale) - else: - # Backends with encoder/decoder support require attention - # type argument to distinguish between encoder attention, - # decoder self-attention, or encoder/decoder cross-attention - return self.impl.forward(query, key, value, kv_cache, attn_metadata, - self._kv_scale, attn_type=attn_type) + def forward(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata, + attn_type: AttentionType = AttentionType.DECODER) \ + -> torch.Tensor: + + return self.impl.forward(query, + key, + value, + kv_cache, + attn_metadata, + self._kv_scale, + attn_type=attn_type) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore From 525303c7c61127900680ff06b6cc09610001b71e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Jun 2024 18:06:33 -0400 Subject: [PATCH 216/239] num encoder tokens --- tests/kernels/utils.py | 5 +++++ vllm/attention/backends/xformers.py | 11 +++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 49232b209a186..94e7379123c7c 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -763,11 +763,14 @@ def make_test_metadata( if encoder_test_params is None: encoder_seq_lens = None + num_encoder_tokens = None else: # Encoder/decoder or encoder-only models only: # * Extract encoder input sequence lengths assert encoder_test_params.packed_qkvo.packed_qkv is not None encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens + num_encoder_tokens = None if encoder_seq_lens is None else \ + (sum(encoder_seq_lens)) if cross_test_params is None: cross_kv_mmap = None @@ -809,6 +812,7 @@ def make_test_metadata( block_tables=None if kv_mmap is None else \ kv_mmap.block_tables, use_cuda_graph=False, + num_encoder_tokens=num_encoder_tokens, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, max_encoder_seq_len=max_encoder_seq_len, @@ -851,6 +855,7 @@ def make_test_metadata( context_lens_tensor=context_lens_tensor, block_tables=kv_mmap.block_tables, use_cuda_graph=False, + num_encoder_tokens=num_encoder_tokens, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, max_encoder_seq_len=max_encoder_seq_len, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index bf4d755d2a72f..2cd61d0161f9e 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -68,9 +68,6 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): updated from `CUDAGraphRunner.forward` API. """ - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| @@ -78,6 +75,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # |-------------------- seq_len ----------------------| # |-- query_len ---| + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + # FIXME: It is for flash attn. # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. @@ -126,6 +126,9 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Maximum sequence length among encoder sequences max_encoder_seq_len: Optional[int] = None + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + # Cross-attention memory-mapping data structures: slot mapping # and block tables cross_slot_mapping: Optional[torch.Tensor] = None @@ -538,7 +541,7 @@ def forward( # Encoder attention - chunked prefill is not applicable; # derive token-count from query shape & and treat them # as 100% prefill tokens - num_prefill_tokens = query.shape[0] + num_prefill_tokens = attn_metadata.num_encoder_tokens num_decode_tokens = 0 if attn_type == AttentionType.DECODER: From ea37e17ab5ad7c084c13bf8e8492039d6a9bcdbf Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 18 Jun 2024 19:16:38 -0400 Subject: [PATCH 217/239] merge conflict; typing; formatting --- vllm/attention/backends/xformers.py | 1 + vllm/utils.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 2cd61d0161f9e..a3f3d41a5491c 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -541,6 +541,7 @@ def forward( # Encoder attention - chunked prefill is not applicable; # derive token-count from query shape & and treat them # as 100% prefill tokens + assert attn_metadata.num_encoder_tokens is not None num_prefill_tokens = attn_metadata.num_encoder_tokens num_decode_tokens = 0 diff --git a/vllm/utils.py b/vllm/utils.py index edad75a6904b5..ebcd2181c1086 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -825,6 +825,7 @@ def make_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ float('-inf')).masked_fill(mask == 0, 0.0) return mask + #From: https://stackoverflow.com/a/4104188/2749989 def run_once(f): @@ -834,4 +835,4 @@ def wrapper(*args, **kwargs) -> Any: return f(*args, **kwargs) wrapper.has_run = False # type: ignore[attr-defined] - return wrapper \ No newline at end of file + return wrapper From 597526a49e041ec99329add79ef272ce6e457b9e Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:18:02 -0400 Subject: [PATCH 218/239] removed extra line --- vllm/attention/backends/xformers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index a3f3d41a5491c..d66d2ce9c277a 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -556,9 +556,7 @@ def forward( decode_query = query[num_prefill_tokens:] # QKV for prefill. query = query[:num_prefill_tokens] - if key is not None and value is not None: - key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] From a178b7a8c9838665ee7e169471206b70d62e1b71 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:20:00 -0400 Subject: [PATCH 219/239] changed nested if/else to elif/else in xformers mask computation code --- vllm/attention/backends/xformers.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index d66d2ce9c277a..1090c2e062bd5 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -677,19 +677,18 @@ def _run_memory_efficient_xformers_forward( # Default enc/dec cross-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens, attn_metadata.encoder_seq_lens) + elif attn_type == AttentionType.ENCODER: + assert attn_metadata.encoder_seq_lens is not None + + # Default encoder self-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.encoder_seq_lens) else: - if attn_type == AttentionType.ENCODER: - assert attn_metadata.encoder_seq_lens is not None - - # Default encoder self-attention mask is non-causal - attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.encoder_seq_lens) - else: - assert attn_metadata.seq_lens is not None - - # Default decoder self-attention mask is causal - attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.seq_lens) + assert attn_metadata.seq_lens is not None + + # Default decoder self-attention mask is causal + attn_bias = BlockDiagonalCausalMask.from_seqlens( + attn_metadata.seq_lens) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) From 06c7f7500140c574d20a12079dbd1ef83db29688 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:28:42 -0400 Subject: [PATCH 220/239] reorganized helper functions that were only being used for testing into tests/kernels/utils.py from vllm/utils.py --- tests/kernels/test_encoder_decoder_attn.py | 3 +- tests/kernels/utils.py | 81 ++++++++++++++++++++- vllm/utils.py | 84 +--------------------- 3 files changed, 82 insertions(+), 86 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index f61b0a0dcc706..44519e6c9ec3f 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -18,7 +18,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP -from vllm.utils import is_hip, make_causal_mask, maybe_make_long_tensor +from vllm.utils import is_hip +from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor HEAD_SIZES = [64, 256] diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 94e7379123c7c..b7c51c1bcf5c7 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -10,8 +10,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.backends.xformers import XFormersBackend -from vllm.utils import (make_tensor_with_pad, maybe_make_int_tensor, - maybe_make_long_tensor, maybe_max) +import numpy as np +from numbers import Number # String name of register which may be set in order to # force auto-selection of attention backend by Attention @@ -138,6 +138,83 @@ class PhaseTestParameters(NamedTuple): packed_qkvo: PackedQKVO kv_mmap: Optional[KVMemoryMap] +def make_tensor_with_pad( + x: List[List[int]], + max_len: int, + pad: int, + dtype: torch.dtype, + device: Optional[Union[str, torch.device]], +) -> torch.Tensor: + """Make a padded tensor of a 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. + """ + padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad + for ind, blocktb in enumerate(x): + assert len(blocktb) <= max_len + padded_x[ind, :len(blocktb)] = blocktb + return torch.tensor(padded_x, dtype=dtype, device=device) + +def maybe_make_int_tensor(_list: List[int], + device: Union[torch.device, str]) \ + -> torch.Tensor: + ''' + Convert Python int list to a 1D int torch.Tensor on `device` + + Returns: + + * If _list is not None: 1D int torch.Tensor on `device` + * None otherwise + ''' + return None if _list is None else torch.tensor( + _list, dtype=torch.int, device=device) + +def maybe_make_long_tensor(_list: List[int], + device: Union[torch.device, str]) \ + -> torch.Tensor: + ''' + Convert Python int list to a 1D long torch.Tensor on `device` + + Returns: + + * If _list is not None: 1D long torch.Tensor on `device` + * None otherwise + ''' + return None if _list is None else torch.tensor( + _list, dtype=torch.long, device=device) + + +def maybe_max(_list: List) -> Optional[Number]: + ''' + Returns: + + * If _list is not None: max(_list) + * None otherwise + ''' + return None if _list is None else max(_list) + +def make_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ + -> torch.Tensor: + ''' + Create a q_max_seq_len x kv_max_seq_len causal mask + + Arguments: + + * q_max_seq_len: query max seq len + * kv_max_seq_len: key/value max seq len + + Returns: + + * 2D tensor, q_max_seq_len x kv_max_seq_len + ''' + + # Create a matrix where entry (i, j) is True if i >= j + mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) + # Replace True with float('-inf') and False with 0 + mask = mask.masked_fill(mask == 1, + float('-inf')).masked_fill(mask == 0, 0.0) + return mask def override_backend_env_variable(mpatch: pytest.MonkeyPatch, backend_name: str) -> None: diff --git a/vllm/utils.py b/vllm/utils.py index 7e3ebab01513f..7a4639950472d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -13,13 +13,12 @@ import warnings from collections import defaultdict from functools import lru_cache, partial, wraps -from numbers import Number from platform import uname from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar, Union) -import numpy as np + import psutil import torch import torch.types @@ -585,26 +584,6 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]: "String must be a series of integers separated by commas " f"(e.g., 1, 2, 3). Given input: {s}") from e - -def make_tensor_with_pad( - x: List[List[int]], - max_len: int, - pad: int, - dtype: torch.dtype, - device: Optional[Union[str, torch.device]], -) -> torch.Tensor: - """Make a padded tensor of a 2D inputs. - - The padding is applied to the end of each inner list until it reaches - `max_len`. - """ - padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad - for ind, blocktb in enumerate(x): - assert len(blocktb) <= max_len - padded_x[ind, :len(blocktb)] = blocktb - return torch.tensor(padded_x, dtype=dtype, device=device) - - def async_tensor_h2d( data: list, dtype: torch.dtype, @@ -799,67 +778,6 @@ def cuda_device_count_stateless() -> int: return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) -def maybe_make_int_tensor(_list: List[int], - device: Union[torch.device, str]) \ - -> torch.Tensor: - ''' - Convert Python int list to a 1D int torch.Tensor on `device` - - Returns: - - * If _list is not None: 1D int torch.Tensor on `device` - * None otherwise - ''' - return None if _list is None else torch.tensor( - _list, dtype=torch.int, device=device) - -def maybe_make_long_tensor(_list: List[int], - device: Union[torch.device, str]) \ - -> torch.Tensor: - ''' - Convert Python int list to a 1D long torch.Tensor on `device` - - Returns: - - * If _list is not None: 1D long torch.Tensor on `device` - * None otherwise - ''' - return None if _list is None else torch.tensor( - _list, dtype=torch.long, device=device) - - -def maybe_max(_list: List) -> Optional[Number]: - ''' - Returns: - - * If _list is not None: max(_list) - * None otherwise - ''' - return None if _list is None else max(_list) - -def make_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ - -> torch.Tensor: - ''' - Create a q_max_seq_len x kv_max_seq_len causal mask - - Arguments: - - * q_max_seq_len: query max seq len - * kv_max_seq_len: key/value max seq len - - Returns: - - * 2D tensor, q_max_seq_len x kv_max_seq_len - ''' - - # Create a matrix where entry (i, j) is True if i >= j - mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) - # Replace True with float('-inf') and False with 0 - mask = mask.masked_fill(mask == 1, - float('-inf')).masked_fill(mask == 0, 0.0) - return mask - - #From: https://stackoverflow.com/a/4104188/2749989 def run_once(f): From 47c9f396fdcd40895597423ebfefe585b014c2f3 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:32:52 -0400 Subject: [PATCH 221/239] removed attention_type --- tests/kernels/test_encoder_decoder_attn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 44519e6c9ec3f..2421c022b0ec5 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -574,7 +574,7 @@ def _run_encoder_attention_test(attn: Attention, ''' Run encoder attention. - attn_metadata.attention_type is assigned AttentionType.ENCODER in order + attn.forward() is passed attn_type=AttentionType.ENCODER in order to configure the kernel invocation for encoder attention Requires attn_metadata.num_decode_tokens == 0 @@ -612,7 +612,7 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, ''' Run decoder self-attention test. - attn_metadata.attention_type is assigned AttentionType.DECODER + attn.forward() is passed attn_type=AttentionType.DECODER in order to configure the kernel invocation for decoder self-attention. Arguments: @@ -657,7 +657,7 @@ def _run_encoder_decoder_cross_attention_test( is None, this reflects that in decode-phase cross attention there is no growth in the key and value tensors. - attn_metadata.attention_type is assigned AttentionType.ENCODER_DECODER + attn.forward() is passed attn_type=AttentionType.ENCODER_DECODER in order to configure the kernel invocation for encoder/decoder cross- attention. From 2f0b05bb805513e73eb0609ea87b6367ec9d4803 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:35:34 -0400 Subject: [PATCH 222/239] typing and formatting --- tests/kernels/test_encoder_decoder_attn.py | 2 +- tests/kernels/utils.py | 12 +++++++----- vllm/utils.py | 3 ++- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 2421c022b0ec5..654cd621145c5 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -15,11 +15,11 @@ import torch from tests.kernels.utils import * +from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionType from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.utils import is_hip -from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor HEAD_SIZES = [64, 256] diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index b7c51c1bcf5c7..3ae345bafa36e 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -2,16 +2,16 @@ import itertools import random +from numbers import Number from typing import Any, List, NamedTuple, Optional, Tuple, Union +import numpy as np import pytest import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.backends.xformers import XFormersBackend -import numpy as np -from numbers import Number # String name of register which may be set in order to # force auto-selection of attention backend by Attention @@ -138,6 +138,7 @@ class PhaseTestParameters(NamedTuple): packed_qkvo: PackedQKVO kv_mmap: Optional[KVMemoryMap] + def make_tensor_with_pad( x: List[List[int]], max_len: int, @@ -156,7 +157,7 @@ def make_tensor_with_pad( padded_x[ind, :len(blocktb)] = blocktb return torch.tensor(padded_x, dtype=dtype, device=device) -def maybe_make_int_tensor(_list: List[int], +def maybe_make_int_tensor(_list: Optional[List[int]], device: Union[torch.device, str]) \ -> torch.Tensor: ''' @@ -170,7 +171,7 @@ def maybe_make_int_tensor(_list: List[int], return None if _list is None else torch.tensor( _list, dtype=torch.int, device=device) -def maybe_make_long_tensor(_list: List[int], +def maybe_make_long_tensor(_list: Optional[List[int]], device: Union[torch.device, str]) \ -> torch.Tensor: ''' @@ -185,7 +186,7 @@ def maybe_make_long_tensor(_list: List[int], _list, dtype=torch.long, device=device) -def maybe_max(_list: List) -> Optional[Number]: +def maybe_max(_list: Optional[List]) -> Optional[Number]: ''' Returns: @@ -216,6 +217,7 @@ def make_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ float('-inf')).masked_fill(mask == 0, 0.0) return mask + def override_backend_env_variable(mpatch: pytest.MonkeyPatch, backend_name: str) -> None: ''' diff --git a/vllm/utils.py b/vllm/utils.py index 7a4639950472d..cc11e4b00283f 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -18,7 +18,6 @@ Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar, Union) - import psutil import torch import torch.types @@ -584,6 +583,7 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]: "String must be a series of integers separated by commas " f"(e.g., 1, 2, 3). Given input: {s}") from e + def async_tensor_h2d( data: list, dtype: torch.dtype, @@ -778,6 +778,7 @@ def cuda_device_count_stateless() -> int: return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) + #From: https://stackoverflow.com/a/4104188/2749989 def run_once(f): From d23c28466765496049a1696d0a053a0a2505ce9a Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:38:08 -0400 Subject: [PATCH 223/239] typing and formatting; fixed escape sequences in comments --- tests/kernels/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 3ae345bafa36e..45f56e364175e 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -633,18 +633,18 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int], Context: * Your goal is to test (1) prefill of N prompts, with prompt-lengths - {K_i \forall i \in [0,N)}, followed by (2) decoding of a single token + {K_i \\forall i \\in [0,N)}, followed by (2) decoding of a single token for all N prompts (N tokens total); the resultant sequence lengths - after decode would be {K_i + 1 for i \in [0,N)} + after decode would be {K_i + 1 for i \\in [0,N)} * The test you want to do requires (1) having the prefill slot mapping for all tokens present during prefill, the number of which is - M = \sum_i{K_i}, and (2) having the decode slot mapping for all N + M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N decoded tokens This function consumes a single 1D slot mapping, which is the concatenation of N slot mappings each of length K_i + 1 (corresponding to the sequence lengths after decode), with a total length of - P = \sum_i{K_i + 1} = M + N + P = \\sum_i{K_i + 1} = M + N The prefill-phase slot mapping results from excising the (K_i + 1)-th entry from each of the N subsequences in the slot mapping (i.e. omitting the From 1a6e5a31846e2ef886b66e9cc9216ffe983d0ec0 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:52:04 -0400 Subject: [PATCH 224/239] moved make_tensor_with_pad() helper function back to vllm.utils --- tests/kernels/utils.py | 22 ++-------------------- vllm/utils.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 45f56e364175e..7e1e084f650ea 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -5,7 +5,6 @@ from numbers import Number from typing import Any, List, NamedTuple, Optional, Tuple, Union -import numpy as np import pytest import torch @@ -13,6 +12,8 @@ AttentionMetadata, AttentionType) from vllm.attention.backends.xformers import XFormersBackend +from vllm.utils import make_tensor_with_pad + # String name of register which may be set in order to # force auto-selection of attention backend by Attention # wrapper @@ -138,25 +139,6 @@ class PhaseTestParameters(NamedTuple): packed_qkvo: PackedQKVO kv_mmap: Optional[KVMemoryMap] - -def make_tensor_with_pad( - x: List[List[int]], - max_len: int, - pad: int, - dtype: torch.dtype, - device: Optional[Union[str, torch.device]], -) -> torch.Tensor: - """Make a padded tensor of a 2D inputs. - - The padding is applied to the end of each inner list until it reaches - `max_len`. - """ - padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad - for ind, blocktb in enumerate(x): - assert len(blocktb) <= max_len - padded_x[ind, :len(blocktb)] = blocktb - return torch.tensor(padded_x, dtype=dtype, device=device) - def maybe_make_int_tensor(_list: Optional[List[int]], device: Union[torch.device, str]) \ -> torch.Tensor: diff --git a/vllm/utils.py b/vllm/utils.py index cc11e4b00283f..1a778005bd867 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -27,6 +27,8 @@ from vllm import _custom_ops as ops from vllm.logger import enable_trace_function_call, init_logger +import numpy as np + logger = init_logger(__name__) STR_DTYPE_TO_TORCH_DTYPE = { @@ -573,6 +575,23 @@ def __exit__(self, exc_type, exc_val, exc_tb): # Force garbage collection gc.collect() +def make_tensor_with_pad( + x: List[List[int]], + max_len: int, + pad: int, + dtype: torch.dtype, + device: Optional[Union[str, torch.device]], +) -> torch.Tensor: + """Make a padded tensor of a 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. + """ + padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad + for ind, blocktb in enumerate(x): + assert len(blocktb) <= max_len + padded_x[ind, :len(blocktb)] = blocktb + return torch.tensor(padded_x, dtype=dtype, device=device) def str_to_int_tuple(s: str) -> Tuple[int, ...]: """Convert a string to a tuple of integers.""" From e2a46e3b7b9f9d1a9cc751046c3cddd1522620ed Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Tue, 25 Jun 2024 02:53:35 -0400 Subject: [PATCH 225/239] formatting --- tests/kernels/utils.py | 1 - vllm/utils.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 7e1e084f650ea..f0b0dd5dbaee6 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -11,7 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.backends.xformers import XFormersBackend - from vllm.utils import make_tensor_with_pad # String name of register which may be set in order to diff --git a/vllm/utils.py b/vllm/utils.py index 1a778005bd867..127f86733a852 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -18,6 +18,7 @@ Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar, Union) +import numpy as np import psutil import torch import torch.types @@ -27,8 +28,6 @@ from vllm import _custom_ops as ops from vllm.logger import enable_trace_function_call, init_logger -import numpy as np - logger = init_logger(__name__) STR_DTYPE_TO_TORCH_DTYPE = { @@ -575,6 +574,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): # Force garbage collection gc.collect() + def make_tensor_with_pad( x: List[List[int]], max_len: int, @@ -593,6 +593,7 @@ def make_tensor_with_pad( padded_x[ind, :len(blocktb)] = blocktb return torch.tensor(padded_x, dtype=dtype, device=device) + def str_to_int_tuple(s: str) -> Tuple[int, ...]: """Convert a string to a tuple of integers.""" try: From 75756b91e3753a9c2a60dbae42b2e46d3612ece5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 27 Jun 2024 11:28:19 -0400 Subject: [PATCH 226/239] removed redundant elif --- vllm/attention/backends/xformers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 83a63b6b8bf23..b1daaefc9f3b5 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -276,10 +276,9 @@ def _get_attn_bias(attn_metadata: XFormersMetadata, return attn_metadata.attn_bias elif attn_type == AttentionType.ENCODER: return attn_metadata.encoder_attn_bias - elif attn_type == AttentionType.ENCODER_DECODER: - return attn_metadata.cross_attn_bias else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") + # attn_type == AttentionType.ENCODER_DECODER + return attn_metadata.cross_attn_bias def _set_attn_bias(attn_metadata: XFormersMetadata, From a5018499e3b8475749a8d1af80e14c8d172cf2c7 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Thu, 27 Jun 2024 18:57:56 -0400 Subject: [PATCH 227/239] reverted unnecessarily vllm/utils.py changes --- vllm/utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 6cc4af98e5b9f..92abdb3fb9b14 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -579,6 +579,16 @@ def __exit__(self, exc_type, exc_val, exc_tb): gc.collect() +def str_to_int_tuple(s: str) -> Tuple[int, ...]: + """Convert a string to a tuple of integers.""" + try: + return tuple(map(int, s.split(","))) + except ValueError as e: + raise ValueError( + "String must be a series of integers separated by commas " + f"(e.g., 1, 2, 3). Given input: {s}") from e + + def make_tensor_with_pad( x: List[List[int]], max_len: int, @@ -598,16 +608,6 @@ def make_tensor_with_pad( return torch.tensor(padded_x, dtype=dtype, device=device) -def str_to_int_tuple(s: str) -> Tuple[int, ...]: - """Convert a string to a tuple of integers.""" - try: - return tuple(map(int, s.split(","))) - except ValueError as e: - raise ValueError( - "String must be a series of integers separated by commas " - f"(e.g., 1, 2, 3). Given input: {s}") from e - - def async_tensor_h2d( data: list, dtype: torch.dtype, From 5dbebbc6f3aafe706a5555119fefa519b71c4634 Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Mon, 8 Jul 2024 09:32:43 -0400 Subject: [PATCH 228/239] Update vllm/attention/backends/torch_sdpa.py nit: This will reduce the number of line changes and make the code look better. Co-authored-by: Woosuk Kwon --- vllm/attention/backends/torch_sdpa.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index eeef24ed4fb33..c2fefe5342362 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -145,7 +145,8 @@ def forward( kv_cache: Optional[torch.Tensor], attn_metadata: TorchSDPAMetadata, # type: ignore kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. Args: From 07df0e158a60b7d2a90407eecc868eaa10a58180 Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Mon, 8 Jul 2024 09:33:03 -0400 Subject: [PATCH 229/239] Update vllm/attention/layer.py Co-authored-by: Woosuk Kwon --- vllm/attention/layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 5eee5914d3642..ae2607cf71dea 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -89,8 +89,8 @@ def forward(self, value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata, - attn_type: AttentionType = AttentionType.DECODER) \ - -> torch.Tensor: + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: return self.impl.forward(query, key, From 7ce9a51d4fb3e286fdaa3a3ba12e60d0908d2d64 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 09:38:03 -0400 Subject: [PATCH 230/239] merged in first pieces of woosuk feedback & latest main; formatting --- vllm/attention/backends/torch_sdpa.py | 18 +++++++++--------- vllm/attention/layer.py | 17 +++++++++-------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index c2fefe5342362..197981e47e921 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -138,15 +138,15 @@ def __init__( "Please use xFormers backend instead.") def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: TorchSDPAMetadata, # type: ignore - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: TorchSDPAMetadata, # type: ignore + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. Args: diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ae2607cf71dea..b8cc87be8c748 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -83,14 +83,15 @@ def __init__( alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params) - def forward(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata, - attn_type: AttentionType = AttentionType.DECODER, - ) -> torch.Tensor: + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: return self.impl.forward(query, key, From 9ae6728ecfe48769f578b0fad3f8e3950daa683d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 09:46:58 -0400 Subject: [PATCH 231/239] fixed specific point-changes requested by woosuk --- vllm/attention/backends/torch_sdpa.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 197981e47e921..6b9b2c6f4b5a4 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -160,9 +160,9 @@ def forward( """ assert kv_scale == 1.0 if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " + \ - "encoder/decoder cross-attention " + \ - "are not implemented for " + \ + raise NotImplementedError("Encoder self-attention and " + + "encoder/decoder cross-attention " + + "are not implemented for " + "TorchSDPABackendImpl") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. From a1bf65212cab0933b2520d8557a9d9132fff8c3d Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 10:17:04 -0400 Subject: [PATCH 232/239] test_encoder_decoder_attn.py cleanup --- tests/kernels/test_encoder_decoder_attn.py | 315 +++++++++++---------- vllm/attention/backends/torch_sdpa.py | 6 +- 2 files changed, 166 insertions(+), 155 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 654cd621145c5..f25e7d480b6b3 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -99,7 +99,7 @@ class TestResources(NamedTuple): kv_cache: torch.Tensor -def _make_test_resources(test_pt: TestPoint) -> TestResources: +def _make_test_resources(test_pt: TestPoint, ) -> TestResources: ''' Build key components for performing encoder/decoder attention test. @@ -146,8 +146,10 @@ class that Attention will automatically select when it is constructed. return TestResources(scale, attn_backend, attn, kv_cache) -def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ - -> PhaseTestParameters: +def _encoder_attn_setup( + test_pt: TestPoint, + test_rsrcs: TestResources, +) -> PhaseTestParameters: ''' Set up test vectors & data structures for encoder attention test. @@ -177,7 +179,16 @@ def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ implementation, and (3) KVCache field set to None ''' - (num_heads, head_size, _, batch_size, _, _, max_q_seq_len, _) = test_pt + ( + num_heads, + head_size, + _, + batch_size, + _, + _, + max_q_seq_len, + _, + ) = test_pt scale = test_rsrcs.scale @@ -210,12 +221,9 @@ def _encoder_attn_setup(test_pt: TestPoint, test_rsrcs: TestResources) \ packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE) return PhaseTestParameters( - PackedQKVO( - packed_qkv, \ - packed_ideal_output), - - None # No KV cache - ) + PackedQKVO(packed_qkv, packed_ideal_output), + None # No KV cache + ) def _decoder_attn_setup( @@ -279,8 +287,16 @@ def _decoder_attn_setup( constructed in this function) ''' - (num_heads, head_size, _, batch_size, block_size, max_q_seq_len, _, - _) = test_pt + ( + num_heads, + head_size, + _, + batch_size, + block_size, + max_q_seq_len, + _, + _, + ) = test_pt scale = test_rsrcs.scale @@ -288,15 +304,17 @@ def _decoder_attn_setup( # Build test tensors - qkv, \ - prefill_qkv, \ - decode_qkv = make_qkv(batch_size, - max_q_seq_len, - max_kv_seq_len, - num_heads, - head_size, - attn_type=AttentionType.DECODER, - device=CUDA_DEVICE) + ( + qkv, + prefill_qkv, + decode_qkv, + ) = make_qkv(batch_size, + max_q_seq_len, + max_kv_seq_len, + num_heads, + head_size, + attn_type=AttentionType.DECODER, + device=CUDA_DEVICE) # Compute correct answer using naive attention implementation # with causal attention mask @@ -351,49 +369,45 @@ def _decoder_attn_setup( prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) - decode_block_tables, \ - slot_mapping_list, \ - max_block_idx = make_block_tables_slot_mapping(block_size, - qkv.q_seq_lens, - device=CUDA_DEVICE, - block_base_addr = block_base_addr) - - prefill_slot_mapping, \ - decode_slot_mapping = split_slot_mapping(slot_mapping_list, - qkv.q_seq_lens, - device=CUDA_DEVICE) + ( + decode_block_tables, + slot_mapping_list, + max_block_idx, + ) = make_block_tables_slot_mapping(block_size, + qkv.q_seq_lens, + device=CUDA_DEVICE, + block_base_addr=block_base_addr) + + ( + prefill_slot_mapping, + decode_slot_mapping, + ) = split_slot_mapping(slot_mapping_list, + qkv.q_seq_lens, + device=CUDA_DEVICE) prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE) decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE) - return qkv, \ - PhaseTestParameters( # Prefill test params - PackedQKVO( - prefill_pckd_qkv, \ - prefill_packed_ideal_output), \ - KVMemoryMap( - prefill_block_tables, \ - prefill_slot_mapping)), \ - PhaseTestParameters( # Decode test params - PackedQKVO( - decode_pckd_qkv, \ - decode_packed_ideal_output), \ - KVMemoryMap( - decode_block_tables, \ - decode_slot_mapping)), \ - max_block_idx - -def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, - encoder_test_params: - PhaseTestParameters, - prefill_decoder_phase_test_params: - PhaseTestParameters, - test_pt: TestPoint, - test_rsrcs: TestResources, - block_base_addr: int=0) \ - -> Tuple[PhaseTestParameters, - PhaseTestParameters]: + return ( + qkv, + PhaseTestParameters( # Prefill test params + PackedQKVO(prefill_pckd_qkv, prefill_packed_ideal_output), + KVMemoryMap(prefill_block_tables, prefill_slot_mapping)), + PhaseTestParameters( # Decode test params + PackedQKVO(decode_pckd_qkv, decode_packed_ideal_output), + KVMemoryMap(decode_block_tables, decode_slot_mapping)), + max_block_idx) + + +def _enc_dec_cross_attn_setup_reuses_query( + decoder_qkv: QKVInputs, + encoder_test_params: PhaseTestParameters, + prefill_decoder_phase_test_params: PhaseTestParameters, + test_pt: TestPoint, + test_rsrcs: TestResources, + block_base_addr: int = 0, +) -> Tuple[PhaseTestParameters, PhaseTestParameters]: ''' Set up test vectors & data structures for cross-attention test. @@ -460,22 +474,32 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, assert encoder_test_params.packed_qkvo.packed_qkv is not None assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv is not None - (num_heads, head_size, _, batch_size, block_size, max_decoder_seq_len, - max_encoder_seq_len, _) = test_pt + ( + num_heads, + head_size, + _, + batch_size, + block_size, + max_decoder_seq_len, + max_encoder_seq_len, + _, + ) = test_pt scale = test_rsrcs.scale decoder_query = decoder_qkv.query decoder_seq_lens = decoder_qkv.q_seq_lens encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens - prefill_q_seq_lens = \ - prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens + prefill_q_seq_lens = ( + prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens) assert prefill_q_seq_lens is not None - cross_kv, \ - _, \ - _ = make_qkv(batch_size, + ( + cross_kv, + _, + _, + ) = make_qkv(batch_size, max_decoder_seq_len, max_encoder_seq_len, num_heads, @@ -537,13 +561,14 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE) decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE) - decode_block_tables, \ - prefill_slot_mapping_list, \ - _ = make_block_tables_slot_mapping( - block_size, - cross_kv.kv_seq_lens, - block_base_addr=block_base_addr, - device=CUDA_DEVICE) + ( + decode_block_tables, + prefill_slot_mapping_list, + _, + ) = make_block_tables_slot_mapping(block_size, + cross_kv.kv_seq_lens, + block_base_addr=block_base_addr, + device=CUDA_DEVICE) prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list, device=CUDA_DEVICE) @@ -551,26 +576,20 @@ def _enc_dec_cross_attn_setup_reuses_query(decoder_qkv: QKVInputs, # Packed key/value (query is already provided) packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE) - return PhaseTestParameters( # Prefill-phase test params - PackedQKVO( - packed_cross_kv, \ - prefill_packed_ideal_output), \ - KVMemoryMap( - prefill_block_tables, \ - prefill_slot_mapping)), \ - PhaseTestParameters( # Decode-phase test params - PackedQKVO( - None, - decode_packed_ideal_output), \ - KVMemoryMap( - decode_block_tables, \ - decode_slot_mapping)) - - -def _run_encoder_attention_test(attn: Attention, - encoder_test_params: PhaseTestParameters, - attn_metadata: AttentionMetadata) \ - -> torch.Tensor: + return ( + PhaseTestParameters( # Prefill-phase test params + PackedQKVO(packed_cross_kv, prefill_packed_ideal_output), + KVMemoryMap(prefill_block_tables, prefill_slot_mapping)), + PhaseTestParameters( # Decode-phase test params + PackedQKVO(None, decode_packed_ideal_output), + KVMemoryMap(decode_block_tables, decode_slot_mapping))) + + +def _run_encoder_attention_test( + attn: Attention, + encoder_test_params: PhaseTestParameters, + attn_metadata: AttentionMetadata, +) -> torch.Tensor: ''' Run encoder attention. @@ -605,10 +624,11 @@ def _run_encoder_attention_test(attn: Attention, attn_type=attn_type) -def _run_decoder_self_attention_test(test_rsrcs: TestResources, - decoder_test_params: PhaseTestParameters, - attn_metadata: AttentionMetadata) \ - -> torch.Tensor: +def _run_decoder_self_attention_test( + test_rsrcs: TestResources, + decoder_test_params: PhaseTestParameters, + attn_metadata: AttentionMetadata, +) -> torch.Tensor: ''' Run decoder self-attention test. @@ -644,9 +664,11 @@ def _run_decoder_self_attention_test(test_rsrcs: TestResources, def _run_encoder_decoder_cross_attention_test( - test_rsrcs: TestResources, decoder_test_params: PhaseTestParameters, - cross_test_params: Optional[PhaseTestParameters], - attn_metadata: AttentionMetadata) -> torch.Tensor: + test_rsrcs: TestResources, + decoder_test_params: PhaseTestParameters, + cross_test_params: Optional[PhaseTestParameters], + attn_metadata: AttentionMetadata, +) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -689,10 +711,8 @@ def _run_encoder_decoder_cross_attention_test( value = None else: cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv - key = None if cross_pckd_qkv is None else \ - cross_pckd_qkv.key - value = None if cross_pckd_qkv is None else \ - cross_pckd_qkv.value + key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) + value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) return attn.forward(decoder_test_params.packed_qkvo.packed_qkv.query, key, value, @@ -744,11 +764,8 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str, # PREFILL: encoder attention - enc_pckd_act_out: torch.Tensor = \ - _run_encoder_attention_test( - test_rsrcs.attn, - enc_test_params, - prephase_attn_metadata) + enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test( + test_rsrcs.attn, enc_test_params, prephase_attn_metadata)) # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) @@ -762,10 +779,16 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str, @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, - batch_size: int, block_size: int, - max_dec_seq_len: int, max_enc_seq_len: int, - monkeypatch) -> None: +def test_e2e_enc_dec_attn( + num_heads: int, + head_size: int, + backend_name: str, + batch_size: int, + block_size: int, + max_dec_seq_len: int, + max_enc_seq_len: int, + monkeypatch, +) -> None: ''' End-to-end encoder/decoder test: @@ -840,23 +863,26 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, # decoder self-attention block-table, i.e. a base address which the # encoder/decoder cross-attention block-table may build downward toward. - dec_qkv, \ - prephase_dec_test_params, \ - decphase_dec_test_params, \ - cross_block_base_addr = _decoder_attn_setup(test_pt,test_rsrcs) + ( + dec_qkv, + prephase_dec_test_params, + decphase_dec_test_params, + cross_block_base_addr, + ) = _decoder_attn_setup(test_pt, test_rsrcs) # Construct encoder/decoder cross-attention prefill-phase & decode-phase # test params, including key/value tensors, cross-attention memory-mapping - prephase_cross_test_params, \ - decphase_cross_test_params, \ - = _enc_dec_cross_attn_setup_reuses_query(dec_qkv, - enc_test_params, - prephase_dec_test_params, - test_pt, - test_rsrcs, - block_base_addr = \ - cross_block_base_addr) + ( + prephase_cross_test_params, + decphase_cross_test_params, + ) = _enc_dec_cross_attn_setup_reuses_query( + dec_qkv, + enc_test_params, + prephase_dec_test_params, + test_pt, + test_rsrcs, + block_base_addr=cross_block_base_addr) # Shared prefill metadata structure assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None @@ -871,22 +897,17 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, # PREFILL: encoder attention - enc_pckd_act_out: torch.Tensor = \ - _run_encoder_attention_test( - test_rsrcs.attn, - enc_test_params, - prephase_attn_metadata) + enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, + enc_test_params, + prephase_attn_metadata) # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) # PREFILL: decoder self-attention test - prephase_dec_pckd_act_out: torch.Tensor = \ - _run_decoder_self_attention_test( - test_rsrcs, - prephase_dec_test_params, - prephase_attn_metadata) + prephase_dec_pckd_act_out = _run_decoder_self_attention_test( + test_rsrcs, prephase_dec_test_params, prephase_attn_metadata) # - Is prefill decoder self-attention correct? assert_actual_matches_ideal(prephase_dec_test_params, @@ -894,11 +915,8 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, # PREFILL: encoder/decoder cross-attention test - prephase_cross_pckd_act_out: torch.Tensor = \ - _run_encoder_decoder_cross_attention_test( - test_rsrcs, - prephase_dec_test_params, - prephase_cross_test_params, + prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( + test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, prephase_attn_metadata) # - Is prefill encoder/decoder cross-attention correct? @@ -918,11 +936,8 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, # DECODE: decoder self-attention test - decphase_dec_pckd_act_out: torch.Tensor = \ - _run_decoder_self_attention_test( - test_rsrcs, - decphase_dec_test_params, - decphase_attn_metadata) + decphase_dec_pckd_act_out = _run_decoder_self_attention_test( + test_rsrcs, decphase_dec_test_params, decphase_attn_metadata) # - Is decode-phase decoder self-attention correct? assert_actual_matches_ideal(decphase_dec_test_params, @@ -930,12 +945,8 @@ def test_e2e_enc_dec_attn(num_heads: int, head_size: int, backend_name: str, # DECODE: encoder/decoder cross-attention test - decphase_cross_pckd_act_out: torch.Tensor = \ - _run_encoder_decoder_cross_attention_test( - test_rsrcs, - decphase_dec_test_params, - None, - decphase_attn_metadata) + decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( + test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata) # - Is decode-phase encoder/decoder cross-attention correct? assert_actual_matches_ideal(decphase_cross_test_params, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 6b9b2c6f4b5a4..48418f24870f9 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -160,9 +160,9 @@ def forward( """ assert kv_scale == 1.0 if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " + - "encoder/decoder cross-attention " + - "are not implemented for " + + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " "TorchSDPABackendImpl") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. From 4f27946dcfb73f0a60420eb3ca6c9a74f6c6d3d1 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 10:27:35 -0400 Subject: [PATCH 233/239] tests/kernels/utils.py cleanup --- tests/kernels/utils.py | 173 +++++++++++++++++++++-------------------- 1 file changed, 87 insertions(+), 86 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index f0b0dd5dbaee6..23d627820d247 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -138,9 +138,11 @@ class PhaseTestParameters(NamedTuple): packed_qkvo: PackedQKVO kv_mmap: Optional[KVMemoryMap] -def maybe_make_int_tensor(_list: Optional[List[int]], - device: Union[torch.device, str]) \ - -> torch.Tensor: + +def maybe_make_int_tensor( + _list: Optional[List[int]], + device: Union[torch.device, str], +) -> torch.Tensor: ''' Convert Python int list to a 1D int torch.Tensor on `device` @@ -152,9 +154,11 @@ def maybe_make_int_tensor(_list: Optional[List[int]], return None if _list is None else torch.tensor( _list, dtype=torch.int, device=device) -def maybe_make_long_tensor(_list: Optional[List[int]], - device: Union[torch.device, str]) \ - -> torch.Tensor: + +def maybe_make_long_tensor( + _list: Optional[List[int]], + device: Union[torch.device, str], +) -> torch.Tensor: ''' Convert Python int list to a 1D long torch.Tensor on `device` @@ -176,8 +180,11 @@ def maybe_max(_list: Optional[List]) -> Optional[Number]: ''' return None if _list is None else max(_list) -def make_causal_mask(q_max_seq_len: int, kv_max_seq_len: int) \ - -> torch.Tensor: + +def make_causal_mask( + q_max_seq_len: int, + kv_max_seq_len: int, +) -> torch.Tensor: ''' Create a q_max_seq_len x kv_max_seq_len causal mask @@ -394,22 +401,25 @@ def make_qkv( decode_q_seq_lens = [1 for _ in q_seq_lens] decode_kv_seq_lens = [1 for _ in kv_seq_lens] - return QKVInputs(query, # Overall QKV inputs - key, - value, - q_seq_lens, - kv_seq_lens), \ - QKVInputs(prefill_query, # Prefill subset of QKV sequences - prefill_key, - prefill_value, - prefill_q_seq_lens, - prefill_kv_seq_lens), \ - QKVInputs( - decode_query, # Decode subset of KV sequences - decode_key, - decode_value, - decode_q_seq_lens, - decode_kv_seq_lens) + return ( + QKVInputs( + query, # Overall QKV inputs + key, + value, + q_seq_lens, + kv_seq_lens), + QKVInputs( + prefill_query, # Prefill subset of QKV sequences + prefill_key, + prefill_value, + prefill_q_seq_lens, + prefill_kv_seq_lens), + QKVInputs( + decode_query, # Decode subset of KV sequences + decode_key, + decode_value, + decode_q_seq_lens, + decode_kv_seq_lens)) def pack_tensor( @@ -481,14 +491,11 @@ def pack_qkv(qkv: QKVInputs, device: Union[torch.device, qkv.kv_seq_lens, device=device) packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device) - return PackedQKVInputs(packed_query, \ - packed_key, \ - packed_value, \ - q_start_loc_list, \ - kv_start_loc_list, \ - None if q_start_loc_list is None else \ - qkv.q_seq_lens, \ - qkv.kv_seq_lens) + return PackedQKVInputs( + packed_query, packed_key, packed_value, q_start_loc_list, + kv_start_loc_list, + (None if q_start_loc_list is None else qkv.q_seq_lens), + qkv.kv_seq_lens) def make_backend(backend_name: str) -> AttentionBackend: @@ -547,18 +554,13 @@ def _make_metadata_tensors( max_seq_len = maybe_max(seq_lens) encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device) - max_encoder_seq_len = None if encoder_seq_lens is None else \ - max(encoder_seq_lens) + max_encoder_seq_len = (None if encoder_seq_lens is None else + max(encoder_seq_lens)) seq_start_loc = None - return seq_lens_tensor, \ - context_lens_tensor, \ - max_context_len, \ - max_seq_len, \ - seq_start_loc, \ - encoder_seq_lens_tensor, \ - max_encoder_seq_len + return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len, + seq_start_loc, encoder_seq_lens_tensor, max_encoder_seq_len) def make_kv_cache(num_blocks: int, @@ -659,8 +661,8 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int], decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1]) base_idx += seq_len - return maybe_make_long_tensor(prefill_slot_mapping, device), \ - maybe_make_long_tensor(decode_slot_mapping, device) + return (maybe_make_long_tensor(prefill_slot_mapping, device), + maybe_make_long_tensor(decode_slot_mapping, device)) def make_block_tables_slot_mapping( @@ -741,9 +743,7 @@ def make_block_tables_slot_mapping( device=device, ) - return block_tables_tensor, \ - slot_mapping_list, \ - max_block_idx + return (block_tables_tensor, slot_mapping_list, max_block_idx) def make_test_metadata( @@ -797,8 +797,8 @@ def make_test_metadata( # Decoder self-attention memory mapping # decoder_test_params is None signals encoder-only # scenario, so kv_mmap is None - kv_mmap = None if decoder_test_params is None else \ - decoder_test_params.kv_mmap + kv_mmap = (None + if decoder_test_params is None else decoder_test_params.kv_mmap) # This function constructs metadata assuming no chunked prefill, # i.e. 100% prefill tokens or 100% decode tokens @@ -811,11 +811,10 @@ def make_test_metadata( # seq_lens is None signals encoder-only # scenario, in which case num_prefills_or_decodes and # num_prefill_or_decode_tokens are unused - num_prefills_or_decodes = None if seq_lens is None else \ - len(seq_lens) + num_prefills_or_decodes = (None if seq_lens is None else len(seq_lens)) - num_prefill_or_decode_tokens = None if seq_lens is None else \ - (sum(seq_lens) if is_prompt else len(seq_lens)) + num_prefill_or_decode_tokens = (None if seq_lens is None else ( + sum(seq_lens) if is_prompt else len(seq_lens))) # Seems for non-prefix-caching scenarios context_lens # is never needed @@ -829,8 +828,8 @@ def make_test_metadata( # * Extract encoder input sequence lengths assert encoder_test_params.packed_qkvo.packed_qkv is not None encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens - num_encoder_tokens = None if encoder_seq_lens is None else \ - (sum(encoder_seq_lens)) + num_encoder_tokens = (None if encoder_seq_lens is None else + (sum(encoder_seq_lens))) if cross_test_params is None: cross_kv_mmap = None @@ -847,21 +846,22 @@ def make_test_metadata( num_prefill_tokens = num_prefill_or_decode_tokens num_decode_tokens = 0 - seq_lens_tensor, \ - context_lens_tensor, \ - _, \ - _, \ - _, \ - encoder_seq_lens_tensor, \ - max_encoder_seq_len = _make_metadata_tensors(seq_lens, - context_lens, - encoder_seq_lens, - device=device) + ( + seq_lens_tensor, + context_lens_tensor, + _, + _, + _, + encoder_seq_lens_tensor, + max_encoder_seq_len, + ) = _make_metadata_tensors(seq_lens, + context_lens, + encoder_seq_lens, + device=device) return attn_backend.make_metadata( num_prefills=num_prefills, - slot_mapping=None if kv_mmap is None else \ - kv_mmap.slot_mapping, + slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -869,17 +869,16 @@ def make_test_metadata( max_prefill_seq_len=None if seq_lens is None else max(seq_lens), max_decode_seq_len=0, context_lens_tensor=context_lens_tensor, - block_tables=None if kv_mmap is None else \ - kv_mmap.block_tables, + block_tables=(None if kv_mmap is None else kv_mmap.block_tables), use_cuda_graph=False, num_encoder_tokens=num_encoder_tokens, encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=None if cross_kv_mmap is None else \ - cross_kv_mmap.slot_mapping, - cross_block_tables=None if cross_kv_mmap is None else \ - cross_kv_mmap.block_tables) + cross_slot_mapping=(None if cross_kv_mmap is None else + cross_kv_mmap.slot_mapping), + cross_block_tables=(None if cross_kv_mmap is None else + cross_kv_mmap.block_tables)) else: # not is_prompt # Decode-phase scenario @@ -892,16 +891,18 @@ def make_test_metadata( num_prefill_tokens = 0 num_decode_tokens = num_prefill_or_decode_tokens - seq_lens_tensor, \ - context_lens_tensor, \ - _, \ - _, \ - _, \ - encoder_seq_lens_tensor, \ - max_encoder_seq_len = _make_metadata_tensors(seq_lens, - context_lens, - encoder_seq_lens, - device=device) + ( + seq_lens_tensor, + context_lens_tensor, + _, + _, + _, + encoder_seq_lens_tensor, + max_encoder_seq_len, + ) = _make_metadata_tensors(seq_lens, + context_lens, + encoder_seq_lens, + device=device) return attn_backend.make_metadata( num_prefills=num_prefills, @@ -919,10 +920,10 @@ def make_test_metadata( encoder_seq_lens=encoder_seq_lens, encoder_seq_lens_tensor=encoder_seq_lens_tensor, max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=None if cross_kv_mmap is None else \ - cross_kv_mmap.slot_mapping, - cross_block_tables=None if cross_kv_mmap is None else \ - cross_kv_mmap.block_tables) + cross_slot_mapping=(None if cross_kv_mmap is None else + cross_kv_mmap.slot_mapping), + cross_block_tables=(None if cross_kv_mmap is None else + cross_kv_mmap.block_tables)) def assert_actual_matches_ideal(test_params: PhaseTestParameters, From 5ee30fed1d27dbef98dc3e4512741c9ca301197c Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 10:31:09 -0400 Subject: [PATCH 234/239] vllm/attention/backends/abstract.py cleanup --- vllm/attention/backends/abstract.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 8e386fd4e3ce8..adb8325168cdf 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -128,12 +128,13 @@ def __init__( @abstractmethod def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: T, - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: raise NotImplementedError From 45fc9f71641bdd17c67997598463f12ead3998b2 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 10:35:00 -0400 Subject: [PATCH 235/239] vllm/attention/backends/blocksparse_attn.py cleanup --- vllm/attention/backends/blocksparse_attn.py | 23 +++++++++++---------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 470b6339a3006..fe4c4a45dca0d 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -321,14 +321,15 @@ def __init__( ) def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: BlocksparseFlashAttentionMetadata, - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: BlocksparseFlashAttentionMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. Args: @@ -341,9 +342,9 @@ def forward( shape = [num_tokens, num_heads * head_size] """ if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " + \ - "encoder/decoder cross-attention " + \ - "are not implemented for " + \ + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " "BlocksparseFlashAttentionImpl") num_tokens, hidden_size = query.shape From 097aff2029e4560ae28bd7a7acf0f20509f803fe Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 10:36:05 -0400 Subject: [PATCH 236/239] vllm/attention/backends/flash_attn.py cleanup --- vllm/attention/backends/flash_attn.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index f9a04f63acbec..048abed48d2e9 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -250,14 +250,15 @@ def __init__( f"Supported head sizes are: {support_head_sizes}.") def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Forward pass with FlashAttention. Args: @@ -270,9 +271,9 @@ def forward( shape = [num_tokens, num_heads * head_size] """ if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " + \ - "encoder/decoder cross-attention " + \ - "are not implemented for " + \ + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " "FlashAttentionImpl") # NOTE(woosuk): FlashAttention does not support FP8 KV cache. From d8a692b7dde0656696b726497030970aac0b53d3 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 10:39:37 -0400 Subject: [PATCH 237/239] cleaning up a number of backends & backends utils.py --- vllm/attention/backends/flashinfer.py | 23 +++++++++++----------- vllm/attention/backends/ipex_attn.py | 23 +++++++++++----------- vllm/attention/backends/pallas.py | 23 +++++++++++----------- vllm/attention/backends/rocm_flash_attn.py | 23 +++++++++++----------- vllm/attention/backends/utils.py | 5 ++--- 5 files changed, 50 insertions(+), 47 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 615e427089865..b27e3e40f566d 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -217,19 +217,20 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: FlashInferMetadata, - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: FlashInferMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: assert kv_scale == 1.0 if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " + \ - "encoder/decoder cross-attention " + \ - "are not implemented for " + \ + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " "FlashInferImpl") num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 2404ff68fd47f..6a1295b1000bc 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -150,14 +150,15 @@ def split_kv_cache( return key_cache, value_cache def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Optional[torch.Tensor], - attn_metadata: IpexAttnMetadata, # type: ignore - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: IpexAttnMetadata, # type: ignore + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. Args: @@ -171,9 +172,9 @@ def forward( """ assert kv_scale == 1.0 if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " + \ - "encoder/decoder cross-attention " + \ - "are not implemented for " + \ + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " "IpexAttnBackendImpl") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index fbfba742fb643..7a6954ceb6d6a 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -125,14 +125,15 @@ def __init__( self.megacore_mode = "batch" def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], - attn_metadata: PallasMetadata, - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], + attn_metadata: PallasMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Forward pass with Pallas attention. Args: @@ -147,9 +148,9 @@ def forward( """ assert kv_scale == 1.0 if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " + \ - "encoder/decoder cross-attention " + \ - "are not implemented for " + \ + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " "PallasAttentionBackendImpl") batch_size, seq_len, hidden_size = query.shape query = query.view(batch_size, seq_len, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 6107a3652b049..81b546c65c819 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -290,14 +290,15 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: head_dim)) def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: ROCmFlashAttentionMetadata, - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: ROCmFlashAttentionMetadata, + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. Args: @@ -310,9 +311,9 @@ def forward( shape = [num_tokens, num_heads * head_size] """ if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " + \ - "encoder/decoder cross-attention " + \ - "are not implemented for " + \ + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " "ROCmFlashAttentionImpl") num_tokens, hidden_size = query.shape diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 82a1f46db6e09..a3cfc6e20748b 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -3,6 +3,5 @@ # Error string(s) for encoder/decoder # unsupported attention scenarios -STR_NOT_IMPL_ENC_DEC_ROCM_HIP = \ -"ROCm/HIP is not currently supported" + \ -"with encoder/decoder models." +STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " + "with encoder/decoder models.") From 5df73fc708bf3370a5f6d7f85cce4772d5c679b5 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 10:47:04 -0400 Subject: [PATCH 238/239] xformers backend cleanup --- vllm/attention/backends/xformers.py | 146 ++++++++++++++-------------- 1 file changed, 74 insertions(+), 72 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index b1daaefc9f3b5..79aa8309bb225 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -149,9 +149,9 @@ def is_all_encoder_attn_metadata_set(self): ''' All attention metadata required for encoder attention is set. ''' - return (self.encoder_seq_lens is not None) and \ - (self.encoder_seq_lens_tensor is not None) and \ - (self.max_encoder_seq_len is not None) + return ((self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None)) @property def is_all_cross_attn_metadata_set(self): @@ -160,9 +160,9 @@ def is_all_cross_attn_metadata_set(self): Superset of encoder attention required metadata. ''' - return self.is_all_encoder_attn_metadata_set and \ - (self.cross_slot_mapping is not None) and \ - (self.cross_block_tables is not None) + return (self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None)) @property def prefill_metadata(self) -> Optional["XFormersMetadata"]: @@ -174,24 +174,24 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: # metadata structure return self._cached_prefill_metadata - assert (self.seq_lens is not None) or \ - (self.encoder_seq_lens is not None) - assert (self.seq_lens_tensor is not None) or \ - (self.encoder_seq_lens_tensor is not None) + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) # Compute some attn_metadata fields which default to None - query_start_loc = None if self.query_start_loc is None \ - else self.query_start_loc[:self.num_prefills + 1] - slot_mapping=None if self.slot_mapping is None else \ - self.slot_mapping[:self.num_prefill_tokens] - seq_lens=None if self.seq_lens is None \ - else self.seq_lens[:self.num_prefills] - seq_lens_tensor=None if self.seq_lens_tensor is None else \ - self.seq_lens_tensor[:self.num_prefills] - context_lens_tensor=None if self.context_lens_tensor is None else \ - self.context_lens_tensor[:self.num_prefills] - block_tables=None if self.block_tables is None else \ - self.block_tables[:self.num_prefills] + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) # Construct & cache prefill-phase attention metadata structure self._cached_prefill_metadata = XFormersMetadata( @@ -225,16 +225,16 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: # Recover cached decode-phase attention # metadata structure return self._cached_decode_metadata - assert (self.seq_lens_tensor is not None) or \ - (self.encoder_seq_lens_tensor is not None) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) # Compute some attn_metadata fields which default to None - slot_mapping=None if self.slot_mapping is None else \ - self.slot_mapping[self.num_prefill_tokens:] - seq_lens_tensor=None if self.seq_lens_tensor is None else \ - self.seq_lens_tensor[self.num_prefills:] - block_tables=None if self.block_tables is None else \ - self.block_tables[self.num_prefills:] + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) # Construct & cache decode-phase attention metadata structure self._cached_decode_metadata = XFormersMetadata( @@ -255,9 +255,11 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata -def _get_attn_bias(attn_metadata: XFormersMetadata, - attn_type: AttentionType) -> \ - Optional[AttentionBias]: + +def _get_attn_bias( + attn_metadata: XFormersMetadata, + attn_type: AttentionType, +) -> Optional[AttentionBias]: ''' Extract appropriate attention bias from attention metadata according to attention type. @@ -283,7 +285,8 @@ def _get_attn_bias(attn_metadata: XFormersMetadata, def _set_attn_bias(attn_metadata: XFormersMetadata, attn_bias: List[Optional[AttentionBias]], - attn_type: AttentionType) -> None: + attn_type: AttentionType, + ) -> None: ''' Update appropriate attention bias field of attention metadata, according to attention type. @@ -306,10 +309,11 @@ def _set_attn_bias(attn_metadata: XFormersMetadata, raise AttributeError(f"Invalid attention type {str(attn_type)}") -def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, - is_prompt: bool, - attn_type: AttentionType) \ - -> tuple: +def _get_seq_len_block_table_args( + attn_metadata: XFormersMetadata, + is_prompt: bool, + attn_type: AttentionType, +) -> tuple: ''' The particular choice of sequence-length- and block-table-related attributes which should be extracted from attn_metadata is dependent @@ -341,20 +345,18 @@ def _get_seq_len_block_table_args(attn_metadata: XFormersMetadata, max_seq_len = attn_metadata.max_prefill_seq_len else: max_seq_len = attn_metadata.max_decode_seq_len - return attn_metadata.seq_lens_tensor, \ - max_seq_len, \ - attn_metadata.block_tables + return (attn_metadata.seq_lens_tensor, max_seq_len, + attn_metadata.block_tables) elif attn_type == AttentionType.ENCODER_DECODER: # Enc/dec cross-attention KVs match encoder sequence length; # cross-attention utilizes special "cross" block tables - return attn_metadata.encoder_seq_lens_tensor, \ - attn_metadata.max_encoder_seq_len, \ - attn_metadata.cross_block_tables + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_block_tables) elif attn_type == AttentionType.ENCODER: # No block tables associated with encoder attention - return attn_metadata.encoder_seq_lens_tensor, \ - attn_metadata.max_encoder_seq_len, \ - None + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, None) else: raise AttributeError(f"Invalid attention type {str(attn_type)}") @@ -418,14 +420,15 @@ def __init__( f"Supported head sizes are: {suppored_head_sizes}.") def forward( - self, - query: torch.Tensor, - key: Optional[torch.Tensor], - value: Optional[torch.Tensor], - kv_cache: Optional[torch.Tensor], - attn_metadata: "XFormersMetadata", - kv_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + self, + query: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + kv_cache: Optional[torch.Tensor], + attn_metadata: "XFormersMetadata", + kv_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. For decoder-only models: query, key and value must be non-None. @@ -475,14 +478,14 @@ def forward( # Check that appropriate attention metadata attributes are # selected for the desired attention type - if attn_type == AttentionType.ENCODER and \ - (not attn_metadata.is_all_encoder_attn_metadata_set): - raise AttributeError("Encoder attention requires setting " + \ + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " "encoder metadata attributes.") - elif attn_type == AttentionType.ENCODER_DECODER and \ - (not attn_metadata.is_all_cross_attn_metadata_set): - raise AttributeError("Encoder/decoder cross-attention " + \ - "requires setting cross-attention " + \ + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " "metadata attributes.") query = query.view(-1, self.num_heads, self.head_size) @@ -497,8 +500,7 @@ def forward( # which KV cache memory-mapping & which # seqlen datastructures we utilize - if (attn_type != AttentionType.ENCODER and \ - kv_cache is not None): + if (attn_type != AttentionType.ENCODER and kv_cache is not None): # KV-cache during decoder-self- or # encoder-decoder-cross-attention, but not # during encoder attention. @@ -600,11 +602,11 @@ def forward( if decode_meta := attn_metadata.decode_metadata: - seq_lens_arg, \ - max_seq_len_arg,\ - block_tables_arg = _get_seq_len_block_table_args(decode_meta, - False, - attn_type) + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = _get_seq_len_block_table_args(decode_meta, False, attn_type) output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, @@ -629,7 +631,8 @@ def _run_memory_efficient_xformers_forward( key: torch.Tensor, value: torch.Tensor, attn_metadata: XFormersMetadata, - attn_type: AttentionType = AttentionType.DECODER) -> torch.Tensor: + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. @@ -668,8 +671,7 @@ def _run_memory_efficient_xformers_forward( attn_bias = _get_attn_bias(attn_metadata, attn_type) if attn_bias is None: if self.alibi_slopes is None: - if attn_type == \ - AttentionType.ENCODER_DECODER: + if (attn_type == AttentionType.ENCODER_DECODER): assert attn_metadata.seq_lens is not None assert attn_metadata.encoder_seq_lens is not None From 6cd595c3c879d4ee603bb6a5bc0f1724647a5135 Mon Sep 17 00:00:00 2001 From: Andrew Feldman Date: Mon, 8 Jul 2024 10:47:20 -0400 Subject: [PATCH 239/239] formatting --- vllm/attention/backends/xformers.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 79aa8309bb225..6cc5f1d1477ae 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -283,10 +283,11 @@ def _get_attn_bias( return attn_metadata.cross_attn_bias -def _set_attn_bias(attn_metadata: XFormersMetadata, - attn_bias: List[Optional[AttentionBias]], - attn_type: AttentionType, - ) -> None: +def _set_attn_bias( + attn_metadata: XFormersMetadata, + attn_bias: List[Optional[AttentionBias]], + attn_type: AttentionType, +) -> None: ''' Update appropriate attention bias field of attention metadata, according to attention type. @@ -626,13 +627,13 @@ def forward( return output.view(-1, self.num_heads * self.head_size) def _run_memory_efficient_xformers_forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_metadata: XFormersMetadata, - attn_type: AttentionType = AttentionType.DECODER, - ) -> torch.Tensor: + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: XFormersMetadata, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input.