[Model] LLaVA model refactor #4910

Merged · 2 commits · May 20, 2024
137 changes: 107 additions & 30 deletions vllm/model_executor/models/llava.py
@@ -1,4 +1,4 @@
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
from torch import nn
@@ -67,6 +67,21 @@ def _merge_vision_embeddings(input_ids: torch.Tensor,
return inputs_embeds


class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""


class LlavaImageFeatureInputs(TypedDict):
type: Literal["image_features"]
data: torch.Tensor
"""Shape: (batch_size, image_feature_size, hidden_size)"""


LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
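
As a side note, the Literal "type" field makes LlavaImageInputs a tagged union that can be narrowed with a plain key check; a minimal sketch, not part of this diff (the helper name describe_image_input is hypothetical):

def describe_image_input(image_input: LlavaImageInputs) -> str:
    # Dispatch on the "type" tag instead of isinstance checks on the payload.
    if image_input["type"] == "pixel_values":
        # data is (batch_size, num_channels, height, width)
        return f"pixel values with shape {tuple(image_input['data'].shape)}"
    # data is (batch_size, image_feature_size, hidden_size)
    return f"image features with shape {tuple(image_input['data'].shape)}"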


class LlavaForConditionalGeneration(VisionLanguageModelBase):

def __init__(self,
@@ -102,6 +117,90 @@ def __init__(self,
config.vocab_size, logit_scale)
self.sampler = Sampler()

def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != list(
self.vision_language_config.image_input_shape[1:]):
raise ValueError(
f"The expected image tensor shape is batch dimension plus "
f"{self.vision_language_config.image_input_shape[1:]}. "
f"You supplied {data.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args.")

return data
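
For reference, the check above only constrains the non-batch dimensions; a minimal standalone sketch of the same comparison, assuming image_input_shape is (1, 3, 336, 336) as in the forward() docstring further down:

import torch

image_input_shape = (1, 3, 336, 336)          # engine-level expectation (assumed value)
ok = torch.zeros(4, 3, 336, 336)              # any batch size passes
bad = torch.zeros(4, 3, 224, 224)             # wrong spatial size would raise ValueError
assert list(ok.shape[1:]) == list(image_input_shape[1:])
assert list(bad.shape[1:]) != list(image_input_shape[1:])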

def _parse_and_validate_image_input(
self, data: object) -> Optional[LlavaImageInputs]:
expected_input_type = self.vision_language_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType

if data is None:
return None

if expected_input_type == ImageInputType.PIXEL_VALUES:
if not isinstance(data, torch.Tensor):
raise TypeError("Image pixel vector should be a tensor, "
f"but received type: {type(data)}")

return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_image_data(data),
)
elif expected_input_type == ImageInputType.IMAGE_FEATURES:
if not isinstance(data, torch.Tensor):
raise TypeError("Image feature vector should be a tensor, "
f"but received type: {type(data)}")

return LlavaImageFeatureInputs(
type="image_features",
data=self._validate_image_data(data),
)

return None

def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features

raise ValueError(f"Unexpected select feature strategy: {strategy}")

def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs = vision_tower(pixel_values.to(vision_tower.device),
output_hidden_states=True)

image_features = image_outputs.hidden_states[
self.config.vision_feature_layer]

return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
)

def _process_image_pixels(self,
inputs: LlavaImagePixelInputs) -> torch.Tensor:
assert self.vision_tower is not None

pixel_values = inputs["data"]

return self._image_pixels_to_features(self.vision_tower, pixel_values)

def _process_image_input(self,
image_input: LlavaImageInputs) -> torch.Tensor:
if image_input["type"] == "pixel_values":
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
else:
image_features = image_input["data"]

return self.multi_modal_projector(image_features)
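
At shape level, _process_image_input takes either raw pixels or precomputed features to the same place: (1, 576, 1024) features projected into the language model's embedding space. A rough stand-in for that last step (the 1024 -> 4096 sizes are assumptions for a LLaVA-1.5-7B-style setup; the real multi_modal_projector is a small MLP):

import torch
from torch import nn

projector = nn.Linear(1024, 4096)             # shape-level stand-in only
image_features = torch.randn(1, 576, 1024)    # vision tower output or user-supplied features
vision_embeddings = projector(image_features)
assert vision_embeddings.shape == (1, 576, 4096)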

def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
@@ -144,42 +243,20 @@ def forward(self,
For PIXEL_VALUES, expecting [1, 3, 336, 336].
For IMAGE_FEATURES, expecting [1, 576, 1024].
"""
if image_input is not None:
if list(image_input.shape[1:]) != list(
self.vision_language_config.image_input_shape[1:]):
raise ValueError(
f"The expected image tensor shape is batch dimension "
f"plus "
f"{self.vision_language_config.image_input_shape[1:]}."
f" You supplied {image_input.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args.")
if self.vision_tower is not None:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs = self.vision_tower(image_input,
output_hidden_states=True)
image_features = image_outputs.hidden_states[
self.config.vision_feature_layer]
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if self.config.vision_feature_select_strategy == "default":
image_features = image_features[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
image_features = image_features
else:
raise ValueError(
f"Unexpected select feature strategy: "
f"{self.config.vision_feature_select_strategy}")
else:
image_features = image_input
vision_embeddings = self.multi_modal_projector(image_features)
parsed_image_input = self._parse_and_validate_image_input(image_input)

if parsed_image_input is not None:
vision_embeddings = self._process_image_input(parsed_image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)

inputs_embeds = _merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vision_language_config.image_token_id)

input_ids = None
else:
inputs_embeds = None

hidden_states = self.language_model(input_ids,
positions,
kv_caches,