EquilibriumAggregation global aggregation layer #4522

Merged · 27 commits · Jul 27, 2022

Commits
fbafde5
wip equilibrium aggregation
Padarn Apr 24, 2022
3c1d0ba
update with docs and batch
Padarn Apr 29, 2022
9b66232
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2022
fcef12c
example of median
Padarn Apr 29, 2022
b21760c
update init
Padarn May 4, 2022
58e1678
update changelog
Padarn May 4, 2022
3b1bc45
add details to median example
Padarn May 6, 2022
a02eea8
match parameters from paper
Padarn May 6, 2022
3f0a23b
match parameters from paper
Padarn May 6, 2022
8d79424
wip
Padarn May 9, 2022
3f1216a
add nesterov loss
Padarn May 14, 2022
894a50f
change loss to be a mean across a single batch
Padarn May 15, 2022
5091ffe
move to aggreation package
Padarn May 28, 2022
ed94b33
move example import
Padarn May 28, 2022
3b3aaca
add assertions for unsupported
Padarn May 28, 2022
99f020f
update changelog
Padarn May 28, 2022
8561c8b
update changelog
Padarn May 28, 2022
8460438
add suggestions and test update
Padarn Jun 4, 2022
6a38ff9
Update test/nn/aggr/test_equilibrium.py
Padarn Jun 3, 2022
b34b3a5
Update test/nn/aggr/test_equilibrium.py
Padarn Jun 3, 2022
4171372
fix error message
Padarn Jun 4, 2022
cc0eac1
fix indent error
Padarn Jun 6, 2022
fa47205
rebase to fix tests
Padarn Jun 18, 2022
7359edc
minor update for agg class
Padarn Jul 9, 2022
2236dc1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 9, 2022
d504053
Merge branch 'master' into padarn/optim-embedding
Padarn Jul 26, 2022
a3a3a96
Merge branch 'master' into padarn/optim-embedding
Padarn Jul 27, 2022
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -39,7 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
@@ -58,6 +58,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `HeteroData.is_undirected()` support ([#4604](https://github.com/pyg-team/pytorch_geometric/pull/4604))
- Added the `Genius` and `Wiki` datasets to `nn.datasets.LINKXDataset` ([#4570](https://github.com/pyg-team/pytorch_geometric/pull/4570), [#4600](https://github.com/pyg-team/pytorch_geometric/pull/4600))
- Added `nn.glob.GlobalPooling` module with support for multiple aggregations ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added `nn.aggr.EquilibriumAggregation` implicit global layer ([#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522))
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
### Changed
41 changes: 41 additions & 0 deletions examples/equilibrium_median.py
@@ -0,0 +1,41 @@
r"""
Replicates the experiment from `"Deep Graph Infomax"
<https://arxiv.org/abs/1809.10341>`_ to try and teach
`EquilibriumAggregation` to learn to take the median of
a set of numbers

This example converges slowly to being able to predict the
median similar to what is observed in the paper.
"""

Reviewer comment (Contributor):
It would be great to add a few comments about the convergence of EquilibriumAggregation.

import numpy as np
import torch

from torch_geometric.nn.aggr import EquilibriumAggregation

input_size = 100
steps = 10000000
embedding_size = 10
eval_each = 1000

model = EquilibriumAggregation(1, embedding_size, [256, 256], grad_iter=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

norm = torch.distributions.normal.Normal(0.5, 0.4)
gamma = torch.distributions.gamma.Gamma(0.2, 0.5)
uniform = torch.distributions.uniform.Uniform(0, 1)
total_loss = 0
n_loss = 0

for i in range(steps):
optimizer.zero_grad()
dist = np.random.choice([norm, gamma, uniform])
x = dist.sample((input_size, 1))
y = model(x)
loss = (y - x.median()).norm(2) / input_size
loss.backward()
optimizer.step()
    total_loss += float(loss)
n_loss += 1
if i % eval_each == (eval_each - 1):
print(f"Average loss at epoc {i} is {total_loss / n_loss}")
53 changes: 53 additions & 0 deletions test/nn/aggr/test_equilibrium.py
@@ -0,0 +1,53 @@
import pytest
import torch

from torch_geometric.nn.aggr import EquilibriumAggregation


@pytest.mark.parametrize('iter', [0, 1, 5])
@pytest.mark.parametrize('alpha', [0, .1, 5])
def test_equilibrium(iter, alpha):

batch_size = 10
feature_channels = 3
output_channels = 2
x = torch.randn(batch_size, feature_channels)
    model = EquilibriumAggregation(feature_channels, output_channels,
                                   num_layers=[10, 10], grad_iter=iter,
                                   lamb=alpha)

assert model.__repr__() == 'EquilibriumAggregation()'
out = model(x)
assert out.size() == (1, 2)

with pytest.raises(ValueError):
model(x, dim_size=0)

out = model(x, dim_size=3)
assert out.size() == (3, 2)
assert torch.all(out[1:, :] == 0)


@pytest.mark.parametrize('iter', [0, 1, 5])
@pytest.mark.parametrize('alpha', [0, .1, 5])
def test_equilibrium_batch(iter, alpha):

batch_1, batch_2 = 4, 6
feature_channels = 3
output_channels = 2
x = torch.randn(batch_1 + batch_2, feature_channels)
batch = torch.tensor([0 for _ in range(batch_1)] +
[1 for _ in range(batch_2)])

    model = EquilibriumAggregation(feature_channels, output_channels,
                                   num_layers=[10, 10], grad_iter=iter,
                                   lamb=alpha)

assert model.__repr__() == 'EquilibriumAggregation()'
out = model(x, batch)
assert out.size() == (2, 2)

with pytest.raises(ValueError):
model(x, dim_size=0)

out = model(x, dim_size=3)
assert out.size() == (3, 2)
assert torch.all(out[1:, :] == 0)
2 changes: 2 additions & 0 deletions torch_geometric/nn/aggr/__init__.py
@@ -14,6 +14,7 @@
from .lstm import LSTMAggregation
from .set2set import Set2Set
from .scaler import DegreeScalerAggregation
from .equilibrium import EquilibriumAggregation
from .sort import SortAggr
from .gmt import GraphMultisetTransformer
from .attention import AttentionalAggregation
@@ -36,4 +37,5 @@
'SortAggr',
'GraphMultisetTransformer',
'AttentionalAggregation',
'EquilibriumAggregation',
]
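
With this export in place, the layer can be imported directly from
`torch_geometric.nn.aggr`. A minimal usage sketch (mirroring the shapes
used in the tests above; not part of the diff):

import torch

from torch_geometric.nn.aggr import EquilibriumAggregation

# Pool 10 nodes with 3 features each into a single 2-dimensional output.
x = torch.randn(10, 3)
aggr = EquilibriumAggregation(3, 2, num_layers=[10, 10])
out = aggr(x)  # shape: [1, 2]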
189 changes: 189 additions & 0 deletions torch_geometric/nn/aggr/equilibrium.py
@@ -0,0 +1,189 @@
from typing import Callable, List, Optional

import torch
from torch import Tensor
from torch_scatter import scatter

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.inits import reset


class ResNetPotential(torch.nn.Module):
def __init__(self, in_channels: int, out_channels: int,
num_layers: List[int]):

super().__init__()
sizes = [in_channels] + num_layers + [out_channels]
self.layers = torch.nn.ModuleList([
torch.nn.Sequential(torch.nn.Linear(in_size, out_size),
torch.nn.LayerNorm(out_size), torch.nn.Tanh())
for in_size, out_size in zip(sizes[:-2], sizes[1:-1])
])
self.layers.append(torch.nn.Linear(sizes[-2], sizes[-1]))

self.res_trans = torch.nn.ModuleList([
torch.nn.Linear(in_channels, layer_size)
for layer_size in num_layers + [out_channels]
])

def forward(self, x: Tensor, y: Tensor, index: Optional[Tensor]) -> Tensor:
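        # `y` holds one embedding per graph; pair each node feature with its
        # graph's current embedding before evaluating the potential network.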
if index is None:
inp = torch.cat([x, y.expand(x.size(0), -1)], dim=1)
else:
inp = torch.cat([x, y[index]], dim=1)

h = inp
for layer, res in zip(self.layers, self.res_trans):
h = layer(h)
h = res(inp) + h

if index is None:
return h.mean()

size = int(index.max().item() + 1)
        return scatter(h, index, dim=0, dim_size=size, reduce='mean').sum()


class MomentumOptimizer(torch.nn.Module):
r"""
Provides an inner loop optimizer for the implicitly defined output
layer. It is based on an unrolled Nesterov momentum algorithm.

Args:
        learning_rate (float): The learning rate of the optimizer.
        momentum (float): The momentum of the optimizer.
        learnable (bool): If :obj:`True`, the :obj:`learning_rate` and
            :obj:`momentum` are learnable parameters. If :obj:`False`,
            they are fixed. (default: :obj:`True`)
"""
def __init__(self, learning_rate: float = 0.1, momentum: float = 0.9,
learnable: bool = True):
super().__init__()

self._initial_lr = learning_rate
self._initial_mom = momentum
self._lr = torch.nn.Parameter(Tensor([learning_rate]),
requires_grad=learnable)
self._mom = torch.nn.Parameter(Tensor([momentum]),
requires_grad=learnable)
self.softplus = torch.nn.Softplus()
self.sigmoid = torch.nn.Sigmoid()

def reset_parameters(self):
self._lr.data.fill_(self._initial_lr)
self._mom.data.fill_(self._initial_mom)

@property
def learning_rate(self):
return self.softplus(self._lr)

@property
def momentum(self):
return self.sigmoid(self._mom)

def forward(
self,
x: Tensor,
y: Tensor,
index: Optional[Tensor],
func: Callable[[Tensor, Tensor, Optional[Tensor]], Tensor],
iterations: int = 5,
    ) -> Tensor:

momentum_buffer = torch.zeros_like(y)
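        # Unrolled momentum descent on `func` w.r.t. `y`: `create_graph=True`
        # keeps every inner step differentiable, so the outer training loss
        # can update the potential's parameters and the learnable
        # learning rate and momentum.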
for _ in range(iterations):
val = func(x, y, index)
grad = torch.autograd.grad(val, y, create_graph=True,
retain_graph=True)[0]
delta = self.learning_rate * grad
momentum_buffer = self.momentum * momentum_buffer - delta
y = y + momentum_buffer
return y


class EquilibriumAggregation(Aggregation):
r"""
The graph global pooling layer from the
`"Equilibrium Aggregation: Encoding Sets via Optimization"
<https://arxiv.org/abs/2202.12795>`_ paper.
    The output of this layer :math:`\mathbf{y}` is defined implicitly via a
    potential function :math:`F(\mathbf{x}, \mathbf{y})`, a regularization
    function :math:`R(\mathbf{y})`, and the condition

    .. math::
        \mathbf{y} = \underset{\mathbf{y}}{\mathrm{argmin}} \;
        R(\mathbf{y}) + \sum_{i} F(\mathbf{x}_i, \mathbf{y})

    This implementation uses a ResNet-like model for the potential function
    and a simple squared :math:`L_2` norm with a learnable weight
    :math:`\lambda` as the regularizer.

    Args:
        in_channels (int): The number of channels in the input to the layer.
        out_channels (int): The number of channels in the output.
        num_layers (List[int]): A list of the number of hidden units in the
            potential function.
        grad_iter (int): The number of steps to take in the internal gradient
            descent. (default: :obj:`5`)
        lamb (float): The initial value of the learnable regularization
            constant :math:`\lambda`. (default: :obj:`0.1`)
"""
def __init__(self, in_channels: int, out_channels: int,
num_layers: List[int], grad_iter: int = 5, lamb: float = 0.1):
super().__init__()

self.potential = ResNetPotential(in_channels + out_channels, 1,
num_layers)
self.optimizer = MomentumOptimizer()
self._initial_lambda = lamb
        self._lambda = torch.nn.Parameter(Tensor([lamb]), requires_grad=True)
self.softplus = torch.nn.Softplus()
self.grad_iter = grad_iter
self.output_dim = out_channels
self.reset_parameters()

def reset_parameters(self):
        self._lambda.data.fill_(self._initial_lambda)
reset(self.optimizer)
reset(self.potential)

@property
def lamb(self):
        return self.softplus(self._lambda)

def init_output(self, index: Optional[Tensor] = None) -> Tensor:
index_size = 1 if index is None else int(index.max().item() + 1)
        return torch.zeros(index_size, self.output_dim, dtype=torch.float,
                           requires_grad=True)

    def reg(self, y: Tensor) -> Tensor:
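        # Squared-norm penalty on the output: per-channel squares averaged
        # across graphs and summed across channels, scaled by the
        # softplus-constrained lambda.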
return self.lamb * y.square().mean(dim=-2).sum(dim=0)

    def energy(self, x: Tensor, y: Tensor,
               index: Optional[Tensor]) -> Tensor:
return self.potential(x, y, index) + self.reg(y)

def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:

        index_size = 1 if index is None else int(index.max().item() + 1)
        dim_size = index_size if dim_size is None else dim_size

        if dim_size < index_size:
            raise ValueError("`dim_size` is less than the size "
                             "implied by `index`")

with torch.enable_grad():
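            # Grad mode must be enabled even under `torch.no_grad()`/eval,
            # since the inner optimizer differentiates the energy w.r.t. `y`.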
y = self.optimizer(x, self.init_output(index), index, self.energy,
iterations=self.grad_iter)

        if dim_size > index_size:
            zero = torch.zeros(dim_size - index_size, *y.size()[1:],
                               device=y.device)
            y = torch.cat([y, zero])

return y

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'