model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class RW_NN(nn.Module):
    def __init__(self, input_dim, max_step, hidden_graphs, size_hidden_graphs,
                 hidden_dim, penultimate_dim, normalize, n_classes, dropout, device):
        super(RW_NN, self).__init__()
        self.max_step = max_step
        self.hidden_graphs = hidden_graphs
        self.size_hidden_graphs = size_hidden_graphs
        self.normalize = normalize
        self.device = device
        # Upper-triangular entries of the adjacency matrices of the trainable hidden graphs.
        self.adj_hidden = Parameter(torch.FloatTensor(hidden_graphs, (size_hidden_graphs * (size_hidden_graphs - 1)) // 2))
        # Trainable node features of the hidden graphs.
        self.features_hidden = Parameter(torch.FloatTensor(hidden_graphs, size_hidden_graphs, hidden_dim))
        self.fc = nn.Linear(input_dim, hidden_dim)
        self.bn = nn.BatchNorm1d(hidden_graphs * max_step)
        self.fc1 = nn.Linear(hidden_graphs * max_step, penultimate_dim)
        self.fc2 = nn.Linear(penultimate_dim, n_classes)
        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.init_weights()

    def init_weights(self):
        self.adj_hidden.data.uniform_(-1, 1)
        self.features_hidden.data.uniform_(0, 1)

    def forward(self, adj, features, graph_indicator):
        unique, counts = torch.unique(graph_indicator, return_counts=True)
        n_graphs = unique.size(0)
        if self.normalize:
            # Number of nodes per input graph, used to normalize the kernel values.
            norm = counts.unsqueeze(1).repeat(1, self.hidden_graphs)
        # Rebuild the symmetric adjacency matrices of the hidden graphs
        # from their upper-triangular parameters.
        adj_hidden_norm = torch.zeros(self.hidden_graphs, self.size_hidden_graphs, self.size_hidden_graphs).to(self.device)
        idx = torch.triu_indices(self.size_hidden_graphs, self.size_hidden_graphs, 1)
        adj_hidden_norm[:, idx[0], idx[1]] = self.relu(self.adj_hidden)
        adj_hidden_norm = adj_hidden_norm + torch.transpose(adj_hidden_norm, 1, 2)
        # Project input node features into the hidden dimension.
        x = self.sigmoid(self.fc(features))
        z = self.features_hidden
        # Similarity between hidden-graph node features and input node features.
        zx = torch.einsum("abc,dc->abd", (z, x))
        out = list()
        for i in range(self.max_step):
            if i == 0:
                eye = torch.eye(self.size_hidden_graphs, device=self.device)
                eye = eye.repeat(self.hidden_graphs, 1, 1)
                o = torch.einsum("abc,acd->abd", (eye, z))
                t = torch.einsum("abc,dc->abd", (o, x))
            else:
                # One step of the walk on both the input graph and the hidden graphs.
                x = torch.spmm(adj, x)
                z = torch.einsum("abc,acd->abd", (adj_hidden_norm, z))
                t = torch.einsum("abc,dc->abd", (z, x))
            t = self.dropout(t)
            t = torch.mul(zx, t)
            # Aggregate walk similarities over the nodes of each input graph.
            t = torch.zeros(t.size(0), t.size(1), n_graphs, device=self.device).index_add_(2, graph_indicator, t)
            t = torch.sum(t, dim=1)
            t = torch.transpose(t, 0, 1)
            if self.normalize:
                t /= norm
            out.append(t)
        # Concatenate the kernel values for all walk lengths and classify.
        out = torch.cat(out, dim=1)
        out = self.bn(out)
        out = self.relu(self.fc1(out))
        out = self.dropout(out)
        out = self.fc2(out)
        return F.log_softmax(out, dim=1)
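

# A minimal usage sketch, not part of the original file: the hyperparameter
# values and the toy batch below are illustrative assumptions. `adj` is the
# block-diagonal sparse adjacency matrix of the batched graphs, `features`
# holds one row per node, and `graph_indicator` maps each node to its graph.
if __name__ == "__main__":
    device = torch.device("cpu")
    model = RW_NN(input_dim=4, max_step=3, hidden_graphs=8, size_hidden_graphs=5,
                  hidden_dim=16, penultimate_dim=32, normalize=True,
                  n_classes=2, dropout=0.2, device=device).to(device)
    model.eval()  # avoids BatchNorm1d batch-statistics issues on tiny batches

    # Two toy graphs with 3 and 2 nodes: a triangle (nodes 0-2) and a single edge (nodes 3-4).
    edges = torch.tensor([[0, 1, 1, 2, 0, 2, 3, 4],
                          [1, 0, 2, 1, 2, 0, 4, 3]])
    adj = torch.sparse_coo_tensor(edges, torch.ones(edges.size(1)), (5, 5)).coalesce()
    features = torch.rand(5, 4)
    graph_indicator = torch.tensor([0, 0, 0, 1, 1])

    log_probs = model(adj, features, graph_indicator)
    print(log_probs.shape)  # torch.Size([2, 2]): one row of class log-probabilities per graph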