-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdiff_moduleS4p.py
63 lines (57 loc) · 1.98 KB
/
diff_moduleS4p.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import numpy as np
def GCN_diffusion(W, order, feature, device=None):
    """Apply symmetric-normalised GCN diffusion to ``feature`` repeatedly.

    With ``A = W + I`` and ``D`` the row sums of ``A``, every step computes
    ``D^-1/2 @ A @ D^-1/2 @ x`` and the output of each step is collected.

    Args:
        W: batched adjacency tensor of shape [b, n, n].
        order: number of diffusion steps to run; 0 yields an empty list.
        feature: batched node features. The original note gives [b, n, n];
            the code only needs the leading dims to be [b, n] so that the
            [b, n, 1] degree scaling broadcasts.
        device: device on which the identity matrix is created. ``None``
            (the default) means ``W.device``, so the function works for CPU
            and GPU inputs alike. An explicit value is honoured as given.

    Returns:
        A list of ``order`` tensors; entry ``k`` is ``feature`` after
        ``k + 1`` diffusion steps.
    """
    eye_device = W.device if device is None else device
    # expand() gives a [b, n, n] view of one identity matrix instead of
    # materialising b copies of it.
    I_n = torch.eye(W.size(1), device=eye_device).expand(W.size(0), -1, -1)
    A_gcn = W + I_n  # [b, n, n], adjacency with self-loops

    # D^-1/2 kept as a [b, n, 1] column so it scales rows by broadcasting.
    deg_inv_sqrt = torch.pow(torch.sum(A_gcn, 2).unsqueeze(dim=2), -0.5)

    gcn_diffusion_list = []
    A_gcn_feature = feature
    for _ in range(order):
        A_gcn_feature = deg_inv_sqrt * A_gcn_feature
        # Batched matrix product: [b, n, n] @ [b, n, f].
        A_gcn_feature = torch.matmul(A_gcn, A_gcn_feature)
        A_gcn_feature = torch.mul(A_gcn_feature, deg_inv_sqrt)
        gcn_diffusion_list.append(A_gcn_feature)
    return gcn_diffusion_list
def scattering_diffusionS4(sptensor, feature):
    """Return four scattering features by calling SCT1stv2 with order 4.

    sptensor: graph tensor forwarded to SCT1stv2 as ``W``.
    feature: node features forwarded unchanged (torch.FloatTensor).

    NOTE(review): the earlier note here gave the shapes as (N, N) and (N, 3)
    with everything on CUDA, whereas SCT1stv2 documents ``W`` as [b, n, n]
    and sums over dim 2 -- confirm the expected shapes against the callers.
    """
    scattering_order = 4
    bands = SCT1stv2(sptensor, scattering_order, feature)
    first, second, third, fourth = bands
    return first, second, third, fourth
def SCT1stv2(W, order, feature):
    """Compute four first-order scattering (wavelet) features.

    Each iteration applies the lazy random-walk step
    ``x <- 0.5 * x + 0.5 * W @ (D^-1 * x)``, where ``D`` holds the row sums
    of ``W``. Writing ``P`` for one step, snapshots are taken after
    1, 2, 4, 8, 16, ... steps and consecutive snapshots are subtracted, so
    the outputs are ``P x - P^2 x``, ``P^2 x - P^4 x``, ``P^4 x - P^8 x``
    and ``P^8 x - P^16 x``.

    Args:
        W: batched adjacency tensor of shape [b, n, n].
        order: sets the iteration count (``2 ** order``) and the number of
            snapshots (``order + 1``). Exactly four features are returned,
            so ``order`` must be at least 4; a smaller value leaves too few
            snapshots and raises ``IndexError``.
        feature: batched node features whose leading dims are [b, n].

    Returns:
        A tuple of four tensors shaped like ``feature``.
    """
    degrees = torch.sum(W, 2)  # [b, n]
    # Nodes with zero degree (isolated or padded) would give 1/0 = inf and
    # then NaN through inf * 0, poisoning the whole batch item. Their
    # inverse degree is defined as 0 instead. The zero is replaced before
    # inverting so no inf is ever produced, which also keeps gradients finite.
    isolated = degrees == 0
    D = torch.pow(degrees.masked_fill(isolated, 1), -1)
    D = D.masked_fill(isolated, 0).unsqueeze(dim=2)  # [b, n, 1]

    iteration = 2 ** order
    # Zero-based step indices after which a snapshot is kept: 0, 1, 3, 7, ...
    scales = {2 ** k - 1 for k in range(order + 1)}

    feature_p = feature
    sct_diffusion_list = []
    for step in range(iteration):
        W_D_inv_x = torch.matmul(W, D * feature_p)
        feature_p = 0.5 * feature_p + 0.5 * W_D_inv_x
        if step in scales:
            sct_diffusion_list.append(feature_p)

    # Differences of consecutive dyadic snapshots.
    return tuple(
        sct_diffusion_list[k] - sct_diffusion_list[k + 1] for k in range(4)
    )