Fix LLaVA-NeXT input processor and cleanup code
DarkLight1337 committed Jun 5, 2024
1 parent a38b347 commit 9cfbcce
Showing 1 changed file with 145 additions and 72 deletions.
217 changes: 145 additions & 72 deletions vllm/multimodal/image.py
@@ -5,7 +5,7 @@
import torch
from PIL import Image
from transformers import (CLIPVisionConfig, LlavaConfig, LlavaNextConfig,
PretrainedConfig)
PretrainedConfig, PreTrainedTokenizerBase)
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape)

@@ -67,25 +67,28 @@ def _get_llava_next_num_unpadded_features(
return (unpadded_features, newline_features)


def _get_llava_next_image_feature_size(hf_config: LlavaNextConfig) -> int:
def _get_llava_next_image_feature_size(
hf_config: LlavaNextConfig,
*,
input_height: int,
input_width: int,
) -> int:
vision_config = hf_config.vision_config

if isinstance(vision_config, CLIPVisionConfig):
num_patches = _get_clip_num_patches(vision_config)
base_feature_size = num_patches * num_patches

# Results in the max possible feature size
dummy_height, dummy_width = 448, 448
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(dummy_height, dummy_width),
image_size=(input_height, input_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_config.image_size,
)

(
unpadded_feature_size,
newline_feature_size,
) = _get_llava_next_num_unpadded_features(dummy_height, dummy_width,
) = _get_llava_next_num_unpadded_features(input_height, input_width,
num_patches,
num_patch_height,
num_patch_width)
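For context: the change above threads the caller's input resolution into the feature-size computation instead of a hard-coded 448x448 dummy. Below is a self-contained sketch of the arithmetic for a 448x448 input; the 336px CLIP tower, 14px patches, and the 2x2 anyres grid are assumptions for illustration, not values read from a real config.

```python
# Hedged sketch of the LLaVA-NeXT feature-size arithmetic; the constants
# (336px tower, 14px patches, 2x2 anyres grid) are assumed, not read from
# an actual HF config.
npatches = 336 // 14                       # 24 patches per side in the base tile
base_feature_size = npatches * npatches    # 576 base features

num_patch_height = num_patch_width = 2     # assumed anyres grid for 448x448
current_height = npatches * num_patch_height   # 48
current_width = npatches * num_patch_width     # 48
# A square input matches the square grid, so no aspect-ratio cropping occurs.
unpadded_features = current_height * current_width  # 2304
newline_features = current_height                   # one newline per row: 48

print(base_feature_size + unpadded_features + newline_features)  # 2928
```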
@@ -105,10 +108,8 @@ class DummyImageDataFactories:
"""

@classmethod
def _dummy_data_for_clip(
def _dummy_seq_data_for_clip(
cls,
model_config: ModelConfig,
multimodal_config: VisionLanguageConfig,
hf_config: CLIPVisionConfig,
seq_len: int,
*,
@@ -122,22 +123,40 @@ def _dummy_data_for_clip(

token_ids = [image_token_id] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
seq_data = SequenceData(token_ids)

image_input_type = multimodal_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
multi_modal_data: MultiModalData
if image_input_type == ImageInputType.PIXEL_VALUES:
width = height = hf_config.image_size
image = Image.new("RGB", (width, height), color=0)
multi_modal_data = ImagePixelData(image)
elif image_input_type == ImageInputType.IMAGE_FEATURES:
depth = hf_config.hidden_size
values = torch.zeros((1, image_feature_size, depth),
dtype=torch.float16)
multi_modal_data = ImageFeatureData(values)

return seq_data, multi_modal_data
return SequenceData(token_ids)

@classmethod
def _dummy_pixel_data_for_clip(
cls,
hf_config: CLIPVisionConfig,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
width = height = hf_config.image_size
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override

image = Image.new("RGB", (width, height), color=0)
return ImagePixelData(image)

@classmethod
def _dummy_feature_data_for_clip(
cls,
hf_config: CLIPVisionConfig,
*,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = _get_clip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override

values = torch.zeros((1, image_feature_size, hf_config.hidden_size),
dtype=torch.float16)
return ImageFeatureData(values)

@classmethod
def _dummy_data_for_llava(
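The hunks above split the old all-in-one `_dummy_data_for_clip` into three reusable pieces: one for the placeholder token sequence, one for a dummy PIL image, and one for a dummy feature tensor. The stand-ins below mirror that logic as plain functions; they are illustrative, not vLLM's actual classmethods.

```python
# Standalone stand-ins for the three split-out helpers; names and logic
# mirror the diff, but these are not the vLLM classes themselves.
from typing import List, Optional

import torch
from PIL import Image


def dummy_token_ids(image_token_id: int, image_feature_size: int,
                    seq_len: int) -> List[int]:
    # One placeholder token per image feature, right-padded with token id 0.
    return ([image_token_id] * image_feature_size +
            [0] * (seq_len - image_feature_size))


def dummy_image(default_size: int,
                width_override: Optional[int] = None,
                height_override: Optional[int] = None) -> Image.Image:
    width = height = default_size
    if width_override is not None:
        width = width_override
    if height_override is not None:
        height = height_override
    return Image.new("RGB", (width, height), color=0)


def dummy_features(feature_size: int, hidden_size: int) -> torch.Tensor:
    # Matches the zero-filled fp16 tensor built by the feature-data helper.
    return torch.zeros((1, feature_size, hidden_size), dtype=torch.float16)
```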
@@ -150,14 +169,24 @@ def _dummy_data_for_llava(
vision_config = hf_config.vision_config

if isinstance(vision_config, CLIPVisionConfig):
return cls._dummy_data_for_clip(
model_config,
multimodal_config,
seq_data = cls._dummy_seq_data_for_clip(
vision_config,
seq_len=seq_len,
seq_len,
image_token_id=hf_config.image_token_index,
)

image_input_type = multimodal_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
multi_modal_data: MultiModalData
if image_input_type == ImageInputType.PIXEL_VALUES:
multi_modal_data = cls._dummy_pixel_data_for_clip(
vision_config)
elif image_input_type == ImageInputType.IMAGE_FEATURES:
multi_modal_data = cls._dummy_feature_data_for_clip(
vision_config)

return seq_data, multi_modal_data

msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)

@@ -170,18 +199,37 @@ def _dummy_data_for_llava_next(
seq_len: int,
):
vision_config = hf_config.vision_config
image_feature_size = _get_llava_next_image_feature_size(hf_config)

# Results in the max possible feature size
dummy_height = dummy_width = 448
image_feature_size = _get_llava_next_image_feature_size(
hf_config, input_height=dummy_height, input_width=dummy_width)

if isinstance(vision_config, CLIPVisionConfig):
return cls._dummy_data_for_clip(
model_config,
multimodal_config,
seq_data = cls._dummy_seq_data_for_clip(
vision_config,
seq_len=seq_len,
seq_len,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)

image_input_type = multimodal_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
multi_modal_data: MultiModalData
if image_input_type == ImageInputType.PIXEL_VALUES:
multi_modal_data = cls._dummy_pixel_data_for_clip(
vision_config,
image_width_override=dummy_width,
image_height_override=dummy_height,
)
elif image_input_type == ImageInputType.IMAGE_FEATURES:
multi_modal_data = cls._dummy_feature_data_for_clip(
vision_config,
image_feature_size_override=image_feature_size,
)

return seq_data, multi_modal_data

msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
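Composing the pieces for LLaVA-NeXT: the dummy prompt reserves one placeholder token per image feature, and the dummy image is created at the same 448x448 resolution that sized those features, so profiling exercises the worst case. A hedged, self-contained sketch with hypothetical values (seq_len of 4096 and image token id 32000; 2928 is the figure from the worked arithmetic above):

```python
# Hedged composition sketch; 2928 comes from the worked arithmetic above,
# and seq_len / the image token id 32000 are hypothetical example values.
from PIL import Image

dummy_height = dummy_width = 448
image_feature_size = 2928
seq_len = 4096

token_ids = ([32000] * image_feature_size +
             [0] * (seq_len - image_feature_size))
image = Image.new("RGB", (dummy_width, dummy_height), color=0)

assert len(token_ids) == seq_len
assert image.size == (448, 448)
```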

@@ -244,52 +292,45 @@ def _repeat_and_pad_token(
@classmethod
def _repeat_and_pad_image_tokens(
cls,
model_config: ModelConfig,
llm_inputs: LLMInputs,
tokenizer: PreTrainedTokenizerBase,
prompt: Optional[str],
prompt_token_ids: List[int],
*,
image_token_id: int,
repeat_count: int = 1,
pad_token_left: Optional[int] = None,
pad_token_right: Optional[int] = None,
) -> LLMInputs:
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None:
return llm_inputs

tokenizer = _cached_get_tokenizer(model_config.tokenizer)
image_token_str = tokenizer.decode(image_token_id)
pad_token_str_left = (None if pad_token_left is None else
tokenizer.decode(pad_token_left))
pad_token_str_right = (None if pad_token_right is None else
tokenizer.decode(pad_token_right))

replacement_str = "".join(
cls._repeat_and_pad_token(
image_token_str,
repeat_count=repeat_count,
pad_token_left=pad_token_str_left,
pad_token_right=pad_token_str_right,
))
replacement_ids = cls._repeat_and_pad_token(
image_token_id,
repeat_count=repeat_count,
pad_token_left=pad_token_left,
pad_token_right=pad_token_right,
)

) -> Tuple[Optional[str], List[int]]:
# To avoid invoking the tokenizer, we assume that the
# image token is called "<image>"
prompt = llm_inputs.get("prompt")
if prompt is None:
new_prompt = None
else:
image_token_str = tokenizer.decode(image_token_id)
pad_token_str_left = (None if pad_token_left is None else
tokenizer.decode(pad_token_left))
pad_token_str_right = (None if pad_token_right is None else
tokenizer.decode(pad_token_right))
replacement_str = "".join(
cls._repeat_and_pad_token(
image_token_str,
repeat_count=repeat_count,
pad_token_left=pad_token_str_left,
pad_token_right=pad_token_str_right,
))

# The image tokens are removed to be consistent with HuggingFace
new_prompt = prompt.replace(image_token_str, replacement_str, 1)

prompt_token_ids = llm_inputs["prompt_token_ids"]
new_token_ids: List[int] = []
for i, token in enumerate(prompt_token_ids):
if token == image_token_id:
replacement_ids = cls._repeat_and_pad_token(
image_token_id,
repeat_count=repeat_count,
pad_token_left=pad_token_left,
pad_token_right=pad_token_right,
)
new_token_ids.extend(replacement_ids)

# No need to further scan the list since we only replace once
@@ -298,10 +339,7 @@ def _repeat_and_pad_image_tokens(
else:
new_token_ids.append(token)

# NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
return new_prompt, new_token_ids

@classmethod
def _input_processor_for_clip(
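The hunk above turns `_repeat_and_pad_image_tokens` into a pure function over the prompt text and token ids, leaving the defensive `LLMInputs` copy to each caller. The standalone sketch below shows the expansion it performs; the image token id 32000 and feature size 576 are hypothetical example values:

```python
# Standalone sketch of the token expansion; 32000 and 576 are hypothetical
# example values, not read from a real tokenizer or model config.
from typing import List, Optional


def repeat_and_pad_token(token: int, *, repeat_count: int = 1,
                         pad_token_left: Optional[int] = None,
                         pad_token_right: Optional[int] = None) -> List[int]:
    replacement = [token] * repeat_count
    if pad_token_left is not None:
        replacement = [pad_token_left] + replacement
    if pad_token_right is not None:
        replacement = replacement + [pad_token_right]
    return replacement


prompt_token_ids = [1, 32000, 42]
new_token_ids: List[int] = []
for i, token in enumerate(prompt_token_ids):
    if token == 32000:
        new_token_ids.extend(repeat_and_pad_token(32000, repeat_count=576))
        # Only the first image token is expanded, matching the diff's
        # early exit after a single replacement.
        new_token_ids.extend(prompt_token_ids[i + 1:])
        break
    new_token_ids.append(token)

assert len(new_token_ids) == 1 + 576 + 1
```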
@@ -314,18 +352,31 @@ def _input_processor_for_clip(
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or not isinstance(
multi_modal_data, (ImagePixelData, ImageFeatureData)):
return llm_inputs

tokenizer = _cached_get_tokenizer(model_config.tokenizer)

if image_feature_size_override is None:
image_feature_size = _get_clip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override

return cls._repeat_and_pad_image_tokens(
model_config,
llm_inputs,
new_prompt, new_token_ids = cls._repeat_and_pad_image_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=image_token_id,
repeat_count=image_feature_size,
)

# NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)

@classmethod
def _input_processor_for_llava(
cls,
@@ -334,6 +385,11 @@ def _input_processor_for_llava(
hf_config: LlavaConfig,
llm_inputs: LLMInputs,
):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or not isinstance(
multi_modal_data, (ImagePixelData, ImageFeatureData)):
return llm_inputs

vision_config = hf_config.vision_config

if isinstance(vision_config, CLIPVisionConfig):
@@ -356,8 +412,25 @@ def _input_processor_for_llava_next(
hf_config: LlavaNextConfig,
llm_inputs: LLMInputs,
):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or not isinstance(
multi_modal_data, (ImagePixelData, ImageFeatureData)):
return llm_inputs

if isinstance(multi_modal_data, ImagePixelData):
image = multi_modal_data.image
if isinstance(image, torch.Tensor):
_, _, _, height, width = image.shape
else:
width, height = image.size

image_feature_size = _get_llava_next_image_feature_size(
hf_config, input_height=height, input_width=width)
else:
image_features = multi_modal_data.image_features
image_feature_size = image_features.shape[-2]

vision_config = hf_config.vision_config
image_feature_size = _get_llava_next_image_feature_size(hf_config)

if isinstance(vision_config, CLIPVisionConfig):
return cls._input_processor_for_clip(
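The final hunk carries the actual LLaVA-NeXT fix: the input processor now reads the image's real dimensions from the multimodal data and sizes the placeholder tokens accordingly. A hedged sketch of that size handling; the 5-D tensor layout is an assumption read off the unpacking in the diff, and the PIL branch uses the standard (width, height) ordering of `Image.size`:

```python
# Hedged sketch of the new size handling; the 5-D tensor layout is an
# assumption inferred from the diff's unpacking.
from typing import Tuple, Union

import torch
from PIL import Image


def infer_input_size(
        image: Union[torch.Tensor, Image.Image]) -> Tuple[int, int]:
    if isinstance(image, torch.Tensor):
        _, _, _, height, width = image.shape
    else:
        width, height = image.size
    return height, width


print(infer_input_size(Image.new("RGB", (1344, 336))))    # (336, 1344)
print(infer_input_size(torch.zeros(1, 1, 3, 336, 1344)))  # (336, 1344)
```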
