This repository has been archived by the owner on Jul 22, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 163
/
Copy pathtrainer.py
203 lines (160 loc) · 6.94 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import os
import time

import numpy as np
import pandas as pd
import torch

import logger
import utils as u
class Trainer():
def __init__(self,args, splitter, gcn, classifier, comp_loss, dataset, num_classes):
self.args = args
self.splitter = splitter
self.tasker = splitter.tasker
self.gcn = gcn
self.classifier = classifier
self.comp_loss = comp_loss
self.num_nodes = dataset.num_nodes
self.data = dataset
self.num_classes = num_classes
self.logger = logger.Logger(args, self.num_classes)
self.init_optimizers(args)
if self.tasker.is_static:
adj_matrix = u.sparse_prepare_tensor(self.tasker.adj_matrix, torch_size = [self.num_nodes], ignore_batch_dim = False)
self.hist_adj_list = [adj_matrix]
self.hist_ndFeats_list = [self.tasker.nodes_feats.float()]
def init_optimizers(self,args):
params = self.gcn.parameters()
self.gcn_opt = torch.optim.Adam(params, lr = args.learning_rate)
params = self.classifier.parameters()
self.classifier_opt = torch.optim.Adam(params, lr = args.learning_rate)
self.gcn_opt.zero_grad()
self.classifier_opt.zero_grad()
def save_checkpoint(self, state, filename='checkpoint.pth.tar'):
torch.save(state, filename)
def load_checkpoint(self, filename, model):
if os.path.isfile(filename):
print("=> loading checkpoint '{}'".format(filename))
checkpoint = torch.load(filename)
epoch = checkpoint['epoch']
self.gcn.load_state_dict(checkpoint['gcn_dict'])
self.classifier.load_state_dict(checkpoint['classifier_dict'])
self.gcn_opt.load_state_dict(checkpoint['gcn_optimizer'])
self.classifier_opt.load_state_dict(checkpoint['classifier_optimizer'])
self.logger.log_str("=> loaded checkpoint '{}' (epoch {})".format(filename, checkpoint['epoch']))
return epoch
else:
self.logger.log_str("=> no checkpoint found at '{}'".format(filename))
return 0
def train(self):
self.tr_step = 0
best_eval_valid = 0
eval_valid = 0
epochs_without_impr = 0
for e in range(self.args.num_epochs):
eval_train, nodes_embs = self.run_epoch(self.splitter.train, e, 'TRAIN', grad = True)
if len(self.splitter.dev)>0 and e>self.args.eval_after_epochs:
eval_valid, _ = self.run_epoch(self.splitter.dev, e, 'VALID', grad = False)
if eval_valid>best_eval_valid:
best_eval_valid = eval_valid
epochs_without_impr = 0
print ('### w'+str(self.args.rank)+') ep '+str(e)+' - Best valid measure:'+str(eval_valid))
else:
epochs_without_impr+=1
if epochs_without_impr>self.args.early_stop_patience:
print ('### w'+str(self.args.rank)+') ep '+str(e)+' - Early stop.')
break
if len(self.splitter.test)>0 and eval_valid==best_eval_valid and e>self.args.eval_after_epochs:
eval_test, _ = self.run_epoch(self.splitter.test, e, 'TEST', grad = False)
if self.args.save_node_embeddings:
self.save_node_embs_csv(nodes_embs, self.splitter.train_idx, log_file+'_train_nodeembs.csv.gz')
self.save_node_embs_csv(nodes_embs, self.splitter.dev_idx, log_file+'_valid_nodeembs.csv.gz')
self.save_node_embs_csv(nodes_embs, self.splitter.test_idx, log_file+'_test_nodeembs.csv.gz')
def run_epoch(self, split, epoch, set_name, grad):
t0 = time.time()
log_interval=999
if set_name=='TEST':
log_interval=1
self.logger.log_epoch_start(epoch, len(split), set_name, minibatch_log_interval=log_interval)
torch.set_grad_enabled(grad)
for s in split:
if self.tasker.is_static:
s = self.prepare_static_sample(s)
else:
s = self.prepare_sample(s)
predictions, nodes_embs = self.predict(s.hist_adj_list,
s.hist_ndFeats_list,
s.label_sp['idx'],
s.node_mask_list)
loss = self.comp_loss(predictions,s.label_sp['vals'])
# print(loss)
if set_name in ['TEST', 'VALID'] and self.args.task == 'link_pred':
self.logger.log_minibatch(predictions, s.label_sp['vals'], loss.detach(), adj = s.label_sp['idx'])
else:
self.logger.log_minibatch(predictions, s.label_sp['vals'], loss.detach())
if grad:
self.optim_step(loss)
torch.set_grad_enabled(True)
eval_measure = self.logger.log_epoch_done()
return eval_measure, nodes_embs
def predict(self,hist_adj_list,hist_ndFeats_list,node_indices,mask_list):
nodes_embs = self.gcn(hist_adj_list,
hist_ndFeats_list,
mask_list)
predict_batch_size = 100000
gather_predictions=[]
for i in range(1 +(node_indices.size(1)//predict_batch_size)):
cls_input = self.gather_node_embs(nodes_embs, node_indices[:, i*predict_batch_size:(i+1)*predict_batch_size])
predictions = self.classifier(cls_input)
gather_predictions.append(predictions)
gather_predictions=torch.cat(gather_predictions, dim=0)
return gather_predictions, nodes_embs
def gather_node_embs(self,nodes_embs,node_indices):
cls_input = []
for node_set in node_indices:
cls_input.append(nodes_embs[node_set])
return torch.cat(cls_input,dim = 1)
def optim_step(self,loss):
self.tr_step += 1
loss.backward()
if self.tr_step % self.args.steps_accum_gradients == 0:
self.gcn_opt.step()
self.classifier_opt.step()
self.gcn_opt.zero_grad()
self.classifier_opt.zero_grad()
def prepare_sample(self,sample):
sample = u.Namespace(sample)
for i,adj in enumerate(sample.hist_adj_list):
adj = u.sparse_prepare_tensor(adj,torch_size = [self.num_nodes])
sample.hist_adj_list[i] = adj.to(self.args.device)
nodes = self.tasker.prepare_node_feats(sample.hist_ndFeats_list[i])
sample.hist_ndFeats_list[i] = nodes.to(self.args.device)
node_mask = sample.node_mask_list[i]
sample.node_mask_list[i] = node_mask.to(self.args.device).t() #transposed to have same dimensions as scorer
label_sp = self.ignore_batch_dim(sample.label_sp)
if self.args.task in ["link_pred", "edge_cls"]:
label_sp['idx'] = label_sp['idx'].to(self.args.device).t() ####### ALDO TO CHECK why there was the .t() -----> because I concatenate embeddings when there are pairs of them, the embeddings are row vectors after the transpose
else:
label_sp['idx'] = label_sp['idx'].to(self.args.device)
label_sp['vals'] = label_sp['vals'].type(torch.long).to(self.args.device)
sample.label_sp = label_sp
return sample
def prepare_static_sample(self,sample):
sample = u.Namespace(sample)
sample.hist_adj_list = self.hist_adj_list
sample.hist_ndFeats_list = self.hist_ndFeats_list
label_sp = {}
label_sp['idx'] = [sample.idx]
label_sp['vals'] = sample.label
sample.label_sp = label_sp
return sample
def ignore_batch_dim(self,adj):
if self.args.task in ["link_pred", "edge_cls"]:
adj['idx'] = adj['idx'][0]
adj['vals'] = adj['vals'][0]
return adj
def save_node_embs_csv(self, nodes_embs, indexes, file_name):
csv_node_embs = []
for node_id in indexes:
orig_ID = torch.DoubleTensor([self.tasker.data.contID_to_origID[node_id]])
csv_node_embs.append(torch.cat((orig_ID,nodes_embs[node_id].double())).detach().numpy())
pd.DataFrame(np.array(csv_node_embs)).to_csv(file_name, header=None, index=None, compression='gzip')
#print ('Node embs saved in',file_name)