-
Notifications
You must be signed in to change notification settings - Fork 3.7k
/
Copy pathgen_conv.py
243 lines (204 loc) · 9.49 KB
/
gen_conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
from typing import List, Optional, Tuple, Union
from torch import Tensor
from torch.nn import (
BatchNorm1d,
Dropout,
InstanceNorm1d,
LayerNorm,
ReLU,
Sequential,
)
from torch_geometric.nn.aggr import Aggregation, MultiAggregation
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import reset
from torch_geometric.nn.norm import MessageNorm
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
class MLP(Sequential):
def __init__(self, channels: List[int], norm: Optional[str] = None,
bias: bool = True, dropout: float = 0.):
m = []
for i in range(1, len(channels)):
m.append(Linear(channels[i - 1], channels[i], bias=bias))
if i < len(channels) - 1:
if norm and norm == 'batch':
m.append(BatchNorm1d(channels[i], affine=True))
elif norm and norm == 'layer':
m.append(LayerNorm(channels[i], elementwise_affine=True))
elif norm and norm == 'instance':
m.append(InstanceNorm1d(channels[i], affine=False))
elif norm:
raise NotImplementedError(
f'Normalization layer "{norm}" not supported.')
m.append(ReLU())
m.append(Dropout(dropout))
super().__init__(*m)
class GENConv(MessagePassing):
r"""The GENeralized Graph Convolution (GENConv) from the `"DeeperGCN: All
You Need to Train Deeper GCNs" <https://arxiv.org/abs/2006.07739>`_ paper.
:class:`GENConv` supports both :math:`\textrm{softmax}` (see
:class:`~torch_geometric.nn.aggr.SoftmaxAggregation`) and
:math:`\textrm{powermean}` (see
:class:`~torch_geometric.nn.aggr.PowerMeanAggregation`) aggregation.
Its message construction is given by:
.. math::
\mathbf{x}_i^{\prime} = \mathrm{MLP} \left( \mathbf{x}_i +
\mathrm{AGG} \left( \left\{
\mathrm{ReLU} \left( \mathbf{x}_j + \mathbf{e_{ji}} \right) +\epsilon
: j \in \mathcal{N}(i) \right\} \right)
\right)
.. note::
For an example of using :obj:`GENConv`, see
`examples/ogbn_proteins_deepgcn.py
<https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
ogbn_proteins_deepgcn.py>`_.
Args:
in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
derive the size from the first input(s) to the forward method.
A tuple corresponds to the sizes of source and target
dimensionalities.
out_channels (int): Size of each output sample.
aggr (str or Aggregation, optional): The aggregation scheme to use.
Any aggregation of :obj:`torch_geometric.nn.aggr` can be used,
(:obj:`"softmax"`, :obj:`"powermean"`, :obj:`"add"`, :obj:`"mean"`,
:obj:`max`). (default: :obj:`"softmax"`)
t (float, optional): Initial inverse temperature for softmax
aggregation. (default: :obj:`1.0`)
learn_t (bool, optional): If set to :obj:`True`, will learn the value
:obj:`t` for softmax aggregation dynamically.
(default: :obj:`False`)
p (float, optional): Initial power for power mean aggregation.
(default: :obj:`1.0`)
learn_p (bool, optional): If set to :obj:`True`, will learn the value
:obj:`p` for power mean aggregation dynamically.
(default: :obj:`False`)
msg_norm (bool, optional): If set to :obj:`True`, will use message
normalization. (default: :obj:`False`)
learn_msg_scale (bool, optional): If set to :obj:`True`, will learn the
scaling factor of message normalization. (default: :obj:`False`)
norm (str, optional): Norm layer of MLP layers (:obj:`"batch"`,
:obj:`"layer"`, :obj:`"instance"`) (default: :obj:`batch`)
num_layers (int, optional): The number of MLP layers.
(default: :obj:`2`)
expansion (int, optional): The expansion factor of hidden channels in
MLP layers. (default: :obj:`2`)
eps (float, optional): The epsilon value of the message construction
function. (default: :obj:`1e-7`)
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
edge_dim (int, optional): Edge feature dimensionality. If set to
:obj:`None`, Edge feature dimensionality is expected to match
the `out_channels`. Other-wise, edge features are linearly
transformed to match `out_channels` of node feature dimensionality.
(default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.GenMessagePassing`.
Shapes:
- **input:**
node features :math:`(|\mathcal{V}|, F_{in})` or
:math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
if bipartite,
edge indices :math:`(2, |\mathcal{E}|)`,
edge attributes :math:`(|\mathcal{E}|, D)` *(optional)*
- **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
:math:`(|\mathcal{V}_t|, F_{out})` if bipartite
"""
def __init__(
self,
in_channels: Union[int, Tuple[int, int]],
out_channels: int,
aggr: Optional[Union[str, List[str], Aggregation]] = 'softmax',
t: float = 1.0,
learn_t: bool = False,
p: float = 1.0,
learn_p: bool = False,
msg_norm: bool = False,
learn_msg_scale: bool = False,
norm: str = 'batch',
num_layers: int = 2,
expansion: int = 2,
eps: float = 1e-7,
bias: bool = False,
edge_dim: Optional[int] = None,
**kwargs,
):
# Backward compatibility:
semi_grad = True if aggr == 'softmax_sg' else False
aggr = 'softmax' if aggr == 'softmax_sg' else aggr
aggr = 'powermean' if aggr == 'power' else aggr
# Override args of aggregator if `aggr_kwargs` is specified
if 'aggr_kwargs' not in kwargs:
if aggr == 'softmax':
kwargs['aggr_kwargs'] = dict(t=t, learn=learn_t,
semi_grad=semi_grad)
elif aggr == 'powermean':
kwargs['aggr_kwargs'] = dict(p=p, learn=learn_p)
super().__init__(aggr=aggr, **kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.eps = eps
if isinstance(in_channels, int):
in_channels = (in_channels, in_channels)
if in_channels[0] != out_channels:
self.lin_src = Linear(in_channels[0], out_channels, bias=bias)
if edge_dim is not None and edge_dim != out_channels:
self.lin_edge = Linear(edge_dim, out_channels, bias=bias)
if isinstance(self.aggr_module, MultiAggregation):
aggr_out_channels = self.aggr_module.get_out_channels(out_channels)
else:
aggr_out_channels = out_channels
if aggr_out_channels != out_channels:
self.lin_aggr_out = Linear(aggr_out_channels, out_channels,
bias=bias)
if in_channels[1] != out_channels:
self.lin_dst = Linear(in_channels[1], out_channels, bias=bias)
channels = [out_channels]
for i in range(num_layers - 1):
channels.append(out_channels * expansion)
channels.append(out_channels)
self.mlp = MLP(channels, norm=norm, bias=bias)
if msg_norm:
self.msg_norm = MessageNorm(learn_msg_scale)
def reset_parameters(self):
super().reset_parameters()
reset(self.mlp)
if hasattr(self, 'msg_norm'):
self.msg_norm.reset_parameters()
if hasattr(self, 'lin_src'):
self.lin_src.reset_parameters()
if hasattr(self, 'lin_edge'):
self.lin_edge.reset_parameters()
if hasattr(self, 'lin_aggr_out'):
self.lin_aggr_out.reset_parameters()
if hasattr(self, 'lin_dst'):
self.lin_dst.reset_parameters()
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
edge_attr: OptTensor = None, size: Size = None) -> Tensor:
if isinstance(x, Tensor):
x = (x, x)
if hasattr(self, 'lin_src'):
x = (self.lin_src(x[0]), x[1])
# propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
if hasattr(self, 'lin_aggr_out'):
out = self.lin_aggr_out(out)
if hasattr(self, 'msg_norm'):
h = x[1] if x[1] is not None else x[0]
assert h is not None
out = self.msg_norm(h, out)
x_dst = x[1]
if x_dst is not None:
if hasattr(self, 'lin_dst'):
x_dst = self.lin_dst(x_dst)
out = out + x_dst
return self.mlp(out)
def message(self, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
if edge_attr is not None and hasattr(self, 'lin_edge'):
edge_attr = self.lin_edge(edge_attr)
if edge_attr is not None:
assert x_j.size(-1) == edge_attr.size(-1)
msg = x_j if edge_attr is None else x_j + edge_attr
return msg.relu() + self.eps
def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, aggr={self.aggr})')