-
Notifications
You must be signed in to change notification settings - Fork 89
/
txt2image_dataset.py
95 lines (73 loc) · 3.06 KB
/
txt2image_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import io
from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np
import pdb
from PIL import Image
import torch
from torch.autograd import Variable
import pdb
import torch.nn.functional as F
class Text2ImageDataset(Dataset):
    """Paired image / text-embedding samples read from one split of an HDF5 file.

    Each example group inside the split is expected to contain the keys
    ``'img'`` (encoded image bytes readable by PIL), ``'embeddings'``,
    ``'class'`` and ``'txt'``.

    Args:
        datasetFile: Path of the HDF5 file.
        transform: Stored on the instance but not applied by this class.
        split: 0 selects ``'train'``, 1 selects ``'valid'``, anything else ``'test'``.
        image_size: Side length (pixels) images are resized to; defaults to 64.
    """

    def __init__(self, datasetFile, transform=None, split=0, image_size=64):
        self.datasetFile = datasetFile
        self.transform = transform  # NOTE(review): never applied in this class.
        self.dataset = None  # HDF5 handle, opened lazily in __getitem__.
        self.dataset_keys = None
        self.split = 'train' if split == 0 else 'valid' if split == 1 else 'test'
        self.image_size = image_size
        # Not used inside this class; kept as a public attribute for callers.
        self.h5py2int = lambda x: int(np.array(x))

    def __len__(self):
        """Return the number of examples in the selected split.

        Also records the split's example names in ``self.dataset_keys``.
        The file is closed even if reading fails.
        """
        with h5py.File(self.datasetFile, 'r') as f:
            group = f[self.split]
            self.dataset_keys = [str(k) for k in group.keys()]
            return len(group)

    def __getitem__(self, idx):
        """Build one training sample.

        Returns:
            dict with
              ``'right_images'``: float32 tensor (3, H, W) in [-1, 1] for example ``idx``;
              ``'right_embed'``: float32 tensor of that example's embeddings;
              ``'wrong_images'``: image of a random example with a different class;
              ``'inter_embed'``: embeddings of a random example;
              ``'txt'``: the example's text as ``str``.
        """
        if self.dataset is None:
            # Opened lazily rather than in __init__ — presumably so each
            # DataLoader worker opens its own handle; verify with the caller.
            self.dataset = h5py.File(self.datasetFile, mode='r')
            self.dataset_keys = [str(k) for k in self.dataset[self.split].keys()]

        example_name = self.dataset_keys[idx]
        example = self.dataset[self.split][example_name]

        right_bytes = bytes(np.array(example['img']))
        right_embed = np.array(example['embeddings'], dtype=float)
        # Random draws happen in this order: wrong image first, then inter embed.
        wrong_bytes = bytes(np.array(self.find_wrong_image(example['class'])))
        inter_embed = np.array(self.find_inter_embed())

        size = (self.image_size, self.image_size)
        right_image = self.validate_image(Image.open(io.BytesIO(right_bytes)).resize(size))
        wrong_image = self.validate_image(Image.open(io.BytesIO(wrong_bytes)).resize(size))

        txt = np.array(example['txt']).astype(str)

        return {
            'right_images': self._to_normalized_tensor(right_image),
            'right_embed': torch.tensor(right_embed, dtype=torch.float32),
            'wrong_images': self._to_normalized_tensor(wrong_image),
            'inter_embed': torch.tensor(inter_embed, dtype=torch.float32),
            'txt': str(txt),
        }

    @staticmethod
    def _to_normalized_tensor(img):
        """Convert a pixel array to float32 and map the range 0..255 onto [-1, 1]."""
        return torch.tensor(img, dtype=torch.float32).sub_(127.5).div_(127.5)

    def find_wrong_image(self, category, max_tries=1000):
        """Return the ``'img'`` entry of a random example whose class differs from ``category``.

        Class values are compared by content. Comparing h5py Dataset handles
        directly would compare object identity, not the stored label.

        Args:
            category: Class value (array-like or h5py Dataset) to avoid.
            max_tries: Maximum number of random draws before giving up.

        Raises:
            RuntimeError: no differing class was found in ``max_tries`` draws
                (e.g. the split contains a single class).
        """
        target = np.asarray(category)
        group = self.dataset[self.split]
        for _ in range(max_tries):
            idx = np.random.randint(len(self.dataset_keys))
            example = group[self.dataset_keys[idx]]
            if not np.array_equal(np.asarray(example['class']), target):
                return example['img']
        raise RuntimeError(
            f"no example with a class different from {target!r} found "
            f"after {max_tries} draws in split '{self.split}'"
        )

    def find_inter_embed(self):
        """Return the ``'embeddings'`` entry of a uniformly random example in the split."""
        idx = np.random.randint(len(self.dataset_keys))
        example = self.dataset[self.split][self.dataset_keys[idx]]
        return example['embeddings']

    def validate_image(self, img):
        """Return ``img`` as a float array of shape (3, height, width).

        A 2-D (single-channel) image is replicated into three channels; a
        4-channel image has its last channel (alpha) dropped. Works for any size.
        """
        img = np.array(img, dtype=float)
        if img.ndim == 2:
            img = np.stack([img, img, img], axis=-1)
        elif img.ndim == 3 and img.shape[2] == 4:
            img = img[:, :, :3]
        return img.transpose(2, 0, 1)