Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VisualBERT VQA model inference gives low validation accuracy (around 40%) with the Hugging Face framework #45

Open
guanhdrmq opened this issue Nov 21, 2023 · 0 comments

Comments

@guanhdrmq
Copy link

class VQADataset(torch.utils.data.Dataset):
    """VQA (v2) dataset that tokenizes questions and extracts FRCNN region features on the fly.

    Each item is a dict of per-sample tensors (batch dimension removed) intended for
    ``VisualBertForQuestionAnswering``. It also carries a ``labels`` soft-target vector,
    the raw question ``text``, and an ``output_attentions`` flag.

    Args:
        questions: Sequence of question records; each must have a ``"question"`` string.
            Indexed in lockstep with ``annotations``, so both are assumed to be aligned
            (NOTE(review): confirm they are ordered by the same question id).
        annotations: Sequence of annotation records with ``"image_id"``, ``"labels"``
            (answer indices) and ``"scores"`` (score written at each of those indices).
        tokenizer: Tokenizer applied to the question text.
        image_preprocess: Callable mapping an image path to ``(images, sizes, scales_yx)``.
        frcnn: Region feature extractor whose output provides ``"roi_features"``.
        frcnn_cfg: FRCNN config; only ``max_detections`` is read here.
        id_to_filename: Optional mapping ``image_id -> image path``. When omitted, the
            module-level ``id_to_filename`` is used (original behaviour).
        num_labels: Optional size of the answer vocabulary. When omitted,
            ``len(config.id2label)`` from the module-level ``config`` is used.
    """

    def __init__(self, questions, annotations, tokenizer, image_preprocess, frcnn, frcnn_cfg,
                 id_to_filename=None, num_labels=None):
        self.questions = questions
        self.annotations = annotations
        self.tokenizer = tokenizer
        self.image_preprocess = image_preprocess
        self.frcnn = frcnn
        self.frcnn_cfg = frcnn_cfg
        self.id_to_filename = id_to_filename
        self.num_labels = num_labels

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """Build the model inputs and soft answer targets for sample ``idx``."""
        annotation = self.annotations[idx]  # answer labels/scores and image id
        question = self.questions[idx]
        text = question["question"]

        # Fall back to the module-level globals so existing callers keep working.
        filenames = self.id_to_filename if self.id_to_filename is not None else id_to_filename
        num_labels = self.num_labels if self.num_labels is not None else len(config.id2label)

        image_path = filenames[annotation["image_id"]]
        # Strip a dataset-specific path prefix (first occurrence only), as in the original.
        image_path = image_path.replace("./multimodal_data/vqa2/val2014/.", "", 1)

        inputs = self.tokenizer(
            text,
            padding="max_length",
            max_length=25,
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        images, sizes, scales_yx = self.image_preprocess(image_path)
        output_dict = self.frcnn(
            images,
            sizes,
            scales_yx=scales_yx,
            padding="max_detections",
            max_detections=self.frcnn_cfg.max_detections,
            return_tensors="pt",
        )

        # Only the region features are consumed; the box output is not passed on.
        feature = output_dict.get("roi_features")

        inputs.update(
            {
                "visual_embeds": feature,
                # Every padded detection slot is attended to.
                "visual_attention_mask": torch.ones(feature.shape[:-1], dtype=torch.float),
                "output_attentions": False,
            }
        )

        # Remove ONLY the leading batch dimension. A bare squeeze() would also drop
        # any other size-1 dimension (e.g. max_detections == 1), corrupting shapes.
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.squeeze(0)

        # Soft target vector: the annotation's score at each listed answer index.
        targets = torch.zeros(num_labels, dtype=torch.float)
        for label, score in zip(annotation["labels"], annotation["scores"]):
            targets[label] = score
        inputs["labels"] = targets
        inputs["text"] = text

        return inputs

# Evaluate VisualBERT VQA on a subset of VQA v2 validation questions.
from visualbert.processing_image import Preprocess
from visualbert.visualizing_image import SingleImageViz
from visualbert.modeling_frcnn import GeneralizedRCNN
from visualbert.utils import Config

frcnn_cfg = Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
frcnn = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=frcnn_cfg)
image_preprocess = Preprocess(frcnn_cfg)

from transformers import VisualBertForQuestionAnswering, AutoTokenizer, BertTokenizerFast
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

model = VisualBertForQuestionAnswering.from_pretrained(
    "uclanlp/visualbert-vqa",
    num_labels=len(config.id2label),
    id2label=config.id2label,
    label2id=config.label2id,
    output_hidden_states=True,
)
model.to(device)
model.eval()

dataset = VQADataset(
    questions=questions[:100],
    annotations=annotations[:100],
    tokenizer=tokenizer,
    image_preprocess=image_preprocess,
    frcnn=frcnn,
    frcnn_cfg=frcnn_cfg,
)
test_dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# Only these batch entries are model inputs. The batch also holds "labels" (soft
# targets), "text" (a list of str after collation) and a collated "output_attentions"
# flag. None of those can be moved with .to(device) or forwarded as model kwargs.
MODEL_INPUT_KEYS = (
    "input_ids",
    "token_type_ids",
    "attention_mask",
    "visual_embeds",
    "visual_attention_mask",
)
VERBOSE = True  # print predicted/target answers for every sample

# Standard VQA soft accuracy: the target score at the predicted answer index.
# NOTE(review): assumes annotation "scores" are VQA soft scores (e.g. min(#humans/3, 1));
# confirm against the preprocessing that produced the annotations.
vqa_score = 0.0
# Secondary metric: prediction equals the single highest-scored target answer
# (this is what the original script reported as "accuracy").
exact_match = 0.0
total = 0

with torch.no_grad():  # inference only: no autograd graph needed
    for batch in tqdm(test_dataloader):
        model_inputs = {k: batch[k].to(device) for k in MODEL_INPUT_KEYS if k in batch}
        targets = batch["labels"].to(device)

        logits = model(**model_inputs).logits  # (batch_size, num_labels)
        pred = logits.argmax(dim=1)
        best_target = targets.argmax(dim=1)

        vqa_score += targets.gather(1, pred.unsqueeze(1)).sum().item()
        exact_match += (pred == best_target).sum().item()
        total += pred.size(0)

        if VERBOSE:
            for p, t in zip(pred.tolist(), best_target.tolist()):
                print("Predicted answer:", model.config.id2label[p])
                print("Target answer:", model.config.id2label[t])

if total:
    print('Accuracy of test (VQA soft score): %f %%' % (100 * vqa_score / total))
    print('Top-answer exact match: %f %%' % (100 * exact_match / total))
else:
    print('No samples were evaluated.')

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant