From e254497b66dcd87038969b0ad34d34425edfc5fe Mon Sep 17 00:00:00 2001 From: Chang Su Date: Sat, 11 May 2024 11:30:37 -0700 Subject: [PATCH] [Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734) --- examples/offline_inference_embedding.py | 17 ++ examples/openai_embedding_client.py | 23 ++ requirements-dev.txt | 9 +- tests/conftest.py | 38 ++- .../output_processor/test_multi_step.py | 12 +- tests/entrypoints/openai/test_serving_chat.py | 1 + tests/entrypoints/test_openai_server.py | 96 ++++++- tests/models/test_embedding.py | 44 +++ tests/samplers/test_logits_processor.py | 6 +- tests/samplers/test_seeded_generate.py | 2 +- tests/spec_decode/utils.py | 6 +- tests/test_sequence.py | 12 +- vllm/__init__.py | 7 +- vllm/config.py | 15 + vllm/core/embedding_model_block_manager.py | 84 ++++++ vllm/core/interfaces.py | 5 + vllm/core/scheduler.py | 10 +- vllm/engine/arg_utils.py | 1 + vllm/engine/async_llm_engine.py | 158 +++++++++-- vllm/engine/llm_engine.py | 143 ++++++++-- vllm/entrypoints/llm.py | 150 ++++++++-- vllm/entrypoints/openai/api_server.py | 20 +- vllm/entrypoints/openai/protocol.py | 36 ++- vllm/entrypoints/openai/serving_embedding.py | 134 +++++++++ vllm/entrypoints/openai/serving_engine.py | 16 +- vllm/executor/gpu_executor.py | 10 +- vllm/model_executor/layers/pooler.py | 56 ++++ vllm/model_executor/layers/sampler.py | 7 +- vllm/model_executor/models/__init__.py | 12 +- vllm/model_executor/models/llama_embedding.py | 87 ++++++ vllm/model_executor/pooling_metadata.py | 69 +++++ vllm/outputs.py | 82 +++++- vllm/pooling_params.py | 20 ++ vllm/sequence.py | 89 +++++- vllm/spec_decode/util.py | 5 +- vllm/worker/embedding_model_runner.py | 266 ++++++++++++++++++ vllm/worker/model_runner.py | 25 +- vllm/worker/worker.py | 14 +- 38 files changed, 1627 insertions(+), 160 deletions(-) create mode 100644 examples/offline_inference_embedding.py create mode 100644 examples/openai_embedding_client.py create mode 100644 tests/models/test_embedding.py create mode 100644 vllm/core/embedding_model_block_manager.py create mode 100644 vllm/entrypoints/openai/serving_embedding.py create mode 100644 vllm/model_executor/layers/pooler.py create mode 100644 vllm/model_executor/models/llama_embedding.py create mode 100644 vllm/model_executor/pooling_metadata.py create mode 100644 vllm/pooling_params.py create mode 100644 vllm/worker/embedding_model_runner.py diff --git a/examples/offline_inference_embedding.py b/examples/offline_inference_embedding.py new file mode 100644 index 000000000000..7d5ef128bc8e --- /dev/null +++ b/examples/offline_inference_embedding.py @@ -0,0 +1,17 @@ +from vllm import LLM + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Create an LLM. +model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True) +# Generate embedding. The output is a list of EmbeddingRequestOutputs. +outputs = model.encode(prompts) +# Print the outputs. +for output in outputs: + print(output.outputs.embedding) # list of 4096 floats diff --git a/examples/openai_embedding_client.py b/examples/openai_embedding_client.py new file mode 100644 index 000000000000..b73360fe15a2 --- /dev/null +++ b/examples/openai_embedding_client.py @@ -0,0 +1,23 @@ +from openai import OpenAI + +# Modify OpenAI's API key and API base to use vLLM's API server. 
+openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, +) + +models = client.models.list() +model = models.data[0].id + +responses = client.embeddings.create(input=[ + "Hello my name is", + "The best thing about vLLM is that it supports many different models" +], + model=model) + +for data in responses.data: + print(data.embedding) # list of float of len 4096 diff --git a/requirements-dev.txt b/requirements-dev.txt index e6d375cbafa3..796c9e37d023 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -19,12 +19,15 @@ pytest-forked pytest-asyncio pytest-rerunfailures pytest-shard -httpx + +# testing utils +awscli einops # required for MPT +httpx +peft requests ray -peft -awscli +sentence-transformers # required for embedding # Benchmarking aiohttp diff --git a/tests/conftest.py b/tests/conftest.py index 1f2ad1cbd729..b8117a19c75d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -133,6 +133,10 @@ def example_long_prompts() -> List[str]: "llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration, } +_EMBEDDING_MODELS = [ + "intfloat/e5-mistral-7b-instruct", +] + class HfRunner: @@ -145,14 +149,7 @@ def __init__( assert dtype in _STR_DTYPE_TO_TORCH_DTYPE torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] self.model_name = model_name - if model_name not in _VISION_LANGUAGE_MODELS: - self.model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - ).cuda() - self.processor = None - else: + if model_name in _VISION_LANGUAGE_MODELS: self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained( model_name, torch_dtype=torch_dtype, @@ -162,6 +159,20 @@ def __init__( model_name, torch_dtype=torch_dtype, ) + elif model_name in _EMBEDDING_MODELS: + # Lazy init required for AMD CI + from sentence_transformers import SentenceTransformer + self.model = SentenceTransformer( + model_name, + device="cpu", + ).to(dtype=torch_dtype).cuda() + else: + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ).cuda() + self.processor = None if tokenizer_name is None: tokenizer_name = model_name self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True) @@ -334,6 +345,9 @@ def generate_greedy_logprobs_limit( return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + return self.model.encode(prompts) + def __del__(self): del self.model cleanup() @@ -459,6 +473,14 @@ def generate_beam_search( outputs = self.generate(prompts, beam_search_params) return outputs + def encode(self, prompts: List[str]) -> List[List[float]]: + req_outputs = self.model.encode(prompts) + outputs = [] + for req_output in req_outputs: + embedding = req_output.outputs.embedding + outputs.append(embedding) + return outputs + def __del__(self): del self.model cleanup() diff --git a/tests/engine/output_processor/test_multi_step.py b/tests/engine/output_processor/test_multi_step.py index 6da3da091db7..2bf4bf69da20 100644 --- a/tests/engine/output_processor/test_multi_step.py +++ b/tests/engine/output_processor/test_multi_step.py @@ -9,8 +9,8 @@ from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor from vllm.engine.output_processor.stop_checker import StopChecker from vllm.sampling_params import SamplingParams -from 
vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput, - SequenceStatus) +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.utils import Counter @@ -51,7 +51,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): new_token_ids = list(range(num_new_tokens)) outputs = [ - SequenceGroupOutput( + CompletionSequenceGroupOutput( samples=[ SequenceOutput( parent_seq_id=seq.seq_id, @@ -103,7 +103,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, new_token_ids = list(range(num_new_tokens)) outputs = [ - SequenceGroupOutput( + CompletionSequenceGroupOutput( samples=[ SequenceOutput( parent_seq_id=seq.seq_id, @@ -170,7 +170,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, new_token_ids[eos_index] = eos_token_id outputs = [ - SequenceGroupOutput( + CompletionSequenceGroupOutput( samples=[ SequenceOutput( parent_seq_id=seq.seq_id, @@ -239,7 +239,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, new_token_ids[eos_index] = eos_token_id outputs = [ - SequenceGroupOutput( + CompletionSequenceGroupOutput( samples=[ SequenceOutput( parent_seq_id=seq.seq_id, diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 13e2e372cef3..74b49726734b 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -14,6 +14,7 @@ class MockModelConfig: tokenizer_mode = "auto" max_model_len = 100 tokenizer_revision = None + embedding_mode = False @dataclass diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index e53e64a0c1ff..c22ac4507658 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -23,6 +23,7 @@ MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct" # technically this needs Mistral-7B-v0.1 as base, but we're not testing # generation quality here LORA_NAME = "typeof/zephyr-7b-beta-lora" @@ -121,7 +122,7 @@ def zephyr_lora_files(): return snapshot_download(repo_id=LORA_NAME) -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def server(zephyr_lora_files): ray.init() server_runner = ServerRunner.remote([ @@ -150,6 +151,25 @@ def server(zephyr_lora_files): ray.shutdown() +@pytest.fixture(scope="module") +def embedding_server(zephyr_lora_files): + ray.shutdown() + ray.init() + server_runner = ServerRunner.remote([ + "--model", + EMBEDDING_MODEL_NAME, + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + ]) + ray.get(server_runner.ready.remote()) + yield server_runner + ray.shutdown() + + @pytest.fixture(scope="module") def client(): client = openai.AsyncOpenAI( @@ -890,5 +910,79 @@ async def test_long_seed(server, client: openai.AsyncOpenAI): or "less_than_equal" in exc_info.value.message) +@pytest.mark.parametrize( + "model_name", + [EMBEDDING_MODEL_NAME], +) +async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI, + model_name: str): + input = [ + "The chef prepared a delicious meal.", + ] + + # test single embedding + embeddings = await client.embeddings.create( + model=model_name, + 
input=input, + encoding_format="float", + ) + assert embeddings.id is not None + assert embeddings.data is not None and len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding) == 4096 + assert embeddings.usage.completion_tokens == 0 + assert embeddings.usage.prompt_tokens == 9 + assert embeddings.usage.total_tokens == 9 + + # test using token IDs + input = [1, 1, 1, 1, 1] + embeddings = await client.embeddings.create( + model=model_name, + input=input, + encoding_format="float", + ) + assert embeddings.id is not None + assert embeddings.data is not None and len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding) == 4096 + assert embeddings.usage.completion_tokens == 0 + assert embeddings.usage.prompt_tokens == 5 + assert embeddings.usage.total_tokens == 5 + + +@pytest.mark.parametrize( + "model_name", + [EMBEDDING_MODEL_NAME], +) +async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI, + model_name: str): + # test List[str] + inputs = [ + "The cat sat on the mat.", "A feline was resting on a rug.", + "Stars twinkle brightly in the night sky." + ] + embeddings = await client.embeddings.create( + model=model_name, + input=inputs, + encoding_format="float", + ) + assert embeddings.id is not None + assert embeddings.data is not None and len(embeddings.data) == 3 + assert len(embeddings.data[0].embedding) == 4096 + + # test List[List[int]] + inputs = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], + [25, 32, 64, 77]] + embeddings = await client.embeddings.create( + model=model_name, + input=inputs, + encoding_format="float", + ) + assert embeddings.id is not None + assert embeddings.data is not None and len(embeddings.data) == 4 + assert len(embeddings.data[0].embedding) == 4096 + assert embeddings.usage.completion_tokens == 0 + assert embeddings.usage.prompt_tokens == 17 + assert embeddings.usage.total_tokens == 17 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/models/test_embedding.py b/tests/models/test_embedding.py new file mode 100644 index 000000000000..59bf054913f7 --- /dev/null +++ b/tests/models/test_embedding.py @@ -0,0 +1,44 @@ +"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling. + +Run `pytest tests/models/test_llama_embedding.py`. 
+""" +import pytest +import torch +import torch.nn.functional as F + +MODELS = [ + "intfloat/e5-mistral-7b-instruct", +] + + +def compare_embeddings(embeddings1, embeddings2): + similarities = [ + F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0) + for e1, e2 in zip(embeddings1, embeddings2) + ] + return similarities + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.encode(example_prompts) + del hf_model + + vllm_model = vllm_runner(model, dtype=dtype) + vllm_outputs = vllm_model.encode(example_prompts) + del vllm_model + + similarities = compare_embeddings(hf_outputs, vllm_outputs) + all_similarities = torch.stack(similarities) + tolerance = 1e-2 + assert torch.all((all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"Not all values are within {tolerance} of 1.0" diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py index 3788e9e9752f..be4c2ea1b781 100644 --- a/tests/samplers/test_logits_processor.py +++ b/tests/samplers/test_logits_processor.py @@ -36,14 +36,14 @@ def pick_vllm(token_ids, logits): # test logits_processors when prompt_logprobs is not None vllm_model.model._add_request( prompt=example_prompts[0], - sampling_params=params_with_logprobs, + params=params_with_logprobs, prompt_token_ids=None, ) # test prompt_logprobs is not None vllm_model.model._add_request( prompt=example_prompts[1], - sampling_params=SamplingParams( + params=SamplingParams( prompt_logprobs=3, max_tokens=max_tokens, ), @@ -53,7 +53,7 @@ def pick_vllm(token_ids, logits): # test grouped requests vllm_model.model._add_request( prompt=example_prompts[2], - sampling_params=SamplingParams(max_tokens=max_tokens), + params=SamplingParams(max_tokens=max_tokens), prompt_token_ids=None, ) diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py index 3cd659cef58d..ce4501bbf71e 100644 --- a/tests/samplers/test_seeded_generate.py +++ b/tests/samplers/test_seeded_generate.py @@ -60,7 +60,7 @@ def test_random_sample_with_seed( llm._add_request( prompt=prompt, prompt_token_ids=None, - sampling_params=params, + params=params, ) results = llm._run_engine(use_tqdm=False) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index f288652d5155..d52b22c30bd4 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -7,8 +7,8 @@ from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.utils import set_random_seed from vllm.sampling_params import SamplingParams -from vllm.sequence import (Logprob, SamplerOutput, SequenceData, - SequenceGroupMetadata, SequenceGroupOutput, +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceOutput) from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.cache_engine import CacheEngine @@ -170,7 +170,7 @@ def create_sampler_output_list( return [ SamplerOutput(outputs=[ - SequenceGroupOutput( + CompletionSequenceGroupOutput( samples=[ SequenceOutput( output_token=token_id, diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 53061278d5be..b8ea1f6b7720 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,17 +1,17 @@ import pytest from tests.core.utils import create_dummy_prompt -from 
vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput, - SequenceOutput) +from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput, + SequenceData, SequenceOutput) @pytest.fixture def sample_outputs(): return [ - SequenceGroupOutput(samples=[ + CompletionSequenceGroupOutput(samples=[ SequenceOutput(parent_seq_id=0, output_token=i, logprobs={}) ], - prompt_logprobs=None) for i in range(5) + prompt_logprobs=None) for i in range(5) ] @@ -32,10 +32,10 @@ def test_sampler_output_getitem(sampler_output, sample_outputs): def test_sampler_output_setitem(sampler_output): - new_output = SequenceGroupOutput(samples=[ + new_output = CompletionSequenceGroupOutput(samples=[ SequenceOutput(parent_seq_id=0, output_token=99, logprobs={}) ], - prompt_logprobs=None) + prompt_logprobs=None) sampler_output[2] = new_output assert sampler_output[2] == new_output diff --git a/vllm/__init__.py b/vllm/__init__.py index 59810da3ca41..74674ca0d12a 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -6,7 +6,9 @@ from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster from vllm.model_executor.models import ModelRegistry -from vllm.outputs import CompletionOutput, RequestOutput +from vllm.outputs import (CompletionOutput, EmbeddingOutput, + EmbeddingRequestOutput, RequestOutput) +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams __version__ = "0.4.2" @@ -17,9 +19,12 @@ "SamplingParams", "RequestOutput", "CompletionOutput", + "EmbeddingOutput", + "EmbeddingRequestOutput", "LLMEngine", "EngineArgs", "AsyncLLMEngine", "AsyncEngineArgs", "initialize_ray_cluster", + "PoolingParams", ] diff --git a/vllm/config.py b/vllm/config.py index 275814d72e6c..fab9cfbf41a2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -9,6 +9,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, get_quantization_config) +from vllm.model_executor.models import ModelRegistry from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron @@ -22,6 +23,7 @@ logger = init_logger(__name__) _GB = 1 << 30 +_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 class ModelConfig: @@ -126,6 +128,7 @@ def __init__( served_model_name) if not self.skip_tokenizer_init: self._verify_tokenizer_mode() + self._verify_embedding_mode() self._verify_quantization() self._verify_cuda_graph() @@ -137,6 +140,11 @@ def _verify_tokenizer_mode(self) -> None: "either 'auto' or 'slow'.") self.tokenizer_mode = tokenizer_mode + def _verify_embedding_mode(self) -> None: + architectures = getattr(self.hf_config, "architectures", []) + self.embedding_mode = any( + ModelRegistry.is_embedding_model(arch) for arch in architectures) + def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] rocm_supported_quantization = ["gptq", "squeezellm"] @@ -591,6 +599,7 @@ class SchedulerConfig: prompt latency) before scheduling next prompt. enable_chunked_prefill: If True, prefill requests can be chunked based on the remaining max_num_batched_tokens. + embedding_mode: Whether the running model is for embedding. 
""" def __init__( @@ -602,6 +611,7 @@ def __init__( num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, + embedding_mode: Optional[bool] = False, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -610,6 +620,10 @@ def __init__( # It is the values that have the best balance between ITL # and TTFT on A100. Note it is not optimized for throughput. self.max_num_batched_tokens = 512 + elif embedding_mode: + # For embedding, choose specific value for higher throughput + self.max_num_batched_tokens = max( + max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS) else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. @@ -623,6 +637,7 @@ def __init__( self.num_lookahead_slots = num_lookahead_slots self.delay_factor = delay_factor self.chunked_prefill_enabled = enable_chunked_prefill + self.embedding_mode = embedding_mode self._verify_args() diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py new file mode 100644 index 000000000000..a09d79ec3c42 --- /dev/null +++ b/vllm/core/embedding_model_block_manager.py @@ -0,0 +1,84 @@ +from typing import List, Tuple + +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.sequence import Sequence, SequenceGroup + + +class EmbeddingModelBlockSpaceManager(BlockSpaceManager): + """An embedding version of BlockSpaceManager for use in environments + with embedding models where block management is not required. + + This class provides the same interface as BlockSpaceManager, but its + methods perform no actions or return simple values like True in specific + actions. It's designed to be used in scenarios where the overhead of + block management is unnecessary, such as in an embedding environment. 
+ """ + + def __init__( + self, + **kwargs, + ) -> None: + pass + + def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + # Always return OK for dummy purposes + return AllocStatus.OK + + def allocate(self, seq_group: SequenceGroup) -> None: + # No actual allocation logic needed + pass + + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + return True + + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + return None # type: ignore + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + pass + + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + return AllocStatus.OK + + def swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> List[Tuple[int, int]]: + return None # type: ignore + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + return True + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + return None # type: ignore + + def free(self, seq: Sequence) -> None: + # No operation on free + return + + def get_block_table(self, seq: Sequence) -> List[int]: + return None # type: ignore + + def get_num_free_gpu_blocks(self) -> int: + return 1 + + def get_num_free_cpu_blocks(self) -> int: + return 1 + + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + pass + + def get_common_computed_block_ids(self, + seq_group: SequenceGroup) -> List[int]: + return None # type: ignore + + def mark_blocks_as_computed(self, seq_group: SequenceGroup): + pass diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index b2a5e41990f3..689cbc2179ee 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -35,6 +35,11 @@ def get_block_space_manager_class(version: str): from vllm.core.block_manager_v2 import BlockSpaceManagerV2 return BlockSpaceManagerV2 + if version == "embedding": + from vllm.core.embedding_model_block_manager import ( + EmbeddingModelBlockSpaceManager) + return EmbeddingModelBlockSpaceManager + raise ValueError(f"Unknown version {version=}") @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 35e3db18f1c4..fb6e985b2f31 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -270,9 +270,14 @@ def __init__( self.scheduler_config.max_model_len, self.scheduler_config.max_num_batched_tokens) + version = "v1" + if self.scheduler_config.use_v2_block_manager: + version = "v2" + if self.scheduler_config.embedding_mode: + version = "embedding" + BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( - version="v2" if self.scheduler_config. - use_v2_block_manager else "v1") + version) # Create the block space manager. 
self.block_manager = BlockSpaceManagerImpl( @@ -968,6 +973,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: sampling_params=seq_group.sampling_params, block_tables=block_tables, do_sample=do_sample, + pooling_params=seq_group.pooling_params, token_chunk_size=token_chunk_size, lora_request=seq_group.lora_request, computed_block_nums=common_computed_block_nums, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5c2acbef1312..163723b4be36 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -574,6 +574,7 @@ def create_engine_config(self, ) -> EngineConfig: speculative_config.num_lookahead_slots), delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, + embedding_mode=model_config.embedding_mode, ) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 37a2dc77a3b5..a31f10b7748d 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -14,7 +14,8 @@ from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput from vllm.usage.usage_lib import UsageContext @@ -47,15 +48,16 @@ def _raise_exception_on_finish( class AsyncStream: - """A stream of RequestOutputs for a request that can be - iterated over asynchronously.""" + """A stream of RequestOutputs or EmbeddingRequestOutputs for a request + that can be iterated over asynchronously.""" def __init__(self, request_id: str) -> None: self.request_id = request_id self._queue: asyncio.Queue = asyncio.Queue() self._finished = False - def put(self, item: Union[RequestOutput, Exception]) -> None: + def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, + Exception]) -> None: if self._finished: return self._queue.put_nowait(item) @@ -71,7 +73,7 @@ def finished(self) -> bool: def __aiter__(self): return self - async def __anext__(self) -> RequestOutput: + async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]: result = await self._queue.get() if isinstance(result, Exception): raise result @@ -108,7 +110,8 @@ def propagate_exception(self, self.abort_request(rid) def process_request_output(self, - request_output: RequestOutput, + request_output: Union[RequestOutput, + EmbeddingRequestOutput], *, verbose: bool = False) -> None: """Process a request output from the engine.""" @@ -196,7 +199,8 @@ def has_new_requests(self): class _AsyncLLMEngine(LLMEngine): """Extension of LLMEngine to add async methods.""" - async def step_async(self) -> List[RequestOutput]: + async def step_async( + self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. The workers are ran asynchronously if possible. 
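With step_async() and the per-request streams now yielding either RequestOutput or EmbeddingRequestOutput, code that consumes engine results generally has to branch on the output type. A minimal consumer sketch (illustrative only, not part of this patch; it relies only on the output.finished flag, the RequestOutput.outputs completion list, and the EmbeddingRequestOutput.outputs.embedding field used elsewhere in this diff):

from typing import List, Union

from vllm.outputs import EmbeddingRequestOutput, RequestOutput


def collect_finished(
    outputs: List[Union[RequestOutput, EmbeddingRequestOutput]]
) -> list:
    results = []
    for output in outputs:
        if not output.finished:
            continue
        if isinstance(output, EmbeddingRequestOutput):
            # Embedding requests finish after a single model pass; the pooled
            # vector is exposed as a plain list of floats.
            results.append(output.outputs.embedding)
        else:
            # Generation requests carry one CompletionOutput per sequence.
            results.append([completion.text for completion in output.outputs])
    return results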
@@ -251,7 +255,7 @@ async def add_request_async( self, request_id: str, prompt: Optional[str], - sampling_params: SamplingParams, + params: Union[SamplingParams, PoolingParams], prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -270,8 +274,8 @@ async def add_request_async( return self.add_request(request_id, prompt=prompt, + params=params, prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params, arrival_time=arrival_time, lora_request=lora_request, multi_modal_data=multi_modal_data) @@ -511,7 +515,7 @@ async def add_request( self, request_id: str, prompt: Optional[str], - sampling_params: SamplingParams, + params: Union[SamplingParams, PoolingParams], prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -528,9 +532,9 @@ async def add_request( max_log_len] logger.info( "Received request %s: prompt: %r, " - "sampling_params: %s, prompt_token_ids: %s, " - "lora_request: %s.", request_id, shortened_prompt, - sampling_params, shortened_token_ids, lora_request) + "params: %s, prompt_token_ids: %s, " + "lora_request: %s.", request_id, shortened_prompt, params, + shortened_token_ids, lora_request) if not self.is_running: if self.start_engine_loop: @@ -562,7 +566,7 @@ async def add_request( stream = self._request_tracker.add_request( request_id, prompt=prompt, - sampling_params=sampling_params, + params=params, prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, lora_request=lora_request, @@ -597,8 +601,8 @@ async def generate( multi_modal_data: Multi modal data per request. Yields: - The output `RequestOutput` objects from the LLMEngine for the - request. + The output `RequestOutput` objects from the LLMEngine + for the request. Details: - If the engine is not running, start the background loop, @@ -643,25 +647,123 @@ async def generate( >>> # Process and return the final output >>> ... """ - # Preprocess the request. - arrival_time = time.time() - - try: - stream = await self.add_request( + async for output in self.process_request( request_id, prompt, sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time, - lora_request=lora_request, - multi_modal_data=multi_modal_data, - ) + prompt_token_ids, + lora_request, + multi_modal_data, + ): + yield output + + async def encode( + self, + prompt: Optional[str], + pooling_params: PoolingParams, + request_id: str, + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None + ) -> AsyncIterator[EmbeddingRequestOutput]: + """Generate outputs for a request from an embedding model. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt string. Can be None if prompt_token_ids is + provided. + pooling_params: The pooling parameters of the request. + request_id: The unique id of the request. + prompt_token_ids: The token IDs of the prompt. If None, we + use the tokenizer to convert the prompts to token IDs. + lora_request: LoRA request to use for generation, if any. + multi_modal_data: Multi modal data per request. + + Yields: + The output `EmbeddingRequestOutput` objects from the LLMEngine + for the request. 
+ + Details: + - If the engine is not running, start the background loop, + which iteratively invokes + :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step` + to process the waiting requests. + - Add the request to the engine's `RequestTracker`. + On the next background loop, this request will be sent to + the underlying engine. + Also, a corresponding `AsyncStream` will be created. + - Wait for the request outputs from `AsyncStream` and yield them. + + Example: + >>> # Please refer to entrypoints/api_server.py for + >>> # the complete example. + >>> + >>> # initialize the engine and the example input + >>> engine = AsyncLLMEngine.from_engine_args(engine_args) + >>> example_input = { + >>> "input": "What is LLM?", + >>> "request_id": 0, + >>> } + >>> + >>> # start the generation + >>> results_generator = engine.encode( + >>> example_input["input"], + >>> PoolingParams(), + >>> example_input["request_id"]) + >>> + >>> # get the results + >>> final_output = None + >>> async for request_output in results_generator: + >>> if await request.is_disconnected(): + >>> # Abort the request if the client disconnects. + >>> await engine.abort(request_id) + >>> # Return or raise an error + >>> ... + >>> final_output = request_output + >>> + >>> # Process and return the final output + >>> ... + """ + async for output in self.process_request( + request_id, + prompt, + pooling_params, + prompt_token_ids, + lora_request, + multi_modal_data, + ): + yield output + + async def process_request( + self, + request_id: str, + prompt: Optional[str], + params: Union[SamplingParams, PoolingParams], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: + """Common logic to process requests with SamplingParams or + PoolingParams.""" + arrival_time = time.time() + + stream = await self.add_request( + request_id, + prompt, + params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, + lora_request=lora_request, + multi_modal_data=multi_modal_data, + ) + try: async for request_output in stream: yield request_output except (Exception, asyncio.CancelledError) as e: - # If there is an exception or coroutine is cancelled, abort the - # request. self._abort(request_id) raise e diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index b9938b045ba2..46fa41030b4a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -20,9 +20,12 @@ from vllm.executor.ray_utils import initialize_ray_cluster from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, + RequestOutputFactory) +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput, +from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, + MultiModalData, PoolerOutput, SamplerOutput, Sequence, SequenceGroup, SequenceGroupMetadata, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer @@ -169,7 +172,8 @@ def __init__( load_config=load_config, ) - self._initialize_kv_caches() + if not self.model_config.embedding_mode: + self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. 
if is_usage_stats_enabled(): @@ -354,7 +358,7 @@ def add_request( self, request_id: str, prompt: Optional[str], - sampling_params: SamplingParams, + params: Union[SamplingParams, PoolingParams], prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -370,7 +374,8 @@ def add_request( request_id: The unique ID of the request. prompt: The prompt string. Can be None if prompt_token_ids is provided. - sampling_params: The sampling parameters for text generation. + params: Parameters for sampling or pooling. SamplingParams + for text generation. PoolingParams for pooling. prompt_token_ids: The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs. arrival_time: The arrival time of the request. If None, we use @@ -404,13 +409,6 @@ def add_request( if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") - max_logprobs = self.get_model_config().max_logprobs - if (sampling_params.logprobs - and sampling_params.logprobs > max_logprobs) or ( - sampling_params.prompt_logprobs - and sampling_params.prompt_logprobs > max_logprobs): - raise ValueError(f"Cannot request more than " - f"{max_logprobs} logprobs.") if arrival_time is None: arrival_time = time.time() prompt_token_ids = self.encode_request( @@ -432,6 +430,50 @@ def add_request( seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, eos_token_id, lora_request) + # Create a SequenceGroup based on SamplingParams or PoolingParams + if isinstance(params, SamplingParams): + seq_group = self._create_sequence_group_with_sampling( + request_id, + seq, + params, + arrival_time, + lora_request, + multi_modal_data, + ) + elif isinstance(params, PoolingParams): + seq_group = self._create_sequence_group_with_pooling( + request_id, + seq, + params, + arrival_time, + lora_request, + multi_modal_data, + ) + else: + raise ValueError( + "Either SamplingParams or PoolingParams must be provided.") + + # Add the sequence group to the scheduler. + self.scheduler.add_seq_group(seq_group) + + def _create_sequence_group_with_sampling( + self, + request_id: str, + seq: Sequence, + sampling_params: SamplingParams, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> SequenceGroup: + """Creates a SequenceGroup with SamplingParams.""" + max_logprobs = self.get_model_config().max_logprobs + if (sampling_params.logprobs + and sampling_params.logprobs > max_logprobs) or ( + sampling_params.prompt_logprobs + and sampling_params.prompt_logprobs > max_logprobs): + raise ValueError(f"Cannot request more than " + f"{max_logprobs} logprobs.") + # Defensive copy of SamplingParams, which are used by the sampler, # this doesn't deep-copy LogitsProcessor objects sampling_params = sampling_params.clone() @@ -443,11 +485,35 @@ def add_request( self.generation_config_fields) # Create the sequence group. - seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time, lora_request, multi_modal_data) + seq_group = SequenceGroup(request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + sampling_params=sampling_params, + lora_request=lora_request, + multi_modal_data=multi_modal_data) - # Add the sequence group to the scheduler. 
- self.scheduler.add_seq_group(seq_group) + return seq_group + + def _create_sequence_group_with_pooling( + self, + request_id: str, + seq: Sequence, + pooling_params: PoolingParams, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> SequenceGroup: + """Creates a SequenceGroup with PoolingParams.""" + # Defensive copy of PoolingParams, which are used by the pooler + pooling_params = pooling_params.clone() + # Create the sequence group. + seq_group = SequenceGroup(request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + lora_request=lora_request, + multi_modal_data=multi_modal_data, + pooling_params=pooling_params) + return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: """Aborts a request(s) with the given ID. @@ -484,13 +550,25 @@ def has_unfinished_requests(self) -> bool: """Returns True if there are unfinished requests.""" return self.scheduler.has_unfinished_seqs() + def _process_sequence_group_outputs( + self, + seq_group: SequenceGroup, + outputs: List[EmbeddingSequenceGroupOutput], + ) -> None: + seq_group.embeddings = outputs[0].embeddings + + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_STOPPED + + return + def _process_model_outputs( self, - output: List[SamplerOutput], + output: List[Union[SamplerOutput, PoolerOutput]], scheduled_seq_groups: List[ScheduledSequenceGroup], ignored_seq_groups: List[SequenceGroup], seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> List[RequestOutput]: + ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Apply the model output to the sequences in the scheduled seq groups. Returns RequestOutputs that can be returned to the client. @@ -510,6 +588,9 @@ def _process_model_outputs( seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) + if self.model_config.embedding_mode: + self._process_sequence_group_outputs(seq_group, outputs) + continue self.output_processor.process_prompt_logprob(seq_group, outputs) if seq_group_meta.do_sample: @@ -519,18 +600,19 @@ def _process_model_outputs( self.scheduler.free_finished_seq_groups() # Create the outputs. - request_outputs: List[RequestOutput] = [] + request_outputs: List[Union[RequestOutput, + EmbeddingRequestOutput]] = [] for scheduled_seq_group in scheduled_seq_groups: seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - request_output = RequestOutput.from_seq_group(seq_group) + request_output = RequestOutputFactory.create(seq_group) request_outputs.append(request_output) for seq_group in ignored_seq_groups: - request_output = RequestOutput.from_seq_group(seq_group) + request_output = RequestOutputFactory.create(seq_group) request_outputs.append(request_output) return request_outputs - def step(self) -> List[RequestOutput]: + def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. .. 
figure:: https://i.imgur.com/sv2HssD.png @@ -570,7 +652,7 @@ def step(self) -> List[RequestOutput]: >>> while True: >>> if example_inputs: >>> req_id, prompt, sampling_params = example_inputs.pop(0) - >>> engine.add_request(str(req_id), prompt, sampling_params) + >>> engine.add_request(str(req_id),prompt,sampling_params) >>> >>> # continue the request processing >>> request_outputs = engine.step() @@ -637,12 +719,15 @@ def _get_stats( # KV Cache Usage in % num_total_gpu = self.cache_config.num_gpu_blocks - num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks() - gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) + gpu_cache_usage_sys = 0. + if num_total_gpu is not None: + num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks( + ) + gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) num_total_cpu = self.cache_config.num_cpu_blocks cpu_cache_usage_sys = 0. - if num_total_cpu > 0: + if num_total_cpu is not None and num_total_cpu > 0: num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( ) cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) @@ -716,8 +801,10 @@ def _get_stats( seq.get_output_len() for seq in seq_group.get_finished_seqs() ]) - best_of_requests.append(seq_group.sampling_params.best_of) - n_requests.append(seq_group.sampling_params.n) + if seq_group.sampling_params is not None: + best_of_requests.append( + seq_group.sampling_params.best_of) + n_requests.append(seq_group.sampling_params.n) finished_reason_requests.extend([ SequenceStatus.get_finished_reason(seq.status) for seq in seq_group.get_finished_seqs() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 71620139fba3..25f4428100b2 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -6,13 +6,17 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine +from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.sequence import MultiModalData from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter +logger = init_logger(__name__) + class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -164,8 +168,89 @@ def generate( multi_modal_data: Multi modal data. Returns: - A list of `RequestOutput` objects containing the generated - completions in the same order as the input prompts. + A list of `RequestOutput` objects containing the + generated completions in the same order as the input prompts. + """ + if sampling_params is None: + # Use default sampling params. + sampling_params = SamplingParams() + + requests_data = self._validate_and_prepare_requests( + prompts, + sampling_params, + prompt_token_ids, + lora_request, + multi_modal_data, + ) + + # Add requests to the engine and run the engine + for request_data in requests_data: + self._add_request(**request_data) + + return self._run_engine(use_tqdm) + + def encode( + self, + prompts: Optional[Union[str, List[str]]] = None, + pooling_params: Optional[Union[PoolingParams, + List[PoolingParams]]] = None, + prompt_token_ids: Optional[List[List[int]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + """Generates the completions for the input prompts. 
+ + NOTE: This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: A list of prompts to generate completions for. + pooling_params: The pooling parameters for pooling. If None, we + use the default pooling parameters. + prompt_token_ids: A list of token IDs for the prompts. If None, we + use the tokenizer to convert the prompts to token IDs. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + multi_modal_data: Multi modal data. + + Returns: + A list of `EmbeddingRequestOutput` objects containing the + generated embeddings in the same order as the input prompts. + """ + if pooling_params is None: + # Use default pooling params. + pooling_params = PoolingParams() + + requests_data = self._validate_and_prepare_requests( + prompts, + pooling_params, + prompt_token_ids, + lora_request, + multi_modal_data, + ) + + # Add requests to the engine and run the engine + for request_data in requests_data: + self._add_request(**request_data) + + return self._run_engine(use_tqdm) + + def _validate_and_prepare_requests( + self, + prompts: Optional[Union[str, List[str]]], + params: Union[Union[SamplingParams, PoolingParams], + List[Union[SamplingParams, + PoolingParams]]], # Unified parameter + prompt_token_ids: Optional[List[List[int]]] = None, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[dict]: + """Validates and prepares request data for adding to the engine. + + Ensures prompts and token IDs are consistent, and returns a list of + dictionaries with request data for further processing. """ if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " @@ -188,40 +273,43 @@ def generate( assert prompt_token_ids is not None num_requests = len(prompt_token_ids) - if sampling_params is None: - # Use default sampling params. - sampling_params = SamplingParams() - - elif isinstance(sampling_params, - list) and len(sampling_params) != num_requests: - raise ValueError("The lengths of prompts and sampling_params " + if isinstance(params, list) and len(params) != num_requests: + raise ValueError("The lengths of prompts and params " "must be the same.") if multi_modal_data: multi_modal_data.data = multi_modal_data.data.to(torch.float16) # Add requests to the engine. + requests_data = [] for i in range(num_requests): prompt = prompts[i] if prompts is not None else None token_ids = None if prompt_token_ids is None else prompt_token_ids[ i] - self._add_request( + + multi_modal_item = MultiModalData( + type=multi_modal_data.type, + data=multi_modal_data.data[i].unsqueeze(0), + ) if multi_modal_data else None + + requests_data.append({ + "prompt": prompt, - sampling_params[i] - if isinstance(sampling_params, list) else sampling_params, + "params": + params[i] if isinstance(params, list) else params, + "prompt_token_ids": token_ids, - lora_request=lora_request, - # Get ith image while maintaining the batch dim. 
- multi_modal_data=MultiModalData( - type=multi_modal_data.type, - data=multi_modal_data.data[i].unsqueeze(0)) - if multi_modal_data else None, - ) - return self._run_engine(use_tqdm) + "lora_request": + lora_request, + "multi_modal_data": + multi_modal_item, + }) + + return requests_data def _add_request( self, prompt: Optional[str], - sampling_params: SamplingParams, + params: Union[SamplingParams, PoolingParams], prompt_token_ids: Optional[List[int]], lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalData] = None, @@ -229,12 +317,14 @@ def _add_request( request_id = str(next(self.request_counter)) self.llm_engine.add_request(request_id, prompt, - sampling_params, + params, prompt_token_ids, lora_request=lora_request, multi_modal_data=multi_modal_data) - def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: + def _run_engine( + self, use_tqdm: bool + ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() @@ -245,7 +335,7 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: postfix=f"Generation Speed: {0:.2f} toks/s", ) # Run the engine. - outputs: List[RequestOutput] = [] + outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] total_toks = 0 while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() @@ -253,10 +343,12 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: if output.finished: outputs.append(output) if use_tqdm: - total_toks += (sum( - len(stp.token_ids) for stp in output.outputs)) - spd = total_toks / pbar.format_dict["elapsed"] - pbar.postfix = f"Generation Speed: {spd:.2f} toks/s" + if isinstance(output, RequestOutput): + # Calculate tokens only for RequestOutput + total_toks += sum( + len(stp.token_ids) for stp in output.outputs) + spd = total_toks / pbar.format_dict["elapsed"] + pbar.postfix = f"Generation Speed: {spd:.2f} toks/s" pbar.update(1) if use_tqdm: pbar.close() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 362f28d05c3b..7cd51b959a0e 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -22,9 +22,11 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, - CompletionRequest, ErrorResponse) + CompletionRequest, + EmbeddingRequest, ErrorResponse) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext @@ -32,6 +34,8 @@ openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion +openai_serving_embedding: OpenAIServingEmbedding + logger = init_logger(__name__) _running_tasks: Set[asyncio.Task] = set() @@ -123,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return JSONResponse(content=generator.model_dump()) +@app.post("/v1/embeddings") +async def create_embedding(request: EmbeddingRequest, raw_request: Request): + generator = await openai_serving_embedding.create_embedding( + request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + else: + return 
JSONResponse(content=generator.model_dump()) + + if __name__ == "__main__": args = parse_args() @@ -190,7 +205,8 @@ async def authentication(request: Request, call_next): args.chat_template) openai_serving_completion = OpenAIServingCompletion( engine, model_config, served_model_names, args.lora_modules) - + openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, + served_model_names) app.root_path = args.root_path uvicorn.run(app, host=args.host, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3cd9ddad3b7b..139c5716c7ce 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1,13 +1,14 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import time -from typing import Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union import torch from openai.types.chat import ChatCompletionMessageParam from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Annotated +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid @@ -363,6 +364,24 @@ def check_guided_decoding_count(cls, data): return data +class EmbeddingRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/embeddings + model: str + input: Union[List[int], List[List[int]], str, List[str]] + encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$') + dimensions: Optional[int] = None + user: Optional[str] = None + + # doc: begin-embedding-pooling-params + additional_data: Optional[Any] = None + + # doc: end-embedding-pooling-params + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + class LogProbs(OpenAIBaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) @@ -416,6 +435,21 @@ class CompletionStreamResponse(OpenAIBaseModel): usage: Optional[UsageInfo] = Field(default=None) +class EmbeddingResponseData(BaseModel): + index: int + object: str = "embedding" + embedding: List[float] + + +class EmbeddingResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: List[EmbeddingResponseData] + usage: UsageInfo + + class ChatMessage(OpenAIBaseModel): role: str content: str diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py new file mode 100644 index 000000000000..7a57be0c8891 --- /dev/null +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -0,0 +1,134 @@ +import time +from typing import AsyncIterator, List, Tuple + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import (EmbeddingRequest, + EmbeddingResponse, + EmbeddingResponseData, UsageInfo) +from vllm.entrypoints.openai.serving_completion import parse_prompt_format +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.logger import init_logger +from vllm.outputs import EmbeddingRequestOutput +from vllm.utils import merge_async_iterators, random_uuid + +logger = init_logger(__name__) + +TypeTokenIDs = List[int] + + +def 
request_output_to_embedding_response( + final_res_batch: List[EmbeddingRequestOutput], + request_id: str, + created_time: int, + model_name: str, +) -> EmbeddingResponse: + data = [] + num_prompt_tokens = 0 + for idx, final_res in enumerate(final_res_batch): + assert final_res is not None + prompt_token_ids = final_res.prompt_token_ids + + embedding_data = EmbeddingResponseData( + index=idx, embedding=final_res.outputs.embedding) + data.append(embedding_data) + + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return EmbeddingResponse( + id=request_id, + created=created_time, + model=model_name, + data=data, + usage=usage, + ) + + +class OpenAIServingEmbedding(OpenAIServing): + + def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, + served_model_names: List[str]): + super().__init__(engine=engine, + model_config=model_config, + served_model_names=served_model_names, + lora_modules=None) + self._check_embedding_mode(model_config.embedding_mode) + + async def create_embedding(self, request: EmbeddingRequest, + raw_request: Request): + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/embeddings/create + for the API specification. This API mimics the OpenAI Embedding API. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # Return error for unsupported features. + if request.encoding_format == "base64": + return self.create_error_response( + "base64 encoding is not currently supported") + if request.dimensions is not None: + return self.create_error_response( + "dimensions is currently not supported") + + model_name = request.model + request_id = f"cmpl-{random_uuid()}" + created_time = int(time.monotonic()) + + # Schedule the request and get the result generator. + generators = [] + try: + prompt_is_tokens, prompts = parse_prompt_format(request.input) + pooling_params = request.to_pooling_params() + + for i, prompt in enumerate(prompts): + if prompt_is_tokens: + prompt_formats = self._validate_prompt_and_tokenize( + request, prompt_ids=prompt) + else: + prompt_formats = self._validate_prompt_and_tokenize( + request, prompt=prompt) + + prompt_ids, prompt_text = prompt_formats + + generators.append( + self.engine.generate(prompt_text, + pooling_params, + f"{request_id}-{i}", + prompt_token_ids=prompt_ids)) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator: AsyncIterator[Tuple[ + int, EmbeddingRequestOutput]] = merge_async_iterators(*generators) + + # Non-streaming response + final_res_batch: EmbeddingRequestOutput = [None] * len(prompts) + async for i, res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + await self.engine.abort(f"{request_id}-{i}") + # TODO: Use a vllm-specific Validation Error + return self.create_error_response("Client disconnected") + final_res_batch[i] = res + response = request_output_to_embedding_response( + final_res_batch, request_id, created_time, model_name) + + return response + + def _check_embedding_mode(self, embedding_mode: bool): + if not embedding_mode: + logger.warning( + "embedding_mode is False. 
Embedding API will not work.") + else: + logger.info("Activating the server engine with embedding enabled.") diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index f10718c5f3d8..58a1c2f7e73f 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -9,7 +9,8 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - CompletionRequest, ErrorResponse, + CompletionRequest, + EmbeddingRequest, ErrorResponse, LogProbs, ModelCard, ModelList, ModelPermission) from vllm.logger import init_logger @@ -165,7 +166,8 @@ def _maybe_get_lora( def _validate_prompt_and_tokenize( self, - request: Union[ChatCompletionRequest, CompletionRequest], + request: Union[ChatCompletionRequest, CompletionRequest, + EmbeddingRequest], prompt: Optional[str] = None, prompt_ids: Optional[List[int]] = None, truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None @@ -191,6 +193,16 @@ def _validate_prompt_and_tokenize( prompt_ids) token_num = len(input_ids) + # Note: EmbeddingRequest doesn't have max_tokens + if isinstance(request, EmbeddingRequest): + if token_num > self.max_model_len: + raise ValueError( + f"This model's maximum context length is " + f"{self.max_model_len} tokens. However, you requested " + f"{token_num} tokens in the input for embedding " + f"generation. Please reduce the length of the input.", ) + return input_ids, input_text + if request.max_tokens is None: if token_num >= self.max_model_len: raise ValueError( diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index fa3480fa6483..2b72b31b5f07 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,9 +1,9 @@ -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Union from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) from vllm.worker.worker_base import WorkerWrapperBase @@ -123,8 +123,8 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> List[Union[SamplerOutput, PoolerOutput]]: output = self.driver_worker.execute_model(execute_model_req) return output @@ -150,7 +150,7 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): async def execute_model_async( self, execute_model_req: ExecuteModelRequest, - ) -> List[SamplerOutput]: + ) -> List[Union[SamplerOutput, PoolerOutput]]: output = await make_async(self.driver_worker.execute_model )(execute_model_req=execute_model_req, ) return output diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py new file mode 100644 index 000000000000..445b30b8c6e9 --- /dev/null +++ b/vllm/model_executor/layers/pooler.py @@ -0,0 +1,56 @@ +from enum import IntEnum + +import torch +import torch.nn as nn + +from vllm.model_executor.pooling_metadata import (PoolingMetadata, + PoolingTensors) +from vllm.sequence 
import EmbeddingSequenceGroupOutput, PoolerOutput + + +class PoolingType(IntEnum): + """Enumeration for different types of pooling methods.""" + LAST = 0 + + +class Pooler(nn.Module): + """A layer that pools specific information from hidden states. + + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Normalizes output if specified. + 3. Returns structured results as `PoolerOutput`. + + Attributes: + pooling_type: The type of pooling to use (LAST, AVERAGE, MAX). + normalize: Whether to normalize the pooled data. + """ + + def __init__(self, pooling_type: PoolingType, normalize: bool): + super().__init__() + self.pooling_type = pooling_type + self.normalize = normalize + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + """Pools specific information from hidden states based on metadata.""" + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + if self.pooling_type == PoolingType.LAST: + last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 + pooled_data = hidden_states[last_token_flat_indices] + else: + raise ValueError(f"Invalid pooling type: {self.pooling_type}") + + if self.normalize: + pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1) + + pooled_outputs = [ + EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data + ] + + return PoolerOutput(outputs=pooled_outputs) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index e52e350d2726..c8bab46c83ec 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -10,8 +10,9 @@ SamplingTensors, SequenceGroupToSample) from vllm.sampling_params import SamplingType -from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, - SamplerOutput, SequenceGroupOutput, SequenceOutput) +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + PromptLogprobs, SampleLogprobs, SamplerOutput, + SequenceOutput) # (num_token_ids, num_parent_ids) per sequence group. SampleResultType = List[Tuple[List[int], List[int]]] @@ -1019,7 +1020,7 @@ def _build_sampler_output( seq_outputs.append( SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) sampler_output.append( - SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) + CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs)) # If not specified, store None values in SamplerOutput. if on_device_tensors is not None: diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index d5263b500fe0..6aec104be8da 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -9,7 +9,7 @@ logger = init_logger(__name__) # Architecture -> (module, class). -_MODELS = { +_GENERATION_MODELS = { "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b @@ -58,6 +58,12 @@ "XverseForCausalLM": ("xverse", "XverseForCausalLM"), } +_EMBEDDING_MODELS = { + "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), +} + +_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS} + # Architecture -> type. 
# out of tree models _OOT_MODELS: Dict[str, Type[nn.Module]] = {} @@ -114,6 +120,10 @@ def register_model(model_arch: str, model_cls: Type[nn.Module]): global _OOT_MODELS _OOT_MODELS[model_arch] = model_cls + @staticmethod + def is_embedding_model(model_arch: str) -> bool: + return model_arch in _EMBEDDING_MODELS + __all__ = [ "ModelRegistry", diff --git a/vllm/model_executor/models/llama_embedding.py b/vllm/model_executor/models/llama_embedding.py new file mode 100644 index 000000000000..8f1c77da50d9 --- /dev/null +++ b/vllm/model_executor/models/llama_embedding.py @@ -0,0 +1,87 @@ +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import AttentionMetadata +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import PoolerOutput + + +class LlamaEmbeddingModel(nn.Module): + """A model that uses Llama with additional embedding functionalities. + + This class encapsulates the LlamaModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of LlamaModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__() + self.model = LlamaModel(**kwargs) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model.forward(input_ids, positions, kv_caches, + attn_metadata, inputs_embeds) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.model.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
diff --git a/vllm/model_executor/pooling_metadata.py b/vllm/model_executor/pooling_metadata.py
new file mode 100644
index 000000000000..b86cafce85d1
--- /dev/null
+++ b/vllm/model_executor/pooling_metadata.py
@@ -0,0 +1,69 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Tuple
+
+import torch
+
+from vllm.pooling_params import PoolingParams
+from vllm.utils import is_pin_memory_available
+
+
+class PoolingMetadata:
+    """Metadata for pooling operations in the Pooler layer.
+
+    This class holds the necessary information for pooling operations,
+    providing context for how to perform pooling and other related operations.
+
+    Attributes:
+        seq_groups: List of (seq_ids, pooling_params).
+        seq_data: A mapping of sequence ID to additional sequence data.
+        prompt_lens: List of the lengths of each prompt.
+    """
+
+    def __init__(
+        self,
+        seq_groups: List[Tuple[List[int], PoolingParams]],
+        seq_data: Dict[int, Any],  # Specific data related to sequences
+        prompt_lens: List[int],
+    ) -> None:
+        self.seq_groups = seq_groups
+        self.seq_data = seq_data
+        self.prompt_lens = prompt_lens
+
+    def __repr__(self) -> str:
+        return ("PoolingMetadata("
+                f"seq_groups={self.seq_groups}, "
+                f"seq_data={self.seq_data}, "
+                f"prompt_lens={self.prompt_lens})")
+
+
+@dataclass
+class PoolingTensors:
+    """Tensors for pooling."""
+
+    prompt_lens: torch.Tensor
+
+    @classmethod
+    def from_pooling_metadata(
+        cls,
+        pooling_metadata: "PoolingMetadata",
+        device: torch.device,
+    ) -> "PoolingTensors":
+        """
+        Create PoolingTensors from PoolingMetadata.
+
+        Args:
+            pooling_metadata: PoolingMetadata instance to convert.
+            device: Device to store the tensors.
+        """
+        # Convert prompt lengths to tensor
+        pin_memory = is_pin_memory_available()
+
+        prompt_lens_t = torch.tensor(
+            pooling_metadata.prompt_lens,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
+
+        return cls(prompt_lens=prompt_lens_t.to(device=device,
+                                                non_blocking=True), )
diff --git a/vllm/outputs.py b/vllm/outputs.py
index d01be0eb0efd..f9bce9e683f2 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -57,8 +57,27 @@ def __repr__(self) -> str:
                 f"stop_reason={self.stop_reason})")
 
 
+class EmbeddingOutput:
+    """The output data of one embedding output of a request.
+
+    Args:
+        embedding: The embedding vector, which is a list of floats. The
+        length of the vector depends on the model as listed in the embedding guide.
+    """
+
+    def __init__(
+        self,
+        embedding: List[float],
+    ) -> None:
+        self.embedding = embedding
+
+    def __repr__(self) -> str:
+        return (f"EmbeddingOutput("
+                f"embedding={len(self.embedding)})")
+
+
 class RequestOutput:
-    """The output data of a request to the LLM.
+    """The output data of a completion request to the LLM.
 
     Args:
         request_id: The unique ID of the request.
@@ -93,6 +112,9 @@ def __init__(
 
     @classmethod
     def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
+        if seq_group.sampling_params is None:
+            raise ValueError(
+                "Sampling parameters are missing for a CompletionRequest.")
         seqs = seq_group.get_seqs()
         if len(seqs) == 1:
             top_n_seqs = seqs
@@ -148,3 +170,61 @@ def __repr__(self) -> str:
                 f"finished={self.finished}, "
                 f"metrics={self.metrics}, "
                 f"lora_request={self.lora_request})")
+
+
+class EmbeddingRequestOutput:
+    """
+    The output data of an embedding request to the LLM. 
+ + Args: + request_id (str): A unique identifier for the embedding request. + outputs (EmbeddingOutput): The embedding results for the given input. + prompt_token_ids (List[int]): A list of token IDs used in the prompt. + finished (bool): A flag indicating whether the embedding is completed. + """ + + def __init__(self, request_id: str, outputs: 'EmbeddingOutput', + prompt_token_ids: List[int], finished: bool): + self.request_id = request_id + self.prompt_token_ids = prompt_token_ids + self.finished = finished + self.outputs = outputs + + @classmethod + def from_seq_group(cls, + seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput": + if seq_group.embeddings is None: + raise ValueError( + "Embeddings are missing in seq_group for EmbeddingRequest.") + output = EmbeddingOutput(seq_group.embeddings) + prompt_token_ids = seq_group.prompt_token_ids + finished = seq_group.is_finished() + + return cls(seq_group.request_id, output, prompt_token_ids, finished) + + def __repr__(self): + """ + Returns a string representation of an EmbeddingRequestOutput instance. + + The representation includes the request_id and the number of outputs, + providing a quick overview of the embedding request's results. + + Returns: + str: A string representation of the EmbeddingRequestOutput instance. + """ + return (f"EmbeddingRequestOutput(request_id='{self.request_id}', " + f"outputs={repr(self.outputs)}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"finished={self.finished})") + + +class RequestOutputFactory: + + @staticmethod + def create(seq_group): + # Determine the type based on a condition, for example: + if hasattr(seq_group, + 'embeddings') and seq_group.embeddings is not None: + return EmbeddingRequestOutput.from_seq_group(seq_group) + else: + return RequestOutput.from_seq_group(seq_group) diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py new file mode 100644 index 000000000000..3b95d73ddc2c --- /dev/null +++ b/vllm/pooling_params.py @@ -0,0 +1,20 @@ +from typing import Any, Optional + + +class PoolingParams: + """Pooling parameters for pooling. + + Attributes: + additional_data: Any additional data needed for pooling. + """ + + def __init__(self, additional_data: Optional[Any] = None): + self.additional_data = additional_data + + def clone(self) -> "PoolingParams": + """Returns a deep copy of the PoolingParams instance.""" + return PoolingParams(additional_data=self.additional_data, ) + + def __repr__(self) -> str: + return (f"PoolingParams(" + f"additional_metadata={self.additional_data})") diff --git a/vllm/sequence.py b/vllm/sequence.py index 3cebb85b49d2..46ac33b7ecab 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1,11 +1,13 @@ """Sequence and its related classes.""" import copy import enum +from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from vllm.block import LogicalTokenBlock from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams if TYPE_CHECKING: @@ -375,12 +377,12 @@ class SequenceGroupState: class MultiModalData: """Multi modal request. - + Args: type: The data type. data: The actual data. The required shape and semantic meaning of it depends on the vision - language config of the hosted model. + language config of the hosted model. See `VisionLanguageConfig` in `config.py`. """ @@ -402,16 +404,22 @@ class SequenceGroup: arrival_time: The arrival time of the request. 
lora_request: LoRA request. multi_modal_data: Multi modal data associated with the request. + embeddings: The embeddings vectors of the prompt of the sequence group + for an embedding model. + pooling_params: The pooling parameters used to generate the pooling + for an embedding model. """ def __init__( self, request_id: str, seqs: List[Sequence], - sampling_params: SamplingParams, arrival_time: float, + sampling_params: Optional[SamplingParams] = None, lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalData] = None, + embeddings: Optional[List[float]] = None, + pooling_params: Optional[PoolingParams] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -425,6 +433,8 @@ def __init__( self.prompt_logprobs: Optional[PromptLogprobs] = None self.state = SequenceGroupState() self.multi_modal_data = multi_modal_data + self.embeddings = embeddings + self.pooling_params = pooling_params @property def prompt(self) -> str: @@ -479,12 +489,13 @@ def set_finished_time(self, time: Optional[float]) -> None: def get_max_num_running_seqs(self) -> int: """The maximum number of sequences running in parallel in the remaining lifetime of the request.""" - if self.sampling_params.use_beam_search: + if self.sampling_params and self.sampling_params.use_beam_search: # For beam search, maximally there will always be `best_of` beam # candidates running in the future. return self.sampling_params.best_of else: - if self.sampling_params.best_of > self.num_seqs(): + if (self.sampling_params + and self.sampling_params.best_of > self.num_seqs()): # At prompt stage, the sequence group is not yet filled up # and only have one sequence running. However, in the # generation stage, we will have `best_of` sequences running. @@ -555,7 +566,7 @@ def is_finished(self) -> bool: return all(seq.is_finished() for seq in self.get_seqs()) def is_prefill(self) -> bool: - # Every sequences should be in the same stage. + # Every sequence should be in the same stage. 
return self.get_seqs()[0].is_prefill() def __repr__(self) -> str: @@ -594,6 +605,7 @@ def __init__( sampling_params: SamplingParams, block_tables: Dict[int, List[int]], do_sample: bool = True, + pooling_params: Optional[PoolingParams] = None, token_chunk_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, @@ -605,6 +617,7 @@ def __init__( self.seq_data = seq_data self.sampling_params = sampling_params self.block_tables = block_tables + self.pooling_params = pooling_params self.lora_request = lora_request self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data @@ -669,8 +682,20 @@ def __eq__(self, other: object) -> bool: return equal and log_probs_equal -class SequenceGroupOutput: - """The model output associated with a sequence group.""" +class SequenceGroupOutput(ABC): + """The base class for model outputs associated with a sequence group.""" + + @abstractmethod + def __repr__(self) -> str: + pass + + @abstractmethod + def __eq__(self, other: object) -> bool: + pass + + +class CompletionSequenceGroupOutput(SequenceGroupOutput): + """The model output associated with a completion sequence group.""" def __init__( self, @@ -682,26 +707,45 @@ def __init__( self.prompt_logprobs = prompt_logprobs def __repr__(self) -> str: - return (f"SequenceGroupOutput(samples={self.samples}, " + return (f"CompletionSequenceGroupOutput(samples={self.samples}, " f"prompt_logprobs={self.prompt_logprobs})") def __eq__(self, other: object) -> bool: - if not isinstance(other, SequenceGroupOutput): + if not isinstance(other, CompletionSequenceGroupOutput): raise NotImplementedError() return (self.samples == other.samples and self.prompt_logprobs == other.prompt_logprobs) +class EmbeddingSequenceGroupOutput(SequenceGroupOutput): + """The model output associated with an embedding sequence group.""" + + def __init__( + self, + embeddings: List[float], + ) -> None: + self.embeddings = embeddings + + def __repr__(self) -> str: + return (f"EmbeddingSequenceGroupOutput(" + f"embeddings_shape={len(self.embeddings)})") + + def __eq__(self, other: object) -> bool: + if not isinstance(other, EmbeddingSequenceGroupOutput): + raise NotImplementedError() + return self.embeddings == other.embeddings + + @dataclass class SamplerOutput: """For each sequence group, we generate a list of SequenceOutput object, each of which contains one possible candidate for the next token. - This datastructure implements methods so it can be used like a list, but + This data structure implements methods, so it can be used like a list, but also has optional fields for device tensors. """ - outputs: List[SequenceGroupOutput] + outputs: List[CompletionSequenceGroupOutput] # On-device tensor containing probabilities of each token. 
sampled_token_probs: Optional["torch.Tensor"] = None @@ -742,6 +786,27 @@ def __repr__(self) -> str: f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") +@dataclass +class PoolerOutput: + """The output from a pooling operation in the embedding model.""" + outputs: List[EmbeddingSequenceGroupOutput] + + spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None + + def __getitem__(self, idx: int): + return self.outputs[idx] + + def __setitem__(self, idx: int, value): + self.outputs[idx] = value + + def __len__(self): + return len(self.outputs) + + def __eq__(self, other: object): + return isinstance(other, + self.__class__) and self.outputs == other.outputs + + @dataclass class ExecuteModelRequest: """The model execution request.""" diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index d6f80c82b80b..4dc6c49eb58d 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -4,7 +4,8 @@ import torch -from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata, +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SamplerOutput, SequenceGroupMetadata, SequenceGroupOutput, SequenceOutput) SeqId = int @@ -94,7 +95,7 @@ def create_sequence_group_output( for topk_logprob_index, _ in enumerate(topk_token_ids) }) - return SequenceGroupOutput( + return CompletionSequenceGroupOutput( samples=[ SequenceOutput(parent_seq_id=seq_id, output_token=token_id, diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py new file mode 100644 index 000000000000..2d3f160c60dc --- /dev/null +++ b/vllm/worker/embedding_model_runner.py @@ -0,0 +1,266 @@ +from typing import Dict, List, Optional, Set, Tuple + +import torch + +from vllm.attention import AttentionMetadata +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.distributed import broadcast_tensor_dict +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.pooling_params import PoolingParams +from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata +from vllm.worker.model_runner import BatchType, ModelRunner + +logger = init_logger(__name__) + + +class EmbeddingModelRunner(ModelRunner): + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + vision_language_config: Optional[VisionLanguageConfig] = None, + ): + super().__init__(model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config=lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + vision_language_config=vision_language_config) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + kv_caches: List[torch.Tensor], + ) -> Optional[PoolerOutput]: + (input_tokens, input_positions, attn_metadata, pooling_metadata, + lora_requests, lora_mapping, multi_modal_input + ) = self.prepare_input_tensors(seq_group_metadata_list) + + if self.lora_config: + self.set_active_loras(lora_requests, lora_mapping) + + # Currently cuda 
graph is only supported by the decode phase. + prefill_meta = attn_metadata.prefill_metadata + decode_meta = attn_metadata.decode_metadata + if prefill_meta is None and decode_meta.use_cuda_graph: + graph_batch_size = input_tokens.shape[0] + model_executable = self.graph_runners[graph_batch_size] + else: + model_executable = self.model + + num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [None] * num_layers + + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + } + if self.vision_language_config: + execute_model_kwargs.update({"image_input": multi_modal_input}) + hidden_states = model_executable(**execute_model_kwargs) + + return self.model.pooler(hidden_states=hidden_states, + pooling_metadata=pooling_metadata) + + def prepare_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, + Set[LoRARequest], LoRAMapping, torch.Tensor]: + if self.is_driver_worker: + prefill_reqs = [] + decode_reqs = [] + for seq_group_meta in seq_group_metadata_list: + if seq_group_meta.is_prompt: + prefill_reqs.append(seq_group_meta) + else: + decode_reqs.append(seq_group_meta) + + # Prepare input tensors. + ( + input_tokens, + input_positions, + prefill_attn_metadata, + prompt_lens, + subquery_lens, + lora_index_mapping, + lora_prompt_mapping, + lora_requests, + multi_modal_input, + slot_mapping, + ) = self._prepare_prompt(prefill_reqs) + ( + decode_input_tokens, + decode_input_positions, + decode_attn_metadata, + decode_lora_index_mapping, + decode_lora_prompt_mapping, + decode_lora_requests, + decode_slot_mapping, + ) = self._prepare_decode(decode_reqs) + + # Prepare PoolingMetadata + pooling_metadata = self._prepare_pooling(seq_group_metadata_list, + prompt_lens) + + if not self.scheduler_config.chunked_prefill_enabled: + assert (len(prefill_reqs) and len(decode_reqs)) == 0 + + num_prefills = len(prompt_lens) + num_prefill_tokens = len(input_tokens) + num_decode_tokens = len(decode_input_tokens) + + # Coalesce tensors. Note that attn_metadata is currently not + # coalesced for simplicity. + input_tokens.extend(decode_input_tokens) + input_positions.extend(decode_input_positions) + slot_mapping.extend(decode_slot_mapping) + lora_index_mapping.extend(decode_lora_index_mapping) + lora_prompt_mapping.extend(decode_lora_prompt_mapping) + lora_requests.update(decode_lora_requests) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + + if self.lora_config: + lora_mapping = LoRAMapping( + lora_index_mapping, + lora_prompt_mapping, + ) + else: + lora_mapping = None + + # Broadcast the metadata. + # If batch contains both prefill and decode, it sends 2 broadcasts. + # If it only contains 1 type, it triggers a single broadcast. 
+ if (prefill_attn_metadata is not None + and decode_attn_metadata is not None): + batch_type = BatchType.MIXED + elif prefill_attn_metadata is not None: + batch_type = BatchType.PREFILL + else: + batch_type = BatchType.DECODE + + metadata_dict = { + "input_tokens": input_tokens, + "input_positions": input_positions, + "lora_requests": lora_requests, + "lora_mapping": lora_mapping, + "multi_modal_input": multi_modal_input, + "num_prefill_tokens": num_prefill_tokens, + "num_decode_tokens": num_decode_tokens, + "slot_mapping": slot_mapping, + "num_prefills": num_prefills, + "batch_type": batch_type, + } + if prefill_attn_metadata is not None: + metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) + else: + assert decode_attn_metadata is not None + metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) + broadcast_tensor_dict(metadata_dict, src=0) + + # Broadcast decode attn metadata for mixed batch type. + # The additional broadcast costs 300us overhead on 4 A10 GPUs. + # We can potentially reduce the overhead by coelescing tensors. + if batch_type == BatchType.MIXED: + assert decode_attn_metadata is not None + metadata_dict = decode_attn_metadata.asdict_zerocopy() + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + input_tokens = metadata_dict.pop("input_tokens") + input_positions = metadata_dict.pop("input_positions") + slot_mapping = metadata_dict.pop("slot_mapping") + num_prefills = metadata_dict.pop("num_prefills") + lora_mapping = metadata_dict.pop("lora_mapping") + lora_requests = metadata_dict.pop("lora_requests") + multi_modal_input = metadata_dict.pop("multi_modal_input") + num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") + num_decode_tokens = metadata_dict.pop("num_decode_tokens") + batch_type = metadata_dict.pop("batch_type") + + # Create an attention metadata. + prefill_attn_metadata = None + decode_attn_metadata = None + if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED: + prefill_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + else: + decode_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + + pooling_metadata = PoolingMetadata(seq_groups=None, + seq_data=None, + prompt_lens=None) + + # if it is a mixed batch, decode attn_metadata is broadcasted + # separately. 
+ if batch_type == BatchType.MIXED: + metadata_dict = broadcast_tensor_dict(src=0) + decode_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + + attn_metadata = AttentionMetadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + prefill_metadata=prefill_attn_metadata, + decode_metadata=decode_attn_metadata, + kv_cache_dtype=self.kv_cache_dtype, + ) + + return (input_tokens, input_positions, attn_metadata, pooling_metadata, + lora_requests, lora_mapping, multi_modal_input) + + def _prepare_pooling( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> PoolingMetadata: + """Prepare PoolingMetadata for the sequence group metadata list.""" + seq_groups: List[Tuple[List[int], PoolingParams]] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + pooling_params = seq_group_metadata.pooling_params + seq_groups.append((seq_ids, pooling_params)) + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + pooling_metadata = PoolingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + ) + + return pooling_metadata diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 3fc76c614216..21d76fd531e4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,6 +1,6 @@ import time from enum import IntEnum -from typing import Dict, List, NamedTuple, Optional, Set, Tuple +from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union import numpy as np import torch @@ -287,18 +287,18 @@ def _prepare_prompt( lora_requests.add(seq_group_metadata.lora_request) lora_index_mapping += [lora_id] * (seq_len - context_len) - lora_prompt_mapping.extend( - [lora_id] * - (seq_len - context_len - if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + lora_prompt_mapping.extend([lora_id] * ( + seq_len - context_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs else 1)) if seq_group_metadata.multi_modal_data: multi_modal_input_list.append( seq_group_metadata.multi_modal_data.data) - if seq_group_metadata.block_tables is None: + if _is_block_tables_empty(seq_group_metadata.block_tables): # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. + # In embeddings, the block tables are {seq_id: None}. slot_mapping.extend([_PAD_SLOT_ID] * seq_len) continue @@ -813,7 +813,6 @@ def profile_run(self) -> None: sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_seqs = self.scheduler_config.max_num_seqs - # This represents the maximum number of different requests # that will have unique loras, an therefore the max amount of memory # consumption create dummy lora request copies from the lora request @@ -1139,3 +1138,15 @@ def _prepare_fake_inputs( prompt_tokens = [0] * seq_len fake_image_input = None return SequenceData(prompt_tokens), fake_image_input + + +def _is_block_tables_empty(block_tables: Union[None, Dict]): + """ + Check if block_tables is None or a dictionary with all None values. 
+ """ + if block_tables is None: + return True + if isinstance(block_tables, dict) and all( + value is None for value in block_tables.values()): + return True + return False diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 0ca9c2b64cf3..e4fbc877b8c9 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Union import torch import torch.distributed @@ -16,8 +16,9 @@ init_custom_ar) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.worker.cache_engine import CacheEngine +from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.model_runner import ModelRunner from vllm.worker.worker_base import WorkerBase @@ -68,7 +69,9 @@ def __init__( assert not self.lora_config, ( "To be tested: vision language model with LoRA settings.") - self.model_runner = ModelRunner( + ModelRunnerClass = (EmbeddingModelRunner if + self.model_config.embedding_mode else ModelRunner) + self.model_runner = ModelRunnerClass( model_config, parallel_config, scheduler_config, @@ -83,7 +86,8 @@ def __init__( # Uninitialized cache engine. Will be initialized by # initialize_cache. self.cache_engine: CacheEngine - self.gpu_cache: List[torch.Tensor] + # Initialize gpu_cache as embedding models don't initialize kv_caches + self.gpu_cache: Optional[List[torch.tensor]] = None def init_device(self) -> None: if self.device_config.device.type == "cuda": @@ -209,7 +213,7 @@ def cache_swap( def execute_model( self, execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: + ) -> List[Union[SamplerOutput, PoolerOutput]]: if execute_model_req is None: seq_group_metadata_list = None
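The Pooler added in vllm/model_executor/layers/pooler.py is what turns the flat batch of
hidden states produced by the model into one vector per prompt. The following is a minimal
sketch of that LAST-token selection, not part of the patch itself; the tensor shapes,
prompt lengths, and hidden size used here are illustrative stand-ins, since the real values
come from the model config and the scheduled batch.

import torch
import torch.nn.functional as F

# Two prompts of lengths 3 and 4 packed into a single flat batch of
# 7 token hidden states; hidden size 8 is a toy stand-in for e.g. 4096.
hidden_states = torch.randn(7, 8)
prompt_lens = torch.tensor([3, 4])

# LAST pooling: the final token of each prompt sits at cumsum(lens) - 1,
# i.e. flat indices 2 and 6 in this example.
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
pooled = hidden_states[last_token_flat_indices]  # shape (2, 8)

# normalize=True in the Pooler corresponds to L2-normalizing each row.
pooled = F.normalize(pooled, p=2, dim=1)

Each row of pooled is what the patch wraps in an EmbeddingSequenceGroupOutput inside a
PoolerOutput, and what ultimately reaches the client as EmbeddingRequestOutput.outputs.embedding.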