diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py
index ca592c7ce878..0da0ae042519 100644
--- a/vllm/multimodal/image.py
+++ b/vllm/multimodal/image.py
@@ -5,7 +5,7 @@
 import torch
 from PIL import Image
 from transformers import (CLIPVisionConfig, LlavaConfig, LlavaNextConfig,
-                          PretrainedConfig)
+                          PretrainedConfig, PreTrainedTokenizerBase)
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape)
 
@@ -67,17 +67,20 @@ def _get_llava_next_num_unpadded_features(
     return (unpadded_features, newline_features)
 
 
-def _get_llava_next_image_feature_size(hf_config: LlavaNextConfig) -> int:
+def _get_llava_next_image_feature_size(
+    hf_config: LlavaNextConfig,
+    *,
+    input_height: int,
+    input_width: int,
+) -> int:
     vision_config = hf_config.vision_config
 
     if isinstance(vision_config, CLIPVisionConfig):
         num_patches = _get_clip_num_patches(vision_config)
         base_feature_size = num_patches * num_patches
 
-        # Results in the max possible feature size
-        dummy_height, dummy_width = 448, 448
         num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-            image_size=(dummy_height, dummy_width),
+            image_size=(input_height, input_width),
             grid_pinpoints=hf_config.image_grid_pinpoints,
             patch_size=vision_config.image_size,
         )
@@ -85,7 +88,7 @@ def _get_llava_next_image_feature_size(hf_config: LlavaNextConfig) -> int:
         (
             unpadded_feature_size,
             newline_feature_size,
-        ) = _get_llava_next_num_unpadded_features(dummy_height, dummy_width,
+        ) = _get_llava_next_num_unpadded_features(input_height, input_width,
                                                   num_patches,
                                                   num_patch_height,
                                                   num_patch_width)
@@ -105,10 +108,8 @@ class DummyImageDataFactories:
     """
 
     @classmethod
-    def _dummy_data_for_clip(
+    def _dummy_seq_data_for_clip(
         cls,
-        model_config: ModelConfig,
-        multimodal_config: VisionLanguageConfig,
         hf_config: CLIPVisionConfig,
         seq_len: int,
         *,
@@ -122,22 +123,40 @@ def _dummy_data_for_clip(
 
         token_ids = [image_token_id] * image_feature_size
         token_ids += [0] * (seq_len - image_feature_size)
-        seq_data = SequenceData(token_ids)
-
-        image_input_type = multimodal_config.image_input_type
-        ImageInputType = VisionLanguageConfig.ImageInputType
-        multi_modal_data: MultiModalData
-        if image_input_type == ImageInputType.PIXEL_VALUES:
-            width = height = hf_config.image_size
-            image = Image.new("RGB", (width, height), color=0)
-            multi_modal_data = ImagePixelData(image)
-        elif image_input_type == ImageInputType.IMAGE_FEATURES:
-            depth = hf_config.hidden_size
-            values = torch.zeros((1, image_feature_size, depth),
-                                 dtype=torch.float16)
-            multi_modal_data = ImageFeatureData(values)
-
-        return seq_data, multi_modal_data
+        return SequenceData(token_ids)
+
+    @classmethod
+    def _dummy_pixel_data_for_clip(
+        cls,
+        hf_config: CLIPVisionConfig,
+        *,
+        image_width_override: Optional[int] = None,
+        image_height_override: Optional[int] = None,
+    ):
+        width = height = hf_config.image_size
+        if image_width_override is not None:
+            width = image_width_override
+        if image_height_override is not None:
+            height = image_height_override
+
+        image = Image.new("RGB", (width, height), color=0)
+        return ImagePixelData(image)
+
+    @classmethod
+    def _dummy_feature_data_for_clip(
+        cls,
+        hf_config: CLIPVisionConfig,
+        *,
+        image_feature_size_override: Optional[int] = None,
+    ):
+        if image_feature_size_override is None:
+            image_feature_size = _get_clip_image_feature_size(hf_config)
+        else:
+            image_feature_size = image_feature_size_override
+
+        values = torch.zeros((1, image_feature_size, hf_config.hidden_size),
+                             dtype=torch.float16)
+        return ImageFeatureData(values)
 
     @classmethod
     def _dummy_data_for_llava(
@@ -150,14 +169,24 @@ def _dummy_data_for_llava(
         vision_config = hf_config.vision_config
 
         if isinstance(vision_config, CLIPVisionConfig):
-            return cls._dummy_data_for_clip(
-                model_config,
-                multimodal_config,
+            seq_data = cls._dummy_seq_data_for_clip(
                 vision_config,
-                seq_len=seq_len,
+                seq_len,
                 image_token_id=hf_config.image_token_index,
             )
 
+            image_input_type = multimodal_config.image_input_type
+            ImageInputType = VisionLanguageConfig.ImageInputType
+            multi_modal_data: MultiModalData
+            if image_input_type == ImageInputType.PIXEL_VALUES:
+                multi_modal_data = cls._dummy_pixel_data_for_clip(
+                    vision_config)
+            elif image_input_type == ImageInputType.IMAGE_FEATURES:
+                multi_modal_data = cls._dummy_feature_data_for_clip(
+                    vision_config)
+
+            return seq_data, multi_modal_data
+
         msg = f"Unsupported vision config: {type(vision_config)}"
         raise NotImplementedError(msg)
 
@@ -170,18 +199,37 @@ def _dummy_data_for_llava_next(
         seq_len: int,
     ):
         vision_config = hf_config.vision_config
-        image_feature_size = _get_llava_next_image_feature_size(hf_config)
+
+        # Result in the max possible feature size
+        dummy_height = dummy_width = 448
+        image_feature_size = _get_llava_next_image_feature_size(
+            hf_config, input_height=dummy_height, input_width=dummy_width)
 
         if isinstance(vision_config, CLIPVisionConfig):
-            return cls._dummy_data_for_clip(
-                model_config,
-                multimodal_config,
+            seq_data = cls._dummy_seq_data_for_clip(
                 vision_config,
-                seq_len=seq_len,
+                seq_len,
                 image_token_id=hf_config.image_token_index,
                 image_feature_size_override=image_feature_size,
             )
 
+            image_input_type = multimodal_config.image_input_type
+            ImageInputType = VisionLanguageConfig.ImageInputType
+            multi_modal_data: MultiModalData
+            if image_input_type == ImageInputType.PIXEL_VALUES:
+                multi_modal_data = cls._dummy_pixel_data_for_clip(
+                    vision_config,
+                    image_width_override=dummy_width,
+                    image_height_override=dummy_height,
+                )
+            elif image_input_type == ImageInputType.IMAGE_FEATURES:
+                multi_modal_data = cls._dummy_feature_data_for_clip(
+                    vision_config,
+                    image_feature_size_override=image_feature_size,
+                )
+
+            return seq_data, multi_modal_data
+
         msg = f"Unsupported vision config: {type(vision_config)}"
         raise NotImplementedError(msg)
 
@@ -244,52 +292,45 @@ def _repeat_and_pad_token(
     @classmethod
     def _repeat_and_pad_image_tokens(
         cls,
-        model_config: ModelConfig,
-        llm_inputs: LLMInputs,
+        tokenizer: PreTrainedTokenizerBase,
+        prompt: Optional[str],
+        prompt_token_ids: List[int],
         *,
         image_token_id: int,
         repeat_count: int = 1,
         pad_token_left: Optional[int] = None,
         pad_token_right: Optional[int] = None,
-    ) -> LLMInputs:
-        multi_modal_data = llm_inputs.get("multi_modal_data")
-        if multi_modal_data is None:
-            return llm_inputs
-
-        tokenizer = _cached_get_tokenizer(model_config.tokenizer)
-        image_token_str = tokenizer.decode(image_token_id)
-        pad_token_str_left = (None if pad_token_left is None else
-                              tokenizer.decode(pad_token_left))
-        pad_token_str_right = (None if pad_token_right is None else
-                               tokenizer.decode(pad_token_right))
-
-        replacement_str = "".join(
-            cls._repeat_and_pad_token(
-                image_token_str,
-                repeat_count=repeat_count,
-                pad_token_left=pad_token_str_left,
-                pad_token_right=pad_token_str_right,
-            ))
-        replacement_ids = cls._repeat_and_pad_token(
-            image_token_id,
-            repeat_count=repeat_count,
-            pad_token_left=pad_token_left,
-            pad_token_right=pad_token_right,
-        )
-
+    ) -> Tuple[Optional[str], List[int]]:
         # To avoid invoking the tokenizer, we assume that the
         # image token is called "<image>"
-        prompt = llm_inputs.get("prompt")
         if prompt is None:
             new_prompt = None
         else:
+            image_token_str = tokenizer.decode(image_token_id)
+            pad_token_str_left = (None if pad_token_left is None else
+                                  tokenizer.decode(pad_token_left))
+            pad_token_str_right = (None if pad_token_right is None else
+                                   tokenizer.decode(pad_token_right))
+            replacement_str = "".join(
+                cls._repeat_and_pad_token(
+                    image_token_str,
+                    repeat_count=repeat_count,
+                    pad_token_left=pad_token_str_left,
+                    pad_token_right=pad_token_str_right,
+                ))
+
             # The image tokens are removed to be consistent with HuggingFace
             new_prompt = prompt.replace(image_token_str, replacement_str, 1)
 
-        prompt_token_ids = llm_inputs["prompt_token_ids"]
         new_token_ids: List[int] = []
         for i, token in enumerate(prompt_token_ids):
             if token == image_token_id:
+                replacement_ids = cls._repeat_and_pad_token(
+                    image_token_id,
+                    repeat_count=repeat_count,
+                    pad_token_left=pad_token_left,
+                    pad_token_right=pad_token_right,
+                )
                 new_token_ids.extend(replacement_ids)
 
                 # No need to further scan the list since we only replace once
@@ -298,10 +339,7 @@ def _repeat_and_pad_image_tokens(
             else:
                 new_token_ids.append(token)
 
-        # NOTE: Create a defensive copy of the original inputs
-        return LLMInputs(prompt_token_ids=new_token_ids,
-                         prompt=new_prompt,
-                         multi_modal_data=multi_modal_data)
+        return new_prompt, new_token_ids
 
     @classmethod
     def _input_processor_for_clip(
@@ -314,18 +352,31 @@ def _input_processor_for_clip(
         image_token_id: int,
         image_feature_size_override: Optional[int] = None,
     ):
+        multi_modal_data = llm_inputs.get("multi_modal_data")
+        if multi_modal_data is None or not isinstance(
+                multi_modal_data, (ImagePixelData, ImageFeatureData)):
+            return llm_inputs
+
+        tokenizer = _cached_get_tokenizer(model_config.tokenizer)
+
         if image_feature_size_override is None:
             image_feature_size = _get_clip_image_feature_size(hf_config)
         else:
            image_feature_size = image_feature_size_override
 
-        return cls._repeat_and_pad_image_tokens(
-            model_config,
-            llm_inputs,
+        new_prompt, new_token_ids = cls._repeat_and_pad_image_tokens(
+            tokenizer,
+            llm_inputs.get("prompt"),
+            llm_inputs["prompt_token_ids"],
             image_token_id=image_token_id,
             repeat_count=image_feature_size,
         )
 
+        # NOTE: Create a defensive copy of the original inputs
+        return LLMInputs(prompt_token_ids=new_token_ids,
+                         prompt=new_prompt,
+                         multi_modal_data=multi_modal_data)
+
     @classmethod
     def _input_processor_for_llava(
         cls,
@@ -334,6 +385,11 @@ def _input_processor_for_llava(
         hf_config: LlavaConfig,
         llm_inputs: LLMInputs,
     ):
+        multi_modal_data = llm_inputs.get("multi_modal_data")
+        if multi_modal_data is None or not isinstance(
+                multi_modal_data, (ImagePixelData, ImageFeatureData)):
+            return llm_inputs
+
         vision_config = hf_config.vision_config
 
         if isinstance(vision_config, CLIPVisionConfig):
@@ -356,8 +412,25 @@ def _input_processor_for_llava_next(
         hf_config: LlavaNextConfig,
         llm_inputs: LLMInputs,
     ):
+        multi_modal_data = llm_inputs.get("multi_modal_data")
+        if multi_modal_data is None or not isinstance(
+                multi_modal_data, (ImagePixelData, ImageFeatureData)):
+            return llm_inputs
+
+        if isinstance(multi_modal_data, ImagePixelData):
+            image = multi_modal_data.image
+            if isinstance(image, torch.Tensor):
+                _, _, _, height, width = image.shape
+            else:
+                width, height = image.size
+
+            image_feature_size = _get_llava_next_image_feature_size(
+                hf_config, input_height=height, input_width=width)
+        else:
+            image_features = multi_modal_data.image_features
+            image_feature_size = image_features.shape[-2]
+
         vision_config = hf_config.vision_config
-        image_feature_size = _get_llava_next_image_feature_size(hf_config)
 
         if isinstance(vision_config, CLIPVisionConfig):
             return cls._input_processor_for_clip(