Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Deprecate nn.glob package #5039

Merged
merged 16 commits into from
Jul 26, 2022
78 changes: 78 additions & 0 deletions test/nn/aggr/test_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import torch

from torch_geometric.nn.aggr import SortAggr


def test_sort_aggregation_pool():
    # Two graphs with 4 and 6 nodes, 4 features each.
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

    aggr = SortAggr(k=5)
    assert str(aggr) == 'SortAggr(k=5)'

    out = aggr(x, index)
    assert out.size() == (2, 5 * 4)

    # Passing `dim=0` explicitly must match the default call above.
    # (Previously `out_dim = out = aggr(...)` aliased both names to the
    # same tensor, so the comparison was vacuously true.)
    out_dim = aggr(x, index, dim=0)
    assert torch.allclose(out_dim, out)

    out = out.view(2, 5, 4)

    # First graph output has been filled up with zeros (N_1=4 < k=5).
    assert out[0, -1].tolist() == [0, 0, 0, 0]

    # Nodes are sorted in descending order of their last feature.
    expected = 3 - torch.arange(4)
    assert out[0, :4, -1].argsort().tolist() == expected.tolist()

    expected = 4 - torch.arange(5)
    assert out[1, :, -1].argsort().tolist() == expected.tolist()


def test_sort_aggregation_pool_smaller_than_k():
    # Two graphs with 4 and 6 nodes, 4 features each.
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

    # Set k which is bigger than both N_1=4 and N_2=6.
    aggr = SortAggr(k=10)
    assert str(aggr) == 'SortAggr(k=10)'

    out = aggr(x, index)
    assert out.size() == (2, 10 * 4)

    # Passing `dim=0` explicitly must match the default call above.
    # (Previously `out_dim = out = aggr(...)` aliased both names to the
    # same tensor, so the comparison was vacuously true.)
    out_dim = aggr(x, index, dim=0)
    assert torch.allclose(out_dim, out)

    out = out.view(2, 10, 4)

    # Both graph outputs have been filled up with zeros.
    assert out[0, -1].tolist() == [0, 0, 0, 0]
    assert out[1, -1].tolist() == [0, 0, 0, 0]

    # Nodes are sorted in descending order of their last feature.
    expected = 3 - torch.arange(4)
    assert out[0, :4, -1].argsort().tolist() == expected.tolist()

    expected = 5 - torch.arange(6)
    assert out[1, :6, -1].argsort().tolist() == expected.tolist()


def test_global_sort_pool_dim_size():
    # Two graphs with 4 and 6 nodes, 4 features each.
    num_nodes_1, num_nodes_2 = 4, 6
    x = torch.randn(num_nodes_1 + num_nodes_2, 4)
    index = torch.tensor([0 for _ in range(num_nodes_1)] +
                         [1 for _ in range(num_nodes_2)])

    aggr = SortAggr(k=5)
    assert str(aggr) == 'SortAggr(k=5)'

    # Request one additional (empty) graph in the batched output.
    out = aggr(x, index, dim_size=3)
    assert out.size() == (3, 5 * 4)

    out = out.view(3, 5, 4)

    # Both the first (partially-filled) and the last (empty) graph
    # outputs have been padded with zeros.
    assert out[0, -1].tolist() == [0, 0, 0, 0]
    assert out[2, -1].tolist() == [0, 0, 0, 0]
72 changes: 72 additions & 0 deletions test/nn/pool/test_glob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import torch

from torch_geometric.nn import (
global_add_pool,
global_max_pool,
global_mean_pool,
)


def test_global_pool():
    # Two graphs with 4 and 6 nodes, 4 features each.
    N_1, N_2 = 4, 6
    x = torch.randn(N_1 + N_2, 4)
    batch = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

    # Use `torch.allclose` instead of exact `tolist()` equality: the
    # segmented reduction may accumulate floats in a different order than
    # `sum()`/`mean()`, and this matches `test_permuted_global_pool` below.
    out = global_add_pool(x, batch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], x[:4].sum(dim=0))
    assert torch.allclose(out[1], x[4:].sum(dim=0))

    # `batch=None` treats the whole input as a single graph.
    out = global_add_pool(x, None)
    assert out.size() == (1, 4)
    assert torch.allclose(out, x.sum(dim=0, keepdim=True))

    out = global_mean_pool(x, batch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], x[:4].mean(dim=0))
    assert torch.allclose(out[1], x[4:].mean(dim=0))

    out = global_mean_pool(x, None)
    assert out.size() == (1, 4)
    assert torch.allclose(out, x.mean(dim=0, keepdim=True))

    out = global_max_pool(x, batch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], x[:4].max(dim=0)[0])
    assert torch.allclose(out[1], x[4:].max(dim=0)[0])

    out = global_max_pool(x, None)
    assert out.size() == (1, 4)
    assert torch.allclose(out, x.max(dim=0, keepdim=True)[0])


def test_permuted_global_pool():
    # Two graphs with 4 and 6 nodes; shuffle the nodes to verify that
    # pooling is invariant to node ordering within each graph.
    num_nodes_1, num_nodes_2 = 4, 6
    x = torch.randn(num_nodes_1 + num_nodes_2, 4)
    batch = torch.cat(
        [torch.zeros(num_nodes_1), torch.ones(num_nodes_2)]).to(torch.long)
    perm = torch.randperm(num_nodes_1 + num_nodes_2)

    px, pbatch = x[perm], batch[perm]
    px1, px2 = px[pbatch == 0], px[pbatch == 1]

    out = global_add_pool(px, pbatch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], px1.sum(dim=0))
    assert torch.allclose(out[1], px2.sum(dim=0))

    out = global_mean_pool(px, pbatch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], px1.mean(dim=0))
    assert torch.allclose(out[1], px2.mean(dim=0))

    out = global_max_pool(px, pbatch)
    assert out.size() == (2, 4)
    assert torch.allclose(out[0], px1.max(dim=0)[0])
    assert torch.allclose(out[1], px2.max(dim=0)[0])


def test_dense_global_pool():
    # A dense 3-dimensional input with `batch=None` reduces over the node
    # dimension (dim=1).
    x = torch.randn(3, 16, 32)
    out = global_add_pool(x, None)
    assert torch.allclose(out, x.sum(dim=1))
1 change: 1 addition & 0 deletions torch_geometric/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .to_hetero_transformer import to_hetero
from .to_hetero_with_bases_transformer import to_hetero_with_bases
from .aggr import * # noqa
from .glob import * # noqa
from .conv import * # noqa
from .norm import * # noqa
from .pool import * # noqa
Expand Down
35 changes: 35 additions & 0 deletions torch_geometric/nn/glob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from torch_geometric.deprecation import deprecated
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The global_*_pool methods need to be added here as well.

from torch_geometric.nn.aggr import (
AttentionalAggregation,
GraphMultisetTransformer,
Set2Set,
SortAggr,
)

Set2Set = deprecated(
details="use 'nn.aggr.Set2Set' instead",
func_name='nn.pool.Set2Set',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The deprecations are moved back to nn.glob.*. Should we change back from nn.pool.* to nn.glob.*? Sorry about that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is correct. We provide backward compatibility in nn/glob.py but deprecate its use.

)(Set2Set)

# Backward-compatibility shim: keep `GraphMultisetTransformer` importable from
# its old location, but emit a deprecation warning pointing users to
# `nn.aggr.GraphMultisetTransformer`.
GraphMultisetTransformer = deprecated(
    details="use 'nn.aggr.GraphMultisetTransformer' instead",
    func_name='nn.pool.GraphMultisetTransformer',
)(GraphMultisetTransformer)


@deprecated(
    details="use 'nn.aggr.SortAggr' instead",
    func_name='nn.pool.global_sort_pool',
)
def global_sort_pool(x, index, k):
    # Thin backward-compatibility wrapper: delegate directly to the
    # `SortAggr` aggregation operator.
    return SortAggr(k=k)(x, index=index)


@deprecated(
    details="use 'nn.aggr.AttentionalAggregation' instead",
    func_name='nn.pool.GlobalAttention',
)
class GlobalAttention(AttentionalAggregation):
    # Backward-compatibility wrapper: the legacy `GlobalAttention` call
    # signature is `(x, batch, size)`, while `AttentionalAggregation`
    # expects `(x, index, dim_size)`; translate between the two here.
    def __call__(self, x, batch=None, size=None):
        return super().__call__(x, batch, dim_size=size)
34 changes: 6 additions & 28 deletions torch_geometric/nn/pool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from .asap import ASAPooling
from .pan_pool import PANPooling
from .mem_pool import MemPooling
from .glob import (global_max_pool, global_mean_pool, global_add_pool,
GraphMultisetTransformer, Set2Set, GlobalAttention)
from .glob import global_max_pool, global_mean_pool, global_add_pool

try:
import torch_cluster
Expand Down Expand Up @@ -245,32 +244,11 @@ def nearest(x: Tensor, y: Tensor, batch_x: OptTensor = None,


__all__ = [
'TopKPooling',
'SAGPooling',
'EdgePooling',
'ASAPooling',
'PANPooling',
'MemPooling',
'max_pool',
'avg_pool',
'max_pool_x',
'max_pool_neighbor_x',
'avg_pool_x',
'avg_pool_neighbor_x',
'graclus',
'voxel_grid',
'fps',
'knn',
'knn_graph',
'radius',
'radius_graph',
'nearest',
'global_max_pool',
'global_add_pool',
'global_mean_pool',
'GraphMultisetTransformer',
'Set2Set',
'GlobalAttention',
'TopKPooling', 'SAGPooling', 'EdgePooling', 'ASAPooling', 'PANPooling',
'MemPooling', 'max_pool', 'avg_pool', 'max_pool_x', 'max_pool_neighbor_x',
'avg_pool_x', 'avg_pool_neighbor_x', 'graclus', 'voxel_grid', 'fps', 'knn',
'knn_graph', 'radius', 'radius_graph', 'nearest', 'global_max_pool',
'global_add_pool', 'global_mean_pool'
Copy link
Contributor

@lightaime lightaime Jul 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add back a comma for auto-formatting?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1. Let‘s also move the global pooling methods to the top.

]

classes = __all__
40 changes: 6 additions & 34 deletions torch_geometric/nn/pool/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,9 @@

from torch import Tensor

from torch_geometric.deprecation import deprecated
from torch_geometric.nn.aggr import (
AttentionalAggregation,
GraphMultisetTransformer,
MaxAggregation,
MeanAggregation,
Set2Set,
SortAggr,
SumAggregation,
)

Expand All @@ -32,6 +27,8 @@ def global_add_pool(x: Tensor, batch: Optional[Tensor],
size (int, optional): Batch-size :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
"""
if batch is None:
return x.sum(dim=-2, keepdim=x.dim() == 2)
return sum_aggr(x, batch, dim_size=size)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we leave the implementation as it is? I am not super happy with having global modules here.



Expand All @@ -54,6 +51,8 @@ def global_mean_pool(x: Tensor, batch: Optional[Tensor],
size (int, optional): Batch-size :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
"""
if batch is None:
return x.mean(dim=-2, keepdim=x.dim() == 2)
return mean_aggr(x, batch, dim_size=size)


Expand All @@ -76,33 +75,6 @@ def global_max_pool(x: Tensor, batch: Optional[Tensor],
size (int, optional): Batch-size :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
"""
if batch is None:
return x.max(dim=-2, keepdim=x.dim() == 2)[0]
return max_aggr(x, batch, dim_size=size)


Set2Set = deprecated(
details="use 'nn.aggr.Set2Set' instead",
func_name='nn.pool.Set2Set',
)(Set2Set)

GraphMultisetTransformer = deprecated(
details="use 'nn.aggr.GraphMultisetTransformer' instead",
func_name='nn.pool.GraphMultisetTransformer',
)(GraphMultisetTransformer)


@deprecated(
details="use 'nn.aggr.SortAggr' instead",
func_name='nn.pool.global_sort_pool',
)
def global_sort_pool(x, index, k):
module = SortAggr(k=k)
return module(x, index=index)


@deprecated(
details="use 'nn.aggr.AttentionalAggregation' instead",
func_name='nn.pool.GlobalAttention',
)
class GlobalAttention(AttentionalAggregation):
def __call__(self, x, batch=None, size=None):
return super().__call__(x, batch, dim_size=size)