[Model] LLaVA model refactor #4910

Merged (2 commits) on May 20, 2024

Changes from 1 commit
139 changes: 109 additions & 30 deletions vllm/model_executor/models/llava.py
@@ -1,4 +1,4 @@
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
from torch import nn
@@ -67,6 +67,21 @@ def _merge_vision_embeddings(input_ids: torch.Tensor,
return inputs_embeds


class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""


class LlavaImageFeatureInputs(TypedDict):
type: Literal["image_features"]
data: torch.Tensor
"""Shape: (batch_size, image_feature_size, hidden_size)"""


LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
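These TypedDicts form a tagged union over the two supported image input formats, discriminated by the "type" literal. A rough usage sketch (illustrative only, not part of the file; the dummy shapes follow the docstrings above and assume the default CLIP ViT-L/14-336 configuration):

import torch

pixel_inputs = LlavaImagePixelInputs(
    type="pixel_values",
    data=torch.zeros(1, 3, 336, 336),    # (batch_size, num_channels, height, width)
)
feature_inputs = LlavaImageFeatureInputs(
    type="image_features",
    data=torch.zeros(1, 576, 1024),      # (batch_size, image_feature_size, hidden_size)
)

def describe(inputs: LlavaImageInputs) -> str:
    # The "type" key is the discriminant when handling the union.
    if inputs["type"] == "pixel_values":
        return f"raw pixels {tuple(inputs['data'].shape)}"
    return f"precomputed features {tuple(inputs['data'].shape)}"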


class LlavaForConditionalGeneration(VisionLanguageModelBase):

def __init__(self,
@@ -102,6 +117,92 @@ def __init__(self,
config.vocab_size, logit_scale)
self.sampler = Sampler()

def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != list(
self.vision_language_config.image_input_shape[1:]):
raise ValueError(
f"The expected image tensor shape is batch dimension plus "
f"{self.vision_language_config.image_input_shape[1:]}. "
f"You supplied {data.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args.")

return data

def _parse_and_validate_image_input(
self, data: object) -> Optional[LlavaImageInputs]:
expected_input_type = self.vision_language_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType

if expected_input_type == ImageInputType.PIXEL_VALUES:
if data is None:
return None

if not isinstance(data, torch.Tensor):
raise ValueError("Incorrect type of pixel values")

return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_image_data(data),
)

if expected_input_type == ImageInputType.IMAGE_FEATURES:
if data is None:
return None

if not isinstance(data, torch.Tensor):
raise ValueError("Incorrect type of image features")

return LlavaImageFeatureInputs(
type="image_features",
data=self._validate_image_data(data),
)

return None

def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features

raise ValueError(f"Unexpected select feature strategy: {strategy}")

def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs = vision_tower(pixel_values.to(vision_tower.device),
output_hidden_states=True)

image_features = image_outputs.hidden_states[
self.config.vision_feature_layer]

return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
)

def _process_image_pixels(self,
inputs: LlavaImagePixelInputs) -> torch.Tensor:
assert self.vision_tower is not None

pixel_values = inputs["data"]

return self._image_pixels_to_features(self.vision_tower, pixel_values)

def _process_image_input(self,
image_input: LlavaImageInputs) -> torch.Tensor:
if image_input["type"] == "pixel_values":
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
else:
image_features = image_input["data"]

return self.multi_modal_projector(image_features)
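Taken together, these helpers split what used to be one monolithic block in forward() into parse/validate, feature-extraction, and projection steps. A hedged sketch of the resulting call chain (model is a hypothetical LlavaForConditionalGeneration instance running in PIXEL_VALUES mode; shapes are illustrative):

pixel_values = torch.zeros(1, 3, 336, 336)            # raw image batch

parsed = model._parse_and_validate_image_input(pixel_values)
# -> LlavaImagePixelInputs under PIXEL_VALUES mode, LlavaImageFeatureInputs
#    under IMAGE_FEATURES mode, or None when no image was supplied.

if parsed is not None:
    vision_embeddings = model._process_image_input(parsed)
    # pixel_values path: vision tower -> feature selection -> multi_modal_projector
    # image_features path: multi_modal_projector only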

def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
@@ -144,42 +245,20 @@ def forward(self,
For PIXEL_VALUES, expecting [1, 3, 336, 336].
For IMAGE_FEATURES, expecting [1, 576, 1024].
"""
if image_input is not None:
if list(image_input.shape[1:]) != list(
self.vision_language_config.image_input_shape[1:]):
raise ValueError(
f"The expected image tensor shape is batch dimension "
f"plus "
f"{self.vision_language_config.image_input_shape[1:]}."
f" You supplied {image_input.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args.")
if self.vision_tower is not None:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs = self.vision_tower(image_input,
output_hidden_states=True)
image_features = image_outputs.hidden_states[
self.config.vision_feature_layer]
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if self.config.vision_feature_select_strategy == "default":
image_features = image_features[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
image_features = image_features
else:
raise ValueError(
f"Unexpected select feature strategy: "
f"{self.config.vision_feature_select_strategy}")
else:
image_features = image_input
vision_embeddings = self.multi_modal_projector(image_features)
parsed_image_input = self._parse_and_validate_image_input(image_input)

if parsed_image_input is not None:
vision_embeddings = self._process_image_input(parsed_image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)

inputs_embeds = _merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vision_language_config.image_token_id)

input_ids = None
else:
inputs_embeds = None
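_merge_vision_embeddings itself is untouched by this commit (only its final return appears in the first hunk). For context, a minimal sketch of the merge it performs, reconstructed from its call signature and the usual LLaVA placeholder-token scheme rather than copied from the file (assumes torch is imported and input_ids is the flattened (num_tokens,) tensor vLLM uses here):

def _merge_vision_embeddings_sketch(input_ids: torch.Tensor,
                                    inputs_embeds: torch.Tensor,
                                    vision_embeddings: torch.Tensor,
                                    image_token_id: int) -> torch.Tensor:
    # Overwrite the placeholder embeddings at every image-token position
    # with the projected vision embeddings, in order.
    mask = input_ids == image_token_id
    inputs_embeds[mask] = vision_embeddings.reshape(-1, inputs_embeds.shape[-1])
    return inputs_embeds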

hidden_states = self.language_model(input_ids,
positions,
kv_caches,
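The expected shapes in the forward() docstring follow from the default CLIP vision tower (ViT-L/14 at 336x336): the image is split into (336 / 14)^2 = 576 patches with hidden size 1024, and the "default" feature-select strategy (image_features[:, 1:] above) drops the leading CLS token from the tower's 577 hidden states, leaving exactly the documented [1, 576, 1024] feature tensor for the projector. A quick check of that arithmetic (illustrative only):

image_size, patch_size, hidden_size = 336, 14, 1024   # CLIP ViT-L/14-336 defaults

num_patches = (image_size // patch_size) ** 2          # 24 * 24 = 576
tower_tokens = num_patches + 1                         # + CLS token -> 577

assert (num_patches, hidden_size) == (576, 1024)       # matches [1, 576, 1024]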