LLaVA-OneVision mismatch between image features and image tokens #34625

Closed
agadetsky opened this issue Nov 6, 2024 · 17 comments · Fixed by #34779

@agadetsky commented Nov 6, 2024

System Info

  • transformers version: 4.46.2
  • Platform: Linux-6.5.0-45-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.26.2
  • Safetensors version: 0.4.5
  • Accelerate version: 1.1.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: yes
  • GPU type: NVIDIA H100 80GB HBM3

Who can help?

@amyeroberts @qubvel @ArthurZucker @ITaz

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, BitsAndBytesConfig
import numpy as np
import torch
from PIL import Image


model_id = "llava-hf/llava-onevision-qwen2-72b-ov-hf"

# specify how to quantize the model
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype=torch.float16, 
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
    quantization_config=quantization_config
)
processor = AutoProcessor.from_pretrained(model_id)

conversation = [
    {
        "role": "user",
        "content": [{"type": "text", "text": "Describe the image"}, {"type": "image"}] ,
    },
]

prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# random RGB image with an edge-case resolution (243x387) that triggers the error
image = Image.fromarray(np.random.randn(243, 387, 3).astype('uint8'), 'RGB')
inputs = processor(
    images=image,
    text=prompt,
    return_tensors="pt"
).to(model.device, torch.float16)

output_ids = model.generate(**inputs, max_new_tokens=32)

The error is the following:

ValueError                                Traceback (most recent call last)
Cell In[235], line 16
      9 image = Image.fromarray(np.random.randn(243, 387, 3).astype('uint8'), 'RGB')
     10 inputs = processor(
     11     images=image,
     12     text=prompt,
     13     return_tensors="pt"
     14 ).to(model.device, torch.float16)
---> 16 output_ids = model.generate(**inputs, max_new_tokens=32)

File ~/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/transformers/generation/utils.py:2215, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   2207     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2208         input_ids=input_ids,
   2209         expand_size=generation_config.num_return_sequences,
   2210         is_encoder_decoder=self.config.is_encoder_decoder,
   2211         **model_kwargs,
   2212     )
   2214     # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2215     result = self._sample(
   2216         input_ids,
   2217         logits_processor=prepared_logits_processor,
   2218         stopping_criteria=prepared_stopping_criteria,
   2219         generation_config=generation_config,
   2220         synced_gpus=synced_gpus,
   2221         streamer=streamer,
   2222         **model_kwargs,
   2223     )
   2225 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2226     # 11. prepare beam search scorer
   2227     beam_scorer = BeamSearchScorer(
   2228         batch_size=batch_size,
   2229         num_beams=generation_config.num_beams,
   (...)
   2234         max_length=generation_config.max_length,
   2235     )

File ~/.local/lib/python3.10/site-packages/transformers/generation/utils.py:3206, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
   3203 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
   3205 # forward pass to get next token
-> 3206 outputs = self(**model_inputs, return_dict=True)
   3208 # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
   3209 model_kwargs = self._update_model_kwargs_for_generation(
   3210     outputs,
   3211     model_kwargs,
   3212     is_encoder_decoder=self.config.is_encoder_decoder,
   3213 )

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/.local/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/.local/lib/python3.10/site-packages/transformers/models/llava_onevision/modeling_llava_onevision.py:684, in LlavaOnevisionForConditionalGeneration.forward(self, input_ids, pixel_values, image_sizes, pixel_values_videos, image_sizes_videos, attention_mask, position_ids, past_key_values, inputs_embeds, vision_feature_layer, vision_feature_select_strategy, vision_aspect_ratio, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)
    681 n_image_features = image_features.shape[0]
    683 if n_image_tokens != n_image_features:
--> 684     raise ValueError(
    685         f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
    686     )
    687 special_image_mask = (
    688     (input_ids == self.config.image_token_index)
    689     .unsqueeze(-1)
    690     .expand_as(inputs_embeds)
    691     .to(inputs_embeds.device)
    692 )
    693 image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)

ValueError: Image features and image tokens do not match: tokens: 1890, features 1944

Expected behavior

Given that LLaVA-OneVision can work with images of any resolution, the model is expected to generate the output successfully.
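
For reference, once the bug is fixed and generation succeeds, the new tokens can be decoded with the processor (a minimal sketch that reuses the objects from the snippet above):

# Decode only the newly generated tokens, i.e. everything after the prompt.
generated_ids = output_ids[:, inputs["input_ids"].shape[1]:]
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])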

@zucchini-nlp (Member)

@agadetsky, it seems like there are differences between how we compute the number of image tokens in the processing code and in the modeling code. It might be related to previous bugs with numerical issues when the image resolution is an edge case among the possible grid resolutions (like 337 here). I'll take a look and see where the precision error is coming from.
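
To illustrate the kind of mismatch described above, here is a minimal sketch (not the actual transformers implementation; the function names, the grid width, and the truncate-vs-round choice are assumptions made up for this example) of how a float-rounding difference in the anyres unpadding math can make the processor and the model count a different number of image tokens:

# Two ways of computing the unpadded patch-grid height when an image is
# cropped back to its original aspect ratio inside a padded patch grid.
def unpadded_height_truncate(orig_h, orig_w, padded_w):
    # one side of the code scales with float division and truncates
    return int(orig_h * (padded_w / orig_w))

def unpadded_height_round(orig_h, orig_w, padded_w):
    # the other side does the same math but rounds to the nearest integer
    return round(orig_h * (padded_w / orig_w))

orig_h, orig_w = 243, 387   # image size from the reproduction above
padded_w = 54               # hypothetical width of the padded patch grid

h1 = unpadded_height_truncate(orig_h, orig_w, padded_w)
h2 = unpadded_height_round(orig_h, orig_w, padded_w)
print(h1, h2)                        # 33 vs 34: the two sides disagree by one row
print(h1 * padded_w, h2 * padded_w)  # ... and therefore by a full row of image tokens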

@chenweize1998 (Contributor)

Hi @zucchini-nlp, have you managed to identify the issue? I'm encountering the same error while using llava-hf/llava-v1.6-mistral-7b-hf. I haven't pinpointed the specific data causing the error, as it occurs midway through training. Could you also take a look at the modeling file of LLaVA-NeXT? Maybe some of the anyres calculations are mismatched there as well?

@zucchini-nlp (Member) commented Nov 18, 2024

@chenweize1998 yes, that is most probably the anyres calculations. Unfortunately I didn't have time to look at it in more detail; I will try to take a look today.

EDIT: found the place where there was a precision error and opened a PR to fix it.

@chenweize1998 (Contributor) commented Nov 18, 2024

@zucchini-nlp Thanks for looking into this! I've pinpointed the batch of data causing the issue and uploaded it here. The problem specifically originates from the first data point in the batch. Hope it helps with debugging.

Additionally, here’s a minimal script to reproduce the error (assuming the data point is downloaded as ./tmp.bin):

from transformers import AutoModelForVision2Seq
import torch

# Load the model
model = AutoModelForVision2Seq.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf", 
    torch_dtype=torch.bfloat16
).to("cuda:0")

# Load the problematic input
inputs = torch.load("tmp.bin")
# Note: inputs['input_ids'][0] triggers the error

for k, v in inputs.items():
    inputs[k] = v.to("cuda:0")

# Generate outputs
outputs = model(**inputs)

I'm using torch==2.4.0 and transformers==4.46.2. Let me know if you need more details.

@agadetsky (Author)

Thank you @zucchini-nlp!

@atanasmatev

Hi,
I am facing a similar problem when I try to train Qwen2-VL-7B with image data: "Image features and image tokens do not match" after loading the model. Just for the sake of testing, I commented out the four lines in the transformers modeling file that perform the comparison, and it ran, although that is obviously not how it is supposed to work.
Can you please check and confirm whether the same bug affects the Qwen2 model?
Thanks.

@zucchini-nlp (Member)

@atanasmatev hey, Qwen2 has different processing and no unpadding like in LLaVA-OV. Can you please open a new issue for it and provide a small code snippet for reproduction?

@mano3-1 commented Nov 22, 2024

Hi @agadetsky,
I am facing the same issue with Qwen2-VL-7B with image data. Did you figure out a fix?

@zhangboshen

Same issue with Qwen2-VL-7B with image data.

@zucchini-nlp (Member)

For Qwen we have an issue here: #33399 (comment)

But that issue is about shape errors on the mps device, as mps had some weird bugs in vision LLMs. In case you are not on mps and still experience the error, please open an issue with a short reproducer. Since the Qwen code has not changed since release, it might be something related to specific input images/resolutions.

@gouqi666 (Contributor) commented Jan 7, 2025

@agadetsky @zhangboshen @LysandreJik @chenweize1998 @atanasmatev does anyone have a solution for this problem? I use LLaMA-Factory to train Qwen2-VL.

@atanasmatev

For Qwen2-VL-7B with LLaMA-Factory I changed one of the parameters in the LLaMA-Factory config file from 2 to 1 (I am not near that PC to check which one exactly, but the default was 2).

@gouqi666 (Contributor) commented Jan 8, 2025 via email

@chchch0109

Hi @zucchini-nlp, I faced this issue again with transformers==4.47.1. The data that caused it is sample number 2482 in the test split of the Hugging Face dataset "lmms-lab/docvqa".

@zucchini-nlp (Member)

@chchch0109 it would help me a lot if you can provide a small runnable snippet without many external dependencies :)

The numerical error bug from padding/unpadding should have been fixed by v4.47.1, so I can look into whether there are any other reasons to error out.
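
As a quick sanity check (generic, not tied to this issue's internals), one can verify that the installed transformers is at least the release mentioned above before re-running a reproducer:

# Minimal check that the installed transformers already contains the
# padding/unpadding precision fix said to ship in v4.47.1.
import transformers
from packaging import version

assert version.parse(transformers.__version__) >= version.parse("4.47.1"), (
    f"transformers {transformers.__version__} predates the padding/unpadding fix"
)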

@chchch0109

@zucchini-nlp sure

from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
from datasets import load_dataset
import torch

processor = AutoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    "llava-hf/llava-onevision-qwen2-7b-ov-hf", torch_dtype=torch.float16
).to("cuda")  # move the model to the same device as the inputs
dataset = load_dataset("lmms-lab/docvqa", 'DocVQA')

# sample 2482 of the test split triggers the mismatch
d = dataset['test'][2482]
question = d['question']
image = d['image']
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": question},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda", torch.float16)
with torch.no_grad():
    outputs = model(**inputs)

@sheryc (Contributor) commented Jan 18, 2025

I'm facing the same problem when using llava-hf/llava-onevision-qwen2-0.5b-ov-hf and transformers version 4.48.0. Might be better to open another issue for this? @zucchini-nlp

Edit: new issue for this: #35775
