Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate nn.glob package #5039

Merged
merged 16 commits into from
Jul 26, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
Expand Down
15 changes: 0 additions & 15 deletions docs/source/modules/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,6 @@ Normalization Layers
:undoc-members:
:exclude-members: training

Global Pooling Layers
---------------------

.. currentmodule:: torch_geometric.nn.glob
.. autosummary::
:nosignatures:
{% for cls in torch_geometric.nn.glob.classes %}
{{ cls }}
{% endfor %}

.. automodule:: torch_geometric.nn.glob
:members:
:undoc-members:
:exclude-members: training

Pooling Layers
--------------

Expand Down
4 changes: 2 additions & 2 deletions test/nn/aggr/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch_geometric.nn.aggr import SortAggr


def test_global_sort_pool():
Padarn marked this conversation as resolved.
Show resolved Hide resolved
def test_sort_aggregation_pool():
N_1, N_2 = 4, 6
x = torch.randn(N_1 + N_2, 4)
index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])
Expand All @@ -30,7 +30,7 @@ def test_global_sort_pool():
assert out[1, :, -1].argsort().tolist() == expected.tolist()


def test_global_sort_pool_smaller_than_k():
def test_sort_aggregation_pool_smaller_than_k():
N_1, N_2 = 4, 6
x = torch.randn(N_1 + N_2, 4)
index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])
Expand Down
1 change: 1 addition & 0 deletions test/nn/models/test_gnn_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def test_gnn_explainer_with_existing_self_loops(model, return_type):
[0, 1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])

node_feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)

assert node_feat_mask.size() == (x.size(1), )
assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
assert edge_mask.size() == (edge_index.size(1), )
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions torch_geometric/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from .to_hetero_with_bases_transformer import to_hetero_with_bases
from .aggr import * # noqa
from .conv import * # noqa
from .norm import * # noqa
from .glob import * # noqa
from .pool import * # noqa
from .glob import * # noqa
from .norm import * # noqa
from .unpool import * # noqa
from .dense import * # noqa
from .models import * # noqa
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,15 @@
from .glob import (
GlobalPooling,
from torch_geometric.deprecation import deprecated
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The global_*_pool methods need to be added here as well.

from torch_geometric.nn import (
global_add_pool,
global_max_pool,
global_mean_pool,
)

__all__ = [
'global_add_pool',
'global_mean_pool',
'global_max_pool',
'GlobalPooling',
]

classes = __all__

from torch_geometric.deprecation import deprecated # noqa
from torch_geometric.nn.aggr import AttentionalAggregation # noqa
from torch_geometric.nn.aggr import GraphMultisetTransformer # noqa
from torch_geometric.nn.aggr import SortAggr # noqa
from torch_geometric.nn.aggr import Set2Set # noqa
from torch_geometric.nn.aggr import (
AttentionalAggregation,
GraphMultisetTransformer,
Set2Set,
SortAggr,
)

Set2Set = deprecated(
details="use 'nn.aggr.Set2Set' instead",
Expand All @@ -32,18 +23,34 @@


@deprecated(
details="use 'nn.aggr.GlobalSortAggr' instead",
details="use 'nn.aggr.AttentionalAggregation' instead",
func_name='nn.glob.GlobalAttention',
)
class GlobalAttention(AttentionalAggregation):
def __call__(self, x, batch=None, size=None):
return super().__call__(x, batch, dim_size=size)


@deprecated(
details="use 'nn.aggr.SortAggr' instead",
func_name='nn.glob.global_sort_pool',
)
def global_sort_pool(x, index, k):
module = SortAggr(k=k)
return module(x, index=index)


@deprecated(
details="use 'nn.aggr.GlobalAttention' instead",
func_name='nn.glob.GlobalAttention',
)
class GlobalAttention(AttentionalAggregation):
def __call__(self, x, batch=None, size=None):
return super().__call__(x, batch, dim_size=size)
deprecated(
details="use 'nn.pool.global_add_pool' instead",
func_name='nn.glob.global_add_pool',
)(global_add_pool)

deprecated(
details="use 'nn.pool.global_max_pool' instead",
func_name='nn.glob.global_max_pool',
)(global_max_pool)

deprecated(
details="use 'nn.pool.global_mean_pool' instead",
func_name='nn.glob.global_mean_pool',
)(global_mean_pool)
4 changes: 4 additions & 0 deletions torch_geometric/nn/pool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .asap import ASAPooling
from .pan_pool import PANPooling
from .mem_pool import MemPooling
from .glob import global_max_pool, global_mean_pool, global_add_pool

try:
import torch_cluster
Expand Down Expand Up @@ -243,6 +244,9 @@ def nearest(x: Tensor, y: Tensor, batch_x: OptTensor = None,


__all__ = [
'global_max_pool',
'global_add_pool',
'global_mean_pool',
'TopKPooling',
'SAGPooling',
'EdgePooling',
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import List, Optional, Union
from typing import Optional

import torch
from torch import Tensor
from torch_scatter import scatter

Expand All @@ -10,10 +9,8 @@ def global_add_pool(x: Tensor, batch: Optional[Tensor],
r"""Returns batch-wise graph-level-outputs by adding node features
across the node dimension, so that for a single graph
:math:`\mathcal{G}_i` its output is computed by

.. math::
\mathbf{r}_i = \sum_{n=1}^{N_i} \mathbf{x}_n

Args:
x (Tensor): Node feature matrix
:math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
Expand All @@ -34,10 +31,8 @@ def global_mean_pool(x: Tensor, batch: Optional[Tensor],
r"""Returns batch-wise graph-level-outputs by averaging node features
across the node dimension, so that for a single graph
:math:`\mathcal{G}_i` its output is computed by

.. math::
\mathbf{r}_i = \frac{1}{N_i} \sum_{n=1}^{N_i} \mathbf{x}_n

Args:
x (Tensor): Node feature matrix
:math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
Expand All @@ -58,10 +53,8 @@ def global_max_pool(x: Tensor, batch: Optional[Tensor],
r"""Returns batch-wise graph-level-outputs by taking the channel-wise
maximum across the node dimension, so that for a single graph
:math:`\mathcal{G}_i` its output is computed by

.. math::
\mathbf{r}_i = \mathrm{max}_{n=1}^{N_i} \, \mathbf{x}_n

Args:
x (Tensor): Node feature matrix
:math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
Expand All @@ -75,43 +68,3 @@ def global_max_pool(x: Tensor, batch: Optional[Tensor],
return x.max(dim=-2, keepdim=x.dim() == 2)[0]
size = int(batch.max().item() + 1) if size is None else size
return scatter(x, batch, dim=-2, dim_size=size, reduce='max')


class GlobalPooling(torch.nn.Module):
r"""A global pooling module that wraps the usage of
:meth:`~torch_geometric.nn.glob.global_add_pool`,
:meth:`~torch_geometric.nn.glob.global_mean_pool` and
:meth:`~torch_geometric.nn.glob.global_max_pool` into a single module.

Args:
aggr (string or List[str]): The aggregation scheme to use
(:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
If given as a list, will make use of multiple aggregations in which
different outputs will get concatenated in the last dimension.
"""
def __init__(self, aggr: Union[str, List[str]]):
super().__init__()

self.aggrs = [aggr] if isinstance(aggr, str) else aggr

assert len(self.aggrs) > 0
assert len(set(self.aggrs) | {'sum', 'add', 'mean', 'max'}) == 4

def forward(self, x: Tensor, batch: Optional[Tensor],
size: Optional[int] = None) -> Tensor:
""""""
xs: List[Tensor] = []

for aggr in self.aggrs:
if aggr == 'sum' or aggr == 'add':
xs.append(global_add_pool(x, batch, size))
elif aggr == 'mean':
xs.append(global_mean_pool(x, batch, size))
elif aggr == 'max':
xs.append(global_max_pool(x, batch, size))

return xs[0] if len(xs) == 1 else torch.cat(xs, dim=-1)

def __repr__(self) -> str:
aggr = self.aggrs[0] if len(self.aggrs) == 1 else self.aggrs
return f'{self.__class__.__name__}(aggr={aggr})'