Skip to content

Commit

Permalink
[Type Hints] nn.TopKPooling (#5731)
Browse files Browse the repository at this point in the history
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
camillepradel and rusty1s authored Oct 19, 2022
1 parent 396f183 commit ab131ab
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 42 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5716](https://github.com/pyg-team/pytorch_geometric/pull/5716), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5756](https://github.com/pyg-team/pytorch_geometric/pull/5756), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5758](https://github.com/pyg-team/pytorch_geometric/pull/5758), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5766](https://github.com/pyg-team/pytorch_geometric/pull/5766), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767), [#5768](https://github.com/pyg-team/pytorch_geometric/pull/5768))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5716](https://github.com/pyg-team/pytorch_geometric/pull/5716), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5731](https://github.com/pyg-team/pytorch_geometric/pull/5731), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5756](https://github.com/pyg-team/pytorch_geometric/pull/5756), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5758](https://github.com/pyg-team/pytorch_geometric/pull/5758), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5766](https://github.com/pyg-team/pytorch_geometric/pull/5766), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767), [#5768](https://github.com/pyg-team/pytorch_geometric/pull/5768))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down
81 changes: 51 additions & 30 deletions test/nn/pool/test_topk_pool.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
import torch

from torch_geometric.nn.pool.topk_pool import TopKPooling, filter_adj, topk
from torch_geometric.testing import is_full_test


def test_topk():
x = torch.Tensor([2, 4, 5, 6, 2, 9])
batch = torch.tensor([0, 0, 1, 1, 1, 1])

perm = topk(x, 0.5, batch)
perm1 = topk(x, 0.5, batch)

assert perm.tolist() == [1, 5, 3]
assert x[perm].tolist() == [4, 9, 6]
assert batch[perm].tolist() == [0, 1, 1]
assert perm1.tolist() == [1, 5, 3]
assert x[perm1].tolist() == [4, 9, 6]
assert batch[perm1].tolist() == [0, 1, 1]

perm = topk(x, 3, batch)
perm2 = topk(x, 3, batch)

assert perm.tolist() == [1, 0, 5, 3, 2]
assert x[perm].tolist() == [4, 2, 9, 6, 5]
assert batch[perm].tolist() == [0, 0, 1, 1, 1]
assert perm2.tolist() == [1, 0, 5, 3, 2]
assert x[perm2].tolist() == [4, 2, 9, 6, 5]
assert batch[perm2].tolist() == [0, 0, 1, 1, 1]

if is_full_test():
jit = torch.jit.script(topk)
assert torch.equal(jit(x, 0.5, batch), perm1)
assert torch.equal(jit(x, 3, batch), perm2)


def test_filter_adj():
Expand All @@ -26,9 +32,16 @@ def test_filter_adj():
edge_attr = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8])
perm = torch.tensor([2, 3])

edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm)
assert edge_index.tolist() == [[0, 1], [1, 0]]
assert edge_attr.tolist() == [6, 8]
out = filter_adj(edge_index, edge_attr, perm)
assert out[0].tolist() == [[0, 1], [1, 0]]
assert out[1].tolist() == [6, 8]

if is_full_test():
jit = torch.jit.script(filter_adj)

out = jit(edge_index, edge_attr, perm)
assert out[0].tolist() == [[0, 1], [1, 0]]
assert out[1].tolist() == [6, 8]


def test_topk_pooling():
Expand All @@ -38,22 +51,30 @@ def test_topk_pooling():
num_nodes = edge_index.max().item() + 1
x = torch.randn((num_nodes, in_channels))

pool = TopKPooling(in_channels, ratio=0.5)
assert pool.__repr__() == 'TopKPooling(16, ratio=0.5, multiplier=1.0)'

x, edge_index, _, _, _, _ = pool(x, edge_index)
assert x.size() == (num_nodes // 2, in_channels)
assert edge_index.size() == (2, 2)

pool = TopKPooling(in_channels, ratio=None, min_score=0.1)
assert pool.__repr__() == 'TopKPooling(16, min_score=0.1, multiplier=1.0)'
out = pool(x, edge_index)
assert out[0].size(0) <= x.size(0) and out[0].size(1) == (16)
assert out[1].size(0) == 2 and out[1].size(1) <= edge_index.size(1)

pool = TopKPooling(in_channels, ratio=2)
assert pool.__repr__() == 'TopKPooling(16, ratio=2, multiplier=1.0)'

x, edge_index, _, _, _, _ = pool(x, edge_index)
assert x.size() == (2, in_channels)
assert edge_index.size() == (2, 2)
pool1 = TopKPooling(in_channels, ratio=0.5)
assert str(pool1) == 'TopKPooling(16, ratio=0.5, multiplier=1.0)'
out1 = pool1(x, edge_index)
assert out1[0].size() == (num_nodes // 2, in_channels)
assert out1[1].size() == (2, 2)

pool2 = TopKPooling(in_channels, ratio=None, min_score=0.1)
assert str(pool2) == 'TopKPooling(16, min_score=0.1, multiplier=1.0)'
out2 = pool2(x, edge_index)
assert out2[0].size(0) <= x.size(0) and out2[0].size(1) == (16)
assert out2[1].size(0) == 2 and out2[1].size(1) <= edge_index.size(1)

pool3 = TopKPooling(in_channels, ratio=2)
assert str(pool3) == 'TopKPooling(16, ratio=2, multiplier=1.0)'
out3 = pool3(x, edge_index)
assert out3[0].size() == (2, in_channels)
assert out3[1].size() == (2, 2)

if is_full_test():
jit1 = torch.jit.script(pool1)
assert torch.allclose(jit1(x, edge_index)[0], out1[0])

jit2 = torch.jit.script(pool2)
assert torch.allclose(jit2(x, edge_index)[0], out2[0])

jit3 = torch.jit.script(pool3)
assert torch.allclose(jit3(x, edge_index)[0], out3[0])
44 changes: 33 additions & 11 deletions torch_geometric/nn/pool/topk_pool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Union
from typing import Callable, Optional, Tuple, Union

import torch
from torch import Tensor
Expand All @@ -13,9 +13,9 @@

def topk(
x: Tensor,
ratio: float,
ratio: Optional[Union[float, int]],
batch: Tensor,
min_score: Optional[int] = None,
min_score: Optional[float] = None,
tol: float = 1e-7,
) -> Tensor:
if min_score is not None:
Expand All @@ -24,7 +24,8 @@ def topk(
scores_min = scores_max.clamp(max=min_score)

perm = (x > scores_min).nonzero().view(-1)
else:

elif ratio is not None:
num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
batch_size, max_num_nodes = num_nodes.size(0), int(num_nodes.max())

Expand All @@ -48,7 +49,7 @@ def topk(
k = num_nodes.new_full((num_nodes.size(0), ), int(ratio))
k = torch.min(k, num_nodes)
else:
k = (ratio * num_nodes.to(torch.float)).ceil().to(torch.long)
k = (float(ratio) * num_nodes.to(x.dtype)).ceil().to(torch.long)

mask = [
torch.arange(k[i], dtype=torch.long, device=x.device) +
Expand All @@ -58,17 +59,26 @@ def topk(

perm = perm[mask]

else:
raise ValueError("At least one of 'min_score' and 'ratio' parameters "
"must be specified")

return perm


def filter_adj(edge_index, edge_attr, perm, num_nodes=None):
def filter_adj(
edge_index: Tensor,
edge_attr: Optional[Tensor],
perm: Tensor,
num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
num_nodes = maybe_num_nodes(edge_index, num_nodes)

mask = perm.new_full((num_nodes, ), -1)
i = torch.arange(perm.size(0), dtype=torch.long, device=perm.device)
mask[perm] = i

row, col = edge_index
row, col = edge_index[0], edge_index[1]
row, col = mask[row], mask[col]
mask = (row >= 0) & (col >= 0)
row, col = row[mask], col[mask]
Expand Down Expand Up @@ -131,9 +141,14 @@ class TopKPooling(torch.nn.Module):
nonlinearity (torch.nn.functional, optional): The nonlinearity to use.
(default: :obj:`torch.tanh`)
"""
def __init__(self, in_channels: int, ratio: Union[int, float] = 0.5,
min_score: Optional[float] = None, multiplier: float = 1.,
nonlinearity: Callable = torch.tanh):
def __init__(
self,
in_channels: int,
ratio: Union[int, float] = 0.5,
min_score: Optional[float] = None,
multiplier: float = 1.,
nonlinearity: Callable = torch.tanh,
):
super().__init__()

self.in_channels = in_channels
Expand All @@ -150,7 +165,14 @@ def reset_parameters(self):
size = self.in_channels
uniform(size, self.weight)

def forward(self, x, edge_index, edge_attr=None, batch=None, attn=None):
def forward(
self,
x: Tensor,
edge_index: Tensor,
edge_attr: Optional[Tensor] = None,
batch: Optional[Tensor] = None,
attn: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
""""""

if batch is None:
Expand Down

0 comments on commit ab131ab

Please sign in to comment.