model.py
import numpy as np
import torch
import torch.nn as nn

from layer import *


class GAT(nn.Module):
    """Two-layer Graph Attention Network built from the Attn_head layer in layer.py."""

    def __init__(self,
                 nb_classes,
                 nb_nodes,
                 attn_drop,
                 ffd_drop,
                 bias_mat,
                 hid_units,
                 n_heads,
                 residual=False):
        super(GAT, self).__init__()
        self.nb_classes = nb_classes
        self.nb_nodes = nb_nodes
        self.attn_drop = attn_drop
        self.ffd_drop = ffd_drop
        self.bias_mat = bias_mat
        self.hid_units = hid_units
        self.n_heads = n_heads
        self.residual = residual
        # First attention layer: in_channel is hard-coded to 1433, the Cora feature dimension.
        self.attn1 = Attn_head(in_channel=1433, out_sz=self.hid_units[0],
                               bias_mat=self.bias_mat, in_drop=self.ffd_drop,
                               coef_drop=self.attn_drop, activation=nn.ELU(),
                               residual=self.residual)
        # Output attention layer: in_channel is hard-coded to 64, the width of the
        # concatenated multi-head output (n_heads[0] * hid_units[0] in the expected config).
        self.attn2 = Attn_head(in_channel=64, out_sz=self.nb_classes,
                               bias_mat=self.bias_mat, in_drop=self.ffd_drop,
                               coef_drop=self.attn_drop, activation=nn.ELU(),
                               residual=self.residual)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # Run the first layer once per attention head and concatenate the heads
        # along the channel dimension.
        attns = []
        for _ in range(self.n_heads[0]):
            attns.append(self.attn1(x))
        h_1 = torch.cat(attns, dim=1)
        out = self.attn2(h_1)
        # Reshape to (nb_nodes, nb_classes) and convert the scores to class probabilities.
        logits = torch.transpose(out.view(self.nb_classes, -1), 1, 0)
        logits = self.softmax(logits)
        return logits
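

# Minimal usage sketch (assumptions): the hyperparameters and tensor shapes below are
# illustrative, not taken from this repository. It assumes a Cora-like setup where
# hid_units=[8] and n_heads=[8, 1] give the 64 concatenated channels attn2 expects,
# and that Attn_head (layer.py) consumes a (batch, channels, nodes) tensor, as in
# Conv1d-based GAT implementations.
if __name__ == "__main__":
    nb_nodes, ft_size, nb_classes = 2708, 1433, 7
    features = torch.rand(1, ft_size, nb_nodes)    # node features, (batch, channels, nodes)
    bias_mat = torch.zeros(1, nb_nodes, nb_nodes)  # additive attention mask derived from the adjacency matrix
    model = GAT(nb_classes=nb_classes,
                nb_nodes=nb_nodes,
                attn_drop=0.6,
                ffd_drop=0.6,
                bias_mat=bias_mat,
                hid_units=[8],
                n_heads=[8, 1],   # forward only uses n_heads[0]
                residual=False)
    probs = model(features)       # (nb_nodes, nb_classes) class probabilities
    print(probs.shape)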