Skip to content

Commit

Permalink
[feat] LlavaNext add feature size check to avoid CUDA Runtime Error (#33608)
Browse files Browse the repository at this point in the history

* [feat] add feature size check to avoid CUDA Runtime Error

* [minor] add error handling to all llava models

* [minor] avoid nested if else

* [minor] add error message to Qwen2-vl and chameleon

* [fix] token dimension for check

* [minor] add feature dim check for videos too

* [fix] dimension check

* [fix] test reference values

---------

Co-authored-by: Raushan Turganbay <raushan@huggingface.co>
  • Loading branch information
laurentd-lunit and zucchini-nlp authored Oct 15, 2024
1 parent d00f1ca commit 0f49dea
Show file tree
Hide file tree
Showing 11 changed files with 88 additions and 7 deletions.
6 changes: 6 additions & 0 deletions src/transformers/models/chameleon/modeling_chameleon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,6 +1287,12 @@ def forward(

if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values)
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
n_image_features = image_tokens.shape[0]
if n_image_tokens_in_text != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
Expand Down
6 changes: 6 additions & 0 deletions src/transformers/models/llava/modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,12 @@ def forward(

# TODO: @raushan retain only the new behavior after v4.47
else:
n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
n_image_features = image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
Expand Down
6 changes: 6 additions & 0 deletions src/transformers/models/llava_next/modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,12 @@ def forward(

# TODO: @raushan retain only the new behavior after v4.47
else:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,12 @@ def forward(
# TODO: @raushan retain only the new behavior after v4.47
else:
if image_features is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
Expand All @@ -976,6 +982,12 @@ def forward(
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if video_features is not None:
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
n_video_features = video_features.shape[0]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
special_image_mask = (
(input_ids == self.config.video_token_index)
.unsqueeze(-1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,12 @@ def forward(
# TODO: @raushan retain only the new behavior after v4.47
else:
if image_features is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
Expand All @@ -491,6 +497,12 @@ def forward(
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if video_features is not None:
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
n_video_features = video_features.shape[0]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
special_image_mask = (
(input_ids == self.config.video_token_index)
.unsqueeze(-1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,12 @@ def forward(
image_newline=self.image_newline,
vision_aspect_ratio=vision_aspect_ratio,
)

n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
Expand Down Expand Up @@ -647,7 +652,12 @@ def forward(
image_newline = self.image_newline[None, None, :].repeat(batch_size, 1, 1).to(video_features.device)
video_features = torch.cat((video_features, image_newline), dim=1)
video_features = video_features.flatten(0, 1)

n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
n_video_features = video_features.shape[0]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
special_video_mask = (
(input_ids == self.config.video_token_index)
.unsqueeze(-1)
Expand Down
12 changes: 12 additions & 0 deletions src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1710,6 +1710,12 @@ def forward(
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.get_dtype())
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_mask = (
(input_ids == self.config.image_token_id)
.unsqueeze(-1)
Expand All @@ -1722,6 +1728,12 @@ def forward(
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_mask = (
(input_ids == self.config.video_token_id)
.unsqueeze(-1)
Expand Down
13 changes: 12 additions & 1 deletion src/transformers/models/video_llava/modeling_video_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,12 @@ def forward(
# TODO: @raushan retain only the new behavior after v4.47
else:
if image_outputs is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
n_image_features = image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
Expand All @@ -626,8 +632,13 @@ def forward(
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

if video_outputs is not None:
n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item()
n_video_features = video_features.shape[1]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
special_image_mask = (
(input_ids == self.config.video_token_index)
.unsqueeze(-1)
Expand Down
6 changes: 6 additions & 0 deletions src/transformers/models/vipllava/modeling_vipllava.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,12 @@ def forward(

# TODO: @raushan retain only the new behavior after v4.47
else:
n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
n_image_features = image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
Expand Down
4 changes: 2 additions & 2 deletions tests/models/llava/test_modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ def __init__(
self.batch_size = 3
self.num_channels = 3
self.image_size = 336
self.encoder_seq_length = 231
self.num_image_tokens = 224
self.encoder_seq_length = 232
self.num_image_tokens = 225
self.seq_length = seq_length + self.num_image_tokens

def get_config(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/models/vipllava/test_modeling_vipllava.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def __init__(
self.batch_size = 3
self.num_channels = 3
self.image_size = 336
self.encoder_seq_length = 231
self.num_image_tokens = 224
self.encoder_seq_length = 232
self.num_image_tokens = 225
self.seq_length = seq_length + self.num_image_tokens

def get_config(self):
Expand Down

0 comments on commit 0f49dea

Please sign in to comment.