-
Notifications
You must be signed in to change notification settings - Fork 4
/
utils.py
60 lines (52 loc) · 2.56 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import add_self_loops
def add_inverse_rels(edge_index, rel):
    """Return the edge list extended with a reversed copy of every edge.

    Args:
        edge_index: tensor whose rows 0 and 1 hold the source and target node
            of each edge (presumably shape (2, E) -- confirm with callers).
        rel: tensor with one relation id per edge.

    Returns:
        ``(edge_index_all, rel_all)``. The original edges come first, followed
        by their reversals. The relation ids are the originals followed by
        ``rel + rel.max() + 1``, so inverse relations receive ids disjoint from
        the originals (for non-negative ids).
    """
    reversed_edges = torch.stack((edge_index[1], edge_index[0]))
    inverse_rel = rel + rel.max() + 1
    return (
        torch.cat([edge_index, reversed_edges], dim=1),
        torch.cat([rel, inverse_rel]),
    )
def get_train_batch(x1, x2, train_set, k=5):
    """Collect, for every training pair, the k nearest candidates under L1 distance.

    Column 0 of ``train_set`` indexes rows of ``x1`` and column 1 indexes rows
    of ``x2``. For each anchor the single closest candidate is discarded. For
    same-side pools this is typically the anchor itself at distance 0. The
    next ``k`` candidates are kept; presumably they serve as hard negatives
    (confirm with the training loop).

    Returns:
        Long tensor of shape ``(4, k, len(train_set))``, stacking in order:
        x1-anchors vs x1, x1-anchors vs x2, x2-anchors vs x2, x2-anchors vs x1.
    """
    def nearest_after_first(anchors, pool):
        # (n, k+1) indices of the smallest distances, transposed to (k+1, n);
        # row 0 (the closest hit) is dropped.
        _, idx = torch.cdist(anchors, pool, p=1).topk(k + 1, largest=False)
        return idx.t()[1:]

    left_anchors = x1[train_set[:, 0]]
    right_anchors = x2[train_set[:, 1]]
    parts = [
        nearest_after_first(left_anchors, x1),
        nearest_after_first(left_anchors, x2),
        nearest_after_first(right_anchors, x2),
        nearest_after_first(right_anchors, x1),
    ]
    return torch.stack(parts, dim=0)
def _dist_to_p(dist):
    """Translate a metric name of the form ``'L<p>'`` (``'L1'``, ``'L2'``, ``'Linf'``) to cdist's ``p``."""
    if isinstance(dist, str) and dist[:1] in ('L', 'l'):
        try:
            p = float(dist[1:])
        except ValueError:
            p = None
        # `p >= 0` also rejects NaN ('Lnan').
        if p is not None and p >= 0:
            return p
    raise ValueError("dist must look like 'L1', 'L2' or 'Linf', got %r" % (dist,))


def _rank_metrics(S, Hn_nums):
    """Hits@k and MRR for a distance matrix whose row i has its gold match in column i."""
    n = S.size(0)
    gold = torch.arange(n, device=S.device).view(-1, 1)
    metrics = {}
    for k in Hn_nums:
        # topk raises for k larger than the number of candidates. Every gold
        # match lies within the top S.size(1) anyway, so clamping k gives the
        # correct Hits@k.
        pred_topk = S.topk(min(k, S.size(1)), largest=False)[1]
        metrics['Hits@%d' % k] = (pred_topk == gold).sum().item() / n
    # Each row of the argsort contains its gold index exactly once. torch.where
    # therefore yields one hit per row, in row order; its column is the 0-based rank.
    rank = torch.where(S.sort()[1] == gold)[1].float()
    metrics['MRR'] = (1 / (rank + 1)).mean().item()
    return metrics


def _print_metrics(label, metrics, Hn_nums):
    """Print one direction's metrics on a single line."""
    print('%s:\t' % label, end='')
    for k in Hn_nums:
        print('Hits@%d: %.2f%% ' % (k, metrics['Hits@%d' % k] * 100), end='')
    print('MRR: %.3f' % metrics['MRR'])


def get_hits(x1, x2, pair, dist='L1', Hn_nums=(1, 10)):
    """Print and return Hits@k and MRR for gold-aligned pairs in both directions.

    Row i of ``pair`` is a gold alignment. ``x1[pair[i, 0]]`` should be
    closest to ``x2[pair[i, 1]]`` among all ``x2[pair[:, 1]]`` ("Left"), and
    vice versa ("Right"). Candidates are ranked by ascending distance.

    Args:
        x1, x2: embedding matrices indexed by ``pair[:, 0]`` and ``pair[:, 1]``
            (assumed 2-D, one row per entity -- TODO confirm with callers).
        pair: integer tensor with one gold pair per row (columns 0 and 1).
        dist: Minkowski norm written as ``'L<p>'``: ``'L1'`` (default),
            ``'L2'``, ``'Linf'``, etc.
        Hn_nums: the k values for which Hits@k is reported. Values larger than
            the number of pairs are allowed and yield 100%.

    Returns:
        ``{'Left': metrics, 'Right': metrics}``. Each ``metrics`` maps
        ``'Hits@<k>'`` to a fraction in [0, 1] and ``'MRR'`` to the mean
        reciprocal rank.

    Raises:
        ValueError: if ``dist`` is not of the form ``'L<p>'`` or ``pair`` is empty.
    """
    pair_num = pair.size(0)
    if pair_num == 0:
        raise ValueError('get_hits needs at least one aligned pair')
    p = _dist_to_p(dist)
    S = torch.cdist(x1[pair[:, 0]], x2[pair[:, 1]], p=p)
    results = {}
    # "Right" ranks x1 candidates for each x2 entity, i.e. the transposed matrix.
    for label, scores in (('Left', S), ('Right', S.t())):
        results[label] = _rank_metrics(scores, Hn_nums)
        _print_metrics(label, results[label], Hn_nums)
    return results
def get_hits_stable(x1, x2, pair, *, search_factor=100):
    """Print and return the accuracy of a greedy one-to-one alignment ("Hits@Stable").

    The similarity is the negative L1 distance between ``x1[pair[:, 0]]`` and
    ``x2[pair[:, 1]]``. Candidate pairs are visited in descending order of
    ``softmax(rows) + softmax(cols)`` of that similarity. A candidate is
    accepted when neither of its entities is already matched. The score is the
    fraction of pairs whose accepted match is the gold one (row i <-> column i).

    Args:
        x1, x2: embedding matrices indexed by ``pair[:, 0]`` and ``pair[:, 1]``.
        pair: integer tensor with one gold pair per row.
        search_factor: at most ``search_factor * len(pair)`` top-scored
            candidates are examined, which bounds the cost of the Python loop.
            Entities still unmatched after that count as misses.

    Returns:
        The accuracy as a fraction in [0, 1].

    Raises:
        ValueError: if ``pair`` is empty.
    """
    pair_num = pair.size(0)
    if pair_num == 0:
        raise ValueError('get_hits_stable needs at least one aligned pair')
    S = -torch.cdist(x1[pair[:, 0]], x2[pair[:, 1]], p=1).cpu()
    # A candidate ranks high only if it stands out both within its row and
    # within its column.
    order = (S.softmax(1) + S.softmax(0)).flatten().argsort(descending=True)
    # Clamp to the pair_num**2 candidates that exist; for fewer than
    # search_factor pairs the raw budget would index past the end.
    limit = min(search_factor * pair_num, order.numel())
    # Slice before .tolist() so only the examined prefix becomes a Python list.
    candidates = order[:limit].tolist()
    used_e1 = [False] * pair_num
    used_e2 = [False] * pair_num
    matched = 0
    true_aligned = 0
    for flat in candidates:
        e1, e2 = divmod(flat, pair_num)
        if used_e1[e1] or used_e2[e2]:
            continue
        if e1 == e2:
            true_aligned += 1
        used_e1[e1] = True
        used_e2[e2] = True
        matched += 1
        # Once every entity is matched, all remaining candidates would be skipped.
        if matched == pair_num:
            break
    accuracy = true_aligned / pair_num
    print('Both:\tHits@Stable: %.2f%% ' % (accuracy * 100))
    return accuracy