-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathconfuse_matrix.py
37 lines (25 loc) · 1.08 KB
/
confuse_matrix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import numpy as np
import torch.nn.functional as F
# Number of classes: get_confuse_matrix reads label columns 0..NUM_CLASSES-1
# and produces one probability row per class.
NUM_CLASSES = 5
# Softmax temperature: get_confuse_matrix divides the per-class averaged
# logits by this value before applying softmax.
temperature = 2.0
def torch_tile(tensor, dim, n):
    """Repeat a (2-D) tensor ``n`` times along ``dim``.

    * ``dim == 0``: each row is duplicated ``n`` times consecutively, so the
      output rows are ``r0, r0, ..., r1, r1, ...`` with shape
      ``(rows * n, cols)``.
    * any other ``dim``: the contents of each row are tiled ``n`` times
      along the column axis, giving shape ``(rows, cols * n)``.  Every
      non-zero ``dim`` takes this branch, not only ``dim == 1``.
    """
    # unsqueeze(1) yields the same (rows, 1, ...) layout as
    # unsqueeze(0).transpose(0, 1); repeat() then materialises a contiguous
    # copy, so view() is always legal afterwards.
    stacked = tensor.unsqueeze(1)
    if dim != 0:
        return stacked.repeat(1, 1, n).view(tensor.shape[0], -1)
    return stacked.repeat(1, n, 1).view(-1, tensor.shape[1])
def get_confuse_matrix(logits, labels, *, num_classes=None, softmax_temperature=None):
    """Build a per-class soft "confusion" matrix from model logits.

    For each class ``i``, the rows of ``logits`` are averaged with weights
    ``labels[:, i]``. The average is divided by the softmax temperature and
    turned into a probability vector with ``softmax``. This matches the
    original per-class loop, but does all classes in one weighted matmul.

    Args:
        logits: 2-D tensor of shape ``(num_samples, num_outputs)``.
            ``num_outputs`` no longer has to equal ``num_classes``.
        labels: 2-D tensor of shape ``(num_samples, >= num_classes)``.
            Column ``i`` weights each sample's contribution to class ``i``
            (presumably one-hot; TODO confirm with callers). Integer and
            bool labels are accepted.
        num_classes: Number of label columns to use. Defaults to the
            module-level ``NUM_CLASSES``, read at call time.
        softmax_temperature: Divisor applied to the averaged logits before
            softmax. Defaults to the module-level ``temperature``, read at
            call time.

    Returns:
        Tensor of shape ``(num_classes, num_outputs)`` whose rows each sum
        to 1. A class with zero total weight in ``labels`` gets an all-zero
        average and therefore a uniform row.

    Raises:
        IndexError: If ``labels`` has fewer than ``num_classes`` columns.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    if softmax_temperature is None:
        softmax_temperature = temperature
    if labels.shape[1] < num_classes:
        raise IndexError(
            f"labels has {labels.shape[1]} columns but num_classes={num_classes}"
        )

    # The old element-wise product relied on implicit type promotion;
    # matmul does not promote, so cast both operands to the common dtype.
    dtype = torch.promote_types(logits.dtype, labels.dtype)
    weights = labels[:, :num_classes].to(dtype)

    # (num_classes, num_samples) @ (num_samples, num_outputs)
    # gives the per-class weighted logit sums.
    class_sums = weights.t() @ logits.to(dtype)
    # Per-class total weight; 1e-8 avoids 0/0 for classes absent from the batch.
    class_counts = weights.sum(dim=0).unsqueeze(1) + 1e-8
    class_avg = class_sums / class_counts
    return F.softmax(class_avg / softmax_temperature, dim=1)
def kd_loss(source_matrix, target_matrix):
    """Symmetric KL divergence between two probability matrices.

    With ``Q = source_matrix`` and ``P = target_matrix``, returns
    ``(KL(P || Q) + KL(Q || P)) / 2``. Each KL term is summed over all
    entries and divided by the number of rows (``reduction='batchmean'``).

    Both matrices are passed through ``log()``. A zero entry in either one
    therefore becomes ``-inf`` and can make the loss ``inf`` or ``NaN``.
    """
    q, p = source_matrix, target_matrix
    # F.kl_div(input, target) expects `input` as log-probabilities and
    # measures KL(target || exp(input)).
    # The original passed size_average/reduce positionally as None; those
    # arguments are deprecated, so `reduction` is given by keyword instead.
    forward_kl = F.kl_div(q.log(), p, reduction='batchmean')
    reverse_kl = F.kl_div(p.log(), q, reduction='batchmean')
    return (forward_kl + reverse_kl) / 2.0