-
Notifications
You must be signed in to change notification settings - Fork 335
/
citation_gat.py
102 lines (87 loc) · 3.14 KB
/
citation_gat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
This example implements the experiments on citation networks from the paper:
Graph Attention Networks (https://arxiv.org/abs/1710.10903)
Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, Yoshua Bengio
"""
import numpy as np
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Dropout, Input
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.random import set_seed
from spektral.data.loaders import SingleLoader
from spektral.datasets.citation import Citation
from spektral.layers import GATConv
from spektral.transforms import LayerPreprocess
# Seed TensorFlow's global random generator before anything else is built.
set_seed(0)

# Load data
# The "cora" citation dataset, with normalize_x=True and a transform that
# applies the preprocessing associated with GATConv.
gat_preprocessing = LayerPreprocess(GATConv)
dataset = Citation("cora", normalize_x=True, transforms=[gat_preprocessing])
def mask_to_weights(mask):
    """Convert a node mask into normalised per-node sample weights.

    Every entry of ``mask`` is cast to ``float32`` and divided by the number
    of non-zero entries. For a boolean mask selecting ``k`` nodes, each
    selected node therefore gets weight ``1 / k`` and every other node gets
    weight zero, so the weights sum to one.

    Args:
        mask: Array-like mask over the nodes (boolean or 0/1 values).

    Returns:
        A ``np.float32`` array with the same shape as ``mask``. A mask with
        no non-zero entries yields all-zero weights rather than the NaNs a
        ``0 / 0`` division would produce.
    """
    mask = np.asarray(mask)
    n_selected = np.count_nonzero(mask)
    weights = mask.astype(np.float32)
    if n_selected == 0:
        # Nothing is selected, so `weights` is already all zeros; dividing by
        # the zero count would turn every entry into NaN.
        return weights
    return weights / n_selected
# One sample-weight array per split: training, validation and test.
weights_tr, weights_va, weights_te = map(
    mask_to_weights,
    (dataset.mask_tr, dataset.mask_va, dataset.mask_te),
)
# Parameters

# First GAT layer: channels produced by each head, and number of heads.
channels = 8
n_attn_heads = 8

# Regularisation: dropout rate (used by the Dropout layers and passed to the
# GAT layers as dropout_rate) and L2 penalty coefficient.
dropout = 0.6
l2_reg = 2.5e-4

# Optimisation: learning rate, maximum number of epochs, and the early
# stopping patience.
learning_rate = 5e-3
epochs = 20000
patience = 100

# Sizes read from the dataset: number of nodes, size of the original node
# features, and number of classes.
N = dataset.n_nodes
F = dataset.n_node_features
n_out = dataset.n_labels
# Model definition


def _l2_regularizers():
    """Return fresh L2 regularizers for a GAT layer's kernel, attention kernel and bias."""
    return {
        "kernel_regularizer": l2(l2_reg),
        "attn_kernel_regularizer": l2(l2_reg),
        "bias_regularizer": l2(l2_reg),
    }


# Inputs: node features of size F, and a sparse input of size N for the
# adjacency matrix.
x_in = Input(shape=(F,))
a_in = Input(shape=(N,), sparse=True)

# First layer: dropout on the input features, then a multi-head GAT layer
# with ELU activation whose heads are concatenated.
do_1 = Dropout(rate=dropout)(x_in)
gat_hidden = GATConv(
    channels,
    attn_heads=n_attn_heads,
    concat_heads=True,
    dropout_rate=dropout,
    activation="elu",
    **_l2_regularizers(),
)
gc_1 = gat_hidden([do_1, a_in])

# Output layer: dropout on the hidden features, then a single-head GAT layer
# with softmax activation and one output channel per class.
do_2 = Dropout(rate=dropout)(gc_1)
gat_output = GATConv(
    n_out,
    attn_heads=1,
    concat_heads=False,
    dropout_rate=dropout,
    activation="softmax",
    **_l2_regularizers(),
)
gc_2 = gat_output([do_2, a_in])
# Build model
model = Model(inputs=[x_in, a_in], outputs=gc_2)
optimizer = Adam(learning_rate=learning_rate)
# The loss is summed (reduction="sum") rather than averaged; the accuracy is
# registered as a weighted metric so it is computed with the sample weights.
summed_crossentropy = CategoricalCrossentropy(reduction="sum")
model.compile(optimizer=optimizer, loss=summed_crossentropy, weighted_metrics=["acc"])
model.summary()
# Train model
# Both loaders wrap the same dataset; only the sample weights differ between
# the training and validation passes.
loader_tr = SingleLoader(dataset, sample_weights=weights_tr)
loader_va = SingleLoader(dataset, sample_weights=weights_va)

# Stop after `patience` epochs without improvement and restore the best weights.
early_stopping = EarlyStopping(patience=patience, restore_best_weights=True)

model.fit(
    loader_tr.load(),
    steps_per_epoch=loader_tr.steps_per_epoch,
    validation_data=loader_va.load(),
    validation_steps=loader_va.steps_per_epoch,
    epochs=epochs,
    callbacks=[early_stopping],
)
# Evaluate model
print("Evaluating model.")
# Same dataset again, this time with the test-split sample weights.
loader_te = SingleLoader(dataset, sample_weights=weights_te)
eval_results = model.evaluate(loader_te.load(), steps=loader_te.steps_per_epoch)
# The evaluation results fill the two placeholders in order: loss, then accuracy.
report_template = "Done.\nTest loss: {}\nTest accuracy: {}"
print(report_template.format(*eval_results))