-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgcn_with_edge_removal.py
107 lines (87 loc) · 3.17 KB
/
gcn_with_edge_removal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import argparse
import time
import numpy as np
import torch
import torch.nn.functional as F
import dgl
from dataset import load_graph_dataset
from gcn import GCN
def evaluate(model, features, labels, mask):
    """Return the classification accuracy of *model* on the entries picked by *mask*.

    The model is switched to evaluation mode (``model.eval()``) and is left in
    that mode on return; the forward pass runs under ``torch.no_grad()``.

    Args:
        model: Callable module. ``model(features)`` must return a tensor of
            per-class scores whose class axis is dimension 1.
        features: Input forwarded unchanged to ``model``.
        labels: Tensor of target class indices, indexable by ``mask``.
        mask: Boolean mask or index tensor applied to both the model output
            and ``labels``.

    Returns:
        float: Fraction of selected entries whose highest-scoring class equals
        the label. ``0.0`` when ``mask`` selects nothing (instead of raising
        ``ZeroDivisionError``).
    """
    model.eval()
    with torch.no_grad():
        selected_logits = model(features)[mask]
        selected_labels = labels[mask]

        total = len(selected_labels)
        if total == 0:
            # Accuracy over an empty selection is undefined; report 0.0
            # rather than dividing by zero.
            return 0.0

        _, predicted = torch.max(selected_logits, dim=1)
        correct = torch.sum(predicted == selected_labels)
        return correct.item() / total
class gcn_with_edge_removal:
    """Train and test a GCN on a dataset graph, optionally with edges removed.

    The constructor only records the configuration; loading the data, editing
    the graph, training and testing all happen in :meth:`train_evaluate`.
    No device placement is performed here: tensors and the model stay on
    whatever device ``load_graph_dataset`` and ``GCN`` put them on.
    """

    def __init__(self, dataset, remove_edge_index=False, lr=1e-2, n_epochs=100,
                 n_hidden=16, n_layers=1, weight_decay=5e-4, dropout=0.0,
                 new_labels=None):
        """Store the experiment configuration.

        Args:
            dataset: Dataset identifier forwarded unchanged to
                ``load_graph_dataset``.
            remove_edge_index: Edge ID(s) passed to ``dgl.remove_edges``
                before training. ``False`` (the default) or ``None`` means
                "keep every edge".
            lr: Learning rate for the Adam optimizer.
            n_epochs: Number of training iterations over the full graph.
            n_hidden: Forwarded to the ``GCN`` constructor (presumably the
                hidden layer width -- confirm against ``gcn.GCN``).
            n_layers: Forwarded to the ``GCN`` constructor (presumably the
                number of hidden layers -- confirm against ``gcn.GCN``).
            weight_decay: Weight decay for the Adam optimizer.
            dropout: Forwarded to the ``GCN`` constructor.
            new_labels: Optional replacement for the labels returned by
                ``load_graph_dataset``; used for both training and testing.
        """
        self.dataset = dataset
        self.remove_edge_idx = remove_edge_index
        self.lr = lr
        self.n_epochs = n_epochs
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.weight_decay = weight_decay
        self.dropout = dropout
        self.new_labels = new_labels

    @staticmethod
    def _has_edges_to_remove(edge_index):
        """Return True unless *edge_index* is the ``False``/``None`` sentinel.

        Identity checks are used instead of truthiness so that edge ID ``0``
        is still honoured and multi-element tensors do not raise when coerced
        to ``bool``.
        """
        return edge_index is not None and edge_index is not False

    def train_evaluate(self):
        """Train a fresh GCN on the (possibly edge-reduced) graph and test it.

        Prints the number of edges of the graph used for training, a blank
        line, and the test accuracy.

        Returns:
            float: Accuracy on the test mask, as computed by ``evaluate``.
        """
        (g, features, labels, train_mask, _val_mask, test_mask,
         n_classes) = load_graph_dataset(self.dataset)
        if self.new_labels is not None:
            labels = self.new_labels

        # Only edit the graph when edge IDs were actually supplied; the
        # default ``False`` is not a valid edge ID for dgl.remove_edges.
        if self._has_edges_to_remove(self.remove_edge_idx):
            g = dgl.remove_edges(g, self.remove_edge_idx)

        in_feats = features.shape[1]

        # Per-node normalisation term in_degree ** -0.5. Zero-degree nodes
        # yield inf, which is replaced by 0. Stored as a column vector under
        # ndata['norm'] (presumably read by the GCN layers -- confirm in gcn.py).
        degs = g.in_degrees().float()
        norm = torch.pow(degs, -0.5)
        norm[torch.isinf(norm)] = 0
        g.ndata['norm'] = norm.unsqueeze(1)

        model = GCN(g,
                    in_feats,
                    self.n_hidden,
                    n_classes,
                    self.n_layers,
                    F.relu,
                    self.dropout)
        loss_fcn = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=self.lr,
                                     weight_decay=self.weight_decay)

        for _ in range(self.n_epochs):
            model.train()
            logits = model(features)
            loss = loss_fcn(logits[train_mask], labels[train_mask])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # The edge count is fixed once training starts, so report it once.
        print(g.number_of_edges())
        print()
        acc = evaluate(model, features, labels, test_mask)
        print("Test accuracy {:.2%}".format(acc))
        return acc