-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfeature_map.py
139 lines (122 loc) · 4.05 KB
/
feature_map.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# Library imports
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import argparse
# File imports
from loss_networks import FeatureExtractor, extractor_collector
from workspace_path import home_path
# Workspace directories for logging, checkpointing and image data, all rooted
# at ``home_path``. ``image_dir`` is the tree that ``show_feature_map``
# searches for input images.
log_dir, checkpoint_dir, image_dir = (
    home_path / subdir_name for subdir_name in ('logs', 'checkpoints', 'images')
)
def _find_images(image_ids):
    '''
    Return the matching image paths from the first subdirectory of
    ``image_dir`` (visited in sorted order, skipping one named ``ref``) whose
    files have paths containing every string in ``image_ids``. Within that
    subdirectory all matches are returned, sorted. Returns an empty list when
    nothing matches.
    '''
    for subdir in sorted(image_dir.iterdir()):
        if not subdir.is_dir() or subdir.stem == 'ref':
            continue
        matches = [
            path for path in sorted(subdir.iterdir())
            if all(token in str(path) for token in image_ids)
        ]
        if matches:
            return matches
    return []


def _tile_feature_maps(y):
    '''
    Turn one (N, C, H, W) activation tensor into a (C, 2H, N*W) tensor of
    per-channel display panels.

    For each channel the N images sit side by side. The top H rows hold the
    raw activations min-max scaled to [0, 1] per channel. The bottom H rows
    hold the activations L2-normalized across channels (``F.normalize``).
    '''
    n, c, h, w = y.size()

    def tile(t):
        # (N, C, H, W) -> (C, H, N, W) -> (C, H, N*W): lay the batch out
        # horizontally so each channel becomes one wide image.
        return t.permute(1, 2, 0, 3).reshape(c, h, n * w)

    raw = tile(y)
    # ToPILImage converts float tensors via x * 255 -> uint8, so activations
    # above 1 would overflow. Rescale each channel into [0, 1] first. The
    # clamp avoids dividing by zero for constant channels.
    flat = raw.reshape(c, -1)
    lo = flat.min(dim=1).values.view(c, 1, 1)
    hi = flat.max(dim=1).values.view(c, 1, 1)
    raw = (raw - lo) / (hi - lo).clamp(min=1e-12)
    normed = tile(F.normalize(y))
    return torch.cat([raw, normed], dim=1)


def show_feature_map(loss_network_path, image_ids):
    '''
    Display, one channel at a time, the feature maps a feature-extraction
    network produces for a set of images.

    Image lookup: the subdirectories of ``image_dir`` (skipping one named
    ``ref``) are scanned in sorted order. Every file whose path contains all
    of ``image_ids`` is collected from the first subdirectory that has at
    least one match. All collected images are run through the network as a
    single batch, so they must share the same size.

    Display: for each output of the network and each channel in it, one
    grayscale image is shown, with the matched images tiled left to right.
    The top half shows the raw activations (min-max scaled per channel so
    they render correctly). The bottom half shows the activations
    L2-normalized across channels. Press Enter in the terminal to advance to
    the next map.

    Args:
        loss_network_path (str or Path): Path to a network saved with
            ``torch.save`` whose call returns a sequence of 4-D feature maps.
        image_ids ([str] or str): Substrings that must all appear in an
            image's path. A single string is treated as a one-item list.

    Raises:
        FileNotFoundError: If no image under ``image_dir`` matches.
    '''
    # Iterating a bare string would match on single characters instead.
    if isinstance(image_ids, str):
        image_ids = [image_ids]
    # Convert to tensors and normalize with the ImageNet mean/std that
    # torchvision's pretrained models expect.
    image_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    # The file holds a whole pickled network (it is called directly below),
    # not a state_dict. torch >= 2.6 defaults weights_only=True, which
    # refuses such files, so disable it. Only load files you created
    # yourself. map_location keeps the network on the CPU, where the input
    # batch lives.
    loss_network = torch.load(
        loss_network_path, map_location='cpu', weights_only=False,
    )
    image_paths = _find_images(image_ids)
    if not image_paths:
        raise FileNotFoundError(
            f'No image under {image_dir} matches all of {list(image_ids)}'
        )
    # Load every match as a (1, 3, H, W) tensor and stack into one batch.
    tensors = []
    for path in image_paths:
        with Image.open(path) as img:
            tensors.append(
                image_transform(img.convert(mode='RGB')).unsqueeze(dim=0)
            )
    batch = torch.cat(tensors, dim=0)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        maps = loss_network(batch)
    # Build a new list rather than assigning into ``maps``, which may be an
    # immutable tuple.
    panels = [_tile_feature_maps(y) for y in maps]
    to_pil = torchvision.transforms.ToPILImage()
    for panel in panels:
        for channel in panel:
            to_pil(channel.detach()).resize(
                (768, 512), resample=Image.BOX
            ).show()
            input()
def run_feature_mapping(argv=None):
    '''
    Command-line entry point: visualize the feature maps of the selected
    network for the images identified by ``--image``.

    Only the extractor for the chosen ``--network`` is requested from
    ``extractor_collector``; the other architectures are not built. Press
    Enter in the terminal to advance to the next feature map.

    Args:
        argv ([str] or None): Argument list to parse instead of the real
            command line. ``None`` (the default) parses ``sys.argv[1:]``,
            matching the previous behavior.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--image',
        type=str,
        default=['3ring', 'color'],
        nargs='+',
        help='Strings that must all appear in the path of the image(s) '
             'to visualize',
    )
    parser.add_argument(
        '--network',
        type=str,
        default='alexnet',
        choices=['alexnet', 'squeezenet', 'vgg16'],
        help='Network for which to generate feature maps',
    )
    args = parser.parse_args(argv)
    # Per-network extractor_collector arguments. The ``layers`` lists are
    # passed straight through to FeatureExtractor (presumably indices of the
    # layers whose outputs are returned; confirm in loss_networks).
    network_configs = {
        'alexnet': dict(
            architecture='alexnet',
            layers=[1, 4, 7, 9, 11],
        ),
        'squeezenet': dict(
            architecture='squeezenet1_1',
            layers=[1, 4, 7, 9, 10, 11, 12],
        ),
        'vgg16': dict(
            architecture='vgg16',
            layers=[3, 8, 15, 22, 29],
        ),
    }
    # Settings shared by every network. Unpacked after the per-network dict
    # so the keyword order matches the original calls, in case
    # extractor_collector relies on it.
    shared_config = dict(
        pretrained=True,
        frozen=True,
        flatten_layer=False,
        normalize_in=False,
    )
    # Build (or find) only the extractor that is actually used.
    path = extractor_collector(
        FeatureExtractor,
        **network_configs[args.network],
        **shared_config,
    )
    show_feature_map(path, args.image)
# Script entry point: launch the CLI only when this module is executed
# directly (e.g. ``python feature_map.py``), never on import.
if __name__ == "__main__":
    run_feature_mapping()