Qwen2-VL Doesn't Execute on TPUs #33289

Open
radna0 opened this issue Sep 4, 2024 · 3 comments
Labels
bug, Feature request, TPU

Comments


radna0 commented Sep 4, 2024

System Info

  • transformers version: 4.45.0.dev0
  • Platform: Linux-5.4.0-1043-gcp-x86_64-with-glibc2.31
  • Python version: 3.10.14
  • Huggingface_hub version: 0.24.6
  • Safetensors version: 0.4.4
  • Accelerate version: 0.33.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.0.dev20240830+cpu (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

# Following this Qwen2-VL guide => https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct#quickstart

  1. Script
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import numpy as np
import torch
import torch_xla as xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs

from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy

from torch_xla import runtime as xr
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    _prepare_spmd_partition_spec,
    SpmdFullyShardedDataParallel as FSDPv2,
)

import time

start = time.time()

device = xm.xla_device()

# default: Load the model on the available device(s)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
).to(device)


print(model.device)

# default processor
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4")


message = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "image1.jpg",
            },
            {"type": "text", "text": "Describe this image in detail."},
        ],
    }
]

all_messages = [[message] for _ in range(1)]
for messages in all_messages:

    # Preparation for inference
    texts = [
        processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        for msg in messages
    ]

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)

    # Inference: Generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=512)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    for i, text in enumerate(output_text):
        print(f"Output {i}: {text}")

print(f"Time taken: {time.time() - start}")
  2. Output Logs
kojoe@t1v-n-cb70f560-w-0:~/EasyAnimate/easyanimate/image_caption$ python caption.py
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.39it/s]
xla:0

Expected behavior

The model works fine when changing the device to "cpu", but gets stuck when executing on TPUs. The model should run on TPUs.
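
One way to tell a genuine hang apart from a long first XLA compilation (and to see whether every decode step triggers a fresh recompile) is to run a short generation and print torch_xla's metrics report. A minimal sketch, reusing model, inputs and xm from the script above:

import torch_xla.debug.metrics as met

# Keep the run short so the first compilation can finish in a reasonable time.
generated_ids = model.generate(**inputs, max_new_tokens=8)
xm.mark_step()

# CompileTime and the compile counters show how much time goes into (re)compilation.
print(met.metrics_report())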

radna0 added the bug label Sep 4, 2024

github-actions bot commented Oct 4, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

ArthurZucker added the Feature request and TPU labels Oct 5, 2024
@ArthurZucker
Collaborator

Hey! I don't think we officially test or support TPU for this model 🤗 I can't really reproduce 😢
@tengomucho might have an idea

@tengomucho
Contributor

@radna0 transformers does not officially support TPU, but I think things might work if:

  • you use the raw model forward API, not generate
  • you use the static KV cache, which works best on TPUs with PyTorch XLA

I haven't tried this model myself, but you can give it a try, taking inspiration from this script (a rough sketch of the approach is shown below). If it does not work, feel free to open a PR on optimum-tpu and we will try to help you.
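
A rough, untested sketch of that approach, reusing model, processor, inputs and device from the reproduction script above: a manual greedy-decoding loop that calls the model's forward directly with a pre-allocated StaticCache, so each decode step keeps fixed tensor shapes and XLA only has to compile the step graph once. It assumes the model's forward accepts past_key_values and cache_position like other cache-aware transformers models; Qwen2-VL's multimodal rotary position handling is normally managed inside generate, so this particular model may need extra care and this should be treated only as a starting point.

# Rough sketch (untested): greedy decoding with the raw forward API and a StaticCache.
import torch
import torch_xla.core.xla_model as xm
from transformers import StaticCache

max_new_tokens = 64
batch_size, prompt_len = inputs.input_ids.shape

# Pre-allocate a fixed-size KV cache so tensor shapes never change between steps.
# (`batch_size` is the 4.45-era keyword; newer transformers releases call it `max_batch_size`.)
past_key_values = StaticCache(
    config=model.config,
    batch_size=batch_size,
    max_cache_len=prompt_len + max_new_tokens,
    device=device,
    dtype=model.dtype,
)

new_tokens = []
cache_position = torch.arange(prompt_len, device=device)

with torch.no_grad():
    # Prefill: run the full prompt (text + vision inputs) through the model once.
    outputs = model(
        **inputs,
        past_key_values=past_key_values,
        cache_position=cache_position,
        use_cache=True,
    )
    next_token = outputs.logits[:, -1:].argmax(dim=-1)
    xm.mark_step()  # cut the graph here so the prefill compiles separately

    for _ in range(max_new_tokens - 1):
        new_tokens.append(next_token)
        cache_position = cache_position[-1:] + 1
        outputs = model(
            input_ids=next_token,
            past_key_values=past_key_values,
            cache_position=cache_position,
            use_cache=True,
        )
        next_token = outputs.logits[:, -1:].argmax(dim=-1)
        xm.mark_step()  # every decode step reuses the same fixed-shape graph
    new_tokens.append(next_token)

generated = torch.cat(new_tokens, dim=-1)
print(processor.batch_decode(generated, skip_special_tokens=True))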
