caption_dataset.py
# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE

import glob

from torch.utils.data import Dataset
from dataset.utils import *  # expected to provide json, os, Transform, get_expert_labels, post_label_process, pre_caption
from PIL import ImageFile

# Allow PIL to load partially corrupted images instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class Caption(Dataset):
    def __init__(self, config, train=True):
        self.data_path = config['data_path']
        self.label_path = config['label_path']
        self.experts = config['experts']
        self.prefix = config['prefix']
        self.dataset = config['dataset']
        self.transform = Transform(resize_resolution=config['image_resolution'], scale_size=[0.5, 1.0], train=train)
        self.train = train

        if train:
            # Both the COCO and NoCaps settings train on the COCO Karpathy split.
            self.data_list = []
            if self.dataset in ['coco', 'nocaps']:
                self.data_list += json.load(open(os.path.join(self.data_path, 'coco_karpathy_train.json'), 'r'))
        else:
            if self.dataset == 'coco':
                self.data_list = json.load(open(os.path.join(self.data_path, 'coco_karpathy_test.json'), 'r'))
            elif self.dataset == 'nocaps':
                self.data_list = json.load(open(os.path.join(self.data_path, 'nocaps_val.json'), 'r'))
            elif self.dataset == 'demo':
                # Demo mode: collect every .jpg/.png/.jpeg image found in the sub-folders of data_path.
                data_folders = glob.glob(f'{self.data_path}/*/')
                self.data_list = [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpg')]
                self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.png')]
                self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpeg')]

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        data = self.data_list[index]

        # Load the RGB image together with its pre-computed expert labels.
        if self.dataset == 'coco':
            image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, data['image'], 'vqav2', self.experts)
        elif self.dataset == 'nocaps':
            image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, data['image'], 'nocaps', self.experts)
        elif self.dataset == 'demo':
            img_path_split = self.data_list[index]['image'].split('/')
            img_name = img_path_split[-2] + '/' + img_path_split[-1]
            image, labels, labels_info = get_expert_labels('', self.label_path, img_name, 'helpers', self.experts)

        # Apply the joint image/label transform, then post-process the expert labels.
        experts = self.transform(image, labels)
        experts = post_label_process(experts, labels_info)

        if self.train:
            # Prepend the task prefix and truncate the caption to 30 words.
            caption = pre_caption(self.prefix + ' ' + self.data_list[index]['caption'], max_words=30)
            return experts, caption
        else:
            return experts, index
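

# --- Illustrative usage sketch (not part of the original file) ---
# A minimal example of how this dataset might be constructed and iterated,
# assuming a config dict with the keys read in __init__ above; the concrete
# values (paths, expert names, prefix, resolution) are placeholders, not the
# repository's actual configuration.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    config = {
        'data_path': 'data/coco',          # assumed location of the caption JSON files
        'label_path': 'data/experts',      # assumed location of pre-computed expert labels
        'experts': ['depth', 'seg_coco'],  # hypothetical subset of expert names
        'prefix': 'A picture of',          # assumed caption prefix used during training
        'dataset': 'coco',
        'image_resolution': 480,           # placeholder resolution
    }

    dataset = Caption(config, train=True)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    # Each batch yields the collated expert inputs and a list of caption strings.
    experts, captions = next(iter(loader))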