@@ -1,4 +1,4 @@
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
 from torch import nn
@@ -67,6 +67,21 @@ def _merge_vision_embeddings(input_ids: torch.Tensor,
     return inputs_embeds
 
 
+class LlavaImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: torch.Tensor
+    """Shape: (batch_size, num_channels, height, width)"""
+
+
+class LlavaImageFeatureInputs(TypedDict):
+    type: Literal["image_features"]
+    data: torch.Tensor
+    """Shape: (batch_size, image_feature_size, hidden_size)"""
+
+
+LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
+
+
 class LlavaForConditionalGeneration(VisionLanguageModelBase):
 
     def __init__(self,
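The three definitions added above form a tagged union: the `type` field is a `Literal` discriminator, so callers can branch on it and static type checkers narrow `LlavaImageInputs` to the matching `TypedDict`. Since `TypedDict` instances are plain dicts at runtime, this adds no overhead on the hot path. A rough standalone sketch of the pattern (the class names and shapes here are illustrative, not vLLM's):

import torch
from typing import Literal, TypedDict, Union


class ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor


class ImageFeatureInputs(TypedDict):
    type: Literal["image_features"]
    data: torch.Tensor


ImageInputs = Union[ImagePixelInputs, ImageFeatureInputs]


def describe(inputs: ImageInputs) -> str:
    # Branching on the "type" tag narrows the union for type checkers.
    if inputs["type"] == "pixel_values":
        return f"raw pixels of shape {tuple(inputs['data'].shape)}"
    return f"precomputed features of shape {tuple(inputs['data'].shape)}"


# A dummy single-image batch at LLaVA's 336x336 input resolution.
print(describe({"type": "pixel_values", "data": torch.zeros(1, 3, 336, 336)}))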
@@ -102,6 +117,90 @@ def __init__(self,
                                                 config.vocab_size, logit_scale)
         self.sampler = Sampler()
 
+    def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
+        if list(data.shape[1:]) != list(
+                self.vision_language_config.image_input_shape[1:]):
+            raise ValueError(
+                f"The expected image tensor shape is batch dimension plus "
+                f"{self.vision_language_config.image_input_shape[1:]}. "
+                f"You supplied {data.shape}. "
+                f"If you are using vLLM's entrypoint, make sure your "
+                f"supplied image input is consistent with "
+                f"image_input_shape in engine args.")
+
+        return data
+
+    def _parse_and_validate_image_input(
+            self, data: object) -> Optional[LlavaImageInputs]:
+        expected_input_type = self.vision_language_config.image_input_type
+        ImageInputType = VisionLanguageConfig.ImageInputType
+
+        if data is None:
+            return None
+
+        if expected_input_type == ImageInputType.PIXEL_VALUES:
+            if not isinstance(data, torch.Tensor):
+                raise TypeError("Image pixel vector should be a tensor, "
+                                f"but received type: {type(data)}")
+
+            return LlavaImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_image_data(data),
+            )
+        elif expected_input_type == ImageInputType.IMAGE_FEATURES:
+            if not isinstance(data, torch.Tensor):
+                raise TypeError("Image feature vector should be a tensor, "
+                                f"but received type: {type(data)}")
+
+            return LlavaImageFeatureInputs(
+                type="image_features",
+                data=self._validate_image_data(data),
+            )
+
+        return None
+
+    def _select_image_features(self, image_features: torch.Tensor, *,
+                               strategy: str) -> torch.Tensor:
+        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
+        if strategy == "default":
+            return image_features[:, 1:]
+        elif strategy == "full":
+            return image_features
+
+        raise ValueError(f"Unexpected select feature strategy: {strategy}")
+
+    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
+                                  pixel_values: torch.Tensor) -> torch.Tensor:
+        # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
+        image_outputs = vision_tower(pixel_values.to(vision_tower.device),
+                                     output_hidden_states=True)
+
+        image_features = image_outputs.hidden_states[
+            self.config.vision_feature_layer]
+
+        return self._select_image_features(
+            image_features,
+            strategy=self.config.vision_feature_select_strategy,
+        )
+
+    def _process_image_pixels(self,
+                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
+        assert self.vision_tower is not None
+
+        pixel_values = inputs["data"]
+
+        return self._image_pixels_to_features(self.vision_tower, pixel_values)
+
+    def _process_image_input(self,
+                             image_input: LlavaImageInputs) -> torch.Tensor:
+        if image_input["type"] == "pixel_values":
+            assert self.vision_tower is not None
+            image_features = self._process_image_pixels(image_input)
+        else:
+            image_features = image_input["data"]
+
+        return self.multi_modal_projector(image_features)
+
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
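`_select_image_features` mirrors the HuggingFace code it links to: the vision tower's hidden state carries one leading class ("CLS") embedding followed by the patch embeddings, and the "default" strategy drops that first token so only patch features reach the projector. A quick illustration with dummy tensors; the 577 = 1 + 24*24 and 1024 sizes assume the CLIP ViT-L/14 tower at 336 px that LLaVA-1.5 uses:

import torch


def select_image_features(image_features: torch.Tensor, *,
                          strategy: str) -> torch.Tensor:
    # Same selection logic as the method above, detached from the model class.
    if strategy == "default":
        return image_features[:, 1:]  # drop the leading CLS token
    elif strategy == "full":
        return image_features
    raise ValueError(f"Unexpected select feature strategy: {strategy}")


# Dummy hidden state: (batch, 1 CLS + 576 patches, hidden_size).
hidden_state = torch.zeros(1, 577, 1024)
print(select_image_features(hidden_state, strategy="default").shape)  # torch.Size([1, 576, 1024])
print(select_image_features(hidden_state, strategy="full").shape)     # torch.Size([1, 577, 1024])

The [1, 576, 1024] shape expected for IMAGE_FEATURES in the forward docstring below is exactly the "default" output.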
@@ -144,42 +243,20 @@ def forward(self,
                 For PIXEL_VALUES, expecting [1, 3, 336, 336].
                 For IMAGE_FEATURES, expecting [1, 576, 1024].
         """
-        if image_input is not None:
-            if list(image_input.shape[1:]) != list(
-                    self.vision_language_config.image_input_shape[1:]):
-                raise ValueError(
-                    f"The expected image tensor shape is batch dimension "
-                    f"plus "
-                    f"{self.vision_language_config.image_input_shape[1:]}."
-                    f" You supplied {image_input.shape}. "
-                    f"If you are using vLLM's entrypoint, make sure your "
-                    f"supplied image input is consistent with "
-                    f"image_input_shape in engine args.")
-            if self.vision_tower is not None:
-                # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
-                image_outputs = self.vision_tower(image_input,
-                                                  output_hidden_states=True)
-                image_features = image_outputs.hidden_states[
-                    self.config.vision_feature_layer]
-                # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
-                if self.config.vision_feature_select_strategy == "default":
-                    image_features = image_features[:, 1:]
-                elif self.config.vision_feature_select_strategy == "full":
-                    image_features = image_features
-                else:
-                    raise ValueError(
-                        f"Unexpected select feature strategy: "
-                        f"{self.config.vision_feature_select_strategy}")
-            else:
-                image_features = image_input
-            vision_embeddings = self.multi_modal_projector(image_features)
+        parsed_image_input = self._parse_and_validate_image_input(image_input)
+
+        if parsed_image_input is not None:
+            vision_embeddings = self._process_image_input(parsed_image_input)
             inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+
             inputs_embeds = _merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
                 self.vision_language_config.image_token_id)
+
             input_ids = None
         else:
             inputs_embeds = None
+
         hidden_states = self.language_model(input_ids,
                                             positions,
                                             kv_caches,
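With the refactor, `forward` reduces to parse, project, merge: the placeholder image tokens in `input_ids` are replaced in the embedding matrix by the projected vision features before the language model runs. `_merge_vision_embeddings` itself is defined earlier in the file and is not part of this diff; the snippet below is only a minimal sketch of the merging idea, assuming the number of placeholder tokens equals the flattened count of vision embeddings.

import torch


def merge_vision_embeddings(input_ids: torch.Tensor,
                            inputs_embeds: torch.Tensor,
                            vision_embeddings: torch.Tensor,
                            image_token_id: int) -> torch.Tensor:
    # Overwrite embedding rows at image-placeholder positions with the
    # projected vision features (a sketch of the idea, not the actual helper).
    mask = input_ids == image_token_id
    inputs_embeds[mask] = vision_embeddings.reshape(-1, inputs_embeds.shape[-1])
    return inputs_embeds


# Toy example: 4 text tokens followed by 3 image-placeholder tokens
# (the id 32000 and hidden size 8 are arbitrary choices for illustration).
IMAGE_TOKEN_ID = 32000
input_ids = torch.tensor([1, 2, 3, 4] + [IMAGE_TOKEN_ID] * 3)
inputs_embeds = torch.zeros(7, 8)
vision_embeddings = torch.ones(1, 3, 8)  # (batch, image_feature_size, hidden_size)
merged = merge_vision_embeddings(input_ids, inputs_embeds, vision_embeddings, IMAGE_TOKEN_ID)
print(merged[4:].sum().item())  # 24.0: the three placeholder rows now hold vision features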