-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathplaces2_train.py
executable file
·62 lines (47 loc) · 1.93 KB
/
places2_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import random
import torch
import os
import glob
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from torchvision import utils
# Per-channel mean and standard deviation used to normalize input images.
# NOTE(review): these are the commonly used ImageNet statistics rather than
# values computed from Places2 -- confirm that this is intended.
MEAN = [0.485, 0.456, 0.406]
STDDEV = [0.229, 0.224, 0.225]


# reverses the earlier normalization applied to the image to prepare output
def unnormalize(x):
    """Undo the per-channel normalization ``(img - MEAN) / STDDEV``.

    Args:
        x: Tensor with at least four dimensions whose channel axis is
            dimension 1 (e.g. a batch laid out as N x C x H x W) and whose
            channel count matches ``MEAN`` / ``STDDEV``.

    Returns:
        A new tensor with the same shape as ``x`` holding
        ``x * STDDEV + MEAN``. The input tensor is not modified.
    """
    # Build the statistics on the input's device so the arithmetic below
    # also works for tensors that do not live on the CPU.
    std = torch.tensor(STDDEV, device=x.device)
    mean = torch.tensor(MEAN, device=x.device)
    # Move the channel axis to the end so the per-channel statistics
    # broadcast over it, then move it back. The out-of-place ``transpose``
    # returns a view, so the caller's tensor keeps its original layout
    # (the in-place ``transpose_`` used previously left it permuted).
    restored = x.transpose(1, 3) * std + mean
    return restored.transpose(1, 3)
class Places2Data(torch.utils.data.Dataset):
    """Dataset that pairs every image with a randomly chosen mask.

    Images are all ``*.jpg`` files found recursively below ``path_to_data``;
    masks are all ``*.png`` files directly inside ``path_to_mask``. Both
    arguments are appended verbatim to the directory that contains this
    file, so they are expected to start with a path separator.

    Each item is the tuple ``(masked_img, mask, gt_img)`` where ``gt_img`` is
    the normalized image tensor, ``mask`` is the mask converted to a tensor
    and ``masked_img`` is their element-wise product.
    """

    def __init__(self, path_to_data="/data_256", path_to_mask="/mask"):
        """Collect the image and mask file paths and build the transforms.

        Args:
            path_to_data: Image directory, appended to this file's directory.
            path_to_mask: Mask directory, appended to this file's directory.
        """
        super().__init__()
        base_dir = os.path.dirname(os.path.abspath(__file__))
        # glob yields matches in arbitrary, filesystem-dependent order; sort
        # them so that a given index refers to the same image on every run.
        self.img_paths = sorted(glob.glob(base_dir + path_to_data + "/**/*.jpg", recursive=True))
        self.mask_paths = sorted(glob.glob(base_dir + path_to_mask + "/*.png"))
        self.num_masks = len(self.mask_paths)
        self.num_imgs = len(self.img_paths)
        # normalizes the image: (img - MEAN) / STD and converts to tensor
        self.img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(MEAN, STDDEV)])
        self.mask_transform = transforms.ToTensor()

    def __len__(self):
        """Return the number of images that were found."""
        return self.num_imgs

    def __getitem__(self, index):
        """Load the image at ``index`` together with a random mask.

        Args:
            index: Position of the image in ``img_paths``.

        Returns:
            Tuple ``(gt_img * mask, mask, gt_img)`` of tensors.

        Raises:
            IndexError: If ``index`` is outside ``img_paths``.
            ValueError: If no mask files were found.
        """
        # Resolve the index first so an out-of-range access still raises
        # IndexError, which Python's sequence iteration protocol relies on.
        img_path = self.img_paths[index]
        if not self.mask_paths:
            # Fail with an explicit message instead of the opaque
            # "empty range" error random.randint would produce.
            raise ValueError("no mask images (*.png) were found; cannot build a masked sample")

        # The context managers release the file handles as soon as the pixel
        # data has been converted.
        with Image.open(img_path) as img_file:
            gt_img = self.img_transform(img_file.convert('RGB'))
        with Image.open(random.choice(self.mask_paths)) as mask_file:
            mask = self.mask_transform(mask_file.convert('RGB'))

        # Positions where the mask is zero become zero in the masked image.
        # NOTE(review): assumes mask and image have the same spatial size so
        # the product is element-wise -- confirm against the mask files.
        return gt_img * mask, mask, gt_img
# Smoke test: load one sample and save a side-by-side comparison image
if __name__ == '__main__':
    dataset = Places2Data()
    print(len(dataset))

    # Fetch a single sample and regroup its (masked, mask, ground truth)
    # members into three one-element tuples.
    samples = [dataset[idx] for idx in range(1)]
    masked_parts, mask_parts, gt_parts = zip(*samples)

    # Stack along a new leading batch axis and report how many entries of
    # the masked batch are exactly zero.
    masked_batch = torch.stack(masked_parts)
    zero_entries = masked_batch == 0
    print(zero_entries.sum())

    mask_batch = torch.stack(mask_parts)
    gt_batch = torch.stack(gt_parts)

    # Concatenate the un-normalized masked input, the mask and the
    # un-normalized ground truth along the batch axis and write them out as
    # one image grid.
    panels = (unnormalize(masked_batch), mask_batch, unnormalize(gt_batch))
    grid = utils.make_grid(torch.cat(panels, dim=0))
    utils.save_image(grid, "test.jpg")