From 4abed65c5806d0514432d102f959a1c84d341171 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Fri, 30 Aug 2024 08:49:04 +0800
Subject: [PATCH] [VLM] Disallow overflowing `max_model_len` for multimodal models (#7998)

---
 tests/models/test_llava.py | 17 +++++++++++++++++
 vllm/engine/llm_engine.py  | 21 ++++++++++++++++++---
 2 files changed, 35 insertions(+), 3 deletions(-)

diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py
index 93634f245cee..9d7da5f803ea 100644
--- a/tests/models/test_llava.py
+++ b/tests/models/test_llava.py
@@ -179,3 +179,20 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
         num_logprobs=num_logprobs,
         tensor_parallel_size=1,
     )
+
+
+@pytest.mark.parametrize("model", models)
+def test_context_length_too_short(vllm_runner, image_assets, model):
+    images = [asset.pil_image for asset in image_assets]
+
+    with pytest.raises(ValueError, match="too long to fit into the model"):
+        vllm_model = vllm_runner(
+            model,
+            max_model_len=128,  # LLaVA has a feature size of 576
+            enforce_eager=True,
+        )
+
+        with vllm_model:
+            vllm_model.generate_greedy([HF_IMAGE_PROMPTS[0]],
+                                       max_tokens=1,
+                                       images=[images[0]])
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 92c02072593e..59baf1ef40df 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -2010,7 +2010,22 @@ def is_embedding_model(self):
 
     def _validate_model_inputs(self, inputs: Union[LLMInputs,
                                                     EncoderDecoderLLMInputs]):
-        prompt_key = "encoder_prompt_token_ids" \
-            if self.is_encoder_decoder_model() else "prompt_token_ids"
-        if not inputs.get(prompt_key):
+        if self.is_encoder_decoder_model():
+            prompt_ids = inputs.get("encoder_prompt_token_ids")
+        else:
+            prompt_ids = inputs.get("prompt_token_ids")
+
+        if prompt_ids is None or len(prompt_ids) == 0:
             raise ValueError("Prompt cannot be empty")
+
+        if self.model_config.multimodal_config is not None:
+            max_prompt_len = self.model_config.max_model_len
+
+            if len(prompt_ids) > max_prompt_len:
+                raise ValueError(
+                    f"The prompt (total length {len(prompt_ids)}) is too long "
+                    f"to fit into the model (context length {max_prompt_len}). "
+                    "Make sure that `max_model_len` is no smaller than the "
+                    "number of text tokens plus multimodal tokens. For image "
+                    "inputs, the number of image tokens depends on the number "
+                    "of images, and possibly their aspect ratios as well.")
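
For readers trying the change locally, the sketch below (not part of the patch) shows roughly how the new validation surfaces through vLLM's public API instead of the test fixtures above. The checkpoint name, prompt template, and image path are assumptions chosen for illustration only; depending on the model and configuration, the ValueError may be raised during engine construction or when the request is validated at generation time, which is why the test wraps both steps in pytest.raises.

# Minimal sketch, assuming a LLaVA-1.5 checkpoint and its "USER: <image> ... ASSISTANT:" prompt format.
from PIL import Image

from vllm import LLM, SamplingParams

try:
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",  # assumed example checkpoint
        max_model_len=128,  # far smaller than LLaVA's 576 image feature tokens
        enforce_eager=True,
    )
    llm.generate(
        {
            "prompt": "USER: <image>\nWhat is shown in this image? ASSISTANT:",
            "multi_modal_data": {"image": Image.open("example.jpg")},  # any RGB image
        },
        SamplingParams(max_tokens=1),
    )
except ValueError as err:
    # With this patch the request is rejected up front rather than overflowing
    # the context window, e.g.:
    # "The prompt (total length ...) is too long to fit into the model
    #  (context length 128). Make sure that `max_model_len` is no smaller than
    #  the number of text tokens plus multimodal tokens. ..."
    print(err)

Raising `max_model_len` to at least the text length plus the per-image feature count (576 tokens per image for LLaVA-1.5) makes the same request succeed.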