-
Notifications
You must be signed in to change notification settings - Fork 1
/
input.py
34 lines (25 loc) · 1.01 KB
/
input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import numpy as np
def gen_input(graphs, bkd_gids, nodemax):
    """Build zero-padded adjacency and aggregated-feature inputs per graph.

    For every graph id in ``bkd_gids`` the graph's ``edge_mat`` (A) and
    ``node_features`` (X) are read from ``graphs[gid]``, the product
    ``A @ X`` is computed, and both A and the product are zero-padded so
    that their first dimension equals ``nodemax``.

    Args:
        graphs: Container indexable by graph id; each item must expose
            ``edge_mat`` and ``node_features`` as 2-D torch tensors that
            ``torch.mm`` can multiply (presumably ``edge_mat`` is a square
            ``(n, n)`` adjacency matrix -- confirm against the caller).
        bkd_gids: Iterable of graph ids to process. Repeated ids are
            processed once; output order follows first occurrence.
        nodemax: Target number of rows after padding.

    Returns:
        A tuple ``(Ainputs, Xinputs)`` of dicts keyed by graph id.
        ``Ainputs[gid]`` is ``edge_mat`` padded with ``nodemax - n`` zero
        rows and the same number of zero columns; ``Xinputs[gid]`` is
        ``edge_mat @ node_features`` padded with zero rows only. The
        returned tensors are fresh copies that do not track gradients and
        keep the dtype and device of the source tensors.

    Raises:
        ValueError: If a graph has more rows than ``nodemax``.
    """
    Ainputs = {}
    Xinputs = {}
    for gid in bkd_gids:
        if gid in Ainputs:
            # Same id listed twice: the result would be identical.
            continue

        graph = graphs[gid]
        # Detach so the outputs are plain data tensors, independent of any
        # autograd graph attached to the stored graph tensors.
        adj = graph.edge_mat.detach()
        feats = graph.node_features.detach()
        agg = torch.mm(adj, feats)

        rows, cols = adj.shape
        add_dim = nodemax - rows
        if add_dim < 0:
            raise ValueError(
                f"graph {gid!r} has {rows} nodes, which exceeds nodemax={nodemax}"
            )

        # Pad directly in torch (no numpy / Python-list round trip): allocate
        # the zero-filled target and copy the data into its top-left corner.
        a_padded = adj.new_zeros((rows + add_dim, cols + add_dim))
        a_padded[:rows, :cols] = adj
        x_padded = agg.new_zeros((agg.shape[0] + add_dim, agg.shape[1]))
        x_padded[:agg.shape[0]] = agg

        Ainputs[gid] = a_padded
        Xinputs[gid] = x_padded
    return Ainputs, Xinputs