forked from Ema93sh/pytorch-saliency
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
82 lines (58 loc) · 2.22 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import random
import argparse
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
from saliency import *
# Enable matplotlib interactive mode. This runs at import time, so merely
# importing this module changes global pyplot state.
plt.ion()
# Per-channel RGB mean/std used by torchvision's ImageNet-pretrained models;
# shared by the input normalization and by the display-side denormalization.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _parse_args():
    """Parse the command line and return the selected ``SaliencyMethod``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--type',
        help='Type of saliency to generate',
        choices=['guided', 'vanilla', 'deconv'],
        default='vanilla',
        required=False,
    )
    choices_map = {
        'guided': SaliencyMethod.GUIDED,
        'vanilla': SaliencyMethod.VANILLA,
        'deconv': SaliencyMethod.DECONV,
    }
    args = parser.parse_args()
    return choices_map[args.type]


def _load_random_image(root='./images/'):
    """Return one randomly chosen image under *root* as a normalized batch of one.

    *root* must be laid out for ``torchvision.datasets.ImageFolder``
    (one sub-directory per class). The result has shape (1, 3, 224, 224).
    """
    data_transform = transforms.Compose([
        # Standard ImageNet eval preprocessing: scale the short side to 256
        # first, so the 224 center crop keeps most of the picture instead of
        # a tiny patch (or zero padding for images smaller than 224).
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])
    dataset = datasets.ImageFolder(root=root, transform=data_transform)
    # The dataset is indexed directly: a DataLoader is unnecessary when
    # only a single sample is needed.
    image, _label = dataset[random.randrange(len(dataset))]
    return image.unsqueeze(0)


def _denormalize(image):
    """Undo the ImageNet normalization of a (C, H, W) tensor.

    Returns an (H, W, C) tensor clamped to [0, 1], suitable for ``imshow``.
    """
    mean = torch.tensor(_IMAGENET_MEAN).view(-1, 1, 1)
    std = torch.tensor(_IMAGENET_STD).view(-1, 1, 1)
    # detach() so plotting works even if the saliency code enabled gradients
    # on the input tensor.
    return (image.detach() * std + mean).clamp(0, 1).permute(1, 2, 0)


def _plot_saliency(image, saliency):
    """Show the input image next to its positive, negative and absolute saliency maps."""
    plt.figure(figsize=(8, 8), facecolor='w')

    plt.subplot(2, 2, 1)
    plt.title("Original Image")
    plt.imshow(_denormalize(image.squeeze(0)))

    panels = [
        (2, "Positive Saliency", MapType.POSITIVE),
        (3, "Negative Saliency", MapType.NEGATIVE),
        (4, "Absolute Saliency", MapType.ABSOLUTE),
    ]
    for position, title, map_type in panels:
        plt.subplot(2, 2, position)
        plt.title(title)
        # NOTE(review): assumes each saliency map is (C, H, W), so mean(0)
        # collapses channels into one grayscale map. TODO confirm against
        # the saliency module.
        plt.imshow(saliency[map_type].mean(0), cmap='gray')

    plt.show(block=True)


def main():
    """Plot a VGG16 saliency map for a random image from ``./images/``.

    The saliency flavour is chosen with ``--type`` (guided, vanilla or deconv).
    """
    # Parse first so that --help and bad arguments exit before the model
    # weights are loaded.
    saliency_type = _parse_args()

    model = models.vgg16(pretrained=True)
    # VGG16 contains dropout layers; without eval() the predicted class would
    # change randomly between runs.
    model.eval()

    image = _load_random_image()

    # The prediction only needs the arg-max class, so no autograd graph is
    # built here.
    with torch.no_grad():
        logits = model(image)
    _, prediction = logits.max(1)

    saliency = generate_saliency(model, image, prediction, type=saliency_type)
    _plot_saliency(image, saliency)
# Script entry point: run the demo only when executed directly, not when imported.
if __name__ == "__main__":
    main()