From 753eaf2311f97210dda9f1d7f7840e251b7887bc Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Tue, 6 Aug 2024 16:51:47 -0400 Subject: [PATCH] [Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) (#4942) Co-authored-by: Andrew Feldman Co-authored-by: Nick Hill --- .buildkite/test-pipeline.yaml | 4 +- examples/offline_inference_encoder_decoder.py | 99 ++ tests/conftest.py | 222 +++- tests/core/test_scheduler.py | 30 +- tests/core/test_scheduler_encoder_decoder.py | 99 ++ tests/core/utils.py | 99 +- ...t_basic_distributed_correctness_enc_dec.py | 101 ++ tests/kernels/test_attention_selector.py | 4 +- tests/kernels/test_encoder_decoder_attn.py | 366 ++++--- tests/kernels/test_flash_attn.py | 2 +- tests/kernels/utils.py | 20 +- tests/models/test_bart.py | 153 +++ tests/models/utils.py | 36 + .../test_encoder_decoder_model_runner.py | 480 +++++++++ vllm/attention/__init__.py | 4 +- vllm/attention/layer.py | 2 +- vllm/attention/selector.py | 123 ++- vllm/config.py | 36 +- vllm/core/block/utils.py | 12 +- vllm/core/scheduler.py | 28 + vllm/engine/arg_utils.py | 2 +- vllm/engine/llm_engine.py | 435 +++++++- vllm/entrypoints/llm.py | 21 +- vllm/inputs/__init__.py | 23 +- vllm/inputs/data.py | 124 ++- vllm/model_executor/models/__init__.py | 11 +- vllm/model_executor/models/bart.py | 996 ++++++++++++++++++ vllm/outputs.py | 20 +- vllm/sequence.py | 105 +- vllm/utils.py | 130 +++ vllm/worker/enc_dec_model_runner.py | 472 +++++++++ vllm/worker/utils.py | 56 + vllm/worker/worker.py | 13 +- 33 files changed, 3976 insertions(+), 352 deletions(-) create mode 100644 examples/offline_inference_encoder_decoder.py create mode 100644 tests/core/test_scheduler_encoder_decoder.py create mode 100644 tests/distributed/test_basic_distributed_correctness_enc_dec.py create mode 100644 tests/models/test_bart.py create mode 100644 tests/worker/test_encoder_decoder_model_runner.py create mode 100644 vllm/model_executor/models/bart.py create mode 100644 vllm/worker/enc_dec_model_runner.py create mode 100644 vllm/worker/utils.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6f38cd313f115..6e83c887f89b6 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -148,8 +148,9 @@ steps: - python3 cpu_offload.py - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - - python3 llava_example.py + - python3 offline_inference_vision_language.py - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors + - python3 offline_inference_encoder_decoder.py - label: Models Test # 1hr10min source_file_dependencies: @@ -289,6 +290,7 @@ steps: commands: - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py + - pytest -v -s distributed/test_basic_distributed_correctness_enc_dec.py - pytest -v -s distributed/test_chunked_prefill_distributed.py - pytest -v -s distributed/test_multimodal_broadcast.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py diff --git a/examples/offline_inference_encoder_decoder.py b/examples/offline_inference_encoder_decoder.py new file mode 100644 index 0000000000000..79b284554f172 --- /dev/null +++ b/examples/offline_inference_encoder_decoder.py @@ -0,0 +1,99 @@ +''' +Demonstrate prompting of text-to-text +encoder/decoder models, specifically BART +''' + +from vllm import LLM, SamplingParams +from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt +from vllm.utils import zip_enc_dec_prompt_lists + +dtype = "float" + +# Create a BART encoder/decoder model instance +llm = LLM( + model="facebook/bart-large-cnn", + dtype=dtype, +) + +# Get BART tokenizer +tokenizer = llm.llm_engine.get_tokenizer_group() + +# Test prompts +# +# This section shows all of the valid ways to prompt an +# encoder/decoder model. +# +# - Helpers for building prompts +text_prompt_raw = "Hello, my name is" +text_prompt = TextPrompt(prompt="The president of the United States is") +tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode( + prompt="The capital of France is")) +# - Pass a single prompt to encoder/decoder model +# (implicitly encoder input prompt); +# decoder input prompt is assumed to be None + +single_text_prompt_raw = text_prompt_raw # Pass a string directly +single_text_prompt = text_prompt # Pass a TextPrompt +single_tokens_prompt = tokens_prompt # Pass a TokensPrompt + +# - Pass explicit encoder and decoder input prompts within one data structure. +# Encoder and decoder prompts can both independently be text or tokens, with +# no requirement that they be the same prompt type. Some example prompt-type +# combinations are shown below, note that these are not exhaustive. + +enc_dec_prompt1 = ExplicitEncoderDecoderPrompt( + # Pass encoder prompt string directly, & + # pass decoder prompt tokens + encoder_prompt=single_text_prompt_raw, + decoder_prompt=single_tokens_prompt, +) +enc_dec_prompt2 = ExplicitEncoderDecoderPrompt( + # Pass TextPrompt to encoder, and + # pass decoder prompt string directly + encoder_prompt=single_text_prompt, + decoder_prompt=single_text_prompt_raw, +) +enc_dec_prompt3 = ExplicitEncoderDecoderPrompt( + # Pass encoder prompt tokens directly, and + # pass TextPrompt to decoder + encoder_prompt=single_tokens_prompt, + decoder_prompt=single_text_prompt, +) + +# - Finally, here's a useful helper function for zipping encoder and +# decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt +# instances +zipped_prompt_list = zip_enc_dec_prompt_lists( + ['An encoder prompt', 'Another encoder prompt'], + ['A decoder prompt', 'Another decoder prompt']) + +# - Let's put all of the above example prompts together into one list +# which we will pass to the encoder/decoder LLM. +prompts = [ + single_text_prompt_raw, single_text_prompt, single_tokens_prompt, + enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3 +] + zipped_prompt_list + +print(prompts) + +# Create a sampling params object. +sampling_params = SamplingParams( + temperature=0, + top_p=1.0, + min_tokens=0, + max_tokens=20, +) + +# Generate output tokens from the prompts. The output is a list of +# RequestOutput objects that contain the prompt, generated +# text, and other information. +outputs = llm.generate(prompts, sampling_params) + +# Print the outputs. +for output in outputs: + prompt = output.prompt + encoder_prompt = output.encoder_prompt + generated_text = output.outputs[0].text + print(f"Encoder prompt: {encoder_prompt!r}, " + f"Decoder prompt: {prompt!r}, " + f"Generated text: {generated_text!r}") diff --git a/tests/conftest.py b/tests/conftest.py index c7a349f1e9e2a..c0bf9897c97f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,9 +10,11 @@ import torch.nn as nn import torch.nn.functional as F from PIL import Image -from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq, - AutoTokenizer, BatchEncoding, BatchFeature) +from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM, + AutoModelForVision2Seq, AutoTokenizer, BatchEncoding, + BatchFeature) +from tests.models.utils import DecoderPromptType from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.config import TokenizerPoolConfig @@ -21,9 +23,11 @@ destroy_model_parallel) from vllm.inputs import TextPrompt from vllm.logger import init_logger +from vllm.outputs import RequestOutput from vllm.sequence import SampleLogprobs from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, - is_cpu) + is_cpu, to_enc_dec_tuple_list, + zip_enc_dec_prompt_lists) logger = init_logger(__name__) @@ -120,6 +124,40 @@ def example_prompts() -> List[str]: return prompts +@pytest.fixture +def example_encoder_decoder_prompts() \ + -> Dict[DecoderPromptType, + Tuple[List[str], List[Optional[str]]]]: + ''' + Returns an encoder prompt list and a decoder prompt list, wherein each pair + of same-index entries in both lists corresponds to an (encoder prompt, + decoder prompt) tuple. + + Returns: + + * Encoder prompt list + * Decoder prompt list (reverse of encoder prompt list) + ''' + + encoder_prompts = [] + for filename in _TEST_PROMPTS: + encoder_prompts += _read_prompts(filename) + + custom_decoder_prompts = encoder_prompts[::-1] + empty_str_decoder_prompts = [""] * len(encoder_prompts) + none_decoder_prompts = [None] * len(encoder_prompts) + + # NONE decoder prompt type + return { + DecoderPromptType.NONE: + zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts), + DecoderPromptType.EMPTY_STR: + zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts), + DecoderPromptType.CUSTOM: + zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts), + } + + @pytest.fixture def example_long_prompts() -> List[str]: prompts = [] @@ -152,6 +190,7 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, is_embedding_model: bool = False, is_vision_model: bool = False, + is_encoder_decoder_model: bool = False, ) -> None: torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] @@ -168,6 +207,8 @@ def __init__( else: if is_vision_model: auto_cls = AutoModelForVision2Seq + elif is_encoder_decoder_model: + auto_cls = AutoModelForSeq2SeqLM else: auto_cls = AutoModelForCausalLM @@ -314,6 +355,44 @@ def generate_greedy_logprobs( all_logprobs.append(seq_logprobs) return all_logprobs + def _hidden_states_to_logprobs( + self, + hidden_states, + num_logprobs, + ) -> Tuple[List[Dict[int, float]], int]: + seq_logprobs: List[torch.Tensor] = [] + output_len = len(hidden_states) + for _, hidden_state in enumerate(hidden_states): + last_hidden_states = hidden_state[-1][0] + logits = torch.matmul( + last_hidden_states, + self.model.get_output_embeddings().weight.t(), + ) + if getattr(self.model.get_output_embeddings(), "bias", + None) is not None: + logits += self.model.get_output_embeddings().bias.unsqueeze(0) + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + seq_logprobs.append(logprobs) + + # convert to dict + seq_logprobs_lst: List[Dict[int, float]] = [] + for tok_idx, tok_logprobs in enumerate(seq_logprobs): + # drop prompt logprobs + if tok_idx == 0: + tok_logprobs = tok_logprobs[-1, :].reshape(1, -1) + topk = tok_logprobs.topk(num_logprobs) + + tok_logprobs_dct = {} + for token_id, logprob in zip(topk.indices[0], topk.values[0]): + tok_logprobs_dct[token_id.item()] = logprob.item() + + seq_logprobs_lst.append(tok_logprobs_dct) + + return ( + seq_logprobs_lst, + output_len, + ) + def generate_greedy_logprobs_limit( self, prompts: List[str], @@ -346,37 +425,66 @@ def generate_greedy_logprobs_limit( **kwargs, ) - seq_logprobs: List[torch.Tensor] = [] - for _, hidden_states in enumerate(output.hidden_states): - last_hidden_states = hidden_states[-1][0] - logits = torch.matmul( - last_hidden_states, - self.model.get_output_embeddings().weight.t(), - ) - if getattr(self.model.get_output_embeddings(), "bias", - None) is not None: - logits += self.model.get_output_embeddings( - ).bias.unsqueeze(0) - logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) - seq_logprobs.append(logprobs) + ( + seq_logprobs_lst, + output_len, + ) = self._hidden_states_to_logprobs(output.hidden_states, + num_logprobs) + + all_logprobs.append(seq_logprobs_lst) + seq_ids = output.sequences[0] + output_len = len(seq_logprobs_lst) + output_ids = seq_ids[-output_len:] + all_output_ids.append(output_ids.tolist()) + all_output_strs.append(self.tokenizer.decode(output_ids)) - # convert to dict - seq_logprobs_lst: List[Dict[int, float]] = [] - for tok_idx, tok_logprobs in enumerate(seq_logprobs): - # drop prompt logprobs - if tok_idx == 0: - tok_logprobs = tok_logprobs[-1, :].reshape(1, -1) - topk = tok_logprobs.topk(num_logprobs) + outputs = zip(all_output_ids, all_output_strs, all_logprobs) + return [(output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs] + + def generate_encoder_decoder_greedy_logprobs_limit( + self, + encoder_decoder_prompts: Tuple[List[str], List[str]], + max_tokens: int, + num_logprobs: int, + **kwargs: Any, + ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: + ''' + Greedy logprobs generation for vLLM encoder/decoder models + ''' - tok_logprobs_dct = {} - for token_id, logprob in zip(topk.indices[0], topk.values[0]): - tok_logprobs_dct[token_id.item()] = logprob.item() + all_logprobs: List[List[Dict[int, float]]] = [] + all_output_ids: List[List[int]] = [] + all_output_strs: List[str] = [] - seq_logprobs_lst.append(tok_logprobs_dct) + for (encoder_prompt, + decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts): + encoder_input_ids = self.wrap_device( + self.tokenizer(encoder_prompt, return_tensors="pt").input_ids) + decoder_input_ids = ( + None if decoder_prompt is None else self.wrap_device( + self.tokenizer(decoder_prompt, + return_tensors="pt").input_ids)) + + output = self.model.generate( + encoder_input_ids, + decoder_input_ids=decoder_input_ids, + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + **kwargs, + ) + + ( + seq_logprobs_lst, + output_len, + ) = self._hidden_states_to_logprobs(output.decoder_hidden_states, + num_logprobs) all_logprobs.append(seq_logprobs_lst) seq_ids = output.sequences[0] - output_len = len(seq_logprobs_lst) output_ids = seq_ids[-output_len:] all_output_ids.append(output_ids.tolist()) all_output_strs.append(self.tokenizer.decode(output_ids)) @@ -416,7 +524,7 @@ def __init__( block_size: int = 16, enable_chunked_prefill: bool = False, swap_space: int = 4, - enforce_eager: bool = False, + enforce_eager: Optional[bool] = False, **kwargs, ) -> None: self.model = LLM( @@ -465,6 +573,19 @@ def generate( outputs.append((req_sample_output_ids, req_sample_output_strs)) return outputs + def _final_steps_generate_w_logprobs( + self, + req_outputs: List[RequestOutput], + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] + for req_output in req_outputs: + for sample in req_output.outputs: + output_str = sample.text + output_ids = sample.token_ids + output_logprobs = sample.logprobs + outputs.append((output_ids, output_str, output_logprobs)) + return outputs + def generate_w_logprobs( self, prompts: List[str], @@ -483,14 +604,21 @@ def generate_w_logprobs( req_outputs = self.model.generate(inputs, sampling_params=sampling_params) - outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = [] - for req_output in req_outputs: - for sample in req_output.outputs: - output_str = sample.text - output_ids = sample.token_ids - output_logprobs = sample.logprobs - outputs.append((output_ids, output_str, output_logprobs)) - return outputs + return self._final_steps_generate_w_logprobs(req_outputs) + + def generate_encoder_decoder_w_logprobs( + self, + encoder_decoder_prompts: Tuple[List[str], List[str]], + sampling_params: SamplingParams, + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + ''' + Logprobs generation for vLLM encoder/decoder models + ''' + + assert sampling_params.logprobs is not None + req_outputs = self.model.generate(encoder_decoder_prompts, + sampling_params=sampling_params) + return self._final_steps_generate_w_logprobs(req_outputs) def generate_greedy( self, @@ -523,6 +651,26 @@ def generate_greedy_logprobs( return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] + def generate_encoder_decoder_greedy_logprobs( + self, + encoder_decoder_prompts: Tuple[List[str], List[str]], + max_tokens: int, + num_logprobs: int, + ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: + greedy_logprobs_params = SamplingParams(temperature=0.0, + use_beam_search=False, + max_tokens=max_tokens, + logprobs=num_logprobs) + ''' + Greedy logprobs generation for vLLM encoder/decoder models + ''' + + outputs = self.generate_encoder_decoder_w_logprobs( + encoder_decoder_prompts, greedy_logprobs_params) + + return [(output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs] + def generate_beam_search( self, prompts: List[str], diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 447e8f8a586f6..11168d2423b0e 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -9,33 +9,11 @@ from vllm.core.interfaces import AllocStatus from vllm.core.scheduler import Scheduler, SchedulingBudget from vllm.lora.request import LoRARequest -from vllm.sequence import Logprob, SequenceGroup, SequenceStatus +from vllm.sequence import SequenceGroup, SequenceStatus -from .utils import create_dummy_prompt - - -def get_sequence_groups(scheduler_output): - return [s.seq_group for s in scheduler_output.scheduled_seq_groups] - - -def append_new_token(out, token_id: int): - seq_groups = get_sequence_groups(out) - for seq_group in seq_groups: - for seq in seq_group.get_seqs(): - seq.append_token_id(token_id, {token_id: Logprob(token_id)}) - - -def schedule_and_update_computed_tokens(scheduler): - metas, out = scheduler.schedule() - for s, meta in zip(out.scheduled_seq_groups, metas): - s.seq_group.update_num_computed_tokens(meta.token_chunk_size) - return metas, out - - -def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int): - seq_group.update_num_computed_tokens(token_chunk_size) - for seq in seq_group.get_seqs(): - seq.append_token_id(token_id, {token_id: Logprob(token_id)}) +from .utils import (append_new_token, append_new_token_seq_group, + create_dummy_prompt, get_sequence_groups, + schedule_and_update_computed_tokens) def test_scheduler_add_seq_group(): diff --git a/tests/core/test_scheduler_encoder_decoder.py b/tests/core/test_scheduler_encoder_decoder.py new file mode 100644 index 0000000000000..50c047f30b80d --- /dev/null +++ b/tests/core/test_scheduler_encoder_decoder.py @@ -0,0 +1,99 @@ +from typing import List + +import pytest # noqa + +from vllm.config import CacheConfig, SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.sequence import SequenceGroup + +from .utils import (append_new_token, create_dummy_prompt_encoder_decoder, + get_sequence_groups, schedule_and_update_computed_tokens) + + +def test_scheduler_schedule_simple_encoder_decoder(): + ''' + Test basic scheduler functionality in the context + of an encoder/decoder model. Focus on testing + enc/dec-specific functionality sense tests already + exist for decoder-only functionality + + Test behavior: + * Construct Scheduler + * Construct dummy encoder/decoder sequence groups + * Add dummy seq groups to scheduler backlog + * Schedule the next seq group & validate: + * Cross-attn block tables + * Updated states of seq groups + * Number of batched tokens + * Number of blocks to copy/swap-in/swap-out + * Number of scheduled seq groups + * Repeat for both prefill- and decode-phase + * Abort scheduled seq groups + * Assert that aborted seq groups no longer appear in + cross-attention block table + ''' + + block_size = 4 + num_seq_group = 4 + max_model_len = 16 + scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group + cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add seq groups to scheduler. + req_id_list = [] + for i in range(num_seq_group): + req_id = str(i) + req_id_list.append(req_id) + _, _, seq_group = create_dummy_prompt_encoder_decoder( + req_id, block_size, block_size, block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Schedule seq groups prefill. + num_tokens = block_size * num_seq_group + seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler) + # - Verify that sequence group cross-attention block tables are + # registered with the block manager + assert all([(req_id in scheduler.block_manager.cross_block_tables) + for req_id in req_id_list]) + # - Validate sequence-group status + assert set(get_sequence_groups(out)) == set(running) + # - Validate number of batched tokens + assert out.num_batched_tokens == num_tokens + # - Validate there are no remaining blocks to swap + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + # - Validate all seq groups were scheduled + assert len(seq_group_meta_list) == num_seq_group + append_new_token(out, 1) + + # Schedule seq groups decode. + seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler) + # - Verify that sequence group metadata includes encoder attention + # and cross-attention metadata + assert all([ + not ((seq_group_meta.encoder_seq_data is None) or + (seq_group_meta.cross_block_table is None)) + for seq_group_meta in seq_group_meta_list + ]) + # - Validate sequence-group status + assert set(get_sequence_groups(out)) == set(running) + # - Validate there is one batched token per seq group + assert out.num_batched_tokens == num_seq_group + # - Validate there are no remaining blocks to swap + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + # - Validate that all seq groups were scheduled + assert len(seq_group_meta_list) == num_seq_group + append_new_token(out, 1) + + # Abort sequences + for req_id in req_id_list: + scheduler.abort_seq_group(req_id) + # - Verify that sequence group cross-attention block tables are + # NO LONGER registered with the block manager + assert req_id not in scheduler.block_manager.cross_block_tables diff --git a/tests/core/utils.py b/tests/core/utils.py index f249f4b59a2ee..45a8e74e85324 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -53,27 +53,30 @@ def create_dummy_prompt_encoder_decoder( block_size = decoder_prompt_length # Create dummy prompt sequence with tokens 0...block_size-1 - # and prompt "0 ... block_size". + # and prompt "0 ... block_size". Note that the prompt string + # doesn't actually match the tokens decoder_prompt_tokens = list(range(decoder_prompt_length)) decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens]) + encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) + encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) + + inputs = { + "prompt": decoder_prompt_str, + "prompt_token_ids": decoder_prompt_tokens, + "encoder_prompt": encoder_prompt_str, + "encoder_prompt_token_ids": encoder_prompt_tokens, + "multi_modal_data": None, + } decoder_prompt = Sequence(int(request_id), - inputs={ - "prompt": decoder_prompt_str, - "prompt_token_ids": decoder_prompt_tokens, - "multi_modal_data": None, - }, - block_size=block_size) + inputs=inputs, + block_size=block_size, + from_decoder_prompt=True) - encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) - encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) encoder_prompt = Sequence(int(request_id), - inputs={ - "prompt": encoder_prompt_str, - "prompt_token_ids": encoder_prompt_tokens, - "multi_modal_data": None, - }, - block_size=block_size) + inputs=inputs, + block_size=block_size, + from_decoder_prompt=False) seq_group = SequenceGroup(request_id=request_id, seqs=[decoder_prompt], sampling_params=SamplingParams( @@ -139,17 +142,21 @@ def create_seq_group_encoder_decoder( prompt_token_ids = [0] * seq_prompt_len + inputs = { + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "encoder_prompt": "", + "encoder_prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + } + seqs = [] for seq_id_offset, output_len in enumerate(seq_output_lens): - seq = Sequence( - seq_id=seq_id_start + seq_id_offset, - inputs={ - "prompt": "", - "prompt_token_ids": prompt_token_ids, - "multi_modal_data": None, - }, - block_size=16, - ) + # Construct decoder input sequences + seq = Sequence(seq_id=seq_id_start + seq_id_offset, + inputs=inputs, + block_size=16, + from_decoder_prompt=True) for i in range(output_len): seq.append_token_id( @@ -158,16 +165,11 @@ def create_seq_group_encoder_decoder( ) seqs.append(seq) - # Encoder sequence - encoder_seq = Sequence( - seq_id=seq_id_start + len(seq_output_lens), - inputs={ - "prompt": "", - "prompt_token_ids": prompt_token_ids, - "multi_modal_data": None, - }, - block_size=16, - ) + # Encoder input sequence + encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens), + inputs=inputs, + block_size=16, + from_decoder_prompt=False) return SequenceGroup(request_id=request_id, seqs=seqs, @@ -177,4 +179,31 @@ def create_seq_group_encoder_decoder( def round_up_to_next_block(seq_len: int, block_size: int) -> int: - return (seq_len + block_size - 1) // block_size \ No newline at end of file + return (seq_len + block_size - 1) // block_size + + +# Helper functions for scheduler tests + + +def get_sequence_groups(scheduler_output): + return [s.seq_group for s in scheduler_output.scheduled_seq_groups] + + +def append_new_token(out, token_id: int): + seq_groups = get_sequence_groups(out) + for seq_group in seq_groups: + for seq in seq_group.get_seqs(): + seq.append_token_id(token_id, {token_id: Logprob(token_id)}) + + +def schedule_and_update_computed_tokens(scheduler): + metas, out = scheduler.schedule() + for s, meta in zip(out.scheduled_seq_groups, metas): + s.seq_group.update_num_computed_tokens(meta.token_chunk_size) + return metas, out + + +def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int): + seq_group.update_num_computed_tokens(token_chunk_size) + for seq in seq_group.get_seqs(): + seq.append_token_id(token_id, {token_id: Logprob(token_id)}) diff --git a/tests/distributed/test_basic_distributed_correctness_enc_dec.py b/tests/distributed/test_basic_distributed_correctness_enc_dec.py new file mode 100644 index 0000000000000..69eae62ca7320 --- /dev/null +++ b/tests/distributed/test_basic_distributed_correctness_enc_dec.py @@ -0,0 +1,101 @@ +"""For encoder/decoder models only: +Compare the outputs of HF and distributed vLLM when using greedy sampling. + +Run: +```sh +cd $VLLM_PATH/tests + +pytest distributed/test_basic_distributed_correctness_enc_dec.py +``` +""" + +import pytest + +from tests.models.utils import DecoderPromptType +from vllm.utils import cuda_device_count_stateless + +from ..models.utils import check_logprobs_close +from ..utils import fork_new_process_for_each_test + + +@pytest.mark.skipif(cuda_device_count_stateless() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize("model, distributed_executor_backend", [ + ("facebook/bart-large-cnn", "ray"), + ("facebook/bart-large-cnn", "mp"), +]) +@fork_new_process_for_each_test +def test_models( + model: str, + distributed_executor_backend: str, + hf_runner, + vllm_runner, + example_encoder_decoder_prompts, +) -> None: + ''' + Test vLLM BART inference on more than one GPU, comparing + outputs against HF as a baseline. + + Fork a new process for each test, to prevent CUDA from + being re-initialized by successive tests within the same + process. + + Arguments: + + * model: the HF ID of the specific BART variant under test + * distributed_executor_backend + * hf_runner: HuggingFace (HF) test model runner + * vllm_runner: vLLM test model runner + * example_encoder_decoder_prompts: test fixture which provides a + dictionary of dummy prompts + ''' + + dtype = "float" + max_tokens = 64 + num_logprobs = 5 + + # Example inputs with non-trivial (i.e. not None/empty) encoder & + # decoder prompts. + test_prompts = example_encoder_decoder_prompts[DecoderPromptType.CUSTOM] + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + with vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + ) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + test_prompts, max_tokens, num_logprobs) + + # Configuration settings for HF baseline + hf_kwargs = { + "top_k": None, + "num_beams": 1, + "repetition_penalty": 1.0, + "top_p": 1.0, + "length_penalty": 1.0, + "early_stopping": False, + "no_repeat_ngram_size": None, + "min_length": 0 + } + + with hf_runner(model, dtype=dtype, + is_encoder_decoder_model=True) as hf_model: + hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( + test_prompts, + max_tokens, + num_logprobs, + **hf_kwargs, + )) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index d9404e6442616..a20a741c27f74 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -3,9 +3,9 @@ import pytest import torch -from tests.kernels.utils import (STR_FLASH_ATTN_VAL, STR_INVALID_VAL, - override_backend_env_variable) +from tests.kernels.utils import override_backend_env_variable from vllm.attention.selector import which_attn_to_use +from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL @pytest.mark.parametrize( diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index f25e7d480b6b3..b550a7fdd84f0 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -4,8 +4,6 @@ * E2E test of Encoder attention + Decoder self-attention + Encoder/decoder cross-attention (collectively "encoder/decoder attention") -* Confirm enc/dec models will fail for chunked prefill -* Confirm enc/dec models will fail for prefix caching """ @@ -15,19 +13,22 @@ import torch from tests.kernels.utils import * -from tests.kernels.utils import make_causal_mask, maybe_make_long_tensor -from vllm.attention import Attention, AttentionMetadata -from vllm.attention.backends.abstract import AttentionBackend, AttentionType +from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, + AttentionType) from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP +from vllm.attention.selector import (_Backend, + global_force_attn_backend_context_manager) from vllm.utils import is_hip +# List of support backends for encoder/decoder models +LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] + HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] BATCH_SIZES = [1, 16] BLOCK_SIZES = [16] -BACKEND_NAMES = [STR_XFORMERS_ATTN_VAL] CUDA_DEVICE = "cuda:0" MAX_DEC_SEQ_LENS = [128] @@ -724,57 +725,92 @@ def _run_encoder_decoder_cross_attention_test( @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) -def test_encoder_only(num_heads: int, head_size: int, backend_name: str, - batch_size: int, block_size: int, max_dec_seq_len: int, - max_enc_seq_len: int, monkeypatch): +def test_encoder_only( + num_heads: int, + head_size: int, + attn_backend: _Backend, + batch_size: int, + block_size: int, + max_dec_seq_len: int, + max_enc_seq_len: int, +): + ''' + End-to-end encoder-only attention test: + + * Construct fake test vectors for (1) encoder attention + * Construct (1) attention metadata structure with prefill-phase + encoder attention, and (2) an analogous attention metadata + structure but for decode-phase + * Test & validate encoder attention against ideal output + + No KV cache is required for encoder-only attention. + + Note on ROCm/HIP: currently encoder/decoder models are not supported on + AMD GPUs, therefore this test simply is skipped if is_hip(). + + This test globally forces an override of the usual backend + auto-selection process, forcing the specific backend-under-test + to be utilized. + + Arguments: + + * num_heads + * head_size, + * attn_backend: The attention backend to employ for testing + * batch_size + * block_size: KV cache block size + * max_dec_seq_len: max length of decoder input sequences + * max_enc_seq_len: max length of encoder input sequences + ''' # Force Attention wrapper backend - override_backend_env_variable(monkeypatch, backend_name) + with global_force_attn_backend_context_manager(attn_backend): - # Note: KV cache size of 4096 is arbitrary & chosen intentionally - # to be more than necessary, since exceeding the kv cache size - # is not part of this test - test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, - block_size, max_dec_seq_len, max_enc_seq_len, 4096) + # Note: KV cache size of 4096 is arbitrary & chosen intentionally + # to be more than necessary, since exceeding the kv cache size + # is not part of this test + test_pt = TestPoint(num_heads, head_size, attn_backend.name, + batch_size, block_size, max_dec_seq_len, + max_enc_seq_len, 4096) - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - test_rsrcs = _make_test_resources(test_pt) + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + test_rsrcs = _make_test_resources(test_pt) - # Construct encoder attention test params (only used - # during prefill) + # Construct encoder attention test params (only used + # during prefill) - enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) + enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) - # Shared prefill metadata structure + # Shared prefill metadata structure - prephase_attn_metadata: AttentionMetadata = make_test_metadata( - test_rsrcs.attn_backend, - True, - None, - decoder_test_params=None, - encoder_test_params=enc_test_params, - cross_test_params=None, - device=CUDA_DEVICE) + prephase_attn_metadata: AttentionMetadata = make_test_metadata( + test_rsrcs.attn_backend, + True, + None, + decoder_test_params=None, + encoder_test_params=enc_test_params, + cross_test_params=None, + device=CUDA_DEVICE) - # PREFILL: encoder attention + # PREFILL: encoder attention - enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test( - test_rsrcs.attn, enc_test_params, prephase_attn_metadata)) + enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test( + test_rsrcs.attn, enc_test_params, prephase_attn_metadata)) - # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + # - Is encoder attention result correct? + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("backend_name", BACKEND_NAMES) +@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) @@ -782,12 +818,11 @@ def test_encoder_only(num_heads: int, head_size: int, backend_name: str, def test_e2e_enc_dec_attn( num_heads: int, head_size: int, - backend_name: str, + attn_backend: _Backend, batch_size: int, block_size: int, max_dec_seq_len: int, max_enc_seq_len: int, - monkeypatch, ) -> None: ''' End-to-end encoder/decoder test: @@ -820,8 +855,9 @@ def test_e2e_enc_dec_attn( cross-attention K/Vs are allowed to differ in seq len, as is often the case for cross-attention. - This test utilizes PyTest monkey patching to force the attention backend - via an environment variable. + This test globally forces an override of the usual backend + auto-selection process, forcing the specific backend-under-test + to be utilized. Note on ROCm/HIP: currently encoder/decoder models are not supported on AMD GPUs, therefore this test simply is skipped if is_hip(). @@ -830,124 +866,136 @@ def test_e2e_enc_dec_attn( all prefill-phase attention operations (encoder, decoder, enc/dec cross), and a single one shared by all decode-phase attention operations (decoder & enc/dec cross.) This is intended to reflect the behavior - of ModelRunner, which constructs a single attention metadata structure for - each prefill or decode run. A realistic scenario would rely on the - attention backend to utilize the appropriate attention metadata fields - according to the value of attn_metadata.attention_type. Thus, this test is - organized so as to confirm that the backend-under-test can handle a - shared prefill attention metadata structure & a shared decode attention - metadata structure. - ''' - - # Force Attention wrapper backend - override_backend_env_variable(monkeypatch, backend_name) - - # Note: KV cache size of 4096 is arbitrary & chosen intentionally - # to be more than necessary, since exceeding the kv cache size - # is not part of this test - test_pt = TestPoint(num_heads, head_size, backend_name, batch_size, - block_size, max_dec_seq_len, max_enc_seq_len, 4096) - - # Attention scale factor, attention backend instance, attention wrapper - # instance, KV cache init - test_rsrcs = _make_test_resources(test_pt) + of EncoderDecoderModelRunner, which constructs a single attention metadata + structure for each prefill or decode run. A realistic scenario would rely + on the attention backend to utilize the appropriate attention metadata + fields according to the value of attn_metadata.attention_type. Thus, + this test is organized so as to confirm that the backend-under-test can + handle a shared prefill attention metadata structure & a shared decode\ + attention metadata structure. - # Construct encoder attention test params (only used - # during prefill) - - enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) - - # Construct Decoder self-attention prefill-phase & decode-phase - # test params, including query/key/value tensors, decoder self-attention - # memory-mapping. cross_block_base_addr is the uppermost address in the - # decoder self-attention block-table, i.e. a base address which the - # encoder/decoder cross-attention block-table may build downward toward. - - ( - dec_qkv, - prephase_dec_test_params, - decphase_dec_test_params, - cross_block_base_addr, - ) = _decoder_attn_setup(test_pt, test_rsrcs) + Arguments: - # Construct encoder/decoder cross-attention prefill-phase & decode-phase - # test params, including key/value tensors, cross-attention memory-mapping + * num_heads + * head_size, + * attn_backend: The attention backend to employ for testing + * batch_size + * block_size: KV cache block size + * max_dec_seq_len: max length of decoder input sequences + * max_enc_seq_len: max length of encoder input sequences + ''' - ( - prephase_cross_test_params, - decphase_cross_test_params, - ) = _enc_dec_cross_attn_setup_reuses_query( - dec_qkv, - enc_test_params, - prephase_dec_test_params, - test_pt, - test_rsrcs, - block_base_addr=cross_block_base_addr) - - # Shared prefill metadata structure - assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None - prephase_attn_metadata: AttentionMetadata = make_test_metadata( - test_rsrcs.attn_backend, - True, - prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, - decoder_test_params=prephase_dec_test_params, - encoder_test_params=enc_test_params, - cross_test_params=prephase_cross_test_params, - device=CUDA_DEVICE) - - # PREFILL: encoder attention - - enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, - enc_test_params, - prephase_attn_metadata) - - # - Is encoder attention result correct? - assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) - - # PREFILL: decoder self-attention test - - prephase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, prephase_dec_test_params, prephase_attn_metadata) - - # - Is prefill decoder self-attention correct? - assert_actual_matches_ideal(prephase_dec_test_params, - prephase_dec_pckd_act_out) - - # PREFILL: encoder/decoder cross-attention test - - prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, - prephase_attn_metadata) - - # - Is prefill encoder/decoder cross-attention correct? - assert_actual_matches_ideal(prephase_cross_test_params, - prephase_cross_pckd_act_out) - - # DECODE: build decode-phase attention metadata - - decphase_attn_metadata: AttentionMetadata = make_test_metadata( - test_rsrcs.attn_backend, - False, - dec_qkv.q_seq_lens, - decoder_test_params=decphase_dec_test_params, - encoder_test_params=enc_test_params, - cross_test_params=decphase_cross_test_params, - device=CUDA_DEVICE) - - # DECODE: decoder self-attention test - - decphase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, decphase_dec_test_params, decphase_attn_metadata) - - # - Is decode-phase decoder self-attention correct? - assert_actual_matches_ideal(decphase_dec_test_params, - decphase_dec_pckd_act_out) - - # DECODE: encoder/decoder cross-attention test - - decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata) - - # - Is decode-phase encoder/decoder cross-attention correct? - assert_actual_matches_ideal(decphase_cross_test_params, - decphase_cross_pckd_act_out) + # Force Attention wrapper backend + with global_force_attn_backend_context_manager(attn_backend): + + # Note: KV cache size of 4096 is arbitrary & chosen intentionally + # to be more than necessary, since exceeding the kv cache size + # is not part of this test + test_pt = TestPoint(num_heads, head_size, attn_backend.name, + batch_size, block_size, max_dec_seq_len, + max_enc_seq_len, 4096) + + # Attention scale factor, attention backend instance, attention wrapper + # instance, KV cache init + test_rsrcs = _make_test_resources(test_pt) + + # Construct encoder attention test params (only used + # during prefill) + + enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) + + # Construct Decoder self-attention prefill-phase & decode-phase + # test params, including query/key/value tensors, decoder self-attention + # memory-mapping. cross_block_base_addr is the uppermost address in the + # decoder self-attention block-table, i.e. a base address which the + # encoder/decoder cross-attention block-table may build downward toward. + + ( + dec_qkv, + prephase_dec_test_params, + decphase_dec_test_params, + cross_block_base_addr, + ) = _decoder_attn_setup(test_pt, test_rsrcs) + + # Construct encoder/decoder cross-attention prefill-phase + # & decode-phase test params, including key/value tensors, + # cross-attention memory-mapping + + ( + prephase_cross_test_params, + decphase_cross_test_params, + ) = _enc_dec_cross_attn_setup_reuses_query( + dec_qkv, + enc_test_params, + prephase_dec_test_params, + test_pt, + test_rsrcs, + block_base_addr=cross_block_base_addr) + + # Shared prefill metadata structure + assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None + prephase_attn_metadata: AttentionMetadata = make_test_metadata( + test_rsrcs.attn_backend, + True, + prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, + decoder_test_params=prephase_dec_test_params, + encoder_test_params=enc_test_params, + cross_test_params=prephase_cross_test_params, + device=CUDA_DEVICE) + + # PREFILL: encoder attention + + enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, + enc_test_params, + prephase_attn_metadata) + + # - Is encoder attention result correct? + assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) + + # PREFILL: decoder self-attention test + + prephase_dec_pckd_act_out = _run_decoder_self_attention_test( + test_rsrcs, prephase_dec_test_params, prephase_attn_metadata) + + # - Is prefill decoder self-attention correct? + assert_actual_matches_ideal(prephase_dec_test_params, + prephase_dec_pckd_act_out) + + # PREFILL: encoder/decoder cross-attention test + + prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( + test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, + prephase_attn_metadata) + + # - Is prefill encoder/decoder cross-attention correct? + assert_actual_matches_ideal(prephase_cross_test_params, + prephase_cross_pckd_act_out) + + # DECODE: build decode-phase attention metadata + + decphase_attn_metadata: AttentionMetadata = make_test_metadata( + test_rsrcs.attn_backend, + False, + dec_qkv.q_seq_lens, + decoder_test_params=decphase_dec_test_params, + encoder_test_params=enc_test_params, + cross_test_params=decphase_cross_test_params, + device=CUDA_DEVICE) + + # DECODE: decoder self-attention test + + decphase_dec_pckd_act_out = _run_decoder_self_attention_test( + test_rsrcs, decphase_dec_test_params, decphase_attn_metadata) + + # - Is decode-phase decoder self-attention correct? + assert_actual_matches_ideal(decphase_dec_test_params, + decphase_dec_pckd_act_out) + + # DECODE: encoder/decoder cross-attention test + + decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( + test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata) + + # - Is decode-phase encoder/decoder cross-attention correct? + assert_actual_matches_ideal(decphase_cross_test_params, + decphase_cross_pckd_act_out) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 6c5eff00de44c..0d3edc5d2aaf7 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -211,5 +211,5 @@ def test_varlen_with_paged_kv( sliding_window=sliding_window, soft_cap=soft_cap, ) - assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + assert torch.allclose(output, ref_output, atol=2e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 23d627820d247..e942336ff7fdc 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -8,24 +8,10 @@ import pytest import torch -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, AttentionType) +from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.attention.backends.xformers import XFormersBackend -from vllm.utils import make_tensor_with_pad - -# String name of register which may be set in order to -# force auto-selection of attention backend by Attention -# wrapper -STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" - -# Possible string values of STR_BACKEND_ENV_VAR -# register, corresponding to possible backends -STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" -STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA" -STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" -STR_XFORMERS_ATTN_VAL: str = "XFORMERS" -STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" -STR_INVALID_VAL: str = "INVALID" +from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, + make_tensor_with_pad) class QKVInputs(NamedTuple): diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py new file mode 100644 index 0000000000000..9c26b7163ff62 --- /dev/null +++ b/tests/models/test_bart.py @@ -0,0 +1,153 @@ +"""Compare the outputs of HF and vLLM for BART models using greedy sampling. + +Run `pytest tests/models/test_bart.py`. +""" +from vllm.utils import is_cpu + +if not is_cpu(): + # CPU backend is not currently supported with encoder/decoder models + # skip test definitions entirely to avoid importing GPU kernel libs + # (xFormers, etc.) + + import pytest + + from tests.models.utils import DecoderPromptType + + from .utils import check_logprobs_close + + MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"] + + DECODER_PROMPT_TYPES = ([ + DecoderPromptType.CUSTOM, DecoderPromptType.EMPTY_STR, + DecoderPromptType.NONE + ]) + + @pytest.mark.parametrize("model", MODELS) + @pytest.mark.parametrize("dtype", ["float", "bfloat16"]) + @pytest.mark.parametrize("max_tokens", [64]) + @pytest.mark.parametrize("num_logprobs", [5]) + @pytest.mark.parametrize("decoder_prompt_type", DECODER_PROMPT_TYPES) + def test_models( + hf_runner, + vllm_runner, + example_encoder_decoder_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, + decoder_prompt_type: DecoderPromptType, + ) -> None: + ''' + Test the vLLM BART model for a variety of encoder/decoder input prompts, + by validating it against HuggingFace (HF) BART. + + Arguments: + + * hf_runner: HuggingFace (HF) test model runner + * vllm_runner: vLLM test model runner + * example_encoder_decoder_prompts: test fixture which provides a + dictionary of dummy prompts + * model: the HF ID of the specific BART variant under test + * dtype: the tensor datatype to employ + * max_tokens + * num_logprobs + * decoder_prompt_type: key into the example_encoder_decoder_prompts + dictionary; selects specific encoder/decoder + prompt scenarios to test + + A note on using HF BART as a baseline for validating vLLM BART, + specifically when the decoder prompt is None. + + The HF GenerationMixin's default behavior is to force the first + decoded token to be if the prompt does not already contain + (this is accomplished using a logit + processor setting.) + + So when we use HF BART as our baseline for comparison, note that + when the user provides a request with a None decoder prompt + (i.e. a singleton encoder prompt, or else an explicit encoder/ + decoder prompt with the decoder sub-prompt set to None), HF and + vLLM handle this in different ways: + + * HF will (1) tokenize the None prompt as an empty token-list, + (2) append to the beginning, yielding + [], (3) pass this token list to the model, and + then (4) after computing logits during prefill, override the model + logits & force to be the first generated token. + + * vLLM will (1) tokenize the None prompt as [], (2) append decoder- + start-token to the beginning, yielding [], + (3) pass these tokens to the model & proceed with generation. + + The net effect is that compared to vLLM, the list of HF *decoded* tokens + will contain one more initial than the vLLM generated tokens, + because vLLM's token is injected into the prompt rather than into + the generated output. This is in spite of the fact that overall, the + complete sequences (prompt + decoded tokens) produced by vLLM will match + HF. + + So when we use HF decoded token output to validate vLLM's decoded token + output, the testing process must account for the difference in decoded + token sequences between vLLM and HF specifically in the + decoder-prompt-is-None case. + + One option is to disable the logit processor feature that forces the + token to be decoded (forced_bos_token_id = None), eliminating + the problem entirely. However this is not "normal" BART usage. + + The other option is - only in the decoder-prompt-is-None case - to + discard the first decoded token from the HF output before comparing it + to vLLM. + + To that end, when testing the scenario where the decoder prompt is None + (and only in that one scenario), this test skips the first HF decoded + token during the process of validating the vLLM decoded output. + ''' + + test_case_prompts = example_encoder_decoder_prompts[ + decoder_prompt_type] + + # Configuration settings for HF baseline + hf_kwargs = { + "top_k": None, + "num_beams": 1, + "repetition_penalty": 1.0, + "top_p": 1.0, + "length_penalty": 1.0, + "early_stopping": False, + "no_repeat_ngram_size": None, + "min_length": 0 + } + + with hf_runner(model, dtype=dtype, + is_encoder_decoder_model=True) as hf_model: + hf_outputs = ( + hf_model.generate_encoder_decoder_greedy_logprobs_limit( + test_case_prompts, + max_tokens, + num_logprobs, + **hf_kwargs, + )) + + # Note: currently encoder/decoder models are only compatible with + # enforce_eager=True. Normally this is not a problem because + # for encoder/decoder models vLLM will + # default to enforce_eager=True if enforce_eager + # is left unspecified. However, the + # VllmRunner test fixture (which wraps around the LLM class) defaults to + # enforce_eager=False (a behavior which a number of already-exisitng + # decoder-only unit tests expect), so when testing an encoder/decoder + # model we must explicitly specify enforce_eager=True in the VllmRunner + # constructor. + with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + test_case_prompts, max_tokens, num_logprobs) + + hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE + else 0) + + check_logprobs_close(outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + num_outputs_0_skip_tokens=hf_skip_tokens) diff --git a/tests/models/utils.py b/tests/models/utils.py index 425f57ef9b966..d96301b853c85 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,4 +1,5 @@ import warnings +from enum import Enum from typing import Dict, List, Optional, Sequence, Tuple, Union from vllm.sequence import SampleLogprobs @@ -45,11 +46,27 @@ def check_logprobs_close( outputs_1_lst: Sequence[TokensTextLogprobs], name_0: str, name_1: str, + num_outputs_0_skip_tokens: int = 0, warn_on_mismatch: bool = True, ): """ Compare the logprobs of two sequences generated by different models, which should be similar but not necessarily equal. + + Arguments: + + * outputs_0_lst: First sequence to compare + * outputs_0_lst: Second sequence to compare + * name_0: sequence #0 name + * name_1: sequence #1 name + * num_outputs_0_skip_tokens: If > 0, specifies the number of initial + sequence #0 tokens & logprobs to discard + before comparison, i.e. all + of sequence #1 will be compared to + sequence #0 beginning at index + num_outputs_0_skip_tokens + * warn_on_mismatch: Issue a warning if there is token-wise or text-wise + mismatch between the two sequences """ assert len(outputs_0_lst) == len(outputs_1_lst) @@ -65,6 +82,15 @@ def check_logprobs_close( if logprobs_1 is None: logprobs_1 = [None] * len(output_ids_1) + # Skip specified number of initial sequence #0 tokens + # & logprobs, leaving output text as-is for simplicity + # (text mismatches may generate warnings but do not + # cause the test to fail.) + if num_outputs_0_skip_tokens < 0: + raise ValueError("num_outputs_0_skip_tokens must be non-negative") + output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:] + logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:] + # Loop through generated tokens. for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): @@ -110,3 +136,13 @@ def check_logprobs_close( warnings.simplefilter("always") warnings.warn(fail_msg, stacklevel=2) + + +class DecoderPromptType(Enum): + ''' + For encoder/decoder models only - + + ''' + CUSTOM = 1 + NONE = 2 + EMPTY_STR = 3 diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py new file mode 100644 index 0000000000000..8a2e9b81580fc --- /dev/null +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -0,0 +1,480 @@ +from typing import List + +import pytest +import torch + +from vllm.engine.arg_utils import EngineArgs +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.utils import is_cpu +from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner + +# CUDA graph scenarios to test +# +# Currently CUDA graph is not supported +ENFORCE_EAGER = [True] + +BATCH_SIZES = [1, 4, 16, 64, 256] + + +def _create_model_runner(model: str, *args, + **kwargs) -> EncoderDecoderModelRunner: + engine_args = EngineArgs(model, *args, **kwargs) + engine_config = engine_args.create_engine_config() + model_runner = EncoderDecoderModelRunner( + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + cache_config=engine_config.cache_config, + load_config=engine_config.load_config, + lora_config=engine_config.lora_config, + prompt_adapter_config=engine_config.prompt_adapter_config, + is_driver_worker=True, + ) + return model_runner + + +@pytest.mark.skipif(condition=is_cpu(), + reason="CPU backend is currently " + "unsupported for encoder/ " + "decoder models") +@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) +def test_empty_seq_group(enforce_eager, ): + """Verify prepare prompt and decode returns empty output + for empty seq group list""" + + model_runner = _create_model_runner( + "facebook/bart-base", + seed=0, + dtype="float16", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + enforce_eager=enforce_eager, + ) + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + model_input = model_runner._prepare_model_input_tensors( + seq_group_metadata_list) + ( + input_tokens, + input_positions, + encoder_input_tokens, + encoder_input_positions, + attn_metadata, + return_seq_lens, + ) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.encoder_input_tokens, + model_input.encoder_input_positions, + model_input.attn_metadata, + model_input.seq_lens, + ) + assert input_tokens is None + assert input_positions is None + assert encoder_input_tokens is None + assert encoder_input_positions is None + assert attn_metadata is None + assert return_seq_lens is None + + +@pytest.mark.skipif(condition=is_cpu(), + reason="CPU backend is currently " + "unsupported for encoder/ " + "decoder models") +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) +def test_prepare_prompt( + batch_size, + enforce_eager, +): + ''' + Test the ability of the encoder/decoder model runner subclass to + produce prefill-phase model inputs & attention metadata. + + Test behavior: + + * Instantiate BART base model & enc/dec model runner + * Construct sequence-group metadata for dummy prompts + * Test that encoder attention, decoder self-attention, + and encoder/decoder cross-attention inputs are correct + + Arguments: + + * batch_size + * backend_name: The attention backend under test + * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph) + ''' + + model_runner = _create_model_runner( + "facebook/bart-base", + seed=0, + dtype="float16", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + enforce_eager=enforce_eager, + ) + + seq_lens: List[int] = [] + encoder_seq_lens: List[int] = [] + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + block_tables = {0: [1]} + cross_block_table = [2] + for i in range(batch_size): + # make sure all tokens fit into one block + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) + encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 + encoder_seq_lens.append(encoder_seq_len) + encoder_seq_data = SequenceData(list(range(encoder_seq_len))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table, + ) + assert seq_group_metadata.token_chunk_size == seq_data.get_len() + seq_group_metadata_list.append(seq_group_metadata) + + # Build + # * Decoder model inputs + # * Decoder self-attention KV caching data structures + # * Encoder model inputs + # * Encoder/decoder cross-attention KV caching data structures + model_input = model_runner.prepare_model_input(seq_group_metadata_list) + + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + slot_mapping = attn_metadata.slot_mapping + encoder_input_tokens = model_input.encoder_input_tokens + encoder_input_positions = model_input.encoder_input_positions + cross_slot_mapping = attn_metadata.cross_slot_mapping + assert return_seq_lens == seq_lens + assert len(slot_mapping) == len(input_tokens) + assert len(cross_slot_mapping) == len(encoder_input_tokens) + + # Verify input metadata is correct for prompts. + # - Decoder attention metadata + device = model_runner.device + assert attn_metadata.num_prefills > 0 + assert attn_metadata.num_decode_tokens == 0 + assert torch.equal(attn_metadata.seq_lens_tensor, + torch.tensor(seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.seq_lens == seq_lens + assert attn_metadata.max_prefill_seq_len == max(seq_lens) + assert attn_metadata.max_decode_seq_len == 0 + # - Encoder attention metadata + assert attn_metadata.encoder_seq_lens == encoder_seq_lens + assert torch.equal( + attn_metadata.encoder_seq_lens_tensor, + torch.tensor(encoder_seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens) + assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens) + + # Test decoder subquery start locs. + start_idx = 0 + start_loc = [start_idx] + for seq_len in seq_lens: + start_idx += seq_len + start_loc.append(start_idx) + assert torch.equal( + attn_metadata.query_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device), + ) + + # Test decoder seq start locs & context lengths + + assert torch.equal( + attn_metadata.seq_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device), + ) + assert torch.equal( + attn_metadata.context_lens_tensor, + torch.zeros(attn_metadata.context_lens_tensor.shape[0], + dtype=torch.int, + device=device), + ) + + # Verify block tables are correct for prompts + # - Decoder self-attention + expected = torch.tensor( + [[] for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device, + ) + assert torch.equal( + attn_metadata.block_tables, + expected, + ) + # - Encoder/decoder cross-attention + assert torch.equal( + attn_metadata.cross_block_tables, + expected, + ) + + # Cuda graph should not be used for prefill. + assert attn_metadata.use_cuda_graph is False + + # Verify the lengths of input tokens & positions + # - Decoder + assert len(input_tokens) == sum(seq_lens) + assert len(input_positions) == sum(seq_lens) + # -- An indirect check that model_input.input_tokens + # and model_input.input_positions are correct - + # by design of the test, the input tokens are + # equal to the input position values, so if + # the model_input data structure has the correct + # values then these two should be equal + assert torch.equal( + input_tokens, + input_positions, + ) + # - Encoder + assert len(encoder_input_tokens) == sum(encoder_seq_lens) + # -- An indirect check that model_input.encoder_input_tokens + # and model_input.encoder_input_positions are correct - + # by design of the test, the input tokens are + # equal to the input position values, so if + # the model_input data structure has the correct + # values then these two should be equal + assert torch.equal( + encoder_input_tokens, + encoder_input_positions, + ) + + # Test that vLLM sampling infrastructure chooses the correct + # sequence positions at which to sample (i.e. the end of + # each sequence) in the prefill phase + + expected_selected_token_indices = [] + selected_token_start_idx = 0 + for seq_len in seq_lens: + # Compute the index offset of the final token in each + # prompt (recall that the prompts are concatenated) + expected_selected_token_indices.append(selected_token_start_idx + + seq_len - 1) + selected_token_start_idx += seq_len + + sampling_metadata = model_input.sampling_metadata + actual = sampling_metadata.selected_token_indices + expected = torch.tensor( + expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype, + ) + assert torch.equal(actual, expected) + + +@pytest.mark.skipif(condition=is_cpu(), + reason="CPU backend is currently " + "unsupported for encoder/ " + "decoder models") +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER) +def test_prepare_decode( + batch_size, + enforce_eager, +): + ''' + Test the ability of the encoder/decoder model runner subclass to + produce decode-phase model inputs & attention metadata. + + Test behavior: + + * Instantiate BART base model & enc/dec model runner + * Construct sequence-group metadata for dummy prompts + * Test that encoder attention, decoder self-attention, + and encoder/decoder cross-attention inputs are correct + + Arguments: + + * batch_size + * backend_name: The attention backend under test + * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph) + ''' + + model_runner = _create_model_runner( + "facebook/bart-base", + seed=0, + dtype="float16", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + enforce_eager=enforce_eager, + ) + + seq_lens: List[int] = [] + encoder_seq_lens: List[int] = [] + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + block_tables = {0: [1]} + cross_block_table = [2] + for i in range(batch_size): + # make sure all tokens fit into one block + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) + encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 + encoder_seq_lens.append(encoder_seq_len) + encoder_seq_data = SequenceData(list(range(encoder_seq_len))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table, + ) + assert seq_group_metadata.token_chunk_size == 1 + seq_group_metadata_list.append(seq_group_metadata) + + # Build + # * Decoder model inputs + # * Decoder self-attention KV caching data structures + # * Encoder model inputs + # * Encoder/decoder cross-attention KV caching data structures + model_input = model_runner.prepare_model_input(seq_group_metadata_list) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + slot_mapping = attn_metadata.slot_mapping + encoder_input_tokens = model_input.encoder_input_tokens + encoder_input_positions = model_input.encoder_input_positions + cross_slot_mapping = attn_metadata.cross_slot_mapping + assert return_seq_lens == seq_lens + assert len(slot_mapping) == len(input_tokens) + assert len(cross_slot_mapping) == len(encoder_input_tokens) + + # Verify input metadata is correct for decode phase. + # - Decoder attention metadata + device = model_runner.device + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_decode_tokens > 0 + assert torch.equal(attn_metadata.seq_lens_tensor, + torch.tensor(seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.seq_lens == seq_lens + assert attn_metadata.max_prefill_seq_len == 0 + assert attn_metadata.max_decode_seq_len == max(seq_lens) + # - Encoder attention metadata + assert attn_metadata.encoder_seq_lens == encoder_seq_lens + assert torch.equal( + attn_metadata.encoder_seq_lens_tensor, + torch.tensor(encoder_seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens) + assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens) + + # Test decoder subquery start locs. + start_idx = 0 + start_loc = [start_idx] + for seq_len in seq_lens: + start_idx += 1 + start_loc.append(start_idx) + assert torch.equal( + attn_metadata.query_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device), + ) + + # Test decoder seq start locs. Note that for normal prefill it is + # equivalent to query_start_loc. + start_idx = 0 + seq_start_loc = [start_idx] + for seq_len in seq_lens: + start_idx += seq_len + seq_start_loc.append(start_idx) + + # Test seq_start_loc and context lengths + + assert torch.equal( + attn_metadata.seq_start_loc, + torch.tensor(seq_start_loc, dtype=torch.int32, device=device), + ) + assert torch.equal( + attn_metadata.context_lens_tensor, + torch.tensor([seq_len - 1 for seq_len in seq_lens], + dtype=torch.int, + device=device)) + + # Verify block tables are correct for prompts + # - Decoder self-attention + expected = torch.tensor( + [block_tables[0] for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device) + assert torch.equal( + attn_metadata.block_tables, + expected, + ) + # - Encoder/decoder cross-attention + expected = torch.tensor( + [cross_block_table for _ in range(len(seq_group_metadata_list))], + dtype=torch.int32, + device=model_runner.device) + assert torch.equal( + attn_metadata.cross_block_tables, + expected, + ) + + # Cuda graph should is currently not supported for encoder/decoer. + assert attn_metadata.use_cuda_graph is False + + # Verify the lengths of input tokens & positions + # - Decoder + assert len(input_tokens) == len(seq_lens) + assert len(input_positions) == len(seq_lens) + # -- An indirect check that model_input.input_tokens + # and model_input.input_positions are correct - + # by design of the test, the input tokens are + # equal to the input position values, so if + # the model_input data structure has the correct + # values then these two should be equal + assert torch.equal( + input_tokens, + input_positions, + ) + # - Encoder + assert len(encoder_input_tokens) == 0 + assert len(encoder_input_tokens) == 0 + # -- An indirect check that model_input.encoder_input_tokens + # and model_input.encoder_input_positions are correct - + # by design of the test, the input tokens are + # equal to the input position values, so if + # the model_input data structure has the correct + # values then these two should be equal + assert torch.equal( + encoder_input_tokens, + encoder_input_positions, + ) + + # Test that vLLM sampling infrastructure chooses the correct + # sequence positions at which to sample (i.e. the end of + # each sequence) in the decode phase + + expected_selected_token_indices = [] + selected_token_start_idx = 0 + for seq_len in seq_lens: + # Compute the index offset of the final token in each + # sequence's decoded outputs; since a single token is + # decoded per iteration per sequence, then the length + # of the decoded tokens for a given sequence is 1 and + # the final index offset into a given sequence's + # generated tokens is 0 (i.e. the expected sampling index + # for a given sequence is just `selected_token_start_idx`) + expected_selected_token_indices.append(selected_token_start_idx) + selected_token_start_idx += 1 + + sampling_metadata = model_input.sampling_metadata + actual = sampling_metadata.selected_token_indices + expected = torch.tensor( + expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype, + ) + assert torch.equal(actual, expected) diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 44bfae44cfddd..4643d316d48b7 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,6 +1,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, - AttentionMetadataBuilder) + AttentionMetadataBuilder, + AttentionType) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -8,6 +9,7 @@ "Attention", "AttentionBackend", "AttentionMetadata", + "AttentionType", "AttentionMetadataBuilder", "Attention", "get_attn_backend", diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 2c21502dcf407..ecf964fa49d9b 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import AttentionMetadata, AttentionType +from vllm.attention import AttentionMetadata, AttentionType from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 8fcd85585a18f..d5c8d6a376961 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,6 +1,8 @@ import enum +import os +from contextlib import contextmanager from functools import lru_cache -from typing import Optional, Type +from typing import Generator, Optional, Type import torch @@ -8,7 +10,8 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu +from vllm.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino, + is_tpu, is_xpu) logger = init_logger(__name__) @@ -24,6 +27,66 @@ class _Backend(enum.Enum): IPEX = enum.auto() +def backend_name_to_enum(backend_name: str) -> _Backend: + assert backend_name is not None + + backend_members = _Backend.__members__ + if backend_name not in backend_members: + raise ValueError(f"Invalid attention backend '{backend_name}'. " + f"Available backends: {', '.join(backend_members)} " + "(case-sensitive).") + + return _Backend[backend_name] + + +def get_env_variable_attn_backend() -> Optional[_Backend]: + ''' + Get the backend override specified by the vLLM attention + backend environment variable, if one is specified. + + Returns: + + * _Backend enum value if an override is specified + * None otherwise + ''' + backend_name = os.environ.get(STR_BACKEND_ENV_VAR) + return (None + if backend_name is None else backend_name_to_enum(backend_name)) + + +# Global state allows a particular choice of backend +# to be forced, overriding the logic which auto-selects +# a backend based on system & workload configuration +# (default behavior if this variable is None) +# +# THIS SELECTION TAKES PRECEDENCE OVER THE +# VLLM ATTENTION BACKEND ENVIRONMENT VARIABLE +forced_attn_backend: Optional[_Backend] = None + + +def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: + ''' + Force all attention operations to use a specified backend. + + Passing `None` for the argument re-enables automatic + backend selection., + + Arguments: + + * attn_backend: backend selection (None to revert to auto) + ''' + global forced_attn_backend + forced_attn_backend = attn_backend + + +def get_global_forced_attn_backend() -> Optional[_Backend]: + ''' + Get the currently-forced choice of attention backend, + or None if auto-selection is currently enabled. + ''' + return forced_attn_backend + + @lru_cache(maxsize=None) def get_attn_backend( num_heads: int, @@ -101,16 +164,20 @@ def which_attn_to_use( # Default case. selected_backend = _Backend.FLASH_ATTN - # Check the environment variable and override if specified - backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND - if backend_by_env_var is not None: - backend_members = _Backend.__members__ - if backend_by_env_var not in backend_members: - raise ValueError( - f"Invalid attention backend '{backend_by_env_var}'. " - f"Available backends: {', '.join(backend_members)} " - "(case-sensitive).") - selected_backend = _Backend[backend_by_env_var] + # Check whether a particular choice of backend was + # previously forced. + # + # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND + # ENVIRONMENT VARIABLE. + backend_by_global_setting: Optional[_Backend] = ( + get_global_forced_attn_backend()) + if backend_by_global_setting is not None: + selected_backend = backend_by_global_setting + else: + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = backend_name_to_enum(backend_by_env_var) if is_cpu(): if selected_backend != _Backend.TORCH_SDPA: @@ -193,3 +260,35 @@ def which_attn_to_use( selected_backend = _Backend.XFORMERS return selected_backend + + +@contextmanager +def global_force_attn_backend_context_manager( + attn_backend: _Backend) -> Generator[None, None, None]: + ''' + Globally force a vLLM attention backend override within a + context manager, reverting the global attention backend + override to its prior state upon exiting the context + manager. + + Arguments: + + * attn_backend: attention backend to force + + Returns: + + * Generator + ''' + + # Save the current state of the global backend override (if any) + original_value = get_global_forced_attn_backend() + + # Globally force the new backend override + global_force_attn_backend(attn_backend) + + # Yield control back to the enclosed code block + try: + yield + finally: + # Revert the original global backend override, if any + global_force_attn_backend(original_value) diff --git a/vllm/config.py b/vllm/config.py index 3cc197f3d655f..ec6d587e7925b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -12,7 +12,8 @@ from vllm.model_executor.models import ModelRegistry from vllm.tracing import is_otel_installed from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu, +from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, + cuda_device_count_stateless, get_cpu_memory, is_cpu, is_hip, is_neuron, is_openvino, is_tpu, is_xpu, print_warning_once) @@ -87,6 +88,9 @@ class ModelConfig: enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. + If None, the user did not specify, so default to False - + except for encoder/decoder models, which currently require + eager mode. max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode (DEPRECATED. Use max_seq_len_to_capture instead). @@ -121,7 +125,7 @@ def __init__( max_model_len: Optional[int] = None, quantization: Optional[str] = None, quantization_param_path: Optional[str] = None, - enforce_eager: bool = False, + enforce_eager: Optional[bool] = None, max_context_len_to_capture: Optional[int] = None, max_seq_len_to_capture: Optional[int] = None, max_logprobs: int = 20, @@ -160,6 +164,34 @@ def __init__( self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + # Choose a default enforce_eager value if the user did not specify + # a value (enforce_eager is None) + if getattr(self.hf_config, 'is_encoder_decoder', False): + if self.enforce_eager is None: + # *Only for encoder/decoder models* and + # *only if enforce_eager is unset*, override + # to enforce_eager=True + # + # Add a logger message since it is *somewhat* non-intuitive that + # enforce_eager is True when the user has not specified its + # value. + logger.info("Forcing enforce_eager == True because " + "enforce_eager setting was unspecified and " + "CUDAGraph is not supported with encoder/ " + "decoder models.") + self.enforce_eager = True + + if not self.enforce_eager: + # Eager mode explicitly disabled by user for an encoder/ + # decoder model; however CUDAGRAPH + encoder/decoder is + # not currently supported + raise ValueError(STR_NOT_IMPL_ENC_DEC_CUDAGRAPH) + elif self.enforce_eager is None: + # *Only for decoder-only models*, enforce_eager + # defaults to False if unset. This is intuitive + # so no logging message needed. + self.enforce_eager = False + if (not self.disable_sliding_window and self.hf_text_config.model_type == "gemma2" and self.hf_text_config.sliding_window is not None): diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 2c412a8f472e0..28839437c33c5 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -1,15 +1,7 @@ """Block manager utils.""" from vllm.sequence import SequenceGroup - -# Exception strings for non-implemented block manager enc/dec scenarios - -STR_NOT_IMPL_ENC_DEC_SWA = \ - "Sliding window attention for encoder/decoder models " + \ - "is not currently supported." - -STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ - "Prefix caching for encoder/decoder models " + \ - "is not currently supported." +from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) def _get_block_mgr_sliding_window_attr(block_mgr): diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 11d020be0c940..f60463107be44 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -392,6 +392,19 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: seq.status = SequenceStatus.FINISHED_ABORTED self.free_seq(seq) + self._free_seq_group_cross_attn_blocks(aborted_group) + + def _free_seq_group_cross_attn_blocks( + self, + seq_group: SequenceGroup, + ) -> None: + """ + Free a sequence group from a cross-attention block table. + Has no effect on decoder-only models. + """ + if seq_group.is_encoder_decoder(): + self.block_manager.free_cross(seq_group) + def has_unfinished_seqs(self) -> bool: return len(self.waiting) != 0 or len(self.running) != 0 or len( self.swapped) != 0 @@ -963,6 +976,17 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # seq_id -> physical block numbers block_tables: Dict[int, List[int]] = {} + if seq_group.is_encoder_decoder(): + # Encoder associated with SequenceGroup + encoder_seq_data = seq_group.get_encoder_seq().data + # Block table for cross-attention + # Also managed at SequenceGroup level + cross_block_table = self.block_manager.get_cross_block_table( + seq_group) + else: + encoder_seq_data = None + cross_block_table = None + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq_id = seq.seq_id seq_data[seq_id] = seq.data @@ -1001,6 +1025,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: token_chunk_size=token_chunk_size, lora_request=seq_group.lora_request, computed_block_nums=common_computed_block_nums, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table, # `multi_modal_data` will only be present for the 1st comm # between engine and worker. # the subsequent comms can still use delta, but @@ -1032,6 +1058,8 @@ def free_finished_seq_groups(self) -> None: remaining: Deque[SequenceGroup] = deque() for seq_group in self.running: if seq_group.is_finished(): + # Free cross-attention block table, if it exists + self._free_seq_group_cross_attn_blocks(seq_group) # Add the finished requests to the finished requests list. # This list will be used to update the Mamba cache in the # next step. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 935a509cdb7ce..b6d2ea463940f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -69,7 +69,7 @@ class EngineArgs: rope_theta: Optional[float] = None tokenizer_revision: Optional[str] = None quantization: Optional[str] = None - enforce_eager: bool = False + enforce_eager: Optional[bool] = None max_context_len_to_capture: Optional[int] = None max_seq_len_to_capture: int = 8192 disable_custom_all_reduce: bool = False diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3747f93b16cd1..75c6d7e6c9b21 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -3,7 +3,7 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Mapping, Optional) from typing import Sequence as GenericSequence -from typing import Set, Type, TypeVar, Union +from typing import Set, Tuple, Type, TypeVar, Union import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, @@ -22,7 +22,8 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs +from vllm.inputs import (INPUT_REGISTRY, LLMInputs, PromptInputs, + get_prompt_type) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, @@ -42,7 +43,8 @@ AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import Counter +from vllm.utils import (Counter, is_embedding_model_config, + is_encoder_decoder_model_config) from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -502,8 +504,19 @@ def _verify_args(self) -> None: self.prompt_adapter_config.verify_with_model_config( self.model_config) - def _get_eos_token_id( - self, lora_request: Optional[LoRARequest]) -> Optional[int]: + def _get_bos_token_id(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + if self.tokenizer is None: + logger.warning("Using None for BOS token id because tokenizer " + "is not initialized") + return None + + return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id + + def _get_eos_token_id(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: if self.tokenizer is None: logger.warning("Using None for EOS token id because tokenizer " "is not initialized") @@ -511,6 +524,32 @@ def _get_eos_token_id( return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + def _get_decoder_start_token_id(self, ) -> Optional[int]: + ''' + Obtain the decoder start token id employed by an encoder/decoder + model. Returns None for non-encoder/decoder models or if the + model config is unavailable. + ''' + + if not self.is_encoder_decoder_model(): + logger.warning("Using None for decoder start token id because " + "this is not an encoder/decoder model.") + return None + + if (self.model_config is None or self.model_config.hf_config is None): + logger.warning("Using None for decoder start token id because " + "model config is not available.") + return None + + dec_start_token_id = getattr(self.model_config.hf_config, + 'decoder_start_token_id', None) + if dec_start_token_id is None: + logger.warning("Falling back on for decoder start token id " + "because decoder start token id is not available.") + dec_start_token_id = self._get_bos_token_id() + + return dec_start_token_id + def _add_processed_request( self, request_id: str, @@ -529,6 +568,16 @@ def _add_processed_request( seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, lora_request, prompt_adapter_request) + encoder_seq = None + if 'encoder_prompt_token_ids' in processed_inputs: + encoder_seq = Sequence(seq_id, + processed_inputs, + block_size, + eos_token_id, + lora_request, + prompt_adapter_request, + from_decoder_prompt=False) + # Create a SequenceGroup based on SamplingParams or PoolingParams if isinstance(params, SamplingParams): seq_group = self._create_sequence_group_with_sampling( @@ -538,7 +587,8 @@ def _add_processed_request( arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( request_id, @@ -546,7 +596,8 @@ def _add_processed_request( params, arrival_time=arrival_time, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq) else: raise ValueError( "Either SamplingParams or PoolingParams must be provided.") @@ -562,36 +613,362 @@ def _add_processed_request( def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() - def process_model_inputs( + _LLMInputComponentsType = Tuple[str, List[int], ] + + def _prepare_decoder_input_ids_for_generation( + self, + decoder_input_ids: Optional[List[int]] = None, + ) -> List[int]: + """ + Prepares `decoder_input_ids` for generation with encoder-decoder models. + + Based on + + https://github.com/huggingface/transformers/blob/ + 4037a2b5b1278736e566aec12e169100275545ea/ + src/transformers/generation/utils.py + + specifically GenerationMixin._prepare_decoder_input_ids_for_generation() + + Arguments: + + * decoder_input_ids: input token ids to preprocess + + Returns: + + * Processed token list + """ + + decoder_start_token_id: Optional[int] = ( + self._get_decoder_start_token_id()) + assert decoder_start_token_id is not None + + if decoder_input_ids is None: + # no decoder prompt input -> + # use decoder_start_token_id as decoder_input_ids + (decoder_input_ids) = self._get_default_enc_dec_decoder_prompt() + + if (len(decoder_input_ids) == 0 + or decoder_input_ids[0] != decoder_start_token_id): + decoder_input_ids = [decoder_start_token_id] + decoder_input_ids + + return decoder_input_ids + + def _tokenize_prompt( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[str] = None, + ) -> List[int]: + ''' + Wrapper around application of the model's + tokenizer. + + Arguments: + + * prompt + * request_id + * lora_request + + Returns: + + * prompt token ids + ''' + + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") + + prompt_token_ids = tokenizer.encode(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + + return prompt_token_ids + + def _extract_single_prompt_for_enc_dec_input( + self, + inputs: Optional[PromptInputs], + request_id: Optional[str] = None, + ptype: Optional[str] = None, + is_encoder_prompt: bool = False, + ) -> Tuple[Optional[str], List[int]]: + ''' + Only for encoder/decoder models: + Extract prompt & prompt_token_ids from any single + encoder or decoder input prompt. For encoder input prompts + in particular, also extract multi-modal data. + + This function handles the following scenarios: + 1. The user supplied a singleton encoder prompt + & the prompt/prompt-token-ids must be extracted. + 2. The user supplied an explicit encoder/decoder + prompt & the prompt/prompt-token-ids must be + extracted from either the encoder and decoder prompts. + + For decoder prompts in particular (scenario 2), special + processing is applied to the returned decoder token ids. + + Arguments: + + * request_id + * ptype: str representation of the input prompt type. + If `ptype` is `None`, assume that the prompt + type is unknown and must be inferred. This is the + case for ExplicitEncoderDecoder sub-prompts. + * inputs: single encoder or decoder input prompt + * is_encoder_prompt: True if encoder input prompt. + If False, decoder prompt tokens + are preprocessed. + + Returns: + + * prompt + * prompt_token_ids + ''' + prompt_token_ids = None + ptype = (get_prompt_type(inputs) if ptype is None else ptype) + + if inputs is None: + prompt = None + elif ptype == 'str': + prompt = inputs + prompt_token_ids = self._tokenize_prompt( + prompt, + request_id=request_id, + ) + elif ptype == 'TokensPrompt': + prompt = None + prompt_token_ids = inputs['prompt_token_ids'] + else: + prompt = inputs['prompt'] + prompt_token_ids = self._tokenize_prompt( + prompt, + request_id=request_id, + ) + + if not is_encoder_prompt: + # Apply special pre-processing to + # decoder prompts + prompt_token_ids = (self._prepare_decoder_input_ids_for_generation( + prompt_token_ids, )) + + assert prompt_token_ids is not None + + return ( + prompt, + prompt_token_ids, + ) + + def _get_default_enc_dec_decoder_prompt(self, ) -> List[int]: + ''' + Specifically for encoder/decoder models: + generate a default decoder prompt for when + the user specifies only the encoder prompt. + + Encoder/decoder models utilize the decoder + prompt in different ways; as new models are + added, it is intended that this function + will be extended to produce differing + default decoder prompts, depending on the + model variety. + + Absent a special case, the default behavior + of this method is to mirror the behavior of + the HuggingFace (HF) GenerationMixin for a None + decoder prompt, which is to employ a logit processor + setting to force the first decoded token to be . + Here, this behavior is approximated by having the + "default" decoder prompt be . + + However, it is possible that in the future + other models may have different or more + complex logic for the default decoder prompt. + This motivates having a special helper method + for default decoder prompts. + + Returns: + + * prompt_token_ids + ''' + + bos_token_id = self._get_bos_token_id() + assert bos_token_id is not None + prompt_token_ids: List[int] = [bos_token_id] + return prompt_token_ids + + def _process_encoder_decoder_prompt( + self, + inputs: PromptInputs, + request_id: Optional[str] = None, + ) -> LLMInputs: + ''' + For encoder/decoder models only: + Process an input prompt + into an `LLMInputs` instance. + + There are two types of input prompts: + singleton prompts which carry only the + encoder prompt, and explicit encoder/decoder + prompts which carry both the encoder and the + decoder prompts as member variables. + + This function handles the following scenarios: + * Singleton encoder prompt: extract encoder prompt + token ids & infer default decoder prompt token ids + * Explicit encoder/decoder prompt: extract encoder + and decoder prompt token ids + + Note that for Explicit encoder/decoder prompts, + each sub-prompt (encoder or decoder prompt) can + have any possible singleton type; thus this + method relies on helper functions to obtain + token ids for the sub-prompts. + + Arguments: + + * inputs: an input prompt + * request_id + + Returns: + + * `LLMInputs` instance + ''' + + ptype = get_prompt_type(inputs) + + # Obtain encoder and decoder prompt tokens. Note + # that, no matter what, the decoder + # prompt type is unknown. + if ptype == "ExplicitEncoderDecoder": + # If input is explicit encoder/decoder prompt, + # then it remains to be determined what type + # of encoder prompt we have + extracted_encoder_prompt = inputs.get('encoder_prompt') + encoder_ptype = None + # Extract decoder prompt from explicit + # encoder/decoder prompt + extracted_decoder_prompt = inputs.get('decoder_prompt') + else: + # If input is singleton encoder prompt, then + # we know the encoder prompt type + extracted_encoder_prompt = inputs + encoder_ptype = ptype + # Decoder prompt is always unknown if + # encoder/decoder prompt is not explicit + extracted_decoder_prompt = None + + # Invoke helper function to obtain encoder + # prompt and prompt token ids, either from + # singleton encoder prompt or from the + # encoder sub-prompt of an explicit + # encoder/decode scenario 2), special + # processing is applied to the returned decoder token ids + ( + encoder_prompt, + encoder_prompt_token_ids, + ) = self._extract_single_prompt_for_enc_dec_input( + extracted_encoder_prompt, + request_id=request_id, + ptype=encoder_ptype, + is_encoder_prompt=True, + ) + + # Invoke helper method to obtain + # decoder prompt and prompt token ids. + # + # The helper method will detect the decoder + # prompt type. + # + # Helper method will also apply special + # preprocessing unique to decoder prompts. + ( + decoder_prompt, + decoder_prompt_token_ids, + ) = self._extract_single_prompt_for_enc_dec_input( + extracted_decoder_prompt, + request_id=request_id, + ptype=None, + is_encoder_prompt=False, + ) + + return LLMInputs( + prompt_token_ids=decoder_prompt_token_ids, + prompt=decoder_prompt, + encoder_prompt_token_ids=encoder_prompt_token_ids, + encoder_prompt=encoder_prompt, + ) + + def _process_decoder_only_prompt( self, - request_id: str, inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, + request_id: Optional[str] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: + ''' + For decoder-only models: + Process an input prompt + into an `LLMInputs` instance. + + Arguments: + + * inputs: input prompt + * lora_request + * request_id + * prompt_adapter_request + + Returns: + + * `LLMInputs` instance + ''' + if isinstance(inputs, str): inputs = {"prompt": inputs} + prompt = inputs.get("prompt") if "prompt_token_ids" not in inputs: - tokenizer = self.get_tokenizer_group("prompts must be None if " - "skip_tokenizer_init is True") - - prompt_token_ids = tokenizer.encode(request_id=request_id, - prompt=inputs["prompt"], - lora_request=lora_request) + prompt_token_ids = self._tokenize_prompt( + prompt, + request_id=request_id, + lora_request=lora_request, + ) else: prompt_token_ids = inputs["prompt_token_ids"] if prompt_adapter_request: - prompt_token_ids = \ - [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\ - + prompt_token_ids + prompt_token_ids = ( + [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + + prompt_token_ids) + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=prompt, + multi_modal_data=inputs.get("multi_modal_data")) + + def process_model_inputs( + self, + request_id: str, + inputs: PromptInputs, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> LLMInputs: - llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, - prompt=inputs.get("prompt"), - multi_modal_data=inputs.get("multi_modal_data")) + if self.is_encoder_decoder_model(): + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder - return self.input_processor(llm_inputs) + model_inputs = self._process_encoder_decoder_prompt( + inputs, + request_id=request_id, + ) + else: + # Decoder-only operation + model_inputs = self._process_decoder_only_prompt( + inputs, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + return self.input_processor(model_inputs) def add_request( self, @@ -676,6 +1053,7 @@ def _create_sequence_group_with_sampling( lora_request: Optional[LoRARequest], trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + encoder_seq: Optional[Sequence] = None, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -701,7 +1079,8 @@ def _create_sequence_group_with_sampling( sampling_params=sampling_params, lora_request=lora_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq) return seq_group @@ -713,6 +1092,7 @@ def _create_sequence_group_with_pooling( arrival_time: float, lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], + encoder_seq: Optional[Sequence] = None, ) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler @@ -724,7 +1104,8 @@ def _create_sequence_group_with_pooling( arrival_time=arrival_time, lora_request=lora_request, pooling_params=pooling_params, - prompt_adapter_request=prompt_adapter_request) + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq) return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: @@ -1214,3 +1595,9 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: seq_span.set_attribute( SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft) seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time) + + def is_encoder_decoder_model(self): + return is_encoder_decoder_model_config(self.model_config) + + def is_embedding_model(self): + return is_embedding_model_config(self.model_config) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 262cba79e5712..eaa1572094936 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -121,12 +121,21 @@ def __init__( gpu_memory_utilization: float = 0.9, swap_space: int = 4, cpu_offload_gb: float = 0, - enforce_eager: bool = False, + enforce_eager: Optional[bool] = None, max_context_len_to_capture: Optional[int] = None, max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, **kwargs, ) -> None: + ''' + LLM constructor. + + Note: if enforce_eager is unset (enforce_eager is None) + it defaults to False for decoder-only models and True + for encoder/decoder models, since encoder/decoder models + do not currently support CUDAGraph. + ''' + if "disable_log_stats" not in kwargs: kwargs["disable_log_stats"] = True removed_vision_keys = ("image_token_id", "image_feature_size", @@ -297,8 +306,8 @@ def generate( """ if self.llm_engine.model_config.embedding_mode: raise ValueError( - "LLM.generate() is only supported for generation models " - "(XForCausalLM).") + "LLM.generate() is only supported for (conditional) generation " + "models (XForCausalLM, XForConditionalGeneration).") if prompt_token_ids is not None: inputs = self._convert_v1_inputs( @@ -631,3 +640,9 @@ def _run_engine( # This is necessary because some requests may be finished earlier than # its previous requests. return sorted(outputs, key=lambda x: int(x.request_id)) + + def _is_encoder_decoder_model(self): + return self.llm_engine.is_encoder_decoder_model() + + def _is_embedding_model(self): + return self.llm_engine.is_embedding_model() diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index b13d9acf93d3b..e22b88f2fc38a 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,5 +1,7 @@ -from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs, - TextPrompt, TokensPrompt, parse_and_batch_prompt) +from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, ParsedText, + ParsedTokens, PromptInputs, SingletonPromptInputs, + TextPrompt, TokensPrompt, get_prompt_type, + is_valid_encoder_decoder_llm_inputs, parse_and_batch_prompt) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -12,7 +14,18 @@ """ __all__ = [ - "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt", - "TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY", - "InputContext", "InputRegistry" + "ParsedText", + "ParsedTokens", + "parse_and_batch_prompt", + "TextPrompt", + "TokensPrompt", + "PromptInputs", + "LLMInputs", + "INPUT_REGISTRY", + "InputContext", + "InputRegistry", + "get_prompt_type", + "is_valid_encoder_decoder_llm_inputs", + "ExplicitEncoderDecoderPrompt", + "SingletonPromptInputs", ] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 4443e6c70fe5b..86c2901dc4c80 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -92,15 +92,114 @@ class TokensPrompt(TypedDict): """ -PromptInputs = Union[str, TextPrompt, TokensPrompt] +SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt] """ -The inputs to the LLM, which can take one of the following forms: +Set of possible schemas for a single LLM input: - A text prompt (:class:`str` or :class:`TextPrompt`) - A tokenized prompt (:class:`TokensPrompt`) + +Note that "singleton" is as opposed to a data structure +which encapsulates multiple prompts, i.e. of the sort +which may be utilized for encoder/decoder models when +the user desires to express both the encoder & decoder +prompts explicitly, i.e. ExplicitEncoderDecoderPrompt + +A prompt of type SingletonPromptInputs may be employed +as (1) input to a decoder-only model, (2) input to +the encoder of an encoder/decoder model, in the scenario +where the decoder-prompt is not specified explicitly, or +(3) as a member of a larger data structure encapsulating +more than one prompt, i.e. ExplicitEncoderDecoderPrompt """ +class ExplicitEncoderDecoderPrompt(TypedDict): + """Represents an encoder/decoder model input prompt, + comprising an explicit encoder prompt and a + decoder prompt. + + The encoder and decoder prompts, respectively, + may formatted according to any of the + SingletonPromptInputs schemas, and are not + required to have the same schema. + + Only the encoder prompt may have multi-modal data. + + Note that an ExplicitEncoderDecoderPrompt may not + be used as an input to a decoder-only model, + and that the `encoder_prompt` and `decoder_prompt` + fields of this data structure may not themselves + must be SingletonPromptInputs instances. + """ + + encoder_prompt: SingletonPromptInputs + + decoder_prompt: SingletonPromptInputs + + +PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] +""" +Set of possible schemas for an LLM input, including +both decoder-only and encoder/decoder input types: + +- A text prompt (:class:`str` or :class:`TextPrompt`) +- A tokenized prompt (:class:`TokensPrompt`) +- A single data structure containing both an encoder and a decoder prompt + (:class:`ExplicitEncoderDecoderPrompt`) +""" + + +def _has_required_keys( + d: dict, + required_keys: set, +) -> bool: + return required_keys.issubset(d.keys()) + + +def get_prompt_type(prompt: Optional[PromptInputs]) -> Optional[str]: + """ + Get the type-name of the prompt argument instance, given that + isinstance() cannot apply to TypedDict subclasses directly. + If the prompt is None, return 'None' as the type name. + + Arguments: + + * prompt: LLM input prompt or None + + Returns: + + * String representation of prompt type + """ + + if prompt is None: + return 'None' + + required_keys_dict = { + 'TextPrompt': {'prompt'}, + 'TokensPrompt': {'prompt_token_ids'}, + 'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'}, + } + + if isinstance(prompt, dict): + for (ptype, required_keys) in required_keys_dict.items(): + # Ignore type checking in the conditional below because type + # checker does not understand that is_dict(prompt) narrows + # down the possible types + if _has_required_keys( + prompt, # type: ignore + required_keys): + return ptype + + raise ValueError(f"Invalid prompt {prompt}, valid types are " + "required_keys_dict={required_keys_dict}") + + if isinstance(prompt, str): + return "str" + + raise ValueError(f"Invalid prompt {prompt}") + + class LLMInputs(TypedDict): """ The inputs in :class:`~vllm.LLMEngine` before they are @@ -114,8 +213,29 @@ class LLMInputs(TypedDict): The original prompt text corresponding to the token IDs, if available. """ + encoder_prompt_token_ids: NotRequired[List[int]] + """The token IDs of the encoder prompt.""" + + encoder_prompt: NotRequired[Optional[str]] + """ + The original encoder prompt text corresponding to the token IDs, if + available. + """ + multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] """ Optional multi-modal data to pass to the model, if the model supports it. """ + + +def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs) -> bool: + """ + Return True if the LLMInputs instance has the correct configuration + for encoder/decoder. + """ + + # True if encoder prompt token ids field exists & + # is not None + return ('encoder_prompt_token_ids' in inputs + and inputs['encoder_prompt_token_ids'] is not None) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index ebb77a802d5cb..0f91b92665c28 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -83,7 +83,16 @@ "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), } -_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS} +_CONDITIONAL_GENERATION_MODELS = { + "BartModel": ("bart", "BartForConditionalGeneration"), + "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), +} + +_MODELS = { + **_GENERATION_MODELS, + **_EMBEDDING_MODELS, + **_CONDITIONAL_GENERATION_MODELS +} # Architecture -> type. # out of tree models diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py new file mode 100644 index 0000000000000..5066e991f9003 --- /dev/null +++ b/vllm/model_executor/models/bart.py @@ -0,0 +1,996 @@ +# Derived from BART implementation posted on HuggingFace; license below: +# +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BART model.""" +import math +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import BartConfig +from transformers.utils import logging + +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors, SamplerOutput + +logger = logging.get_logger(__name__) + + +def get_bsz_seq_len(input_ids): + shp = input_ids.shape + ndim = len(shp) + if ndim == 1: + return 1, input_ids.numel() + else: + return shp[:2] + + +class BartLearnedPositionalEmbedding(VocabParallelEmbedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # Bart is set up so that if padding_idx is + # specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. + # Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + positions: torch.Tensor, + attn_type: AttentionType, + ) -> torch.Tensor: + """`input_ids' shape is expected to be [bsz x seqlen].""" + + assert attn_type != AttentionType.ENCODER_DECODER + + return super().forward(positions + self.offset) + + +class BartScaledWordEmbedding(VocabParallelEmbedding): + """ + This module overrides VocabParallelEmbedding's + forward by multiplying with embeddings scale. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + embed_scale: float = 1.0): + super().__init__(num_embeddings, embedding_dim) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + return super().forward(input_ids) * self.embed_scale + + +class BartParallelLMHead(ParallelLMHead): + """ + This module overrides ParallelLMHead's + forward by dividing by embeddings scale, + yielding effectively the inverse of + BartScaledWordEmbedding + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + embed_scale: float = 1.0): + super().__init__(num_embeddings, embedding_dim) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + return super().forward(input_ids) / self.embed_scale + + +class BartEncoderAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.d_model = config.d_model + self.embed_dim = embed_dim + self.total_num_heads = num_heads + self.total_num_kv_heads = self.total_num_heads + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) + + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=AttentionType.ENCODER) + + output, _ = self.out_proj(attn_output) + return output + + +class BartDecoderSelfAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.d_model = config.d_model + self.embed_dim = embed_dim + self.total_num_heads = num_heads + self.total_num_kv_heads = self.total_num_heads + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) + + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=AttentionType.DECODER) + + output, _ = self.out_proj(attn_output) + return output + + +class BartCrossAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.d_model = config.d_model + self.embed_dim = embed_dim + self.total_num_heads = num_heads + self.total_num_kv_heads = self.total_num_heads + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) + + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + decoder_hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + + # (afeldman-nm 2024/07/22) TODO: + # Need a more efficient solution for q/k/v + qkv_dec, _ = self.qkv_proj(decoder_hidden_states) + q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + if encoder_hidden_states is None: + k = None + v = None + else: + qkv_enc, _ = self.qkv_proj(encoder_hidden_states) + _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=AttentionType.ENCODER_DECODER) + + output, _ = self.out_proj(attn_output) + return output + + +class BartEncoderLayer(nn.Module): + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BartEncoderAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + config=config, + cache_config=cache_config, + quant_config=quant_config) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.activation_fn = get_act_fn(config.activation_function, + quant_config) + + ffn_hidden_size = self.embed_dim + ffn_intermediate_size = config.encoder_ffn_dim + ffn_has_bias = True + self.fc1 = ColumnParallelLinear( + ffn_hidden_size, + ffn_intermediate_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + self.act = get_act_fn("gelu", quant_config, ffn_intermediate_size) + self.fc2 = RowParallelLinear( + ffn_intermediate_size, + ffn_hidden_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: + r""" + Args: + hidden_states + torch.Tensor of *encoder* input embeddings. + kv_cache: + Layer-wise list of KV cache tensors + attn_metadata: + vLLM Attention metadata structure + Returns: + Encoder layer output torch.Tensor + """ + residual = hidden_states + hidden_states = self.self_attn(hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + fc1_out, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(fc1_out) + + hidden_states, _ = self.fc2(hidden_states) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() + or torch.isnan(hidden_states).any()): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) + + return hidden_states + + +class BartDecoderLayer(nn.Module): + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BartDecoderSelfAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + cache_config=cache_config, + quant_config=quant_config) + self.activation_fn = get_act_fn(config.activation_function, + quant_config) + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + ''' + afeldman-nm: personally I would call this "cross-attention", + however I left the name as "encoder_attn" to maintain consistency + with the name of the pretrained weights. + ''' + self.encoder_attn = BartCrossAttention( + self.embed_dim, + config.decoder_attention_heads, + config=config, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + ffn_hidden_size = self.embed_dim + ffn_intermediate_size = config.encoder_ffn_dim + ffn_has_bias = True + self.fc1 = ColumnParallelLinear( + ffn_hidden_size, + ffn_intermediate_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + ffn_intermediate_size, + ffn_hidden_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + decoder_hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + decoder_hidden_states + torch.Tensor of *decoder* input embeddings. + kv_cache: + KV cache tensor + attn_metadata: + vLLM Attention metadata structure + encoder_hidden_states + torch.Tensor of *encoder* input embeddings. + Returns: + Decoder layer output torch.Tensor + """ + residual = decoder_hidden_states + + # Self Attention + hidden_states = self.self_attn(hidden_states=decoder_hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + + residual = hidden_states + + hidden_states = self.encoder_attn( + decoder_hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + fc1_out, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(fc1_out) + + hidden_states, _ = self.fc2(hidden_states) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states + + +class BartEncoder(nn.Module): + """ + Transformer encoder consisting of *config.encoder_layers* + self attention layers. Each layer is a [`BartEncoderLayer`]. + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + embed_tokens: Optional[nn.Embedding] = None): + super().__init__() + + self.cache_config = cache_config + self.quant_config = quant_config + self.lora_config = lora_config + embed_dim = config.d_model + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + embed_dim, + embed_scale=embed_scale) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList( + [BartEncoderLayer(config,cache_config,quant_config) \ + for _ in range(config.encoder_layers)]) + + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: + r""" + Args: + input_ids + Indices of *encoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + positions + Positions of *encoder* input sequence tokens. + kv_caches: + Layer-wise list of KV cache tensors + attn_metadata: + vLLM Attention metadata structure + Returns: + Decoder output torch.Tensor + """ + # retrieve input_ids and inputs_embeds + + input_ids = input_ids.view(-1, input_ids.shape[-1]) + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions( + positions, + AttentionType.ENCODER, + ) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + + for idx, encoder_layer in enumerate(self.layers): + hidden_states = encoder_layer( + hidden_states=hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + ) + + return hidden_states + + +class BartDecoder(nn.Module): + """ + Transformer decoder consisting of *config.decoder_layers* layers. + Each layer is a [`BartDecoderLayer`] + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + embed_tokens: Optional[nn.Embedding] = None, + ): + super().__init__() + self.cache_config = cache_config + self.quant_config = quant_config + self.lora_config = lora_config + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + config.d_model, + embed_scale=embed_scale) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + + self.layers = nn.ModuleList( + [BartDecoderLayer(config,cache_config,quant_config) \ + for _ in range(config.decoder_layers)]) + + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + def forward(self, decoder_input_ids: torch.Tensor, + decoder_positions: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: + r""" + Args: + decoder_input_ids + Indices of *decoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + decoder_positions + Positions of *decoder* input sequence tokens. + encoder_hidden_states: + Tensor of encoder output embeddings + kv_caches: + Layer-wise list of KV cache tensors + attn_metadata: + vLLM Attention metadata structure + Returns: + Decoder output torch.Tensor + """ + + inputs_embeds = self.embed_tokens(decoder_input_ids) + + # embed positions + embed_pos = self.embed_positions( + decoder_positions, + AttentionType.DECODER, + ) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + + # decoder layers + + for idx, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + decoder_hidden_states=hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + + return hidden_states + + +class BartModel(nn.Module): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" + ] + + def __init__(self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None): + super().__init__() + + self.config = config + + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.encoder = BartEncoder(config, + cache_config, + quant_config=quant_config) + self.decoder = BartDecoder(config, + cache_config, + quant_config=quant_config) + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: + r""" + Args: + input_ids + Indices of *decoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + positions + Positions of *decoder* input sequence tokens. + encoder_input_ids + Indices of *encoder* input sequence tokens in the vocabulary. + encoder_positions: + Positions of *encoder* input sequence tokens. + kv_caches: + Layer-wise list of KV cache tensors + attn_metadata: + vLLM Attention metadata structure + Returns: + Model output torch.Tensor + """ + + encoder_hidden_states = None + + if encoder_input_ids.numel() > 0: + # Run encoder attention if a non-zero number of encoder tokens + # are provided as input + encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, + positions=encoder_positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata) + + # decoder outputs consists of + # (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=encoder_hidden_states, + kv_caches=kv_caches, + attn_metadata=attn_metadata) + + return decoder_outputs + + +class BartForConditionalGeneration(nn.Module): + base_model_prefix = "model" + + def __init__(self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None): + + super().__init__() + self.config = config + self.model = BartModel(config, + cache_config, + quant_config, + lora_config=lora_config) + + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.lm_head = BartParallelLMHead(config.vocab_size, + config.d_model, + embed_scale=embed_scale) + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. + encoder_positions + torch.Tensor of *encoder* position indices + kv_caches: + Layer-wise list of KV cache tensors + attn_metadata: + vLLM Attention metadata structure + Returns: + Output torch.Tensor + """ + return self.model(input_ids, positions, encoder_input_ids, + encoder_positions, kv_caches, attn_metadata) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + stacked_params_mapping = { + "q_proj": { + "param_name": "qkv_proj", + "shard_id": "q", + }, + "k_proj": { + "param_name": "qkv_proj", + "shard_id": "k", + }, + "v_proj": { + "param_name": "qkv_proj", + "shard_id": "v", + }, + } + + params_mapping = { + "beta": "bias", + "gamma": "weight", + "LayerNorm": "layernorm", + } + + def _rename_key(self, key: str): + prefix = f"{self.base_model_prefix}." + key = key[len(prefix):] if key.startswith(prefix) else key + + for src, dst in self.params_mapping.items(): + key = key.replace(src, dst) + + return key + + def _rename_stacked_param( + self, + name: str, + ) -> Tuple[str, Optional[str]]: + for key, mapping in self.stacked_params_mapping.items(): + if key in name: + name = name.replace(key, mapping["param_name"]) + return name, mapping["shard_id"] + return name, None + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + model_params_dict = dict(self.model.named_parameters()) + top_params_dict = dict(self.named_parameters()) + + weights_tuple_list = list(weights) + + shared_embedding_weight = None + shared_embedding_shard_id = None + + for name, loaded_weight in weights_tuple_list: + + name = self._rename_key(name) + name, shard_id = self._rename_stacked_param(name) + + if ('shared.weight' in name + or 'encoder.embed_tokens.weight' in name + or 'decoder.embed_tokens.weight' in name + or 'lm_head.weight' in name): + assert shared_embedding_weight is None, ( + "Conflicting embedding weights.") + shared_embedding_weight = loaded_weight + shared_embedding_shard_id = shard_id + else: + # Skip the specific downstream task weight. + if name.startswith('cls.'): + continue + # use Pooler instead. + if name.startswith('pooler.'): + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in model_params_dict: + continue + + param = model_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if shard_id: + weight_loader(param, loaded_weight, shard_id) + else: + weight_loader(param, loaded_weight) + + # Assign shared weight values + encoder_in_param = model_params_dict['encoder.embed_tokens.weight'] + encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader", + default_weight_loader) + + decoder_in_param = model_params_dict['decoder.embed_tokens.weight'] + decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader", + default_weight_loader) + + lm_head_in_param = top_params_dict['lm_head.weight'] + lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader", + default_weight_loader) + + assert shared_embedding_weight is not None + + if shared_embedding_shard_id: + encoder_in_weight_loader(encoder_in_param, shared_embedding_weight, + shared_embedding_shard_id) + decoder_in_weight_loader(decoder_in_param, shared_embedding_weight, + shared_embedding_shard_id) + lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight, + shared_embedding_shard_id) + else: + encoder_in_weight_loader(encoder_in_param, shared_embedding_weight) + decoder_in_weight_loader(decoder_in_param, shared_embedding_weight) + lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight) diff --git a/vllm/outputs.py b/vllm/outputs.py index b1cb1cd07fbb1..040f770814576 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -70,12 +70,20 @@ class RequestOutput: Args: request_id: The unique ID of the request. prompt: The prompt string of the request. + For encoder/decoder models, this is the + decoder input prompt. prompt_token_ids: The token IDs of the prompt. + For encoder/decoder models, this is the + decoder input prompt token ids. prompt_logprobs: The log probabilities to return per prompt token. outputs: The output sequences of the request. finished: Whether the whole request is finished. metrics: Metrics associated with the request. lora_request: The LoRA request that was used to generate the output. + encoder_prompt: The encoder prompt string of the request; + None if decoder-only + encoder_prompt_token_ids: The token IDs of the encoder prompt; + None if decoder-only """ def __init__( @@ -88,6 +96,8 @@ def __init__( finished: bool, metrics: Optional[RequestMetrics] = None, lora_request: Optional[LoRARequest] = None, + encoder_prompt: Optional[str] = None, + encoder_prompt_token_ids: Optional[List[int]] = None, ) -> None: self.request_id = request_id self.prompt = prompt @@ -97,6 +107,8 @@ def __init__( self.finished = finished self.metrics = metrics self.lora_request = lora_request + self.encoder_prompt = encoder_prompt + self.encoder_prompt_token_ids = encoder_prompt_token_ids @classmethod def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": @@ -137,6 +149,8 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": # Every sequence in the sequence group should have the same prompt. prompt = seq_group.prompt prompt_token_ids = seq_group.prompt_token_ids + encoder_prompt = seq_group.encoder_prompt + encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids prompt_logprobs = seq_group.prompt_logprobs finished = seq_group.is_finished() finished_time = time.time() if finished else None @@ -148,12 +162,16 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": outputs, finished, seq_group.metrics, - lora_request=seq_group.lora_request) + lora_request=seq_group.lora_request, + encoder_prompt=encoder_prompt, + encoder_prompt_token_ids=encoder_prompt_token_ids) def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"encoder_prompt={self.encoder_prompt!r}, " + f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, " f"prompt_logprobs={self.prompt_logprobs}, " f"outputs={self.outputs}, " f"finished={self.finished}, " diff --git a/vllm/sequence.py b/vllm/sequence.py index 7ef9387c611f8..6347855333822 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -7,10 +7,11 @@ from collections import defaultdict from dataclasses import dataclass, field from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, - Union) + Union, cast) import torch +from vllm.inputs import is_valid_encoder_decoder_llm_inputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -244,24 +245,38 @@ def __repr__(self) -> str: class Sequence: """Stores the data, status, and block information of a sequence. + The sequence is constructed from the LLMInputs instance passed + in through the `inputs` constructor argument. + + For encoder/decoder models, LLMInputs encapsulates both a + decoder and encoder prompt, creating an ambiguity about which + prompt to construct the sequence from. The `from_decoder_prompt` + constructor argument signals whether to construct the Sequence + from the LLMInputs decoder prompt, or encoder prompt. + Args: seq_id: The ID of the sequence. inputs: The inputs of the sequence. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. + eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. lora_request: LoRA request. prompt_adapter_request: Prompt Adapter request. + from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt + (True) or encoder prompt (False.) Must be True + for decoder-only model. """ def __init__( - self, - seq_id: int, - inputs: "LLMInputs", - block_size: int, - eos_token_id: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + self, + seq_id: int, + inputs: "LLMInputs", + block_size: int, + eos_token_id: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + from_decoder_prompt: bool = True, ) -> None: self.seq_id = seq_id self.inputs = inputs @@ -269,6 +284,36 @@ def __init__( self.eos_token_id = eos_token_id self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request + self.from_decoder_prompt = from_decoder_prompt + self._prompt: Optional[str] = None + self._prompt_token_ids: Optional[List[int]] = None + + # For decoder-only models, a Sequence is constructed + # from an LLMInputs instance (the `inputs` arg.) + # + # For encoder/decoder models the same `inputs` + # instance could be utilized to construct either an + # encoder sequence or a decoder sequence, because + # `LLMInputs` has both decoder- and encoder-oriented + # member variables (i.e. it encapsulates both an encoder + # and a decoder prompt.) The decision of which type of sequence + # to generate is determined by the `from_decoder_prompt` argument. + # + # When constructing a encoder sequence + # (`from_decoder_prompt` False) it matters that + # the `LLMInputs` instance stored in `inputs` is valid + # in the sense that its encoder-related member variables are + # populated; below, an exception is raised if this is + # not the case. + # + # When constructing a decoder sequence (`from_decoder_prompt` True) + # it does not matter whether `inputs` has its encoder-related + # member variables populated. + if not (from_decoder_prompt + or is_valid_encoder_decoder_llm_inputs(inputs)): + raise ValueError("Cannot extract encoder input prompt from " + f"invalid input {inputs}; did you forget the " + "encoder input prompt fields?") self.data = SequenceData(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] @@ -289,11 +334,35 @@ def n_blocks(self) -> int: @property def prompt(self) -> Optional[str]: - return self.inputs.get("prompt") + if self._prompt is not None: + # Reuse precomputed prompt string + return self._prompt + + # Select decoder or encoder input prompt str, + # as appropriate + prompt_key: str = ("prompt" + if self.from_decoder_prompt else "encoder_prompt") + + # Cache prompt + self._prompt = cast(Optional[str], self.inputs.get(prompt_key)) + return self._prompt @property def prompt_token_ids(self) -> List[int]: - return self.inputs["prompt_token_ids"] + if self._prompt_token_ids is not None: + # Reuse precomputed prompt token ids + return self._prompt_token_ids + + # Select decoder or encoder input prompt + # token ids, as appropriate + prompt_token_ids_key: str = ("prompt_token_ids" + if self.from_decoder_prompt else + "encoder_prompt_token_ids") + + # Cache computed prompt token ids + self._prompt_token_ids = cast(List[int], + self.inputs.get(prompt_token_ids_key)) + return self._prompt_token_ids @property def multi_modal_data(self) -> "MultiModalDataDict": @@ -472,6 +541,22 @@ def prompt_token_ids(self) -> List[int]: # We use the prompt of an arbitrary sequence. return self.seqs[0].prompt_token_ids + @property + def encoder_prompt(self) -> Optional[str]: + # There are either 0 or 1 encoder sequences + # If one is present, its prompt is distinct + # from the decoder's. + return (self.encoder_seq.prompt + if self.encoder_seq is not None else None) + + @property + def encoder_prompt_token_ids(self) -> Optional[List[int]]: + # There are either 0 or 1 encoder sequences + # If one is present, its prompt token ids are + # distinct from the decoder's. + return (self.encoder_seq.prompt_token_ids + if self.encoder_seq is not None else None) + @property def multi_modal_data(self) -> "MultiModalDataDict": # All sequences in the group should have the same multi-modal data. diff --git a/vllm/utils.py b/vllm/utils.py index 51bd72977a226..7070984eb728b 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -27,10 +27,93 @@ import vllm.envs as envs from vllm import _custom_ops as ops +from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs, + SingletonPromptInputs) from vllm.logger import enable_trace_function_call, init_logger logger = init_logger(__name__) +# Exception strings for non-implemented encoder/decoder scenarios + +STR_NOT_IMPL_ENC_DEC_SWA = \ + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ + "Prefix caching for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ + "Chunked prefill for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = ( + "Models with logits_soft_cap " + "require FlashInfer backend, which is " + "currently not supported for encoder/decoder " + "models.") + +STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is currently not currently " + "supported with encoder/decoder " + "models.") + +STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not " + "currently supported with " + "encoder/decoder models.") + +STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently " + "supported with encoder/decoder " + "models.") + +STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not " + "currently supported with encoder/" + "decoder models.") + +STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = ("CUDAGraph is not " + "currently supported with encoder/" + "decoder models.") + +STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend " + "currently supported with encoder/" + "decoder models.") + +STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not " + "currently supported with encoder/" + "decoder models.") + +# Efficiently import all enc/dec error strings +# rather than having to import all of the above +STR_NOT_IMPL_ENC_DEC_ERR_STRS = { + "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA, + "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL": + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, + "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP, + "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA, + "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP, + "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM, + "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC, + "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, + "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND, + "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER, +} + +# Constants related to forcing the attention backend selection + +# String name of register which may be set in order to +# force auto-selection of attention backend by Attention +# wrapper +STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" + +# Possible string values of STR_BACKEND_ENV_VAR +# register, corresponding to possible backends +STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" +STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA" +STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" +STR_XFORMERS_ATTN_VAL: str = "XFORMERS" +STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" +STR_INVALID_VAL: str = "INVALID" + STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.half, "bfloat16": torch.bfloat16, @@ -1029,3 +1112,50 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, """Utility function to run async task in a lock""" async with lock: return await task(*args, **kwargs) + + +def is_encoder_decoder_model_config(model_config) -> bool: + ''' + Extract the HF encoder/decoder model flag from the ModelConfig instance. + Return False if model_config is None. + ''' + return model_config is not None and \ + getattr(model_config.hf_config, + "is_encoder_decoder", + False) + + +def is_embedding_model_config(model_config) -> bool: + ''' + Extract the embedding model flag from the ModelConfig instance. + Return False if model_config is None. + ''' + return model_config is not None and \ + model_config.embedding_mode + + +def build_explicit_enc_dec_prompt( + encoder_prompt: SingletonPromptInputs, + decoder_prompt: SingletonPromptInputs, +) -> ExplicitEncoderDecoderPrompt: + return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt, + decoder_prompt=decoder_prompt) + + +def zip_enc_dec_prompt_lists( + enc_prompt_list: List[SingletonPromptInputs], + dec_prompt_list: List[SingletonPromptInputs], +) -> List[ExplicitEncoderDecoderPrompt]: + return [ + build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt) + for (encoder_prompt, + decoder_prompt) in zip(enc_prompt_list, dec_prompt_list) + ] + + +def to_enc_dec_tuple_list( + enc_dec_prompts: List[ExplicitEncoderDecoderPrompt], +) -> List[Tuple[PromptInputs, PromptInputs]]: + return [(enc_dec_prompt['encoder_prompt'], + enc_dec_prompt['decoder_prompt']) + for enc_dec_prompt in enc_dec_prompts] diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py new file mode 100644 index 0000000000000..d9b323f2af09e --- /dev/null +++ b/vllm/worker/enc_dec_model_runner.py @@ -0,0 +1,472 @@ +import dataclasses +from typing import Any, Dict, List, Optional, Tuple, Type, cast + +import torch +import torch.distributed + +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) +from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, + get_global_forced_attn_backend, + global_force_attn_backend) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, MultiModalConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig) +from vllm.inputs import INPUT_REGISTRY +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.sampling_params import SamplingParams +from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput, + SequenceGroupMetadata) +from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad +from vllm.worker.model_runner import (_PAD_SLOT_ID, GPUModelRunnerBase, + ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) +from vllm.worker.model_runner_base import ( + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict) +from vllm.worker.utils import assert_enc_dec_mr_supported_scenario + +logger = init_logger(__name__) + + +@dataclasses.dataclass(frozen=True) +class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): + """ + Used by the EncoderDecoderModelRunner. + """ + encoder_input_tokens: Optional[torch.Tensor] = None + encoder_input_positions: Optional[torch.Tensor] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "encoder_input_tokens": self.encoder_input_tokens, + "encoder_input_positions": self.encoder_input_positions, + "virtual_engine": self.virtual_engine, + "request_ids_to_seq_ids": self.request_ids_to_seq_ids, + "finished_requests_ids": self.finished_requests_ids, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "EncoderDecoderModelInput": + return cast( + EncoderDecoderModelInput, + super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) + + +class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): + _model_input_cls: Type[EncoderDecoderModelInput] = ( + EncoderDecoderModelInput) + _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder) + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + multimodal_config: Optional[MultiModalConfig] = None, + ): + ''' + EncoderDecoderModelRunner constructor. + + `lora_config`, `multimodal_config`, and prompt_adapter_config are + unused (since these features are not yet supported for encoder/decoder + models) but these arguments are present here for compatibility with + the base-class constructor. + ''' + + self._maybe_force_supported_attention_backend() + + super().__init__( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config=None, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + ) + + # Crash for unsupported encoder/scenarios + assert_enc_dec_mr_supported_scenario(self) + + def _maybe_force_supported_attention_backend(self): + ''' + Force vLLM to use the XFormers attention backend, + which is currently the only supported option. + ''' + + def raise_backend_err(): + # The user has specified an attention backend override + # which is invalid for encoder/decoder models + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND) + + maybe_env_var_forced_backend = get_env_variable_attn_backend() + maybe_global_forced_backend = get_global_forced_attn_backend() + is_forced_by_global = maybe_global_forced_backend is not None + is_forced_by_env_var = maybe_env_var_forced_backend is not None + + if not (is_forced_by_global or is_forced_by_env_var): + # The user has not already specified an attention backend + # override + logger.info("EncoderDecoderModelRunner requires " + "XFormers backend; overriding backend " + "auto-selection and forcing XFormers.") + global_force_attn_backend(_Backend.XFORMERS) + elif is_forced_by_global: + # Backend override enforced by global variable takes + # precedence over vLLM backend environment variable. + if maybe_global_forced_backend != _Backend.XFORMERS: + raise_backend_err() + elif is_forced_by_env_var: + # Backend override enforced by vLLM backend + # environment variable + if maybe_env_var_forced_backend != _Backend.XFORMERS: + raise_backend_err() + + def _list_to_int32_tensor( + self, + _list: List[int], + ) -> torch.Tensor: + return torch.tensor(_list, dtype=torch.int32, device=self.device) + + def _list_to_long_tensor( + self, + _list: List[int], + ) -> torch.Tensor: + return torch.tensor(_list, dtype=torch.long, device=self.device) + + def _empty_int32_tensor(self) -> torch.Tensor: + return self._list_to_int32_tensor([]) + + def _empty_long_tensor(self) -> torch.Tensor: + return self._list_to_long_tensor([]) + + @torch.inference_mode() + def execute_model( + self, + model_input: EncoderDecoderModelInput, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[PoolerOutput]]: + if num_steps > 1: + raise ValueError("num_steps > 1 is not supported in " + "EncoderDecoderModelRunner") + + model_executable = self.model + + seqlen_agnostic_kwargs = { + "finished_requests_ids": model_input.finished_requests_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, + } if self.has_seqlen_agnostic else {} + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + encoder_input_ids=model_input.encoder_input_tokens, + encoder_positions=model_input.encoder_input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **seqlen_agnostic_kwargs) + + logits = self.model.compute_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) + + if not self.is_driver_worker: + return [] + + # Sample the next token. + output: SamplerOutput = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + + return [output] + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput: + return EncoderDecoderModelInput.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> EncoderDecoderModelInput: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + + Since chunked prefill is not supported for encoder/decoder models, + `input_tokens` is assumed to be either entirely prefill tokens or + entirely decode tokens. + + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + + ( + attn_metadata, + encoder_input_tokens_tensor, + encoder_input_positions_tensor, + ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list, + model_input)) + + # Inject attn_metadata encoder/cross-attention fields & + # encoder input tokens/positions into model_input. + # Frozen dataclass fields cannot be modified, so use + # dataclasses.replace to construct a new model input + # instance. + model_input = dataclasses.replace( + model_input, + attn_metadata=attn_metadata, + encoder_input_tokens=encoder_input_tokens_tensor, + encoder_input_positions=encoder_input_positions_tensor, + ) + + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + self.pin_memory) + is_prompt = (seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list else None) + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + is_prompt=is_prompt, + virtual_engine=virtual_engine) + + @torch.inference_mode() + def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + + model_config = self.model_config + + batch_size = 0 + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + batch_size += seq_len + + seq_data, _ = INPUT_REGISTRY \ + .dummy_data_for_profiling(model_config, seq_len) + + # Having more tokens is over-conservative but otherwise fine + assert len(seq_data.prompt_token_ids) >= seq_len, ( + f"Expected at least {seq_len} dummy tokens for profiling, " + f"but got: {len(seq_data.prompt_token_ids)}") + + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=None, + encoder_seq_data=seq_data, + cross_block_table=None, + ) + seqs.append(seq) + + # Run the model with the dummy inputs. + num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [None] * num_layers + finished_requests_ids = [seq.request_id for seq in seqs] + model_input = self.prepare_model_input( + seqs, finished_requests_ids=finished_requests_ids) + intermediate_tensors = None + self.execute_model(model_input, kv_caches, intermediate_tensors) + torch.cuda.synchronize() + return + + def _prepare_encoder_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + model_input: EncoderDecoderModelInput, + ) -> Tuple[AttentionMetadata, Optional[torch.Tensor], + Optional[torch.Tensor]]: + """Helper method to prepare the encoder- and cross-attn-related + model inputs based on a given sequence group. These additional inputs + are used to augment an already-computed `EncoderDecoderModelInput` + data structure which already has decoder-related model inputs + populated. + + Sets the following attn_metadata fields: + * `num_encoder_tokens` + * `encoder_seq_lens` + * `encoder_seq_lens_tensor` + * `max_encoder_seq_len` + * `cross_slot_mapping` + * `cross_block_tables` + + Constructs a new model inputs data structure, based on + (1) the existing fields in the `model_inputs` argument, + and (2) the following additional fields which are + computed (or in the case of `attn_metadata`, updated) + by this function: + * attn_metadata + * encoder_input_tokens + * encoder_input_positions + + Arguments: + + * seq_group_metadata_list: list of sequence groups for which to + compute inputs + * model_inputs: model inputs data structure with decoder-oriented + fields already computed. + + Return: + + * Updated model inputs data structure + """ + + if len(seq_group_metadata_list) == 0: + return (model_input.attn_metadata, None, None) + + # Since we are not supporting chunked prefill either the entire + # batch is prefill or it is decode + is_prompt = seq_group_metadata_list[0].is_prompt + + # Build encoder inputs + encoder_seq_lens: List[int] = [] + if is_prompt: + # Prefill phase. + cross_block_tables = self._empty_int32_tensor().view( + len(seq_group_metadata_list), -1) + + # Extract input tokens/positions, cross-attention slot-mapping, + # & seq len from each sequence group metadata + ( + encoder_input_tokens, + encoder_input_positions, + cross_slot_mapping, + ) = ( + [], + [], + [], + ) + for seq_group_metadata in seq_group_metadata_list: + # Build seq lens + seq_len = seq_group_metadata.encoder_seq_data.get_len() + token_ids = seq_group_metadata.encoder_seq_data.get_token_ids() + encoder_seq_lens.append(seq_len) + + # Build slot mapping + is_profile_run = (seq_group_metadata.block_tables is None) + if is_profile_run: + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + cross_slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + else: + for i in range(0, seq_len): + block_number = seq_group_metadata.cross_block_table[ + i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + cross_slot_mapping.append(slot) + + # Build encoder input tokens + encoder_input_tokens.extend(token_ids) + encoder_input_positions.extend(list(range(0, seq_len))) + + # Convert tokens/positions & cross-attention + # slot-mapping to encoder input tensors + encoder_input_tokens_tensor = self._list_to_long_tensor( + encoder_input_tokens) + encoder_input_positions_tensor = self._list_to_long_tensor( + encoder_input_positions) + cross_slot_mapping_tensor = self._list_to_long_tensor( + cross_slot_mapping) + + else: + # Decode phase. + encoder_input_tokens_tensor = self._empty_long_tensor() + encoder_input_positions_tensor = self._empty_long_tensor() + cross_slot_mapping_tensor = self._empty_long_tensor() + + # Extract cross-attention block tables & + # seq len from each sequence group metadata. + # Cross-attention block tables are empty + # during vLLM memory profiling. + cross_block_tables = [] + for seq_group_metadata in seq_group_metadata_list: + encoder_seq_lens.append( + seq_group_metadata.encoder_seq_data.get_len()) + cross_block_table = seq_group_metadata.cross_block_table + cross_block_tables.append([] if ( + cross_block_table is None) else cross_block_table) + + # Convert cross-attention block tables to encoder input tensor + cross_block_tables = make_tensor_with_pad( + cross_block_tables, + max_len=max( + len(block_table) for block_table in cross_block_tables), + pad=0, + dtype=torch.int32, + device=self.device, + ) + + # Compute encoder sequence lengths & encoder + # sequence starting offset tensors + max_encoder_seq_len = max(encoder_seq_lens, default=0) + encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens) + encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + + 1, + dtype=torch.int32, + device=self.device) + torch.cumsum(encoder_seq_lens_tensor, + dim=0, + dtype=encoder_seq_start_loc.dtype, + out=encoder_seq_start_loc[1:]) + + # Update attention metadata with encoder-oriented attributes + attn_metadata = model_input.attn_metadata + assert attn_metadata is not None + ( + attn_metadata.num_encoder_tokens, + attn_metadata.encoder_seq_lens, + attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_slot_mapping, + attn_metadata.cross_block_tables, + ) = ( + sum(encoder_seq_lens), + encoder_seq_lens, + encoder_seq_lens_tensor, + max_encoder_seq_len, + cross_slot_mapping_tensor, + cross_block_tables, + ) + + return (attn_metadata, encoder_input_tokens_tensor, + encoder_input_positions_tensor) diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py new file mode 100644 index 0000000000000..8df3c8bc5408b --- /dev/null +++ b/vllm/worker/utils.py @@ -0,0 +1,56 @@ +''' +Worker-related helper functions. +''' + +from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS +from vllm.worker.model_runner import GPUModelRunnerBase + + +def assert_enc_dec_mr_supported_scenario( + enc_dec_mr: GPUModelRunnerBase) -> None: + ''' + Asserted that the provided encoder/decoder model runner instance reflects + a supported scenario. + ''' + + if enc_dec_mr.cache_config.enable_prefix_caching: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE']) + + if enc_dec_mr.sliding_window is not None: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SWA']) + + if enc_dec_mr.scheduler_config.chunked_prefill_enabled: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[ + 'STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL']) + + if getattr(enc_dec_mr.model_config.hf_config, 'attn_logit_softcapping', + None) is not None: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP'] + ) + + if enc_dec_mr.lora_config is not None: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LORA']) + + if enc_dec_mr.parallel_config.pipeline_parallel_size > 1: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP']) + + if enc_dec_mr.multimodal_config is not None: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM']) + + if enc_dec_mr.scheduler_config.num_lookahead_slots > 0: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC']) + + if not enc_dec_mr.model_config.enforce_eager: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH']) + + if enc_dec_mr.prompt_adapter_config is not None: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[ + 'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER']) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 9e2cfff435cf6..ad6f6750ff980 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -19,8 +19,11 @@ from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest +from vllm.utils import (is_embedding_model_config, + is_encoder_decoder_model_config) from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner +from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput @@ -85,8 +88,10 @@ def __init__( ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner if model_runner_cls is not None: ModelRunnerClass = model_runner_cls - elif self.model_config.embedding_mode: + elif self._is_embedding_model(): ModelRunnerClass = EmbeddingModelRunner + elif self._is_encoder_decoder_model(): + ModelRunnerClass = EncoderDecoderModelRunner self.model_runner: GPUModelRunnerBase = ModelRunnerClass( model_config, parallel_config, @@ -107,6 +112,12 @@ def __init__( # Initialize gpu_cache as embedding models don't initialize kv_caches self.gpu_cache: Optional[List[List[torch.Tensor]]] = None + def _is_encoder_decoder_model(self): + return is_encoder_decoder_model_config(self.model_config) + + def _is_embedding_model(self): + return is_embedding_model_config(self.model_config) + def init_device(self) -> None: if self.device_config.device.type == "cuda": # torch.distributed.all_reduce does not free the input tensor until