-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata_set.py
80 lines (68 loc) · 2.7 KB
/
data_set.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
from torch.utils.data import Dataset
import logging
import os
from PIL import Image
import json
# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)
# Root data directory. It is a relative path, so it is resolved against the
# process's current working directory when joined and opened.
WORKING_PATH: str = "data"
class MyDataset(Dataset):
    """Image-text dataset built from a JSON split file plus JPEG images on disk.

    The split file ``<WORKING_PATH>/<text_name>/<mode>.json`` is read with
    ``json.load`` and iterated as a sequence of records, each indexed by
    ``'image_id'``, ``'text'`` and ``'label'``. A record is kept only when the
    image ``<WORKING_PATH>/dataset_image/dataset_image/<image_id>.jpg`` exists.

    Each item is a ``(text, image, label, image_id)`` tuple, where ``image`` is
    the PIL image returned by :meth:`image_loader` and ``image_id`` is the
    record's id cast to ``int``.
    """

    def __init__(self, mode, text_name, limit=None):
        """Load the ``mode`` split of the ``text_name`` annotations.

        Args:
            mode: ``"train"``, ``"test"`` or ``"valid"``. Any other value
                yields an empty dataset and logs a warning.
            text_name: Sub-directory of ``WORKING_PATH`` containing
                ``<mode>.json``.
            limit: Optional cap on the number of kept records. It is only
                applied to the ``"train"`` split.
        """
        self.text_name = text_name
        # Maps int(image_id) -> {"text": ..., "label": ..., "image_path": ...}.
        self.data = self.load_data(mode, limit)
        self.image_ids = list(self.data.keys())

    @staticmethod
    def _image_path(image_id):
        """Return the JPEG path for ``image_id``, using the id exactly as given."""
        return os.path.join(WORKING_PATH, "dataset_image/dataset_image", str(image_id) + ".jpg")

    def load_data(self, mode, limit):
        """Read the split file for ``mode`` and keep records whose image exists.

        Args:
            mode: ``"train"``, ``"test"`` or ``"valid"``.
            limit: Maximum number of kept records, or ``None`` for no cap.
                Honoured for ``"train"`` only; evaluation splits load in full.

        Returns:
            Dict mapping ``int(image_id)`` to a dict with ``"text"``,
            ``"label"`` and ``"image_path"`` keys. An unknown ``mode`` returns
            an empty dict.
        """
        if mode == "train":
            return self._read_split(mode, limit)
        if mode in ("test", "valid"):
            # Evaluation splits are never truncated.
            return self._read_split(mode, None)
        logger.warning("Unknown dataset mode %r; returning an empty dataset.", mode)
        return {}

    def _read_split(self, mode, limit):
        """Parse ``<mode>.json`` and return the records that have an image file."""
        json_path = os.path.join(WORKING_PATH, self.text_name, mode + ".json")
        with open(json_path, 'r', encoding='utf-8') as handle:
            records = json.load(handle)

        samples = {}
        kept = 0
        for record in records:
            if limit is not None and kept >= limit:
                break
            image_id = record['image_id']
            sentence = record['text']
            label = record['label']
            image_path = self._image_path(image_id)
            if os.path.isfile(image_path):
                # Store the path that was actually verified. Rebuilding it from
                # int(image_id) would turn an id like "007" into 7.jpg.
                samples[int(image_id)] = {"text": sentence, 'label': label, "image_path": image_path}
                kept += 1
        return samples

    def image_loader(self, id):
        """Open and decode the image stored for record ``id``."""
        image = Image.open(self.data[id]["image_path"])
        # Image.open is lazy and keeps the file open. load() decodes now, and
        # for single-frame images Pillow then closes the underlying file.
        image.load()
        return image

    def text_loader(self, id):
        """Return the text of record ``id``."""
        return self.data[id]["text"]

    def __getitem__(self, index):
        """Return ``(text, image, label, image_id)`` for the ``index``-th record."""
        image_id = self.image_ids[index]
        text = self.text_loader(image_id)
        image_feature = self.image_loader(image_id)
        label = self.data[image_id]["label"]
        return text, image_feature, label, image_id

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.image_ids)

    @staticmethod
    def collate_fn(batch_data):
        """Transpose a batch of ``(text, image, label, id)`` items into four lists.

        Returns ``{}`` for an empty batch. Otherwise it returns
        ``(text_list, image_list, label_list, id_list)``. Images are left as
        returned by ``image_loader``; no stacking is done here.
        """
        if not batch_data:
            return {}
        text_list = [item[0] for item in batch_data]
        image_list = [item[1] for item in batch_data]
        label_list = [item[2] for item in batch_data]
        id_list = [item[3] for item in batch_data]
        return text_list, image_list, label_list, id_list