-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
488 lines (189 loc) · 6.8 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import datetime
from lib2to3.pytree import BasePattern
import numpy as np
import os
import os.path as osp
import random
import logging
from pathlib import Path
import json
import time
import warnings
from contextlib import contextmanager
import torch
import torch.nn.functional as F
from torch_geometric.utils import to_dense_adj
EPS = 1e-6
def get_date_postfix():
dt = datetime.datetime.now()
post_fix = '{}_{:02d}-{:02d}-{:02d}'.format(dt.date(), dt.hour, dt.minute, dt.second)
return post_fix
def get_n_params(model):
pp = 0
for p in list(model.parameters()):
nn = 1
for s in list(p.size()):
nn = nn * s
pp += nn
return pp
def load_json(path):
with open(path, 'r') as f:
data = json.load(f)
return data
def seed_everything(seed):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def to_MB(byte):
return byte / 1024.0 / 1024.0
def combine_dicts(dicts):
result = {}
for d in dicts:
for key, value in d.items():
if key not in result:
result[key] = []
result[key].append(value)
for key, value in result.items():
result[key] = np.mean(value)
return result
def idx2mask(idx, num_nodes):
mask = torch.zeros(num_nodes, dtype=torch.bool)
mask[idx] = 1
return mask
# Generate random split from the dataset
def get_split(
num_samples: int,
train_ratio: float = 0.1,
test_ratio: float = 0.1,
ind_ratio: float = 0.0,
):
# The ind_ratio is used to split the test set
assert train_ratio + test_ratio < 1
train_size = int(num_samples * train_ratio)
val_size = int(num_samples * test_ratio)
test_size = num_samples - train_size - val_size
trans_size = int(test_size * (1 - ind_ratio))
indices = torch.randperm(num_samples)
if ind_ratio == 0:
return {
'train': indices[:train_size],
'valid': indices[train_size: val_size + train_size],
'test': indices[val_size + train_size:]
}
else:
return {
'train': indices[:train_size],
'valid': indices[train_size: val_size + train_size],
'test': indices[val_size + train_size: trans_size + val_size + train_size],
'ind_test': indices[trans_size + val_size + train_size:]
}
# Convert the split to the split with independent test set (inductive setting)
def to_ind_split(split, ind_ratio: float = 0.2):
test_split = split['test']
test_size = len(test_split)
trans_test_size = int(test_size * (1 - ind_ratio))
split['test'] = test_split[:trans_test_size]
split['ind_test'] = test_split[trans_test_size:]
return split
# Split the transductive set and inductive set
def split_data(data, split):
from torch_geometric.utils import subgraph
trans_nodes = torch.concat(
[split['train'],
split['valid'],
split['test']]).unique()
ind_nodes = split['ind_test'].unique()
x, y, edge_index, edge_attr = data.x, data.y, data.edge_index, data.edge_attr
trans_edge_index, trans_edge_attr = subgraph(
trans_nodes,
edge_index,
edge_attr,
relabel_nodes=True,
num_nodes=x.shape[0],
return_edge_mask=False
)
ind_edge_index, ind_edge_attr = subgraph(
ind_nodes,
edge_index,
edge_attr,
relabel_nodes=True,
num_nodes=x.shape[0],
return_edge_mask=False
)
trans_data = data.clone()
trans_data.x = x[trans_nodes]
trans_data.y = y[trans_nodes]
trans_data.edge_index = trans_edge_index
trans_data.edge_attr = trans_edge_attr
ind_data = data.clone()
ind_data.x = x[ind_nodes]
ind_data.y = y[ind_nodes]
ind_data.edge_index = ind_edge_index
ind_data.edge_attr = ind_edge_attr
return trans_data, ind_data
def accuracy(y_pred, y_true):
y_true = y_true.squeeze().long()
preds = y_pred.max(1)[1].type_as(y_true)
correct = preds.eq(y_true).double()
correct = correct.sum().item()
return correct / len(y_true)
def split2mask(split):
num_nodes = len(split['train']) + len(split['valid']) + len(split['test'])
return {
'train': idx2mask(split['train'], num_nodes),
'valid': idx2mask(split['valid'], num_nodes),
'test': idx2mask(split['test'], num_nodes)
}
def mask2idx(mask):
return torch.where(mask == 1)[0]
def get_split_from_mask(masks):
if masks.get('valid', None) is None:
split = [{
'train': mask2idx(masks['train']),
'test': mask2idx(masks['test'])
}]
num_train = len(split[0]['train'])
num_valid = int(num_train * 0.2)
perm = torch.randperm(num_train)
split[0]['valid'] = split[0]['train'][perm[:num_valid]]
split[0]['train'] = split[0]['train'][perm[num_valid:]]
return split
elif masks['train'].dim() == 1:
return [{
'train': mask2idx(masks['train']),
'valid': mask2idx(masks['valid']),
'test': mask2idx(masks['test'])
}]
elif masks['train'].dim() == 2:
num_splits = masks['train'].shape[1]
return [
{
'train': mask2idx(masks['train'][:, i]),
'valid': mask2idx(masks['valid'][:, i]),
'test': mask2idx(masks['test'][:, i])
} for i in range(num_splits)
]
def get_normalized_cut(data, y):
A = to_dense_adj(data.edge_index)[0]
D = torch.diag(torch.sum(A, dim=1))
normalized_cut = torch.trace(y.T @ A @ y) / torch.trace(y.T @ D @ y)
return normalized_cut.item()
def get_mad_value(h, mask, target_idx=None):
# h and mask are in cpu
from sklearn.metrics import pairwise_distances
h = h.cpu().detach().numpy()
mask = mask.cpu().detach().numpy()
dist_arr = pairwise_distances(h, h, metric='cosine')
mask_dist = np.multiply(dist_arr, mask)
divide_arr = (mask_dist != 0).sum(1) + 1e-8
node_dist = mask_dist.sum(1) / divide_arr
if target_idx is None:
mad = np.mean(node_dist)
else:
node_dist = node_dist * target_idx
mad = node_dist.sum() / ((node_dist != 0).sum() + 1e-8)
return mad