forked from vips4/I-Split
-
Notifications
You must be signed in to change notification settings - Fork 0
/
filters_extraction.py
59 lines (44 loc) · 1.75 KB
/
filters_extraction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = "Federico Cunico, Luigi Capogrosso, Francesco Setti, \
Damiano Carra, Franco Fummi, Marco Cristani"
__version__ = "1.0.0"
__maintainer__ = "Federico Cunico, Luigi Capogrosso"
__email__ = "name.surname@univr.it"
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
def get_filters(input_img: torch.Tensor,
                fwd_layer_list: List[torch.nn.Module],
                gradients: torch.Tensor,
                normalize: bool = False,
                show: bool = False,
                *,
                save_path: Optional[str] = "test.jpg") -> torch.Tensor:
    """Compute a Grad-CAM style heatmap for ``input_img``.

    ``input_img`` is passed sequentially through every module in
    ``fwd_layer_list``. The resulting activations are weighted channel-wise
    by the averaged ``gradients``, averaged over the channel axis and passed
    through a ReLU (Eq. 2 of https://arxiv.org/pdf/1610.02391.pdf).

    Args:
        input_img: Input image batch. A 3-D tensor is treated as a single
            image and gets a leading batch dimension added.
        fwd_layer_list: Callables applied in order to produce the target
            activations (presumably the layers up to the last convolutional
            layer -- confirm against callers).
        gradients: Gradients w.r.t. the target activations. They are
            averaged over dims 0, 2 and 3, so presumably shaped
            (N, C, H, W); C must equal the activations' channel count.
        normalize: If True, divide the heatmap by its maximum value
            (skipped when that maximum is 0).
        show: If True, display the heatmap with matplotlib and, unless
            ``save_path`` is None, save the figure.
        save_path: File the figure is saved to when ``show`` is True.
            ``None`` disables saving. Defaults to "test.jpg" for backward
            compatibility.

    Returns:
        A detached CPU tensor holding the non-negative heatmap, with every
        size-1 dimension squeezed out.

    Raises:
        ValueError: If the number of pooled gradient weights does not match
            the number of activation channels.
    """
    # One weight per channel: average the gradients over batch and space.
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

    if input_img.dim() == 3:
        input_img = input_img.unsqueeze(0)

    # Forward pass through the provided layers to get the target activations.
    activations = input_img
    for layer in fwd_layer_list:
        activations = layer(activations)
    activations = activations.detach()

    chs = activations.shape[1]
    if pooled_gradients.numel() != chs:
        raise ValueError(
            f"gradients provide {pooled_gradients.numel()} channel weights "
            f"but the activations have {chs} channels")

    # Weight out of place: `detach()` shares storage, and with an empty layer
    # list or identity/view layers `activations` aliases the caller's
    # `input_img`, so an in-place multiply would corrupt the caller's tensor.
    # Casting to the activations' dtype/device mirrors the per-channel scalar
    # multiply this replaces.
    weights = pooled_gradients.to(device=activations.device,
                                  dtype=activations.dtype)
    weights = weights.reshape(1, chs, *([1] * (activations.dim() - 2)))
    weighted = activations * weights

    # Average the weighted channels and bring the map to the CPU.
    heatmap = torch.mean(weighted, dim=1).squeeze().detach().cpu()
    # ReLU (Eq. 2 of the Grad-CAM paper). Done in torch so a 0-d map (1x1
    # spatial, batch 1) stays a tensor instead of becoming a numpy scalar.
    heatmap = torch.clamp(heatmap, min=0)

    if normalize:
        max_val = torch.max(heatmap)
        # After the ReLU the max is >= 0; skip the division when it is 0.
        if max_val != 0:
            heatmap = heatmap / max_val

    # Draw the heatmap.
    if show:
        # NOTE(review): the figure is intentionally left open so it stays on
        # screen; repeated calls with show=True accumulate open figures.
        plt.figure()
        plt.imshow(heatmap)
        if save_path is not None:
            plt.savefig(save_path)
        plt.pause(0.5)
    return heatmap