-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into assortativity
- Loading branch information
Showing
10 changed files
with
186 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import torch | ||
from torch_sparse import SparseTensor | ||
|
||
from torch_geometric.nn import SSGConv | ||
from torch_geometric.testing import is_full_test | ||
|
||
|
||
def test_ssg_conv():
    """Check SSGConv on dense edge indices, sparse adjacencies, JIT, and cache."""
    num_nodes, in_dim, out_dim = 4, 16, 32
    x = torch.randn(num_nodes, in_dim)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    weight = torch.rand(row.size(0))
    adj_weighted = SparseTensor(row=row, col=col, value=weight,
                                sparse_sizes=(num_nodes, num_nodes))
    adj_unweighted = adj_weighted.set_value(None)

    conv = SSGConv(in_dim, out_dim, alpha=0.1, K=10)
    assert repr(conv) == 'SSGConv(16, 32, K=10, alpha=0.1)'

    # Unweighted graph: COO edge_index and SparseTensor must agree.
    out1 = conv(x, edge_index)
    assert out1.size() == (num_nodes, out_dim)
    assert torch.allclose(conv(x, adj_unweighted.t()), out1, atol=1e-6)

    # Weighted graph: explicit edge weights vs. SparseTensor values.
    out2 = conv(x, edge_index, weight)
    assert out2.size() == (num_nodes, out_dim)
    assert torch.allclose(conv(x, adj_weighted.t()), out2, atol=1e-6)

    if is_full_test():
        jit = torch.jit.script(
            conv.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
        assert jit(x, edge_index).tolist() == out1.tolist()
        assert jit(x, edge_index, weight).tolist() == out2.tolist()

        jit = torch.jit.script(
            conv.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
        assert torch.allclose(jit(x, adj_unweighted.t()), out1, atol=1e-6)
        assert torch.allclose(jit(x, adj_weighted.t()), out2, atol=1e-6)

    # With caching enabled, repeated calls must reproduce the fresh output.
    conv.cached = True
    conv(x, edge_index)  # First call populates the cache.
    assert conv(x, edge_index).tolist() == out1.tolist()
    assert torch.allclose(conv(x, adj_unweighted.t()), out1, atol=1e-6)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
from typing import Optional | ||
|
||
from torch import Tensor | ||
from torch_sparse import SparseTensor, matmul | ||
|
||
from torch_geometric.nn.conv import MessagePassing | ||
from torch_geometric.nn.conv.gcn_conv import gcn_norm | ||
from torch_geometric.nn.dense.linear import Linear | ||
from torch_geometric.typing import Adj, OptTensor | ||
|
||
|
||
class SSGConv(MessagePassing):
    r"""The simple spectral graph convolutional operator from the
    `"Simple Spectral Graph Convolution"
    <https://openreview.net/forum?id=CYO5T-YjWZV>`_ paper

    .. math::
        \mathbf{X}^{\prime} = \frac{1}{K} \sum_{k=1}^K\left((1-\alpha)
        {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
        \mathbf{\hat{D}}^{-1/2} \right)}^k
        \mathbf{X}+\alpha \mathbf{X}\right) \mathbf{\Theta},

    where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` is the adjacency
    matrix with added self-loops and
    :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix.
    Edge weights other than :obj:`1` can be passed via the optional
    :obj:`edge_weight` tensor.
    :class:`~torch_geometric.nn.conv.SSGConv` improves upon
    :class:`~torch_geometric.nn.conv.SGConv` by adding the :obj:`alpha`
    teleport term, which counteracts oversmoothing.

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        alpha (float): Teleport probability :math:`\alpha \in [0, 1]`.
        K (int, optional): Number of hops :math:`K`. (default: :obj:`1`)
        cached (bool, optional): If set to :obj:`True`, the layer will cache
            the computation of :math:`\frac{1}{K} \sum_{k=1}^K\left((1-\alpha)
            {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
            \mathbf{\hat{D}}^{-1/2} \right)}^k \mathbf{X}+
            \alpha \mathbf{X}\right)` on first execution, and will use the
            cached version for further executions.
            This parameter should only be set to :obj:`True` in transductive
            learning scenarios. (default: :obj:`False`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          edge indices :math:`(2, |\mathcal{E}|)`,
          edge weights :math:`(|\mathcal{E}|)` *(optional)*
        - **output:**
          node features :math:`(|\mathcal{V}|, F_{out})`
    """

    # Cached propagated features (pre-linear-transform); `None` until the
    # first forward pass with `cached=True`.
    _cached_h: Optional[Tensor]

    def __init__(self, in_channels: int, out_channels: int, alpha: float,
                 K: int = 1, cached: bool = False, add_self_loops: bool = True,
                 bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.alpha = alpha
        self.K = K
        self.cached = cached
        self.add_self_loops = add_self_loops
        self._cached_h = None
        self.lin = Linear(in_channels, out_channels, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset learnable parameters and drop any cached propagation."""
        self.lin.reset_parameters()
        self._cached_h = None

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """"""
        h = self._cached_h
        if h is not None:
            # Transductive fast path: reuse the propagated features.
            # Detach so no gradients flow through the stale propagation.
            return self.lin(h.detach())

        # Symmetrically normalize the adjacency (with optional self-loops).
        if isinstance(edge_index, Tensor):
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, self.flow, dtype=x.dtype)
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, self.flow, dtype=x.dtype)

        # h = alpha * X + sum_k (1 - alpha)/K * A_norm^k * X
        h = x * self.alpha
        for _ in range(self.K):
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            x = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                               size=None)
            h = h + (1 - self.alpha) / self.K * x

        if self.cached:
            self._cached_h = h

        return self.lin(h)

    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        # Scale each incoming message by its normalized edge weight.
        return x_j * edge_weight.view(-1, 1)

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        # Fused sparse-dense matmul path for SparseTensor adjacencies.
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return (f'{cls_name}({self.in_channels}, {self.out_channels}, '
                f'K={self.K}, alpha={self.alpha})')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters