From 61b7294dfe3536354276dbcd59343a2f0aabf9c6 Mon Sep 17 00:00:00 2001 From: sang Date: Sat, 6 Apr 2024 06:13:58 -0700 Subject: [PATCH 01/14] ip --- .../basic_correctness/test_chunked_prefill.py | 63 ++++++ .../test_basic_distributed_correctness.py | 14 +- tests/prompts/example.txt | 7 - tests/worker/test_model_runner.py | 58 +++++- vllm/attention/__init__.py | 4 +- vllm/attention/backends/abstract.py | 30 ++- vllm/attention/backends/flash_attn.py | 12 +- vllm/attention/backends/torch_sdpa.py | 10 +- vllm/attention/backends/xformers.py | 88 ++++---- vllm/attention/layer.py | 4 +- vllm/attention/ops/paged_attn.py | 6 - vllm/config.py | 2 +- vllm/core/scheduler.py | 6 +- vllm/engine/llm_engine.py | 1 + .../parallel_utils/communication_op.py | 10 +- vllm/worker/model_runner.py | 197 +++++++++++------- 16 files changed, 353 insertions(+), 159 deletions(-) create mode 100644 tests/basic_correctness/test_chunked_prefill.py diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py new file mode 100644 index 0000000000000..41d4e6ee679d9 --- /dev/null +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -0,0 +1,63 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling. + +Run `pytest tests/models/test_chunked_prefill.py`. +""" +import pytest + +MODELS = [ + "facebook/opt-125m", + # "gpt2", + # "bigcode/tiny_starcoder_py", + # "EleutherAI/pythia-70m", + # "bigscience/bloom-560m", + # "microsoft/phi-2", + # "stabilityai/stablelm-3b-4e1t", + # "allenai/OLMo-1B", # Broken + # "bigcode/starcoder2-3b", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("chunked_prefill_token_size", [-1]) +@pytest.mark.parametrize("enforce_eager", [True]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, + enforce_eager: bool, +) -> None: + # To pass the small model tests, we need full precision. 
+ assert dtype == "float" + enable_chunked_prefill = False + max_num_batched_tokens = None + if chunked_prefill_token_size != -1: + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, + enforce_eager=enforce_eager, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 1eba14d7a6422..7aa74e92540ca 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -25,6 +25,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("chunked_prefill_token_size", [-1]) def test_models( hf_runner, vllm_runner, @@ -32,12 +33,23 @@ def test_models( model: str, dtype: str, max_tokens: int, + chunked_prefill_token_size: int, ) -> None: + enable_chunked_prefill = False + max_num_batched_tokens = None + if chunked_prefill_token_size != -1: + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2) + vllm_model = vllm_runner(model, + dtype=dtype, + tensor_parallel_size=2, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/prompts/example.txt b/tests/prompts/example.txt index e1b97bc6eee75..c90173da05f1c 100644 --- a/tests/prompts/example.txt +++ b/tests/prompts/example.txt @@ -1,8 +1 @@ vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. -Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020. -Compare and contrast artificial intelligence with human intelligence in terms of processing information. -Describe the basic components of a neural network and how it can be trained. -Write a short story about a robot that dreams for the first time. -Analyze the impact of the COVID-19 pandemic on global economic structures and future business models. -Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies. -Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.' 
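For reference, the knobs exercised by the new test map onto the engine arguments this series adds: chunked prefill is opt-in via enable_chunked_prefill, and max_num_batched_tokens caps how many prompt tokens are scheduled per step, so a prompt longer than the cap is prefilled over several scheduler iterations. A minimal sketch of driving the same path outside the test fixture (assuming the LLM entrypoint forwards these engine arguments the same way the vllm_runner fixture above does):

# Sketch only: enable chunked prefill with a deliberately small token budget
# so prompts longer than the budget are split across scheduler iterations.
# The LLM() plumbing is assumed here; the test above goes through vllm_runner.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    enable_chunked_prefill=True,
    max_num_batched_tokens=16,
    enforce_eager=True,
)
outputs = llm.generate(
    ["vLLM is a high-throughput and memory-efficient inference engine."],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)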
diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 5b6f001f62fa7..5885254b29232 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -36,7 +36,8 @@ def test_prepare_prompt(batch_size): prompt_len - 1) selected_token_start_idx += prompt_len (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, - _, _) = (model_runner._prepare_prompt(seq_group_metadata_list)) + _, _, + slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens # Verify input metadata is correct for prompts. @@ -85,21 +86,21 @@ def test_prepare_prompt(batch_size): assert attn_metadata.use_cuda_graph is False assert attn_metadata.kv_cache_dtype == "auto" - assert input_tokens.shape == (sum(prompt_lens), ) - assert input_positions.shape == (sum(prompt_lens), ) + assert len(input_tokens) == sum(prompt_lens) + assert len(input_positions) == sum(prompt_lens) torch.testing.assert_close(input_tokens, input_positions) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) - assert input_tokens.shape == (sum(prompt_lens), ) - assert input_positions.shape == (sum(prompt_lens), ) + assert len(input_tokens) == sum(prompt_lens) + assert len(input_positions) == sum(prompt_lens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) - torch.testing.assert_close(input_tokens, input_positions) + assert input_tokens == input_positions actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, @@ -143,7 +144,7 @@ def test_prepare_decode_cuda_graph(batch_size): assert seq_group_metadata.token_chunk_size == 1 seq_group_metadata_list.append(seq_group_metadata) - input_tokens, input_positions, attn_metadata, _, _, _ = ( + input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( model_runner._prepare_decode(seq_group_metadata_list)) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) @@ -172,9 +173,9 @@ def test_prepare_decode_cuda_graph(batch_size): assert attn_metadata.use_cuda_graph is True assert attn_metadata.kv_cache_dtype == "auto" - assert input_tokens.shape == (expected_bs, ) - assert input_positions.shape == (expected_bs, ) - torch.testing.assert_close(input_tokens, input_positions) + assert len(input_tokens) == expected_bs + assert len(input_positions) == expected_bs + assert input_tokens == input_positions # Verify Sampling expected_selected_token_indices = [] @@ -190,3 +191,40 @@ def test_prepare_decode_cuda_graph(batch_size): device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) + + +def test_empty_seq_group(): + """Verify prepare prompt and decode returns empty output properly when there's no seq groups.""" + model_config = ModelConfig( + "facebook/opt-125m", + "facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + enforce_eager=False, + ) + model_runner = ModelRunner(model_config, None, None, None, None) + model_runner.set_block_size(16) + seq_group_metadata_list = [] + input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( + model_runner._prepare_decode(seq_group_metadata_list)) + assert len(input_tokens) == 0 + assert len(input_positions) == 0 + assert 
attn_metadata is None + assert len(slot_mapping) == 0 + + (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, + _, _, + slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + assert len(input_tokens) == 0 + assert len(input_positions) == 0 + assert attn_metadata is None + assert len(slot_mapping) == 0 + assert len(return_prompt_lens) == 0 + + +# SANG-TODO Test chunked prefill case. diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 9acb82c0df2c2..7636b34a16fed 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,5 +1,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -8,4 +9,5 @@ "AttentionMetadata", "Attention", "get_attn_backend", + "AttentionMetadataPerStage", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index a03cf2dd7a6fa..bcf5467a1af0e 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, fields -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Generic import torch @@ -47,7 +47,8 @@ def copy_blocks( @dataclass -class AttentionMetadata: +class AttentionMetadataPerStage: + """Attention metadata for a specific stage. I.e., prefill or decode.""" def asdict_zerocopy(self) -> Dict[str, Any]: """Similar to dataclasses.asdict, but avoids deepcopying.""" @@ -59,6 +60,29 @@ def asdict_zerocopy(self) -> Dict[str, Any]: } +T = TypeVar("T", bound=AttentionMetadataPerStage) + + +@dataclass +class AttentionMetadata(Generic[T]): + """Attention metadata for prefill and decode.""" + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # Number of decode tokens. Note that it is equivalent to the number of + # decode requests. + num_decode_tokens: int + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. 
+ slot_mapping: torch.Tensor + kv_cache_dtype: str + prefill_metadata: Optional[T] + decode_metadata: Optional[T] + + class AttentionImpl(ABC): @abstractmethod @@ -80,7 +104,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, + attn_metadata: AttentionMetadata[AttentionMetadataPerStage], kv_scale: float, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 4e0d9d1418b32..c10e54e71091c 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -11,7 +11,8 @@ from flash_attn import flash_attn_varlen_func from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -53,7 +54,8 @@ def copy_blocks( @dataclass -class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): +class FlashAttentionMetadata(AttentionMetadataPerStage, + PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -68,10 +70,6 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): prompt_lens: Optional[List[int]] # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| @@ -155,7 +153,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, + attn_metadata: AttentionMetadata[FlashAttentionMetadata], kv_scale: float, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 9706e1910cb79..895a117804da7 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -7,7 +7,8 @@ from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -49,17 +50,14 @@ def copy_blocks( @dataclass -class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): +class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - slot_mapping: torch.Tensor prompt_lens: Optional[List[int]] prompt_lens_tensor: Optional[torch.Tensor] - num_prompt_tokens: int - num_generation_tokens: int max_subquery_len: Optional[int] = None max_prompt_len: Optional[int] = None @@ -113,7 +111,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: TorchSDPAMetadata, + attn_metadata: AttentionMetadata[TorchSDPAMetadata], kv_scale: float, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. 
diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index d349c3ef19ea7..696484669d1e9 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -10,7 +10,8 @@ LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -56,7 +57,7 @@ def copy_blocks( @dataclass -class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): +class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): """Metadata for XFormersbackend. NOTE: Any python object stored here is not updated when it is @@ -67,19 +68,10 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor # (batch_size,). The prompt length per sequence. None if it is a decoding. prompt_lens: Optional[List[int]] # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| @@ -125,11 +117,11 @@ def __post_init__(self): class XFormersImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens --------------->| + |<--------------- num_prefill_tokens --------------->| |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->| Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| + |<------------------ num_decode_tokens (M) ----------------->| |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. @@ -177,7 +169,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: XFormersMetadata, + attn_metadata: AttentionMetadata[XFormersMetadata], kv_scale: float, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. 
@@ -209,9 +201,27 @@ def forward( attn_metadata.kv_cache_dtype, kv_scale) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + print(f"SANG-TODO original query: {query.size()}") + output = torch.empty_like(query) + decode_query = query[num_prefill_tokens:] + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + print(f"SANG-TODO {num_prefill_tokens=} {num_decode_tokens=}") + print(f"SANG-TODO {query.size()=} {decode_query.size()=}") + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if num_prefill_tokens > 0: + print("SANG-TODO run prefill") + prefill_output = output[:num_prefill_tokens] + prefill_meta = attn_metadata.prefill_metadata + assert prefill_meta is not None # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention. # block tables are empty if the prompt does not have a cached # prefix. @@ -234,9 +244,8 @@ def forward( value.shape[-1]) if self.use_naive_attention: - output = torch.empty_like(query) start = 0 - for _, prompt_len in enumerate(attn_metadata.prompt_lens): + for _, prompt_len in enumerate(prefill_meta.prompt_lens): end = start + prompt_len out = _naive_masked_attention( query[None, start:end], @@ -255,37 +264,40 @@ def forward( # with input tensor's size and stride (at least one # dimension spans across two contiguous subspaces). # Use reshape instead. - return output.reshape(num_tokens, hidden_size) + return prefill_output.reshape(num_tokens, hidden_size) - output = self._run_memory_efficient_xformers_forward( - query, key, value, attn_metadata) + prefill_output = self._run_memory_efficient_xformers_forward( + query, key, value, prefill_meta) else: # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, # to be addressed separately. - output = PagedAttention.forward_prefix( + prefill_output = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.prompt_lens_tensor, + prefill_meta.context_lens, + prefill_meta.max_subquery_len, self.alibi_slopes, ) - else: - # Decoding run. - output = PagedAttention.forward_decode( - query, + + if num_decode_tokens > 0: + print("SANG-TODO run decode") + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -307,10 +319,10 @@ def _run_memory_efficient_xformers_forward( tokens are flattened in to `query` input. 
Args: - output: shape = [num_prompt_tokens, num_heads, head_size] - query: shape = [num_prompt_tokens, num_heads, head_size] - key: shape = [num_prompt_tokens, num_kv_heads, head_size] - value: shape = [num_prompt_tokens, num_kv_heads, head_size] + output: shape = [num_prefill_tokens, num_heads, head_size] + query: shape = [num_prefill_tokens, num_heads, head_size] + key: shape = [num_prefill_tokens, num_kv_heads, head_size] + value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. """ # Set attention bias if not provided. This typically happens at diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 9856654fc5f94..070349bf29f45 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.abstract import AttentionMetadataPerStage, AttentionMetadata from vllm.attention.selector import get_attn_backend @@ -41,7 +41,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata, + attn_metadata: AttentionMetadata[AttentionMetadataPerStage], kv_scale: float = 1.0, ) -> torch.Tensor: return self.impl.forward(query, key, value, kv_cache, attn_metadata, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 256bffdf032eb..2d918491d6576 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -13,11 +13,6 @@ @dataclass class PagedAttentionMetadata: """Metadata for PagedAttention.""" - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor # (batch_size,). The length of context (tokens stored in KV cache) per # sequence. WARNING: When it is a prefill request, it doesn't include new # tokens. When it is for decoding, it includes a new token. @@ -31,7 +26,6 @@ class PagedAttentionMetadata: # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # captured. block_tables: Optional[torch.Tensor] - kv_cache_dtype: str class PagedAttention: diff --git a/vllm/config.py b/vllm/config.py index 6762a75f25f28..d31d336034f8d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -586,7 +586,7 @@ def _verify_args(self) -> None: "sequences. 
Please increase max_num_batched_tokens or " "decrease max_model_len.") - if self.max_num_batched_tokens < self.max_num_seqs: + if self.max_num_batched_tokens < self.max_num_seqs and not self.chunked_prefill_enabled: raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " "be greater than or equal to max_num_seqs " diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 0ae53f9374960..aa8d91da7c05a 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -829,10 +829,10 @@ def _schedule_chunked_prefill(self): return SchedulerOutputs( scheduled_seq_groups=(prefills.seq_groups + - running_scheduled.decode_seq_groups + running_scheduled.prefill_seq_groups + - swapped_in.decode_seq_groups + - swapped_in.prefill_seq_groups), + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups), num_prefill_groups=(len(prefills.seq_groups) + len(swapped_in.prefill_seq_groups) + len(running_scheduled.prefill_seq_groups)), diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a9a4a7b83d934..ccaf0ef625529 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -611,6 +611,7 @@ def _process_model_outputs( seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) + # SANG-TODO DO NOT PROCESS IF IT IS CHUNKED PREFILL. self._process_sequence_group_outputs(seq_group, outputs) # Free the finished sequence groups. diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index 9cbb40708dd5b..b46234425d1ee 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -171,10 +171,18 @@ def broadcast_tensor_dict( torch.distributed.broadcast_object_list([metadata_list], src=src, group=group) + async_handles = [] for key, value in metadata_list: if isinstance(value, TensorMetadata): tensor = tensor_dict[key] - torch.distributed.broadcast(tensor, src=src, group=group) + async_handles.append( + torch.distributed.broadcast(tensor, + src=src, + group=group, + async_op=True)) + for async_handle in async_handles: + async_handle.wait() + else: recv_metadata_list = [None] torch.distributed.broadcast_object_list(recv_metadata_list, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e7f20475ab1a7..6bea82b32868d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn -from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.attention import AttentionMetadata, AttentionMetadataPerStage, get_attn_backend from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.logger import init_logger @@ -154,10 +154,9 @@ def get_max_block_per_batch(self) -> int: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - List[int], List[int], List[int], Set[LoRARequest], - torch.Tensor]: - assert len(seq_group_metadata_list) > 0 + ) -> Tuple[List[int], List[int], Optional[AttentionMetadataPerStage], + List[int], List[int], List[int], List[int], Set[LoRARequest], + Optional[torch.Tensor], List[int]]: input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] @@ -171,6 +170,9 @@ def _prepare_prompt( prefix_block_tables: 
List[List[int]] = [] multi_modal_input_list: List[torch.Tensor] = [] + if len(seq_group_metadata_list) == 0: + return [], [], None, [], [], [], [], set(), None, [] + for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -269,20 +271,8 @@ def _prepare_prompt( max_subquery_len = max(subquery_lens) max_prompt_len = max(prompt_lens) - num_prompt_tokens = len(input_tokens) assert max_subquery_len > 0 - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - lora_index_mapping = lora_index_mapping - context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -334,11 +324,8 @@ def _prepare_prompt( attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - slot_mapping=slot_mapping, prompt_lens=prompt_lens, prompt_lens_tensor=prompt_lens_tensor, - num_prompt_tokens=num_prompt_tokens, - num_generation_tokens=0, max_subquery_len=max_subquery_len, max_context_len=None, max_prompt_len=max_prompt_len, @@ -347,18 +334,16 @@ def _prepare_prompt( context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, attn_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests, multi_modal_input) + lora_requests, multi_modal_input, slot_mapping) def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - List[int], Set[LoRARequest]]: - assert len(seq_group_metadata_list) > 0 + ) -> Tuple[List[int], List[int], Optional[AttentionMetadata], List[int], + List[int], Set[LoRARequest], List[int]]: input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] @@ -368,6 +353,9 @@ def _prepare_decode( lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + if len(seq_group_metadata_list) == 0: + return [], [], None, [], [], set(), [] + for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt assert seq_group_metadata.token_chunk_size == 1 @@ -426,15 +414,6 @@ def _prepare_decode( lora_index_mapping.append(0) batch_size = graph_batch_size - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) context_lens = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -442,9 +421,9 @@ def _prepare_decode( if use_captured_graph: # When using cuda-graph all these tensors should be # padded. - assert context_lens.shape[0] == input_tokens.shape[0] - assert context_lens.shape[0] == input_positions.shape[0] - assert context_lens.shape[0] == slot_mapping.shape[0] + assert context_lens.shape[0] == len(input_tokens) + assert context_lens.shape[0] == len(input_positions) + assert context_lens.shape[0] == len(slot_mapping) # The shape of graph_block_tables is # [max batch size, max context len // block size]. 
@@ -466,11 +445,8 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, - slot_mapping=slot_mapping, prompt_lens=None, prompt_lens_tensor=None, - num_prompt_tokens=0, - num_generation_tokens=len(input_tokens), max_subquery_len=None, max_context_len=max_context_len, max_prompt_len=None, @@ -479,10 +455,10 @@ def _prepare_decode( context_lens=context_lens, block_tables=block_tables, use_cuda_graph=use_captured_graph, - kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, attn_metadata, - lora_index_mapping, lora_prompt_mapping, lora_requests) + lora_index_mapping, lora_prompt_mapping, lora_requests, + slot_mapping) def _prepare_sample( self, @@ -585,29 +561,65 @@ def _prepare_sample( def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Set[int], LoRAMapping, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadataPerStage, + SamplingMetadata, Set[int], LoRAMapping, torch.Tensor]: if self.is_driver_worker: - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt + prefill_reqs = [] + decode_reqs = [] + for seq_group_meta in seq_group_metadata_list: + if seq_group_meta.is_prompt: + prefill_reqs.append(seq_group_meta) + else: + decode_reqs.append(seq_group_meta) + # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, attn_metadata, prompt_lens, - subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests, multi_modal_input - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, attn_metadata, - lora_index_mapping, lora_prompt_mapping, - lora_requests) = self._prepare_decode(seq_group_metadata_list) - prompt_lens = [] - subquery_lens = None - multi_modal_input = None + + ( + input_tokens, + input_positions, + prefill_attn_metadata, + prompt_lens, + subquery_lens, + lora_index_mapping, + lora_prompt_mapping, + lora_requests, + multi_modal_input, + slot_mapping, + ) = self._prepare_prompt(prefill_reqs) + (decode_input_tokens, decode_input_positions, decode_attn_metadata, + decode_lora_index_mapping, decode_lora_prompt_mapping, + decode_lora_requests, + decode_slot_mapping) = self._prepare_decode(decode_reqs) sampling_metadata = self._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens) + if not self.scheduler_config.chunked_prefill_enabled: + assert (len(prefill_reqs) and len(decode_reqs)) == 0 + + num_prefills = len(prompt_lens) + num_prefill_tokens = len(input_tokens) + num_decode_tokesn = len(decode_input_tokens) + + # Coalesce tensors. Note that attn_metadata is currently not + # coalesced for simplicity. 
+ input_tokens.extend(decode_input_tokens) + input_positions.extend(decode_input_positions) + slot_mapping.extend(decode_slot_mapping) + lora_prompt_mapping.extend(decode_lora_index_mapping) + lora_prompt_mapping.extend(decode_lora_prompt_mapping) + lora_requests.update(decode_lora_requests) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + if self.lora_config: lora_mapping = LoRAMapping( lora_index_mapping, @@ -625,10 +637,28 @@ def prepare_input_tensors( "lora_requests": lora_requests, "lora_mapping": lora_mapping, "multi_modal_input": multi_modal_input, + "num_prefill_tokens": num_prefill_tokens, + "num_decode_tokesn": num_decode_tokesn } - metadata_dict.update(attn_metadata.asdict_zerocopy()) + s = time.time() + # TODO(sang): It is dangerous if attn_metadata contains the same + # name key. + if prefill_attn_metadata is not None: + metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) + + # NOTE(sang): Broadcast prefill/decode metadata separately for + # simplicity. Compared to one broadcast, its latency increases from + # 20us -> 35~60 us. We can potentially coalesce tensors to reduce the + # overhead. + metadata_dict = None + if decode_attn_metadata is not None: + metadata_dict = decode_attn_metadata.asdict_zerocopy() + broadcast_tensor_dict(metadata_dict, src=0) + print("SANG-TODO broadcast takes ", + (time.time() - s) * 1000 * 1000, "us") else: + # Prefill metadata. metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") @@ -637,7 +667,10 @@ def prepare_input_tensors( lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_input = metadata_dict.pop("multi_modal_input") - attn_metadata = self.attn_backend.make_metadata(**metadata_dict) + num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") + num_decode_tokesn = metadata_dict.pop("num_decode_tokesn") + prefill_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -647,6 +680,19 @@ def prepare_input_tensors( generators=None, perform_sampling=False, ) + # Decode metadata. + metadata_dict = broadcast_tensor_dict(src=0) + decode_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + attn_metadata = AttentionMetadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokesn, + prefill_metadata=prefill_attn_metadata, + decode_metadata=decode_attn_metadata, + kv_cache_dtype=self.kv_cache_dtype, + ) return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, @@ -665,8 +711,9 @@ def execute_model( if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) - # Execute the model. - if attn_metadata.use_cuda_graph: + # Currently cuda graph is only supported by the decode phase. 
+ decode_meta = attn_metadata.decode_metadata + if decode_meta is not None and decode_meta.use_cuda_graph: graph_batch_size = input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: @@ -844,13 +891,10 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): # Create dummy attn_metadata. - attn_metadata = self.attn_backend.make_metadata( + decode_metadata = self.attn_backend.make_metadata( is_prompt=False, - slot_mapping=slot_mapping[:batch_size], prompt_lens=None, prompt_lens_tensor=None, - num_prompt_tokens=0, - num_generation_tokens=batch_size, max_subquery_len=None, max_context_len=self.max_context_len_to_capture, max_prompt_len=None, @@ -859,6 +903,13 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: context_lens=context_lens[:batch_size], block_tables=block_tables[:batch_size], use_cuda_graph=True, + ) + attn_metadata = AttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + prefill_metadata=None, + decode_metadata=decode_metadata, kv_cache_dtype=self.kv_cache_dtype, ) @@ -952,8 +1003,8 @@ def capture( "positions": positions, "kv_caches": kv_caches, "slot_mapping": attn_metadata.slot_mapping, - "context_lens": attn_metadata.context_lens, - "block_tables": attn_metadata.block_tables, + "context_lens": attn_metadata.decode_metadata.context_lens, + "block_tables": attn_metadata.decode_metadata.block_tables, } self.output_buffers = {"hidden_states": hidden_states} return @@ -974,10 +1025,10 @@ def forward( self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - self.input_buffers["context_lens"].copy_(attn_metadata.context_lens, - non_blocking=True) - self.input_buffers["block_tables"].copy_(attn_metadata.block_tables, - non_blocking=True) + self.input_buffers["context_lens"].copy_( + attn_metadata.decode_metadata.context_lens, non_blocking=True) + self.input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) # Run the graph. self.graph.replay() From 0272344da4d78bcbd05cf562bfd530b231073f59 Mon Sep 17 00:00:00 2001 From: sang Date: Sat, 6 Apr 2024 09:11:39 -0700 Subject: [PATCH 02/14] ip --- .../basic_correctness/test_chunked_prefill.py | 32 ++++--- tests/prompts/example.txt | 7 ++ tests/samplers/test_sampler.py | 8 +- tests/test_logits_processor.py | 8 +- vllm/attention/backends/flash_attn.py | 2 +- vllm/attention/backends/xformers.py | 23 ++--- vllm/attention/ops/prefix_prefill.py | 2 +- vllm/config.py | 10 +- vllm/core/scheduler.py | 24 +++-- vllm/engine/llm_engine.py | 5 +- vllm/worker/model_runner.py | 94 ++++++++++++------- 11 files changed, 134 insertions(+), 81 deletions(-) diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 41d4e6ee679d9..7259bdd0941b0 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -1,27 +1,25 @@ """Compare the outputs of HF and vLLM when using greedy sampling. +It tests chunked prefill. Chunked prefill can be enabled by +enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens, +prefill requests are chunked. + Run `pytest tests/models/test_chunked_prefill.py`. 
""" import pytest MODELS = [ "facebook/opt-125m", - # "gpt2", - # "bigcode/tiny_starcoder_py", - # "EleutherAI/pythia-70m", - # "bigscience/bloom-560m", - # "microsoft/phi-2", - # "stabilityai/stablelm-3b-4e1t", - # "allenai/OLMo-1B", # Broken - # "bigcode/starcoder2-3b", + "meta-llama/Llama-2-7b-hf", ] @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) -@pytest.mark.parametrize("chunked_prefill_token_size", [-1]) -@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +@pytest.mark.parametrize("enforce_eager", [False, True]) +@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) def test_models( hf_runner, vllm_runner, @@ -31,9 +29,15 @@ def test_models( max_tokens: int, chunked_prefill_token_size: int, enforce_eager: bool, + tensor_parallel_size: int, ) -> None: + if tensor_parallel_size == 2: + if chunked_prefill_token_size != 16 and not enforce_eager: + pytest.skip( + f"Skip {chunked_prefill_token_size=} and {enforce_eager=} for high TP to save testing time." + ) # To pass the small model tests, we need full precision. - assert dtype == "float" + # assert dtype == "float" enable_chunked_prefill = False max_num_batched_tokens = None if chunked_prefill_token_size != -1: @@ -49,10 +53,12 @@ def test_models( dtype=dtype, max_num_batched_tokens=max_num_batched_tokens, enable_chunked_prefill=enable_chunked_prefill, + tensor_parallel_size=tensor_parallel_size, enforce_eager=enforce_eager, ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model + print(vllm_outputs[0]) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] diff --git a/tests/prompts/example.txt b/tests/prompts/example.txt index c90173da05f1c..e1b97bc6eee75 100644 --- a/tests/prompts/example.txt +++ b/tests/prompts/example.txt @@ -1 +1,8 @@ vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. +Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020. +Compare and contrast artificial intelligence with human intelligence in terms of processing information. +Describe the basic components of a neural network and how it can be trained. +Write a short story about a robot that dreams for the first time. +Analyze the impact of the COVID-19 pandemic on global economic structures and future business models. +Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies. +Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.' 
diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 1626b72282072..288072ad96299 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -50,7 +50,7 @@ def _do_sample( sampling_params: SamplingParams, ): seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -60,11 +60,11 @@ def _do_sample( sampling_params=sampling_params, block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) + seq_lens, + subquery_lens=seq_lens) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index fe321520114f7..08e3282f9d4b9 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -64,7 +64,7 @@ def pick_ith(token_ids, logits): return logits seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -75,11 +75,11 @@ def pick_ith(token_ids, logits): logits_processors=[pick_ith]), block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) + seq_lens, + subquery_lens=seq_lens) logits_processor_output = logits_processor( embedding=None, hidden_states=input_tensor, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c10e54e71091c..d31f83613c15c 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -68,7 +68,7 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, is_prompt: bool # (batch_size,). The prompt length per sequence. None if it is a decoding. prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. + # prompt_lens_tensor stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] # NOTE(sang): Definition of context_len, subquery_len, and seqlen. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 696484669d1e9..6628bfd188ddf 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -70,7 +70,7 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): is_prompt: bool # (batch_size,). The prompt length per sequence. None if it is a decoding. prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. + # prompt_lens_tensor stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] # NOTE(sang): Definition of context_len, subquery_len, and seqlen. 
@@ -203,21 +203,18 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - print(f"SANG-TODO original query: {query.size()}") + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens output = torch.empty_like(query) decode_query = query[num_prefill_tokens:] query = query[:num_prefill_tokens] key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] - print(f"SANG-TODO {num_prefill_tokens=} {num_decode_tokens=}") - print(f"SANG-TODO {query.size()=} {decode_query.size()=}") assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens if num_prefill_tokens > 0: - print("SANG-TODO run prefill") - prefill_output = output[:num_prefill_tokens] prefill_meta = attn_metadata.prefill_metadata assert prefill_meta is not None # Prompt run. @@ -264,16 +261,19 @@ def forward( # with input tensor's size and stride (at least one # dimension spans across two contiguous subspaces). # Use reshape instead. - return prefill_output.reshape(num_tokens, hidden_size) + return output[:num_prefill_tokens].reshape( + num_tokens, hidden_size) - prefill_output = self._run_memory_efficient_xformers_forward( - query, key, value, prefill_meta) + out = self._run_memory_efficient_xformers_forward( + query, key, value, prefill_meta).squeeze(0) + assert out.shape == output[:num_prefill_tokens].shape + output[:num_prefill_tokens] = out.squeeze(0) else: # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, # to be addressed separately. - prefill_output = PagedAttention.forward_prefix( + out = PagedAttention.forward_prefix( query, key, value, @@ -286,9 +286,10 @@ def forward( prefill_meta.max_subquery_len, self.alibi_slopes, ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out if num_decode_tokens > 0: - print("SANG-TODO run decode") decode_meta = attn_metadata.decode_metadata assert decode_meta is not None output[num_prefill_tokens:] = PagedAttention.forward_decode( diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 70f09224f1cf6..6ff5f2b2177ee 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -632,7 +632,7 @@ def context_attention_fwd(q, alibi_slopes=None): cap = torch.cuda.get_device_capability() - BLOCK = 128 if cap[0] >= 8 else 64 + BLOCK = 64 if cap[0] >= 8 else 64 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv diff --git a/vllm/config.py b/vllm/config.py index d31d336034f8d..35f00aaf559a3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -563,9 +563,13 @@ def __init__( if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens else: - # If max_model_len is too short, use 2048 as the default value for - # higher throughput. - self.max_num_batched_tokens = max(max_model_len, 2048) + if enable_chunked_prefill: + # For chunked prefill, choose the well-tuned batch size. + self.max_num_batched_tokens = 768 + else: + # If max_model_len is too short, use 2048 as the default value + # for higher throughput. 
+ self.max_num_batched_tokens = max(max_model_len, 2048) self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len self.use_v2_block_manager = use_v2_block_manager diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index aa8d91da7c05a..4a669d46daac5 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -140,7 +140,12 @@ def _sort_by_lora_ids(self) -> bool: @property def lora_requests(self) -> Set[LoRARequest]: - return {g.seq_group.lora_request for g in self.scheduled_seq_groups} + result = {} + for g in self.scheduled_seq_groups: + lora_request = g.seq_group.lora_request + if lora_request is not None: + result.add(lora_request) + return result @dataclass @@ -826,13 +831,12 @@ def _schedule_chunked_prefill(self): # Update swapped requests. self.swapped = remaining_swapped self.swapped.extend(running_scheduled.swapped_out) - - return SchedulerOutputs( - scheduled_seq_groups=(prefills.seq_groups + - running_scheduled.prefill_seq_groups + - swapped_in.prefill_seq_groups + - running_scheduled.decode_seq_groups + - swapped_in.decode_seq_groups), + groups = (prefills.seq_groups + running_scheduled.prefill_seq_groups + + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups) + out = SchedulerOutputs( + scheduled_seq_groups=groups, num_prefill_groups=(len(prefills.seq_groups) + len(swapped_in.prefill_seq_groups) + len(running_scheduled.prefill_seq_groups)), @@ -847,6 +851,8 @@ def _schedule_chunked_prefill(self): swapped_in.num_lookahead_slots), ) + return out + def _schedule(self) -> SchedulerOutputs: """Schedule queued requests.""" if self.scheduler_config.chunked_prefill_enabled: @@ -907,7 +913,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. - is_prompt = i < scheduler_outputs.num_prefill_groups + is_prompt = seq_group.is_prefill() seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, is_prompt=is_prompt, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ccaf0ef625529..96affa5ddddd2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -611,8 +611,9 @@ def _process_model_outputs( seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - # SANG-TODO DO NOT PROCESS IF IT IS CHUNKED PREFILL. - self._process_sequence_group_outputs(seq_group, outputs) + # If uncomputed tokens > 0, it means prefill is not done. + if seq_group.get_num_uncomputed_tokens() == 0: + self._process_sequence_group_outputs(seq_group, outputs) # Free the finished sequence groups. self.scheduler.free_finished_seq_groups() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 6bea82b32868d..c0d624937a6a8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -182,7 +182,8 @@ def _prepare_prompt( computed_block_nums = seq_group_metadata.computed_block_nums if (self.scheduler_config is not None and self.scheduler_config.chunked_prefill_enabled - and computed_block_nums is not None): + and not (computed_block_nums is None + or computed_block_nums == [])): raise RuntimeError( "chunked prefill cannot be used with prefix caching " "now.") @@ -194,13 +195,8 @@ def _prepare_prompt( # it contains output tokens. prefill_end = min(seq_data.get_len(), computed_len + token_chunk_size) - # TODO(sang): Rename it after chunked prefill is introduced. 
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] - prompt_len = len(prompt_tokens) - # Right now, the prefill_end is always same as the length of - # sequence. However, once chunked prefill is introduced, this - # assumption can be changed. - assert prefill_end == seq_data.get_len() + prompt_len = prefill_end prompt_lens.append(prompt_len) # NOTE: This only works for oooooooxxx style attention. @@ -210,6 +206,15 @@ def _prepare_prompt( computed_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) + assert self.scheduler_config.chunked_prefill_enabled is not None + elif self.scheduler_config.chunked_prefill_enabled: + if seq_group_metadata.block_tables is not None: + # Prefill has chunked before. + block_table = seq_group_metadata.block_tables[seq_id] + prefix_block_tables.append(block_table) + else: + # The first prefill. + prefix_block_tables.append([]) else: prefix_block_tables.append([]) # Right now, prefill start is always 0. However, this @@ -561,8 +566,8 @@ def _prepare_sample( def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadataPerStage, - SamplingMetadata, Set[int], LoRAMapping, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, + Set[int], LoRAMapping, torch.Tensor]: if self.is_driver_worker: prefill_reqs = [] decode_reqs = [] @@ -573,7 +578,6 @@ def prepare_input_tensors( decode_reqs.append(seq_group_meta) # Prepare input tensors. - ( input_tokens, input_positions, @@ -599,7 +603,7 @@ def prepare_input_tensors( num_prefills = len(prompt_lens) num_prefill_tokens = len(input_tokens) - num_decode_tokesn = len(decode_input_tokens) + num_decode_tokens = len(decode_input_tokens) # Coalesce tensors. Note that attn_metadata is currently not # coalesced for simplicity. @@ -629,6 +633,16 @@ def prepare_input_tensors( lora_mapping = None # Broadcast the metadata. + # If batch contains both prefill and decode, it sends 2 broadcasts. + # If it only contains 1 type, it triggers a single broadcast. + if (prefill_attn_metadata is not None + and decode_attn_metadata is not None): + batch_type = "mixed" + elif prefill_attn_metadata is not None: + batch_type = "prefill" + else: + batch_type = "decode" + metadata_dict = { "input_tokens": input_tokens, "input_positions": input_positions, @@ -638,39 +652,48 @@ def prepare_input_tensors( "lora_mapping": lora_mapping, "multi_modal_input": multi_modal_input, "num_prefill_tokens": num_prefill_tokens, - "num_decode_tokesn": num_decode_tokesn + "num_decode_tokens": num_decode_tokens, + "slot_mapping": slot_mapping, + "num_prefills": num_prefills, + "batch_type": batch_type, } - s = time.time() - # TODO(sang): It is dangerous if attn_metadata contains the same - # name key. if prefill_attn_metadata is not None: metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) + else: + metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) - # NOTE(sang): Broadcast prefill/decode metadata separately for - # simplicity. Compared to one broadcast, its latency increases from - # 20us -> 35~60 us. We can potentially coalesce tensors to reduce the - # overhead. - metadata_dict = None - if decode_attn_metadata is not None: + # Broadcast decode attn metadata for mixed batch type. + # The additional broadcast costs 300us overhead on 4 A10 GPUs. 
+ # We can potentially reduce the overhead by coelescing tensors. + if batch_type == "mixed": + assert decode_attn_metadata is not None metadata_dict = decode_attn_metadata.asdict_zerocopy() - broadcast_tensor_dict(metadata_dict, src=0) - print("SANG-TODO broadcast takes ", - (time.time() - s) * 1000 * 1000, "us") + broadcast_tensor_dict(metadata_dict, src=0) else: # Prefill metadata. metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") + slot_mapping = metadata_dict.pop("slot_mapping") + num_prefills = metadata_dict.pop("num_prefills") selected_token_indices = metadata_dict.pop( "selected_token_indices") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_input = metadata_dict.pop("multi_modal_input") num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") - num_decode_tokesn = metadata_dict.pop("num_decode_tokesn") - prefill_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) + num_decode_tokens = metadata_dict.pop("num_decode_tokens") + batch_type = metadata_dict.pop("batch_type") + + prefill_attn_metadata = None + decode_attn_metadata = None + if batch_type == "prefill" or batch_type == "mixed": + prefill_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + else: + decode_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -680,15 +703,18 @@ def prepare_input_tensors( generators=None, perform_sampling=False, ) - # Decode metadata. - metadata_dict = broadcast_tensor_dict(src=0) - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) + + # if it is a mixed batch, decode attn_metadata is also broadcasted. + if batch_type == "mixed": + metadata_dict = broadcast_tensor_dict(src=0) + decode_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + attn_metadata = AttentionMetadata( num_prefills=num_prefills, slot_mapping=slot_mapping, num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokesn, + num_decode_tokens=num_decode_tokens, prefill_metadata=prefill_attn_metadata, decode_metadata=decode_attn_metadata, kv_cache_dtype=self.kv_cache_dtype, @@ -712,8 +738,9 @@ def execute_model( self.set_active_loras(lora_requests, lora_mapping) # Currently cuda graph is only supported by the decode phase. 
+ prefill_meta = attn_metadata.prefill_metadata decode_meta = attn_metadata.decode_metadata - if decode_meta is not None and decode_meta.use_cuda_graph: + if prefill_meta is None and decode_meta.use_cuda_graph: graph_batch_size = input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: @@ -908,6 +935,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: num_prefills=0, num_prefill_tokens=0, num_decode_tokens=batch_size, + slot_mapping=slot_mapping[:batch_size], prefill_metadata=None, decode_metadata=decode_metadata, kv_cache_dtype=self.kv_cache_dtype, From 01e5b3b6396e758f153802b7f87e0d7bce6d1590 Mon Sep 17 00:00:00 2001 From: sang Date: Sat, 6 Apr 2024 09:14:07 -0700 Subject: [PATCH 03/14] working e2e --- tests/basic_correctness/test_chunked_prefill.py | 9 ++++----- tests/worker/test_model_runner.py | 2 +- vllm/attention/backends/abstract.py | 2 +- vllm/attention/layer.py | 3 ++- vllm/config.py | 3 ++- vllm/worker/model_runner.py | 3 ++- 6 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 7259bdd0941b0..0e244215126d6 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -31,11 +31,10 @@ def test_models( enforce_eager: bool, tensor_parallel_size: int, ) -> None: - if tensor_parallel_size == 2: - if chunked_prefill_token_size != 16 and not enforce_eager: - pytest.skip( - f"Skip {chunked_prefill_token_size=} and {enforce_eager=} for high TP to save testing time." - ) + if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16 + and not enforce_eager): + pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} " + "for high TP to save testing time.") # To pass the small model tests, we need full precision. 
# assert dtype == "float" enable_chunked_prefill = False diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 5885254b29232..f0e0bd07ebc55 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -194,7 +194,7 @@ def test_prepare_decode_cuda_graph(batch_size): def test_empty_seq_group(): - """Verify prepare prompt and decode returns empty output properly when there's no seq groups.""" + """Verify prepare prompt and decode returns empty output.""" model_config = ModelConfig( "facebook/opt-125m", "facebook/opt-125m", diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index bcf5467a1af0e..1ed03c83c975b 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, fields -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Generic +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar import torch diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 070349bf29f45..fc65ae108dbb1 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,7 +4,8 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import AttentionMetadataPerStage, AttentionMetadata +from vllm.attention.backends.abstract import (AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.selector import get_attn_backend diff --git a/vllm/config.py b/vllm/config.py index 35f00aaf559a3..c27b2b485b65b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -590,7 +590,8 @@ def _verify_args(self) -> None: "sequences. Please increase max_num_batched_tokens or " "decrease max_model_len.") - if self.max_num_batched_tokens < self.max_num_seqs and not self.chunked_prefill_enabled: + if (self.max_num_batched_tokens < self.max_num_seqs + and not self.chunked_prefill_enabled): raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " "be greater than or equal to max_num_seqs " diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c0d624937a6a8..f386a0c0ab84a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -6,7 +6,8 @@ import torch import torch.nn as nn -from vllm.attention import AttentionMetadata, AttentionMetadataPerStage, get_attn_backend +from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, + get_attn_backend) from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.logger import init_logger From 502bd194df5f7baddea2b2c498c5776d7dc9428f Mon Sep 17 00:00:00 2001 From: sang Date: Sat, 6 Apr 2024 09:31:32 -0700 Subject: [PATCH 04/14] made it work with flash attn --- tests/samplers/test_sampler.py | 8 +-- tests/test_logits_processor.py | 8 +-- tests/worker/test_model_runner.py | 2 + vllm/attention/backends/flash_attn.py | 71 +++++++++++++++++++-------- vllm/attention/backends/torch_sdpa.py | 58 +++++++++++++++------- vllm/attention/backends/xformers.py | 21 ++++++-- vllm/attention/ops/prefix_prefill.py | 2 +- vllm/core/scheduler.py | 14 +++--- vllm/worker/model_runner.py | 18 ++++--- 9 files changed, 137 insertions(+), 65 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 288072ad96299..1626b72282072 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -50,7 +50,7 @@ def _do_sample( 
sampling_params: SamplingParams, ): seq_group_metadata_list = [] - seq_lens = [] + prompt_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -60,11 +60,11 @@ def _do_sample( sampling_params=sampling_params, block_tables={0: [1]}, )) - seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - seq_lens, - subquery_lens=seq_lens) + prompt_lens, + subquery_lens=prompt_lens) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 08e3282f9d4b9..fe321520114f7 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -64,7 +64,7 @@ def pick_ith(token_ids, logits): return logits seq_group_metadata_list = [] - seq_lens = [] + prompt_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -75,11 +75,11 @@ def pick_ith(token_ids, logits): logits_processors=[pick_ith]), block_tables={0: [1]}, )) - seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - seq_lens, - subquery_lens=seq_lens) + prompt_lens, + subquery_lens=prompt_lens) logits_processor_output = logits_processor( embedding=None, hidden_states=input_tensor, diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index f0e0bd07ebc55..eb0cb8bf58f70 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -39,6 +39,7 @@ def test_prepare_prompt(batch_size): _, _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens + assert len(slot_mapping) == len(input_tokens) # Verify input metadata is correct for prompts. device = model_runner.device @@ -146,6 +147,7 @@ def test_prepare_decode_cuda_graph(batch_size): input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( model_runner._prepare_decode(seq_group_metadata_list)) + assert len(slot_mapping) == len(input_tokens) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) # Verify input metadata is correct for prompts. diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index d31f83613c15c..90f9196f8230d 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -68,7 +68,7 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, is_prompt: bool # (batch_size,). The prompt length per sequence. None if it is a decoding. prompt_lens: Optional[List[int]] - # prompt_lens_tensor stored as a tensor. + # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] # NOTE(sang): Definition of context_len, subquery_len, and seqlen. @@ -117,6 +117,15 @@ class FlashAttentionImpl(AttentionImpl): The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. 
+ + |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| + |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -186,52 +195,74 @@ def forward( attn_metadata.kv_cache_dtype, kv_scale) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if num_prefill_tokens > 0: + prefill_meta = attn_metadata.prefill_metadata + assert prefill_meta is not None # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - output = flash_attn_varlen_func( + out = flash_attn_varlen_func( q=query, k=key, v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prompt_len, + max_seqlen_k=prefill_meta.max_prompt_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, # to be addressed separately. - output = PagedAttention.forward_prefix( + output[:num_prefill_tokens] = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.prompt_lens_tensor, + prefill_meta.context_lens, + prefill_meta.max_subquery_len, self.alibi_slopes, ) - else: + if num_decode_tokens > 0: # Decoding run. 
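
The decode branch below only needs the query rows of decoding sequences; their keys and values were written to the paged KV cache in earlier steps. With made-up sizes, the slicing performed above amounts to:

    import torch

    # Hypothetical step: prompts of 3 and 4 tokens are prefilling while two
    # other sequences each decode one token.
    prompt_lens = [3, 4]
    num_prefill_tokens = sum(prompt_lens)              # 7
    num_decode_tokens = 2                              # one query token per decoding sequence
    query = torch.randn(num_prefill_tokens + num_decode_tokens, 8, 64)

    decode_query = query[num_prefill_tokens:]
    assert decode_query.shape[0] == num_decode_tokens  # K/V for these rows come from the cache
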
- output = PagedAttention.forward_decode( - query, + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 895a117804da7..4fde5507341c9 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -140,36 +140,54 @@ def forward( attn_metadata.kv_cache_dtype, kv_scale) - if attn_metadata.is_prompt: - if (kv_cache is None or attn_metadata.block_tables.numel() == 0): + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if num_prefill_tokens > 0: + prefill_meta = attn_metadata.prefill_metadata + assert prefill_meta is not None + + if (kv_cache is None or prefill_meta.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - if attn_metadata.attn_bias is None: + if prefill_meta.attn_bias is None: if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.prompt_lens) # type: ignore + prefill_meta.prompt_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - attn_metadata.prompt_lens, self.sliding_window, + prefill_meta.prompt_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(attn_metadata.prompt_lens) - attn_metadata.attn_bias = att_masks + att_masks = [None] * len(prefill_meta.prompt_lens) + prefill_meta.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) start = 0 - output = torch.empty( - (num_tokens, self.num_heads, self.head_size), - dtype=query.dtype) - for prompt_len, mask in zip(attn_metadata.prompt_lens, - attn_metadata.attn_bias): + out = torch.empty((num_tokens, self.num_heads, self.head_size), + dtype=query.dtype) + for prompt_len, mask in zip(prefill_meta.prompt_lens, + prefill_meta.attn_bias): end = start + prompt_len sub_out = scaled_dot_product_attention( query[:, start:end, :], @@ -179,22 +197,26 @@ def forward( dropout_p=0.0, is_causal=not self.need_mask, scale=self.scale).movedim(query.dim() - 2, 0) - output[start:end, :, :] = sub_out + out[start:end, :, :] = sub_out start = end + assert out.shape == output[:num_prefill_tokens].shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention raise RuntimeError( "Torch SDPA backend doesn't support prefix decoding.") - else: + if num_decode_tokens > 0: + decode_meta = attn_metadata.decode_metadata + assert 
decode_meta is not None # Decoding run. - output = PagedAttention.forward_decode( + output[num_prefill_tokens:] = PagedAttention.forward_decode( query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 6628bfd188ddf..d28faa9fe3178 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -70,7 +70,7 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): is_prompt: bool # (batch_size,). The prompt length per sequence. None if it is a decoding. prompt_lens: Optional[List[int]] - # prompt_lens_tensor stored as a tensor. + # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] # NOTE(sang): Definition of context_len, subquery_len, and seqlen. @@ -117,11 +117,11 @@ def __post_init__(self): class XFormersImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens --------------->| + |<--------------- num_prefill_tokens -------------->| |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->| Otherwise, the layout is as follows: - |<------------------ num_decode_tokens (M) ----------------->| + |<---------------------- num_decode_tokens --------------------->| |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. @@ -129,6 +129,15 @@ class XFormersImpl(AttentionImpl): The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| + |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -205,8 +214,11 @@ def forward( num_decode_tokens = attn_metadata.num_decode_tokens assert key.shape[0] == num_prefill_tokens + num_decode_tokens assert value.shape[0] == num_prefill_tokens + num_decode_tokens + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] + # QKV for prefill. query = query[:num_prefill_tokens] key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] @@ -217,6 +229,7 @@ def forward( if num_prefill_tokens > 0: prefill_meta = attn_metadata.prefill_metadata assert prefill_meta is not None + # Prompt run. if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention. 
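
Each backend now pre-allocates `output` with `torch.empty_like(query)` and writes the prefill and decode results into disjoint slices of it. A stripped-down version of that scatter pattern, with trivial stand-ins for the real attention kernels:

    import torch

    def toy_prefill_attn(q):   # stand-in for the varlen prefill kernel
        return q * 2.0

    def toy_decode_attn(q):    # stand-in for PagedAttention.forward_decode
        return q + 1.0

    num_prefill_tokens, num_decode_tokens = 7, 2
    query = torch.randn(num_prefill_tokens + num_decode_tokens, 8, 64)

    output = torch.empty_like(query)
    output[:num_prefill_tokens] = toy_prefill_attn(query[:num_prefill_tokens])
    output[num_prefill_tokens:] = toy_decode_attn(query[num_prefill_tokens:])
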
@@ -267,7 +280,7 @@ def forward( out = self._run_memory_efficient_xformers_forward( query, key, value, prefill_meta).squeeze(0) assert out.shape == output[:num_prefill_tokens].shape - output[:num_prefill_tokens] = out.squeeze(0) + output[:num_prefill_tokens] = out else: # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 6ff5f2b2177ee..70f09224f1cf6 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -632,7 +632,7 @@ def context_attention_fwd(q, alibi_slopes=None): cap = torch.cuda.get_device_capability() - BLOCK = 64 if cap[0] >= 8 else 64 + BLOCK = 128 if cap[0] >= 8 else 64 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 4a669d46daac5..eeb1d83b86c20 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -831,12 +831,12 @@ def _schedule_chunked_prefill(self): # Update swapped requests. self.swapped = remaining_swapped self.swapped.extend(running_scheduled.swapped_out) - groups = (prefills.seq_groups + running_scheduled.prefill_seq_groups + - swapped_in.prefill_seq_groups + - running_scheduled.decode_seq_groups + - swapped_in.decode_seq_groups) - out = SchedulerOutputs( - scheduled_seq_groups=groups, + return SchedulerOutputs( + scheduled_seq_groups=(prefills.seq_groups + + running_scheduled.prefill_seq_groups + + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups), num_prefill_groups=(len(prefills.seq_groups) + len(swapped_in.prefill_seq_groups) + len(running_scheduled.prefill_seq_groups)), @@ -851,8 +851,6 @@ def _schedule_chunked_prefill(self): swapped_in.num_lookahead_slots), ) - return out - def _schedule(self) -> SchedulerOutputs: """Schedule queued requests.""" if self.scheduler_config.chunked_prefill_enabled: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f386a0c0ab84a..b0b596afbe7a0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -591,10 +591,15 @@ def prepare_input_tensors( multi_modal_input, slot_mapping, ) = self._prepare_prompt(prefill_reqs) - (decode_input_tokens, decode_input_positions, decode_attn_metadata, - decode_lora_index_mapping, decode_lora_prompt_mapping, - decode_lora_requests, - decode_slot_mapping) = self._prepare_decode(decode_reqs) + ( + decode_input_tokens, + decode_input_positions, + decode_attn_metadata, + decode_lora_index_mapping, + decode_lora_prompt_mapping, + decode_lora_requests, + decode_slot_mapping, + ) = self._prepare_decode(decode_reqs) sampling_metadata = self._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens) @@ -672,7 +677,6 @@ def prepare_input_tensors( metadata_dict = decode_attn_metadata.asdict_zerocopy() broadcast_tensor_dict(metadata_dict, src=0) else: - # Prefill metadata. metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") @@ -687,6 +691,7 @@ def prepare_input_tensors( num_decode_tokens = metadata_dict.pop("num_decode_tokens") batch_type = metadata_dict.pop("batch_type") + # Create an attention metadata. 
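
On the non-driver path below, everything that is not popped explicitly is treated as the per-stage payload the driver produced with `asdict_zerocopy()`, so `make_metadata(**metadata_dict)` simply inverts that serialization. A toy round-trip of the idea, with invented field names standing in for a backend's metadata:

    from dataclasses import dataclass, asdict

    @dataclass
    class ToyStageMetadata:     # stand-in for a backend's per-stage metadata
        prompt_lens: list
        max_prompt_len: int

    driver_payload = {"input_tokens": [1, 2, 3], "batch_type": "prefill"}
    driver_payload.update(asdict(ToyStageMetadata(prompt_lens=[3], max_prompt_len=3)))

    worker_payload = dict(driver_payload)         # what broadcast_tensor_dict delivers
    worker_payload.pop("input_tokens")
    worker_payload.pop("batch_type")
    rebuilt = ToyStageMetadata(**worker_payload)  # make_metadata analogue
    assert rebuilt.prompt_lens == [3]
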
prefill_attn_metadata = None decode_attn_metadata = None if batch_type == "prefill" or batch_type == "mixed": @@ -705,7 +710,8 @@ def prepare_input_tensors( perform_sampling=False, ) - # if it is a mixed batch, decode attn_metadata is also broadcasted. + # if it is a mixed batch, decode attn_metadata is broadcasted + # separately. if batch_type == "mixed": metadata_dict = broadcast_tensor_dict(src=0) decode_attn_metadata = self.attn_backend.make_metadata( From afa247ebf77e5da426e2ac42dda2a2e32e3bd465 Mon Sep 17 00:00:00 2001 From: sang Date: Sat, 6 Apr 2024 10:41:41 -0700 Subject: [PATCH 05/14] fix cpu tests --- benchmarks/benchmark_throughput.py | 62 ++++++++++++------- .../test_basic_distributed_correctness.py | 13 +--- vllm/attention/backends/abstract.py | 2 +- vllm/attention/backends/torch_sdpa.py | 6 +- 4 files changed, 44 insertions(+), 39 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index d6bf18c82e465..c0d92d6ddeca1 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -74,25 +74,31 @@ def run_vllm( quantization_param_path: Optional[str], device: str, enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, gpu_memory_utilization: float = 0.9, download_dir: Optional[str] = None, ) -> float: from vllm import LLM, SamplingParams - llm = LLM(model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - quantization_param_path=quantization_param_path, - device=device, - enable_prefix_caching=enable_prefix_caching, - download_dir=download_dir) + llm = LLM( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + ) # Add the requests to the engine. 
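
The two new keyword arguments threaded through `run_vllm` map directly onto the `LLM` constructor; for example (model name, token budget, and sampling settings chosen arbitrarily):

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="facebook/opt-125m",
        enable_chunked_prefill=True,
        # Upper bound on prefill + decode tokens scheduled in one iteration.
        max_num_batched_tokens=1024,
    )
    outputs = llm.generate(["vLLM is"],
                           SamplingParams(temperature=0.0, max_tokens=16))
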
for prompt, _, output_len in requests: @@ -213,15 +219,15 @@ def main(args: argparse.Namespace): args.output_len) if args.backend == "vllm": - elapsed_time = run_vllm(requests, args.model, args.tokenizer, - args.quantization, args.tensor_parallel_size, - args.seed, args.n, args.use_beam_search, - args.trust_remote_code, args.dtype, - args.max_model_len, args.enforce_eager, - args.kv_cache_dtype, - args.quantization_param_path, args.device, - args.enable_prefix_caching, - args.gpu_memory_utilization, args.download_dir) + elapsed_time = run_vllm( + requests, args.model, args.tokenizer, args.quantization, + args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, + args.trust_remote_code, args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.gpu_memory_utilization, + args.download_dir) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -335,6 +341,14 @@ def main(args: argparse.Namespace): "--enable-prefix-caching", action='store_true', help="enable automatic prefix caching for vLLM backend.") + parser.add_argument("--enable-chunked-prefill", + action='store_true', + help="enable chunked prefill for vLLM backend.") + parser.add_argument('--max-num-batched-tokens', + type=int, + default=None, + help='maximum number of batched tokens per ' + 'iteration') parser.add_argument('--download-dir', type=str, default=None, diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 7aa74e92540ca..7ebc108c43c5c 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -25,7 +25,6 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("chunked_prefill_token_size", [-1]) def test_models( hf_runner, vllm_runner, @@ -33,23 +32,13 @@ def test_models( model: str, dtype: str, max_tokens: int, - chunked_prefill_token_size: int, ) -> None: - enable_chunked_prefill = False - max_num_batched_tokens = None - if chunked_prefill_token_size != -1: - enable_chunked_prefill = True - max_num_batched_tokens = chunked_prefill_token_size hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, - dtype=dtype, - tensor_parallel_size=2, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=enable_chunked_prefill) + vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 1ed03c83c975b..b188de267a2dc 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -65,7 +65,7 @@ def asdict_zerocopy(self) -> Dict[str, Any]: @dataclass class AttentionMetadata(Generic[T]): - """Attention metadata for prefill and decode.""" + """Attention metadata for prefill and decode batched together.""" # Total number of prefill requests. num_prefills: int # Number of prefill tokens. 
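
The counters on this dataclass are expected to stay consistent with the flattened token layout: `slot_mapping` covers every scheduled token, and `num_decode_tokens` equals the number of decoding sequences. A toy consistency check with invented sizes:

    import torch

    prompt_lens = [3, 4]          # two prompt chunks being prefilled
    num_prefills = len(prompt_lens)
    num_prefill_tokens = sum(prompt_lens)
    num_decode_tokens = 5         # five sequences generating one token each

    slot_mapping = torch.arange(num_prefill_tokens + num_decode_tokens)
    assert slot_mapping.numel() == num_prefill_tokens + num_decode_tokens
    assert (num_prefill_tokens > 0) == (num_prefills > 0)
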
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 4fde5507341c9..032014acaaadc 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -210,8 +210,8 @@ def forward( decode_meta = attn_metadata.decode_metadata assert decode_meta is not None # Decoding run. - output[num_prefill_tokens:] = PagedAttention.forward_decode( - query, + out = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, decode_meta.block_tables, @@ -223,6 +223,8 @@ def forward( self.alibi_slopes, kv_scale, ) + assert out.shape == output[num_prefill_tokens:].shape + output[num_prefill_tokens:] # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) From e735cc25ccd7f0ed066f283b25ddd112a33b8368 Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 9 Apr 2024 17:03:59 -0700 Subject: [PATCH 06/14] Addressed code review. --- .../basic_correctness/test_chunked_prefill.py | 3 + vllm/attention/backends/abstract.py | 7 ++ vllm/attention/backends/flash_attn.py | 4 +- vllm/attention/backends/torch_sdpa.py | 5 +- vllm/attention/backends/xformers.py | 5 +- vllm/config.py | 3 +- vllm/core/scheduler.py | 11 ++- vllm/worker/model_runner.py | 82 +++++++++++++++---- 8 files changed, 83 insertions(+), 37 deletions(-) diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 0e244215126d6..3732ac27f4b51 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -35,6 +35,8 @@ def test_models( and not enforce_eager): pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} " "for high TP to save testing time.") + max_num_seqs = min(chunked_prefill_token_size, 256) + # To pass the small model tests, we need full precision. # assert dtype == "float" enable_chunked_prefill = False @@ -54,6 +56,7 @@ def test_models( enable_chunked_prefill=enable_chunked_prefill, tensor_parallel_size=tensor_parallel_size, enforce_eager=enforce_eager, + max_num_seqs=max_num_seqs, ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index b188de267a2dc..5d2a8f6134460 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -82,6 +82,13 @@ class AttentionMetadata(Generic[T]): prefill_metadata: Optional[T] decode_metadata: Optional[T] + def __post_init__(self): + if self.num_prefill_tokens > 0: + assert self.num_prefills > 0 + assert self.prefill_metadata is not None + if self.num_decode_tokens > 0: + assert self.decode_metadata is not None + class AttentionImpl(ABC): diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 90f9196f8230d..e4523b14c1632 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -211,9 +211,7 @@ def forward( assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens - if num_prefill_tokens > 0: - prefill_meta = attn_metadata.prefill_metadata - assert prefill_meta is not None + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. 
if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 032014acaaadc..b1390d18f1bf6 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -156,10 +156,7 @@ def forward( assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens - if num_prefill_tokens > 0: - prefill_meta = attn_metadata.prefill_metadata - assert prefill_meta is not None - + if prefill_meta := attn_metadata.prefill_metadata: if (kv_cache is None or prefill_meta.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index bce10f58518d6..b9885488c4dfc 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -219,10 +219,7 @@ def forward( assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens - if num_prefill_tokens > 0: - prefill_meta = attn_metadata.prefill_metadata - assert prefill_meta is not None - + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention. diff --git a/vllm/config.py b/vllm/config.py index eaf90466101e6..5616efbc32e10 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -590,8 +590,7 @@ def _verify_args(self) -> None: "sequences. Please increase max_num_batched_tokens or " "decrease max_model_len.") - if (self.max_num_batched_tokens < self.max_num_seqs - and not self.chunked_prefill_enabled): + if self.max_num_batched_tokens < self.max_num_seqs: raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " "be greater than or equal to max_num_seqs " diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index eeb1d83b86c20..2942eab735a92 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -140,12 +140,11 @@ def _sort_by_lora_ids(self) -> bool: @property def lora_requests(self) -> Set[LoRARequest]: - result = {} - for g in self.scheduled_seq_groups: - lora_request = g.seq_group.lora_request - if lora_request is not None: - result.add(lora_request) - return result + return { + g.seq_group.lora_request + for g in self.scheduled_seq_groups + if g.seq_group.lora_request is not None + } @dataclass diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b0b596afbe7a0..279c8133c072c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,6 +1,7 @@ import contextlib import time -from typing import Dict, List, Optional, Set, Tuple +from enum import IntEnum +from typing import Dict, List, NamedTuple, Optional, Set, Tuple import numpy as np import torch @@ -40,6 +41,39 @@ ] +class PreparePromptMetadata(NamedTuple): + input_tokens: List[int] + input_positions: List[int] + attn_metadata: Optional[AttentionMetadataPerStage] + prompt_lens: List[int] + subquery_lens: List[int] + lora_index_mapping: List[int] + lora_prompt_mapping: List[int] + lora_requests: Set[LoRARequest] + multi_modal_input: Optional[torch.Tensor] + slot_mapping: List[int] + + +class PrepareDecodeMetadata(NamedTuple): + input_tokens: List[int] + input_positions: List[int] + attn_metadata: Optional[AttentionMetadata] + lora_index_mapping: List[int] + lora_prompt_mapping: List[int] + lora_requests: Set[LoRARequest] + slot_mapping: List[int] + + +# 
How batches are constructed. +class BatchType(IntEnum): + # Every batch is prefill. + PREFILL = 0 + # Every batch is decode. + DECODE = 1 + # Batch is a mixture of prefill and decode. + MIXED = 2 + + class ModelRunner: def __init__( @@ -155,9 +189,7 @@ def get_max_block_per_batch(self) -> int: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[List[int], List[int], Optional[AttentionMetadataPerStage], - List[int], List[int], List[int], List[int], Set[LoRARequest], - Optional[torch.Tensor], List[int]]: + ) -> PreparePromptMetadata: input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] @@ -341,15 +373,23 @@ def _prepare_prompt( block_tables=block_tables, use_cuda_graph=False, ) - return (input_tokens, input_positions, attn_metadata, prompt_lens, - subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests, multi_modal_input, slot_mapping) + return PreparePromptMetadata( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + prompt_lens=prompt_lens, + subquery_lens=subquery_lens, + lora_index_mapping=lora_index_mapping, + lora_prompt_mapping=lora_prompt_mapping, + lora_requests=lora_requests, + multi_modal_input=multi_modal_input, + slot_mapping=slot_mapping, + ) def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[List[int], List[int], Optional[AttentionMetadata], List[int], - List[int], Set[LoRARequest], List[int]]: + ) -> PrepareDecodeMetadata: input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] @@ -462,9 +502,15 @@ def _prepare_decode( block_tables=block_tables, use_cuda_graph=use_captured_graph, ) - return (input_tokens, input_positions, attn_metadata, - lora_index_mapping, lora_prompt_mapping, lora_requests, - slot_mapping) + return PrepareDecodeMetadata( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + lora_index_mapping=lora_index_mapping, + lora_prompt_mapping=lora_prompt_mapping, + lora_requests=lora_requests, + slot_mapping=slot_mapping, + ) def _prepare_sample( self, @@ -643,11 +689,11 @@ def prepare_input_tensors( # If it only contains 1 type, it triggers a single broadcast. if (prefill_attn_metadata is not None and decode_attn_metadata is not None): - batch_type = "mixed" + batch_type = BatchType.MIXED elif prefill_attn_metadata is not None: - batch_type = "prefill" + batch_type = BatchType.PREFILL else: - batch_type = "decode" + batch_type = BatchType.DECODE metadata_dict = { "input_tokens": input_tokens, @@ -672,7 +718,7 @@ def prepare_input_tensors( # Broadcast decode attn metadata for mixed batch type. # The additional broadcast costs 300us overhead on 4 A10 GPUs. # We can potentially reduce the overhead by coelescing tensors. - if batch_type == "mixed": + if batch_type == BatchType.MIXED: assert decode_attn_metadata is not None metadata_dict = decode_attn_metadata.asdict_zerocopy() broadcast_tensor_dict(metadata_dict, src=0) @@ -694,7 +740,7 @@ def prepare_input_tensors( # Create an attention metadata. prefill_attn_metadata = None decode_attn_metadata = None - if batch_type == "prefill" or batch_type == "mixed": + if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED: prefill_attn_metadata = self.attn_backend.make_metadata( **metadata_dict) else: @@ -712,7 +758,7 @@ def prepare_input_tensors( # if it is a mixed batch, decode attn_metadata is broadcasted # separately. 
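
This receive path is the mirror image of the driver's send: the first broadcast carries the shared fields plus one stage's metadata, and only a MIXED batch costs a second broadcast for the decode stage. A schematic of the driver-side half, with the metadata dicts trimmed to a couple of illustrative keys:

    from enum import IntEnum

    class BatchType(IntEnum):
        PREFILL = 0
        DECODE = 1
        MIXED = 2

    def driver_send(broadcast, prefill_meta, decode_meta):
        if prefill_meta and decode_meta:
            batch_type = BatchType.MIXED
        elif prefill_meta:
            batch_type = BatchType.PREFILL
        else:
            batch_type = BatchType.DECODE
        payload = {"batch_type": batch_type}
        payload.update(prefill_meta if prefill_meta else decode_meta)
        broadcast(payload)                        # first (and usually only) broadcast
        if batch_type == BatchType.MIXED:
            broadcast(dict(decode_meta))          # extra broadcast only for mixed batches

    sent = []
    driver_send(sent.append, {"prompt_lens": [3, 4]}, {"context_lens": [7, 8]})
    assert len(sent) == 2                         # a mixed batch costs two broadcasts
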
- if batch_type == "mixed": + if batch_type == BatchType.MIXED: metadata_dict = broadcast_tensor_dict(src=0) decode_attn_metadata = self.attn_backend.make_metadata( **metadata_dict) From 62db33af9435551cb94b800cae7c0dae52194de2 Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 9 Apr 2024 17:09:40 -0700 Subject: [PATCH 07/14] addressed other comment --- vllm/attention/backends/abstract.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 5d2a8f6134460..07b6c04967fcc 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -73,14 +73,17 @@ class AttentionMetadata(Generic[T]): # Number of decode tokens. Note that it is equivalent to the number of # decode requests. num_decode_tokens: int + # The attention metadata for prefill requests in a batch. + prefill_metadata: Optional[T] + # The attention metadata for decode requests in a batch. + decode_metadata: Optional[T] # (num_tokens,). The indices of the token slots that input tokens will be # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor + # The kv cache's data type. kv_cache_dtype: str - prefill_metadata: Optional[T] - decode_metadata: Optional[T] def __post_init__(self): if self.num_prefill_tokens > 0: From a18ae3a2c82e92681a1731cb11744e458c6da93e Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 9 Apr 2024 17:21:10 -0700 Subject: [PATCH 08/14] fix a test --- tests/core/test_chunked_prefill_scheduler.py | 16 ++++++++-------- vllm/worker/model_runner.py | 1 - 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 05e62ced5898f..cce396bf4953c 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -104,10 +104,10 @@ def test_chunk(): # One chunked prefill, and one decoding. seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert set(get_sequence_groups(out)) == set(running) - # The first one is decoding. - assert seq_group_meta[0].token_chunk_size == 1 + # The first one is prefill. Scheduler guarantees ordering. + assert seq_group_meta[0].token_chunk_size == 56 # The second one is a chunked prefill. - assert seq_group_meta[1].token_chunk_size == 56 + assert seq_group_meta[1].token_chunk_size == 1 assert out.num_prefill_groups == 1 assert out.num_batched_tokens == 57 @@ -157,12 +157,12 @@ def test_complex(): # Decoding & chunked prefill & first chunk of 3rd request is scheduled. seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert len(get_sequence_groups(out)) == 3 - # The first one is decoding. - assert seq_group_meta[0].token_chunk_size == 1 - # The second one is a chunked prefill. + # The first one is the first chunked prefill. + assert seq_group_meta[0].token_chunk_size == 7 + # The second one is the second new chunked prefill. assert seq_group_meta[1].token_chunk_size == 56 - # The third one is also chunked. - assert seq_group_meta[2].token_chunk_size == 7 + # The last one is decode. + assert seq_group_meta[2].token_chunk_size == 1 # Two of them are in chunked prefill. 
assert out.num_prefill_groups == 2 assert out.num_batched_tokens == 64 diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 279c8133c072c..bc31e61d3553a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -239,7 +239,6 @@ def _prepare_prompt( computed_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) - assert self.scheduler_config.chunked_prefill_enabled is not None elif self.scheduler_config.chunked_prefill_enabled: if seq_group_metadata.block_tables is not None: # Prefill has chunked before. From 4b84904a5340fc0740f6846acb92e1260ff3cf60 Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 9 Apr 2024 22:01:02 -0700 Subject: [PATCH 09/14] Fixed --- tests/worker/test_model_runner.py | 111 ++++++++++++++++++++- vllm/attention/backends/abstract.py | 2 + vllm/attention/backends/flash_attn.py | 16 ++- vllm/attention/backends/rocm_flash_attn.py | 93 +++++++++++------ vllm/attention/backends/torch_sdpa.py | 4 +- vllm/attention/backends/xformers.py | 61 ++++++----- vllm/engine/llm_engine.py | 3 +- vllm/sequence.py | 3 +- vllm/worker/model_runner.py | 1 + 9 files changed, 216 insertions(+), 78 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index eb0cb8bf58f70..48bfc6f9e0cac 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,7 +1,7 @@ import pytest import torch -from vllm.config import ModelConfig +from vllm.config import ModelConfig, SchedulerConfig from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size @@ -229,4 +229,111 @@ def test_empty_seq_group(): assert len(return_prompt_lens) == 0 -# SANG-TODO Test chunked prefill case. +@pytest.mark.parametrize("batch_size", list(range(2, 128))) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_hybrid_batches(batch_size, enforce_eager, monkeypatch): + + def get_world_size(group=None): + return 1 + + def mock_get_process_group_ranks(group=None): + return [0] + + monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size) + monkeypatch.setattr(torch.distributed, "get_process_group_ranks", + mock_get_process_group_ranks) + + model_config = ModelConfig( + "facebook/opt-125m", + "facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + enforce_eager=enforce_eager, + ) + scheduler_config = SchedulerConfig(100000, + 100000, + 100000, + enable_chunked_prefill=True) + model_runner = ModelRunner(model_config, + None, + scheduler_config, + None, + None, + is_driver_worker=True) + model_runner.set_block_size(16) + + # Add prefill requests. 
+ prompt_lens = [] + seq_group_metadata_list = [] + prefill_metadata_list = [] + decode_metadata_list = [] + block_tables = {0: [1]} + prefill_batch_size = batch_size // 2 + decode_batch_size = batch_size - prefill_batch_size + for i in range(prefill_batch_size): + # make sure all tokens fit into one block + prompt_len = i % (model_runner.block_size - 1) + 1 + prompt_lens.append(prompt_len) + seq_data = SequenceData(list(range(prompt_len))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + ) + assert seq_group_metadata.token_chunk_size == seq_data.get_len() + seq_group_metadata_list.append(seq_group_metadata) + prefill_metadata_list.append(seq_group_metadata) + + # Add decode requests + for i in range(prefill_batch_size, batch_size): + # make sure all tokens fit into one block + prompt_len = i % (model_runner.block_size - 1) + 1 + prompt_toks = list(range(prompt_len)) + seq_data = SequenceData(prompt_toks) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables={0: [1]}, + ) + assert seq_group_metadata.token_chunk_size == 1 + seq_group_metadata_list.append(seq_group_metadata) + decode_metadata_list.append(seq_group_metadata) + + (input_tokens, input_positions, attn_metadata, _, _, _, + _) = model_runner.prepare_input_tensors(seq_group_metadata_list) + + prefill_meta_actual = attn_metadata.prefill_metadata + decode_meta_actual = attn_metadata.decode_metadata + + assert len(attn_metadata.slot_mapping) == len(input_tokens) + assert len(input_positions) == len(input_tokens) + assert attn_metadata.num_prefills == prefill_batch_size + if enforce_eager: + assert attn_metadata.num_decode_tokens == decode_batch_size + else: + assert attn_metadata.num_decode_tokens == _get_graph_batch_size( + decode_batch_size) + assert attn_metadata.num_prefill_tokens == sum(prompt_lens) + + # Verify attn metadata is consistent. We don't need to test individual + # values here because they are tested above. + prefill_meta = model_runner._prepare_prompt( + prefill_metadata_list).attn_metadata + decode_meta = model_runner._prepare_decode( + decode_metadata_list).attn_metadata + + for attr_expected, attr_actual in zip(vars(prefill_meta), + vars(prefill_meta_actual)): + assert attr_expected[1] == attr_actual[1] + for attr_expected, attr_actual in zip(vars(decode_meta), + vars(decode_meta_actual)): + assert attr_expected[1] == attr_actual[1] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 07b6c04967fcc..7a4ccecf702f4 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -74,8 +74,10 @@ class AttentionMetadata(Generic[T]): # decode requests. num_decode_tokens: int # The attention metadata for prefill requests in a batch. + # None if there's no prefill requests in a batch. prefill_metadata: Optional[T] # The attention metadata for decode requests in a batch. + # None if there's no decode requests in a batch. decode_metadata: Optional[T] # (num_tokens,). The indices of the token slots that input tokens will be # stored into. 
E.g., if `slot_mapping` is [35, 2, 17] and the block size diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index e4523b14c1632..12e8c4404b94e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -105,12 +105,12 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, class FlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding. @@ -121,8 +121,8 @@ class FlashAttentionImpl(AttentionImpl): If chunked prefill is enabled, prefill tokens and decode tokens can be batched together in a flattened 1D query. - |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| - |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| Currently, cuda graph is disabled for chunked prefill, meaning there's no padding between prefill and decode tokens. @@ -250,10 +250,8 @@ def forward( prefill_meta.max_subquery_len, self.alibi_slopes, ) - if num_decode_tokens > 0: + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - decode_meta = attn_metadata.decode_metadata - assert decode_meta is not None output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, key_cache, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 6019d917b4494..4706954b973c7 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -6,7 +6,8 @@ import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -51,7 +52,8 @@ def copy_blocks( @dataclass -class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): +class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, + PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -117,6 +119,15 @@ class ROCmFlashAttentionImpl(AttentionImpl): The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| + |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. 
""" def __init__( @@ -181,7 +192,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: ROCmFlashAttentionMetadata, + attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata], kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -218,9 +229,25 @@ def forward( kv_scale, ) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. @@ -230,63 +257,69 @@ def forward( key = self.repeat_kv(key, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv) if self.use_naive_attn: - output = self.attn_fuc( + out = self.attn_fuc( query, key, value, - attn_metadata.prompt_lens, + prefill_meta.prompt_lens, self.scale, ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: - output, _ = self.attn_func( + out, _ = self.attn_func( query, key, value, None, - attn_metadata.seq_start_loc, - attn_metadata.seq_start_loc, - attn_metadata.max_prompt_len, - attn_metadata.max_prompt_len, + prefill_meta.seq_start_loc, + prefill_meta.seq_start_loc, + prefill_meta.max_prompt_len, + prefill_meta.max_prompt_len, True, self.scale, ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: - output = self.attn_func( + out = self.attn_func( q=query, k=key, v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prompt_len, + max_seqlen_k=prefill_meta.max_prompt_len, softmax_scale=self.scale, causal=True, ) - + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention - output = PagedAttention.forward_prefix( + output[:num_prefill_tokens] = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.prompt_lens_tensor, + prefill_meta.context_lens, + prefill_meta.max_subquery_len, self.alibi_slopes, ) - else: + + if decode_meta := attn_metadata.decode_metadata: # Decoding run. 
- output = PagedAttention.forward_decode( - query, + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index b1390d18f1bf6..63904ea929870 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -203,9 +203,7 @@ def forward( raise RuntimeError( "Torch SDPA backend doesn't support prefix decoding.") - if num_decode_tokens > 0: - decode_meta = attn_metadata.decode_metadata - assert decode_meta is not None + if decode_meta := attn_metadata.decode_metadata: # Decoding run. out = PagedAttention.forward_decode( decode_query, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index b9885488c4dfc..043fa8f6af805 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -115,12 +115,12 @@ def __post_init__(self): class XFormersImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->| + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| Otherwise, the layout is as follows: - |<---------------------- num_decode_tokens --------------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding. @@ -131,8 +131,8 @@ class XFormersImpl(AttentionImpl): If chunked prefill is enabled, prefill tokens and decode tokens can be batched together in a flattened 1D query. - |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| - |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| Currently, cuda graph is disabled for chunked prefill, meaning there's no padding between prefill and decode tokens. @@ -225,26 +225,8 @@ def forward( # normal attention. # block tables are empty if the prompt does not have a cached # prefix. - if self.num_kv_heads != self.num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # TODO(woosuk): Use MQA/GQA kernels for higher performance. 
- query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], - self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) - out = self._run_memory_efficient_xformers_forward( - query, key, value, prefill_meta).squeeze(0) + query, key, value, prefill_meta) assert out.shape == output[:num_prefill_tokens].shape output[:num_prefill_tokens] = out else: @@ -268,9 +250,7 @@ def forward( assert output[:num_prefill_tokens].shape == out.shape output[:num_prefill_tokens] = out - if num_decode_tokens > 0: - decode_meta = attn_metadata.decode_metadata - assert decode_meta is not None + if decode_meta := attn_metadata.decode_metadata: output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, key_cache, @@ -298,6 +278,9 @@ def _run_memory_efficient_xformers_forward( """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. + See https://facebookresearch.github.io/xformers/components/ops.html + for API spec. + Args: output: shape = [num_prefill_tokens, num_heads, head_size] query: shape = [num_prefill_tokens, num_heads, head_size] @@ -305,6 +288,20 @@ def _run_memory_efficient_xformers_forward( value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. """ + original_query = query + if self.num_kv_heads != self.num_heads: + # GQA/MQA requires the shape [B, M, G, H, K]. + # Note that the output also has the same shape (which is different + # from a spec from the doc). + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. @@ -325,6 +322,7 @@ def _run_memory_efficient_xformers_forward( # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. if self.alibi_slopes is None: + # Add the batch dimension. query = query.unsqueeze(0) key = key.unsqueeze(0) value = value.unsqueeze(0) @@ -335,14 +333,13 @@ def _run_memory_efficient_xformers_forward( attn_bias=attn_metadata.attn_bias[0], p=0.0, scale=self.scale) - - return out.view_as(query) + return out.view_as(original_query) # Attention with alibi slopes. # FIXME(woosuk): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. - output = torch.empty_like(query) + output = torch.empty_like(original_query) start = 0 for i, prompt_len in enumerate(attn_metadata.prompt_lens): end = start + prompt_len @@ -354,7 +351,7 @@ def _run_memory_efficient_xformers_forward( p=0.0, scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. 
- output[start:end].copy_(out.squeeze(0)) + output[start:end].copy_(out.view_as(original_query)) start += prompt_len return output diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a7fe50d1b2d70..ddfdda898a5c6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -633,7 +633,8 @@ def _process_model_outputs( seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - # If uncomputed tokens > 0, it means prefill is not done. + # If uncomputed tokens > 0, it means prefill is chunked. + # We don't need to process outputs in that case. if seq_group.get_num_uncomputed_tokens() == 0: self._process_sequence_group_outputs(seq_group, outputs) diff --git a/vllm/sequence.py b/vllm/sequence.py index 576bbe8c4f6c4..77029908c2218 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -500,7 +500,8 @@ def update_num_computed_tokens(self, num_new_computed_tokens: int): def get_num_uncomputed_tokens(self) -> int: num_uncomputed_tokens = 0 for seq in self.get_seqs(): - num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() + if not seq.is_finished(): + num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() return num_uncomputed_tokens def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index bc31e61d3553a..49819e7a1ad17 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -372,6 +372,7 @@ def _prepare_prompt( block_tables=block_tables, use_cuda_graph=False, ) + return PreparePromptMetadata( input_tokens=input_tokens, input_positions=input_positions, From 5ec4891fccbc4840fe71f7ff7d1b9914871ca4f9 Mon Sep 17 00:00:00 2001 From: sang Date: Tue, 9 Apr 2024 22:53:27 -0700 Subject: [PATCH 10/14] rocm test fix --- vllm/attention/backends/rocm_flash_attn.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 4706954b973c7..e55435cd2c947 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -68,10 +68,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, prompt_lens: Optional[List[int]] # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. 
# |---------- N-1 iteration --------| From b814fdbb6fe281dcf2992039324fd9fa79cbbdde Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 10 Apr 2024 00:29:58 -0700 Subject: [PATCH 11/14] Fixed all remaining tests --- .buildkite/test-pipeline.yaml | 2 + benchmarks/benchmark_latency.py | 3 +- .../basic_correctness/test_chunked_prefill.py | 7 +-- .../test_basic_distributed_correctness.py | 6 +- .../test_chunked_prefill_distributed.py | 59 +++++++++++++++++++ tests/entrypoints/test_openai_server.py | 2 +- tests/worker/test_model_runner.py | 20 ++++--- vllm/config.py | 3 + vllm/engine/arg_utils.py | 5 +- vllm/lora/layers.py | 5 +- vllm/worker/model_runner.py | 2 +- 11 files changed, 92 insertions(+), 22 deletions(-) create mode 100644 tests/distributed/test_chunked_prefill_distributed.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 27e44463a30a6..695290ed74ab5 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -29,6 +29,8 @@ steps: - pytest -v -s test_pynccl.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py - label: Engine Test command: pytest -v -s engine tokenization test_sequence.py test_config.py diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index e2d358ea6631e..af543a421c1d4 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -177,8 +177,7 @@ def run_to_completion(profile_dir: Optional[str] = None): help='block size of key/value cache') parser.add_argument( '--enable-chunked-prefill', - type=bool, - default=False, + action='store_true', help='If True, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') parser.add_argument( diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 3732ac27f4b51..9ff07b3c09020 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -19,7 +19,9 @@ @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) @pytest.mark.parametrize("enforce_eager", [False, True]) -@pytest.mark.parametrize("tensor_parallel_size", [2, 1]) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. +@pytest.mark.parametrize("tensor_parallel_size", [1]) def test_models( hf_runner, vllm_runner, @@ -36,9 +38,6 @@ def test_models( pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} " "for high TP to save testing time.") max_num_seqs = min(chunked_prefill_token_size, 256) - - # To pass the small model tests, we need full precision. 
- # assert dtype == "float" enable_chunked_prefill = False max_num_batched_tokens = None if chunked_prefill_token_size != -1: diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 7ebc108c43c5c..77aa90b12bf8f 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -38,7 +38,11 @@ def test_models( hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2) + vllm_model = vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py new file mode 100644 index 0000000000000..91b3f2780346a --- /dev/null +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -0,0 +1,59 @@ +"""Compare the outputs of HF and distributed vLLM when using greedy sampling. +vLLM will allocate all the available memory, so we need to run the tests one +by one. The solution is to pass arguments (model name) by environment +variables. + +Run: +```sh +TEST_DIST_MODEL=facebook/opt-125m pytest \ + test_chunked_prefill_distributed.py +TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ + test_chunked_prefill_distributed.py +``` +""" +import os + +import pytest +import torch + +MODELS = [ + os.environ["TEST_DIST_MODEL"], +] + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, +) -> None: + # Add a chunked prefill config. 
+ max_num_seqs = min(chunked_prefill_token_size, 256) + assert chunked_prefill_token_size != -1 + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2, max_num_seqs=max_num_seqs, enable_chunked_prefill=enable_chunked_prefill, max_num_batched_tokens=max_num_batched_tokens,) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 442f8bdf3b4ba..6f2086c4dd269 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -141,7 +141,7 @@ def server(zephyr_lora_files): "--max-cpu-loras", "2", "--max-num-seqs", - "128" + "128", ]) ray.get(server_runner.ready.remote()) yield server_runner diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 48bfc6f9e0cac..dcaae4af4a6f8 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -8,7 +8,11 @@ @pytest.mark.parametrize("batch_size", list(range(1, 257))) def test_prepare_prompt(batch_size): - model_runner = ModelRunner(None, None, None, None, None) + scheduler_config = SchedulerConfig(100000, + 100000, + 100000, + enable_chunked_prefill=False) + model_runner = ModelRunner(None, None, scheduler_config, None, None) model_runner.set_block_size(16) prompt_lens = [] @@ -47,8 +51,6 @@ def test_prepare_prompt(batch_size): assert torch.allclose(attn_metadata.prompt_lens_tensor, torch.tensor(prompt_lens, device=device)) assert attn_metadata.prompt_lens == prompt_lens - assert attn_metadata.num_prompt_tokens == sum(prompt_lens) - assert attn_metadata.num_generation_tokens == 0 assert attn_metadata.max_prompt_len == max(prompt_lens) # Test subquery start locs. @@ -85,7 +87,6 @@ def test_prepare_prompt(batch_size): assert torch.allclose(attn_metadata.block_tables, expected) # Cuda graph should not be used for prerill. 
assert attn_metadata.use_cuda_graph is False - assert attn_metadata.kv_cache_dtype == "auto" assert len(input_tokens) == sum(prompt_lens) assert len(input_positions) == sum(prompt_lens) @@ -124,7 +125,12 @@ def test_prepare_decode_cuda_graph(batch_size): revision=None, enforce_eager=False, ) - model_runner = ModelRunner(model_config, None, None, None, None) + scheduler_config = SchedulerConfig(100000, + 100000, + 100000, + enable_chunked_prefill=False) + model_runner = ModelRunner(model_config, None, scheduler_config, None, + None) model_runner.set_block_size(16) prompt_lens = [] @@ -154,8 +160,6 @@ def test_prepare_decode_cuda_graph(batch_size): device = model_runner.device assert attn_metadata.is_prompt is False assert attn_metadata.prompt_lens is None - assert attn_metadata.num_prompt_tokens == 0 - assert attn_metadata.num_generation_tokens == expected_bs assert attn_metadata.max_prompt_len is None assert attn_metadata.subquery_start_loc is None assert attn_metadata.seq_start_loc is None @@ -173,7 +177,6 @@ def test_prepare_decode_cuda_graph(batch_size): model_runner.get_max_block_per_batch()) # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is True - assert attn_metadata.kv_cache_dtype == "auto" assert len(input_tokens) == expected_bs assert len(input_positions) == expected_bs @@ -316,6 +319,7 @@ def mock_get_process_group_ranks(group=None): assert len(attn_metadata.slot_mapping) == len(input_tokens) assert len(input_positions) == len(input_tokens) + assert attn_metadata.kv_cache_dtype == "auto" assert attn_metadata.num_prefills == prefill_batch_size if enforce_eager: assert attn_metadata.num_decode_tokens == decode_batch_size diff --git a/vllm/config.py b/vllm/config.py index 5616efbc32e10..fb5ab70b5ee20 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -570,6 +570,9 @@ def __init__( # If max_model_len is too short, use 2048 as the default value # for higher throughput. 
self.max_num_batched_tokens = max(max_model_len, 2048) + if enable_chunked_prefill: + logger.info("Chunked prefill is enabled (EXPERIMENTAL).") + self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len self.use_v2_block_manager = use_v2_block_manager diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d4b573992c06c..daefddc01b431 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -386,9 +386,8 @@ def add_cli_args( 'prompt latency) before scheduling next prompt.') parser.add_argument( '--enable-chunked-prefill', - type=bool, - default=False, - help='If True, the prefill requests can be chunked based on the ' + action='store_true', + help='If set, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') parser.add_argument( diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 0505014753951..27fc10fd7ff10 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -268,12 +268,13 @@ def set_mapping( def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = x > self.base_layer.org_vocab_size - 1 - indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) + embedding_len = self.indices_len[3] + indices = self.embeddings_indices[1][:embedding_len].view_as(x) full_lora_a_embeddings = F.embedding( x + indices, self.lora_a_stacked_2d, ) - indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) + indices = self.embeddings_indices[0][:embedding_len].view_as(x) full_output = self.base_layer.forward( x.add_(indices * added_tokens_mask)) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 49819e7a1ad17..a7efbc1513701 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -662,7 +662,7 @@ def prepare_input_tensors( input_tokens.extend(decode_input_tokens) input_positions.extend(decode_input_positions) slot_mapping.extend(decode_slot_mapping) - lora_prompt_mapping.extend(decode_lora_index_mapping) + lora_index_mapping.extend(decode_lora_index_mapping) lora_prompt_mapping.extend(decode_lora_prompt_mapping) lora_requests.update(decode_lora_requests) From d01f893dfa3ff06e1b37352ecf74dc10f85194d7 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 10 Apr 2024 00:31:45 -0700 Subject: [PATCH 12/14] lint --- tests/distributed/test_chunked_prefill_distributed.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py index 91b3f2780346a..737b1f3169519 100644 --- a/tests/distributed/test_chunked_prefill_distributed.py +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -46,7 +46,14 @@ def test_models( hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2, max_num_seqs=max_num_seqs, enable_chunked_prefill=enable_chunked_prefill, max_num_batched_tokens=max_num_batched_tokens,) + vllm_model = vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + max_num_seqs=max_num_seqs, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model From 346e8625a40e3fa125e202214fd092d107af9ea4 Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 10 Apr 2024 03:25:26 -0700 Subject: [PATCH 13/14] addressed code review. 
--- vllm/worker/model_runner.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a7efbc1513701..1722798b095e3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -53,6 +53,21 @@ class PreparePromptMetadata(NamedTuple): multi_modal_input: Optional[torch.Tensor] slot_mapping: List[int] + @classmethod + def empty(cls): + return PreparePromptMetadata( + input_tokens=[], + input_positions=[], + attn_metadata=None, + prompt_lens=[], + subquery_lens=[], + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + multi_modal_input=None, + slot_mapping=[], + ) + class PrepareDecodeMetadata(NamedTuple): input_tokens: List[int] @@ -63,6 +78,18 @@ class PrepareDecodeMetadata(NamedTuple): lora_requests: Set[LoRARequest] slot_mapping: List[int] + @classmethod + def empty(cls): + return PrepareDecodeMetadata( + input_tokens=[], + input_positions=[], + attn_metadata=None, + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + slot_mapping=[], + ) + # How batches are constructed. class BatchType(IntEnum): @@ -204,7 +231,7 @@ def _prepare_prompt( multi_modal_input_list: List[torch.Tensor] = [] if len(seq_group_metadata_list) == 0: - return [], [], None, [], [], [], [], set(), None, [] + return PreparePromptMetadata.empty() for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt @@ -400,7 +427,7 @@ def _prepare_decode( lora_requests: Set[LoRARequest] = set() if len(seq_group_metadata_list) == 0: - return [], [], None, [], [], set(), [] + return PrepareDecodeMetadata.empty() for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt From addf88ea4803a8ee94106c6fe5b78ae6c3fe02ad Mon Sep 17 00:00:00 2001 From: sang Date: Wed, 10 Apr 2024 03:48:20 -0700 Subject: [PATCH 14/14] Fixed a broken model test --- tests/models/test_models.py | 2 +- vllm/attention/backends/xformers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 53a80d4619646..cfe2539e3a052 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -12,7 +12,7 @@ "gpt2", "bigcode/tiny_starcoder_py", "EleutherAI/pythia-70m", - "bigscience/bloom-560m", + "bigscience/bloom-560m", # Testing alibi slopes. "microsoft/phi-2", "stabilityai/stablelm-3b-4e1t", # "allenai/OLMo-1B", # Broken diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 043fa8f6af805..b745a04a143b4 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -351,7 +351,7 @@ def _run_memory_efficient_xformers_forward( p=0.0, scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. - output[start:end].copy_(out.view_as(original_query)) + output[start:end].copy_(out.view_as(original_query[start:end])) start += prompt_len return output
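
For reviewers who want to exercise the series end to end, below is a minimal, hedged sketch of driving the chunked-prefill path from the offline API. It assumes the `LLM` entrypoint forwards `enable_chunked_prefill` and `max_num_batched_tokens` to `EngineArgs` (mirroring the `--enable-chunked-prefill` flag added above); the model name, prompt, and token counts are illustrative placeholders, not values taken from the patches.

```python
# Hedged sketch, not part of the patches above. Assumes LLM(**kwargs) forwards
# enable_chunked_prefill / max_num_batched_tokens to EngineArgs; model, prompt,
# and token counts are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    enable_chunked_prefill=True,    # same switch as --enable-chunked-prefill
    max_num_batched_tokens=16,      # chunk size, like chunked_prefill_token_size in the tests
    enforce_eager=True,             # skip cuda-graph capture
)

# Greedy decoding, matching the generate_greedy comparisons in the test files.
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["Hello, my name is"], sampling_params)
print(outputs[0].outputs[0].text)
```

Using `enforce_eager=True` keeps the run on the eager path; per the xformers docstring in this series, cuda graph is currently disabled for chunked prefill anyway, so there is no padding between prefill and decode tokens in the flattened batch.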