sage_conv.py
from typing import List, Optional, Tuple, Union

import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn.aggr import Aggregation, MultiAggregation
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptPairTensor, Size, SparseTensor
from torch_geometric.utils import spmm


class SAGEConv(MessagePassing):
    r"""The GraphSAGE operator from the `"Inductive Representation Learning
    on Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2
        \cdot \mathrm{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j

    If :obj:`project = True`, then :math:`\mathbf{x}_j` will first get
    projected via

    .. math::
        \mathbf{x}_j \leftarrow \sigma ( \mathbf{W}_3 \mathbf{x}_j +
        \mathbf{b})

    as described in Eq. (3) of the paper.

    Args:
        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
            derive the size from the first input(s) to the forward method.
            A tuple corresponds to the sizes of source and target
            dimensionalities.
        out_channels (int): Size of each output sample.
        aggr (str or Aggregation, optional): The aggregation scheme to use.
            Any aggregation of :obj:`torch_geometric.nn.aggr` can be used,
            *e.g.*, :obj:`"mean"`, :obj:`"max"`, or :obj:`"lstm"`.
            (default: :obj:`"mean"`)
        normalize (bool, optional): If set to :obj:`True`, output features
            will be :math:`\ell_2`-normalized, *i.e.*,
            :math:`\frac{\mathbf{x}^{\prime}_i}
            {\| \mathbf{x}^{\prime}_i \|_2}`.
            (default: :obj:`False`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        project (bool, optional): If set to :obj:`True`, the layer will apply
            a linear transformation followed by an activation function before
            aggregation (as described in Eq. (3) of the paper).
            (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **inputs:**
          node features :math:`(|\mathcal{V}|, F_{in})` or
          :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
          if bipartite,
          edge indices :math:`(2, |\mathcal{E}|)`
        - **outputs:** node features :math:`(|\mathcal{V}|, F_{out})` or
          :math:`(|\mathcal{V_t}|, F_{out})` if bipartite
    """
    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        aggr: Optional[Union[str, List[str], Aggregation]] = "mean",
        normalize: bool = False,
        root_weight: bool = True,
        project: bool = False,
        bias: bool = True,
        **kwargs,
    ):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight
        self.project = project

        # A single `int` means source and target node features share one size:
        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        # LSTM aggregation is parametric and needs explicit channel sizes:
        if aggr == 'lstm':
            kwargs.setdefault('aggr_kwargs', {})
            kwargs['aggr_kwargs'].setdefault('in_channels', in_channels[0])
            kwargs['aggr_kwargs'].setdefault('out_channels', in_channels[0])

        super().__init__(aggr, **kwargs)

        if self.project:
            # Lazy (`-1`) input sizes cannot be resolved for the projection:
            if in_channels[0] <= 0:
                raise ValueError(f"'{self.__class__.__name__}' does not "
                                 f"support lazy initialization with "
                                 f"`project=True`")
            self.lin = Linear(in_channels[0], in_channels[0], bias=True)

        # Multiple aggregators may change the post-aggregation feature size:
        if isinstance(self.aggr_module, MultiAggregation):
            aggr_out_channels = self.aggr_module.get_out_channels(
                in_channels[0])
        else:
            aggr_out_channels = in_channels[0]

        self.lin_l = Linear(aggr_out_channels, out_channels, bias=bias)
        if self.root_weight:
            self.lin_r = Linear(in_channels[1], out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        super().reset_parameters()
        if self.project:
            self.lin.reset_parameters()
        self.lin_l.reset_parameters()
        if self.root_weight:
            self.lin_r.reset_parameters()

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        size: Size = None,
    ) -> Tensor:
        # For non-bipartite graphs, source and target features coincide:
        if isinstance(x, Tensor):
            x = (x, x)

        # Optionally project source features before aggregation (Eq. (3)):
        if self.project and hasattr(self, 'lin'):
            x = (self.lin(x[0]).relu(), x[1])

        # propagate_type: (x: OptPairTensor)
        out = self.propagate(edge_index, x=x, size=size)
        out = self.lin_l(out)

        # Add the transformed root node features, if enabled:
        x_r = x[1]
        if self.root_weight and x_r is not None:
            out = out + self.lin_r(x_r)

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out

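    # Bipartite graphs are supported by passing a tuple of feature matrices,
    # e.g. `conv((x_src, x_dst), edge_index)`; `x_dst` may be `None`, in which
    # case no root-weight term is added (see the `x_r is not None` check
    # above).
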
    def message(self, x_j: Tensor) -> Tensor:
        # Messages are the (optionally projected) source node features:
        return x_j

    def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:
        # Fused sparse path: drop stored edge values and aggregate neighbor
        # features via sparse matrix multiplication.
        if isinstance(adj_t, SparseTensor):
            adj_t = adj_t.set_value(None, layout=None)
        return spmm(adj_t, x[0], reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, aggr={self.aggr})')
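
# A minimal usage sketch (not part of the upstream file): builds a tiny
# 4-node graph and runs a single SAGEConv layer. The feature sizes and the
# graph itself are arbitrary illustrative choices.
if __name__ == '__main__':
    import torch

    x = torch.randn(4, 16)  # 4 nodes with 16 features each
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]])  # a directed 4-cycle

    conv = SAGEConv(16, 32, aggr='mean')
    out = conv(x, edge_index)
    assert out.size() == (4, 32)  # one 32-dimensional embedding per node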