forked from frigategnn/Bonsai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_saint_dataset.py
84 lines (81 loc) · 2.8 KB
/
load_saint_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
r"""Utility to load the datasets distributed with GraphSAINT (flickr, ogbn-arxiv, reddit).
"""
import json
import numpy as np
import scipy.sparse as sp
import torch
from torch_geometric.data import Data
from torch_sparse import SparseTensor
def load_saint_dataset(name: str, *, root: str = "datasets") -> Data:
    r"""Load a GraphSAINT-format dataset into a ``torch_geometric.data.Data``.

    Four files are read from ``{root}/{name}/``:

    * ``role.json``      - node index lists under the keys "tr", "va", "te"
    * ``adj_full.npz``   - scipy sparse adjacency matrix
    * ``feats.npy``      - node feature matrix (cast to float32)
    * ``class_map.json`` - mapping from node id (string key) to class label

    Params:
        name: str - Name of the dataset. One of flickr, ogbn-arxiv, or reddit.
            For "ogbn-arxiv" the adjacency is symmetrised (``adj + adj.T``)
            and entries greater than 1 are reset to 1.
        root: str - Root directory of the dataset.

    returns: torch_geometric.data.Data with the attributes
        x           - float32 node features
        edge_index  - 2 x nnz long tensor of the (un-normalised) adjacency
        y           - int64 labels; nodes missing from class_map keep label 0
        train_mask, val_mask, test_mask - boolean node masks
        num_nodes, num_classes - 0-d tensors (num_classes counts the distinct
            values present in ``y``)
        adj         - torch_sparse.SparseTensor holding
                      D^{-1/2} (A + I) D^{-1/2} in float32
    """
    base = f"{root}/{name}"

    # --- split definition --------------------------------------------------
    with open(f"{base}/role.json", "r", encoding="utf-8") as jsonfile:
        roles = json.load(jsonfile)
    train_nodes = roles["tr"]
    val_nodes = roles["va"]
    test_nodes = roles["te"]

    # --- raw graph structure -----------------------------------------------
    adj = sp.load_npz(f"{base}/adj_full.npz")
    if name == "ogbn-arxiv":
        # Make the graph undirected; edges stored in both directions sum to
        # 2, so clamp them back to 1.
        adj = adj + adj.T
        adj[adj > 1] = 1
    adj = adj.tocoo()
    edge_index = np.stack((adj.row, adj.col), axis=0)

    # --- features and labels -----------------------------------------------
    feats = np.load(f"{base}/feats.npy").astype(np.float32)
    num_nodes = feats.shape[0]

    with open(f"{base}/class_map.json", "r", encoding="utf-8") as jsonfile:
        class_map = json.load(jsonfile)
    # Allocate with the final dtype instead of float64 followed by a cast.
    ys = np.zeros(num_nodes, dtype=np.int64)
    for node, cls in class_map.items():
        ys[int(node)] = cls
    nc = np.unique(ys).shape[0]

    # --- boolean split masks -----------------------------------------------
    train_mask = np.zeros(num_nodes, dtype=bool)
    train_mask[train_nodes] = True
    val_mask = np.zeros(num_nodes, dtype=bool)
    val_mask[val_nodes] = True
    test_mask = np.zeros(num_nodes, dtype=bool)
    test_mask[test_nodes] = True

    # --- symmetric normalisation with self loops ---------------------------
    # The sparse addition converts its operands to CSR internally, so no
    # explicit (and very slow) LIL conversion is needed beforehand.
    adj_loops = adj + sp.eye(adj.shape[0])
    rowsum = np.array(adj_loops.sum(1))
    r_inv = np.power(rowsum, -1 / 2).flatten()
    r_inv[np.isinf(r_inv)] = 0.0  # rows with zero degree contribute nothing
    r_mat_inv = sp.diags(r_inv)
    norm = r_mat_inv.dot(adj_loops).dot(r_mat_inv).tocoo().astype(np.float32)

    # Build the SparseTensor straight from the COO arrays; this avoids the
    # deprecated torch.sparse.FloatTensor / torch.LongTensor constructors and
    # an intermediate torch sparse tensor that was immediately unpacked.
    norm_adj = SparseTensor(
        row=torch.as_tensor(norm.row, dtype=torch.long),
        col=torch.as_tensor(norm.col, dtype=torch.long),
        value=torch.as_tensor(norm.data, dtype=torch.float32),
        sparse_sizes=(int(norm.shape[0]), int(norm.shape[1])),
    )

    return Data(
        x=torch.tensor(feats),
        edge_index=torch.tensor(edge_index).long(),
        y=torch.tensor(ys),
        train_mask=torch.tensor(train_mask),
        val_mask=torch.tensor(val_mask),
        test_mask=torch.tensor(test_mask),
        num_nodes=torch.tensor(num_nodes),
        num_classes=torch.tensor(nc),
        adj=norm_adj,
    )