GCNII_layer.py
from typing import Tuple, Optional
import torch
from torch import Tensor
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.nn.inits import glorot

# Type aliases in the style of torch_geometric.typing.
OptTensor = Optional[Tensor]
PairTensor = Tuple[Tensor, Tensor]
OptPairTensor = Tuple[Tensor, Optional[Tensor]]
PairOptTensor = Tuple[Optional[Tensor], Optional[Tensor]]
Size = Optional[Tuple[int, int]]
NoneType = Optional[Tensor]


def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             add_self_loops=True, dtype=None):
    """Symmetric GCN normalization: optionally add self-loops and return
    edge weights of D^-1/2 * A_hat * D^-1/2."""
    fill_value = 2. if improved else 1.
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)

    if add_self_loops:
        edge_index, tmp_edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)
        assert tmp_edge_weight is not None
        edge_weight = tmp_edge_weight

    # Node degrees from the target side, then D^-1/2 with isolated nodes set to 0.
    row, col = edge_index[0], edge_index[1]
    deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
    deg_inv_sqrt = deg.pow_(-0.5)
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
    return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
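
# A minimal usage sketch for gcn_norm (illustrative only; the toy graph below
# is an assumption, not part of the original module). With self-loops added,
# each returned weight equals deg(row)^-0.5 * edge_weight * deg(col)^-0.5:
#
#   edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
#   edge_index, norm = gcn_norm(edge_index, num_nodes=3)
#   # edge_index now also holds the 3 self-loops; norm has one entry per edge.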


class GCNIIdenseConv(MessagePassing):
    """GCNII-style graph convolution with an initial-residual term `h0` and an
    identity mapping controlled by `alpha` and `beta` (see `forward`)."""

    _cached_edge_index: Optional[Tuple[torch.Tensor, torch.Tensor]]

    def __init__(self, in_channels, out_channels, improved=False, cached=True,
                 **kwargs):
        super(GCNIIdenseConv, self).__init__(aggr='add', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached

        # Separate weights for the propagated features and the initial residual.
        self.weight1 = Parameter(torch.Tensor(in_channels, out_channels))
        self.weight2 = Parameter(torch.Tensor(in_channels, out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight1)
        glorot(self.weight2)
        # Invalidate any cached normalization from a previous graph.
        self.cached_result = None
        self.cached_num_edges = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight=None, improved=False,
             dtype=None):
        """Add remaining self-loops and symmetrically normalize edge weights."""
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                     device=edge_index.device)

        fill_value = 1 if not improved else 2
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, alpha, h0, beta, edge_weight=None):
        """GCNII-style update:
        out = (1 - alpha) * P @ ((1 - beta) * x + beta * x @ W1)
              + alpha * ((1 - beta) * h0 + beta * h0 @ W2),
        where P is the symmetrically normalized adjacency (with self-loops)
        and h0 is the initial node representation.
        """
        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}. Please '
                    'disable the caching behavior of this layer by removing '
                    'the `cached=True` argument in its constructor.'.format(
                        self.cached_num_edges, edge_index.size(1)))

        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            edge_index, norm = self.norm(edge_index, x.size(0), edge_weight,
                                         self.improved, x.dtype)
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result

        # Identity mapping on the transformed features plus the initial residual.
        support = (1 - beta) * (1 - alpha) * x + \
            (1 - alpha) * beta * torch.matmul(x, self.weight1)
        initial = (1 - beta) * alpha * h0 + beta * alpha * torch.matmul(h0, self.weight2)
        out = self.propagate(edge_index, x=support, norm=norm) + initial
        return out

    def message(self, x_j, norm):
        # Scale each neighbor message by its normalized edge weight.
        return norm.view(-1, 1) * x_j

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
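

# A minimal, hedged usage sketch (not part of the original file): the layer
# sizes, toy graph, and the alpha/beta values below are illustrative
# assumptions, chosen only to exercise the forward pass of GCNIIdenseConv.
if __name__ == '__main__':
    # Toy undirected graph with 4 nodes and 16-dimensional features.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                               [1, 0, 2, 1, 3, 2]])
    x = torch.randn(4, 16)
    h0 = x.clone()  # initial representation, typically the first layer's output

    conv = GCNIIdenseConv(16, 16, cached=False)
    out = conv(x, edge_index, alpha=0.1, h0=h0, beta=0.5)
    print(conv, out.shape)  # GCNIIdenseConv(16, 16) torch.Size([4, 16])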