diff --git a/vllm/config.py b/vllm/config.py index 7e0b75eceae5b..b84d91d402370 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -32,6 +32,7 @@ logger = init_logger(__name__) _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 +_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096 _PP_SUPPORTED_MODELS = [ "AquilaModel", @@ -571,6 +572,10 @@ def is_embedding_model(self) -> bool: """Extract the embedding model flag.""" return self.embedding_mode + @property + def is_multimodal_model(self) -> bool: + return self.multimodal_config is not None + class CacheConfig: """Configuration for the KV cache. @@ -947,25 +952,36 @@ def __init__(self, num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, - embedding_mode: Optional[bool] = False, + embedding_mode: bool = False, + is_multimodal_model: bool = False, preemption_mode: Optional[str] = None, num_scheduler_steps: int = 1, send_delta_data: bool = False) -> None: - if max_num_batched_tokens is not None: - self.max_num_batched_tokens = max_num_batched_tokens - else: + if max_num_batched_tokens is None: if enable_chunked_prefill: # It is the values that have the best balance between ITL # and TTFT on A100. Note it is not optimized for throughput. - self.max_num_batched_tokens = 512 - elif embedding_mode: - # For embedding, choose specific value for higher throughput - self.max_num_batched_tokens = max( - max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS) + max_num_batched_tokens = 512 else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. - self.max_num_batched_tokens = max(max_model_len, 2048) + max_num_batched_tokens = max(max_model_len, 2048) + + if embedding_mode: + # For embedding, choose specific value for higher throughput + max_num_batched_tokens = max( + max_num_batched_tokens, + _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + if is_multimodal_model: + # The value needs to be at least the number of multimodal tokens + max_num_batched_tokens = max( + max_num_batched_tokens, + _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + + self.max_num_batched_tokens = max_num_batched_tokens + if enable_chunked_prefill: logger.info( "Chunked prefill is enabled with max_num_batched_tokens=%d.", diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6e66198e203fc..d98f57bc2d353 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -921,6 +921,7 @@ def create_engine_config(self) -> EngineConfig: delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, embedding_mode=model_config.embedding_mode, + is_multimodal_model=model_config.is_multimodal_model, preemption_mode=self.preemption_mode, num_scheduler_steps=self.num_scheduler_steps, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index aa33933c668ed..1eab83f3b9889 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2019,7 +2019,7 @@ def _validate_model_inputs(self, inputs: Union[LLMInputs, if prompt_ids is None or len(prompt_ids) == 0: raise ValueError("Prompt cannot be empty") - if self.model_config.multimodal_config is not None: + if self.model_config.is_multimodal_model: max_prompt_len = self.model_config.max_model_len if len(prompt_ids) > max_prompt_len: @@ -2030,3 +2030,7 @@ def _validate_model_inputs(self, inputs: Union[LLMInputs, "number of text tokens plus multimodal tokens. For image " "inputs, the number of image tokens depends on the number " "of images, and possibly their aspect ratios as well.") + + # TODO: Find out how many placeholder tokens are there so we can + # check that chunked prefill does not truncate them + # max_batch_len = self.scheduler_config.max_num_batched_tokens diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py index 79c48896469e8..d73023e8e1724 100644 --- a/vllm/worker/utils.py +++ b/vllm/worker/utils.py @@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario( raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP']) - if enc_dec_mr.model_config.multimodal_config is not None: + if enc_dec_mr.model_config.is_multimodal_model: raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])