EquilibriumAggregation global aggregation layer (#4522)
* wip equilibrium aggregation

* update with docs and batch

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* example of median

* update init

* update changelog

* add details to median example

* match parameters from paper

* match parameters from paper

* wip

* add nesterov loss

* change loss to be a mean across a single batch

* move to aggregation package

* move example import

* add assertions for unsupported

* update changelog

* update changelog

* add suggestions and test update

* Update test/nn/aggr/test_equilibrium.py

Co-authored-by: Guohao Li <lightaime@gmail.com>

* Update test/nn/aggr/test_equilibrium.py

Co-authored-by: Guohao Li <lightaime@gmail.com>

* fix error message

* fix indent error

* rebase to fix tests

* minor update for agg class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Guohao Li <lightaime@gmail.com>
3 people authored Jul 27, 2022
1 parent 0c24277 commit 333d3d3
Showing 5 changed files with 287 additions and 1 deletion.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -39,7 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
@@ -58,6 +58,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `HeteroData.is_undirected()` support ([#4604](https://github.com/pyg-team/pytorch_geometric/pull/4604))
- Added the `Genius` and `Wiki` datasets to `nn.datasets.LINKXDataset` ([#4570](https://github.com/pyg-team/pytorch_geometric/pull/4570), [#4600](https://github.com/pyg-team/pytorch_geometric/pull/4600))
- Added `nn.glob.GlobalPooling` module with support for multiple aggregations ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added `nn.glob.EquilibriumAggregation` implicit global layer ([#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522))
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
### Changed
41 changes: 41 additions & 0 deletions examples/equilibrium_median.py
@@ -0,0 +1,41 @@
r"""
Replicates the experiment from `"Deep Graph Infomax"
<https://arxiv.org/abs/1809.10341>`_ to try and teach
`EquilibriumAggregation` to learn to take the median of
a set of numbers
This example converges slowly to being able to predict the
median similar to what is observed in the paper.
"""

import numpy as np
import torch

from torch_geometric.nn.aggr import EquilibriumAggregation

input_size = 100
steps = 10000000
embedding_size = 10
eval_each = 1000

model = EquilibriumAggregation(1, embedding_size, [256, 256], grad_iter=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

norm = torch.distributions.normal.Normal(0.5, 0.4)
gamma = torch.distributions.gamma.Gamma(0.2, 0.5)
uniform = torch.distributions.uniform.Uniform(0, 1)
total_loss = 0
n_loss = 0

for i in range(steps):
    optimizer.zero_grad()
    dist = np.random.choice([norm, gamma, uniform])
    x = dist.sample((input_size, 1))
    y = model(x)
    loss = (y - x.median()).norm(2) / input_size
    loss.backward()
    optimizer.step()
    total_loss += loss.item()  # detach so the graph is not kept alive
    n_loss += 1
    if i % eval_each == (eval_each - 1):
        print(f"Average loss at epoch {i} is {total_loss / n_loss}")
53 changes: 53 additions & 0 deletions test/nn/aggr/test_equilibrium.py
@@ -0,0 +1,53 @@
import pytest
import torch

from torch_geometric.nn.aggr import EquilibriumAggregation


@pytest.mark.parametrize('iter', [0, 1, 5])
@pytest.mark.parametrize('alpha', [0, .1, 5])
def test_equilibrium(iter, alpha):
    batch_size = 10
    feature_channels = 3
    output_channels = 2
    x = torch.randn(batch_size, feature_channels)
    model = EquilibriumAggregation(feature_channels, output_channels,
                                   num_layers=[10, 10], grad_iter=iter)

    assert model.__repr__() == 'EquilibriumAggregation()'
    out = model(x)
    assert out.size() == (1, 2)

    with pytest.raises(ValueError):
        model(x, dim_size=0)

    out = model(x, dim_size=3)
    assert out.size() == (3, 2)
    assert torch.all(out[1:, :] == 0)


@pytest.mark.parametrize('iter', [0, 1, 5])
@pytest.mark.parametrize('alpha', [0, .1, 5])
def test_equilibrium_batch(iter, alpha):
    batch_1, batch_2 = 4, 6
    feature_channels = 3
    output_channels = 2
    x = torch.randn(batch_1 + batch_2, feature_channels)
    batch = torch.tensor([0 for _ in range(batch_1)] +
                         [1 for _ in range(batch_2)])

    model = EquilibriumAggregation(feature_channels, output_channels,
                                   num_layers=[10, 10], grad_iter=iter)

    assert model.__repr__() == 'EquilibriumAggregation()'
    out = model(x, batch)
    assert out.size() == (2, 2)

    with pytest.raises(ValueError):
        model(x, dim_size=0)

    out = model(x, dim_size=3)
    assert out.size() == (3, 2)
    assert torch.all(out[1:, :] == 0)
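
The new tests can be run in isolation with pytest:

pytest test/nn/aggr/test_equilibrium.py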
2 changes: 2 additions & 0 deletions torch_geometric/nn/aggr/__init__.py
@@ -14,6 +14,7 @@
from .lstm import LSTMAggregation
from .set2set import Set2Set
from .scaler import DegreeScalerAggregation
from .equilibrium import EquilibriumAggregation
from .sort import SortAggr
from .gmt import GraphMultisetTransformer
from .attention import AttentionalAggregation
@@ -36,4 +37,5 @@
    'SortAggr',
    'GraphMultisetTransformer',
    'AttentionalAggregation',
    'EquilibriumAggregation',
]
189 changes: 189 additions & 0 deletions torch_geometric/nn/aggr/equilibrium.py
@@ -0,0 +1,189 @@
from typing import Callable, List, Optional, Tuple

import torch
from torch import Tensor
from torch_scatter import scatter

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.inits import reset


class ResNetPotential(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int,
                 num_layers: List[int]):
        super().__init__()
        sizes = [in_channels] + num_layers + [out_channels]
        self.layers = torch.nn.ModuleList([
            torch.nn.Sequential(torch.nn.Linear(in_size, out_size),
                                torch.nn.LayerNorm(out_size), torch.nn.Tanh())
            for in_size, out_size in zip(sizes[:-2], sizes[1:-1])
        ])
        self.layers.append(torch.nn.Linear(sizes[-2], sizes[-1]))

        self.res_trans = torch.nn.ModuleList([
            torch.nn.Linear(in_channels, layer_size)
            for layer_size in num_layers + [out_channels]
        ])

    def forward(self, x: Tensor, y: Tensor, index: Optional[Tensor]) -> Tensor:
        if index is None:
            inp = torch.cat([x, y.expand(x.size(0), -1)], dim=1)
        else:
            inp = torch.cat([x, y[index]], dim=1)

        h = inp
        for layer, res in zip(self.layers, self.res_trans):
            h = layer(h)
            h = res(inp) + h

        if index is None:
            return h.mean()

        size = int(index.max().item() + 1)
        # Pool the network output `h` (not the raw input `x`) per graph:
        return scatter(h, index, dim=0, dim_size=size, reduce='mean').sum()


class MomentumOptimizer(torch.nn.Module):
    r"""Provides an inner loop optimizer for the implicitly defined output
    layer. It is based on an unrolled Nesterov momentum algorithm.

    Args:
        learning_rate (float): learning rate for the optimizer.
        momentum (float): momentum for the optimizer.
        learnable (bool): If :obj:`True`, the :obj:`learning_rate` and
            :obj:`momentum` will be learnable parameters. If :obj:`False`,
            they are fixed. (default: :obj:`True`)
    """
    def __init__(self, learning_rate: float = 0.1, momentum: float = 0.9,
                 learnable: bool = True):
        super().__init__()

        self._initial_lr = learning_rate
        self._initial_mom = momentum
        self._lr = torch.nn.Parameter(Tensor([learning_rate]),
                                      requires_grad=learnable)
        self._mom = torch.nn.Parameter(Tensor([momentum]),
                                       requires_grad=learnable)
        self.softplus = torch.nn.Softplus()
        self.sigmoid = torch.nn.Sigmoid()

    def reset_parameters(self):
        self._lr.data.fill_(self._initial_lr)
        self._mom.data.fill_(self._initial_mom)

    @property
    def learning_rate(self):
        return self.softplus(self._lr)

    @property
    def momentum(self):
        return self.sigmoid(self._mom)

    def forward(
        self,
        x: Tensor,
        y: Tensor,
        index: Optional[Tensor],
        func: Callable[[Tensor, Tensor, Optional[Tensor]], Tensor],
        iterations: int = 5,
    ) -> Tensor:
        momentum_buffer = torch.zeros_like(y)
        for _ in range(iterations):
            val = func(x, y, index)
            grad = torch.autograd.grad(val, y, create_graph=True,
                                       retain_graph=True)[0]
            delta = self.learning_rate * grad
            momentum_buffer = self.momentum * momentum_buffer - delta
            y = y + momentum_buffer
        return y
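
# Each unrolled step above computes
#     buffer <- momentum * buffer - learning_rate * grad_y func(x, y, index)
#     y      <- y + buffer
# Because `create_graph=True` keeps the graph of every inner step, the outer
# training loop can backpropagate through the whole inner optimization and
# learn `learning_rate` and `momentum` themselves.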


class EquilibriumAggregation(Aggregation):
    r"""The graph global pooling layer from the
    `"Equilibrium Aggregation: Encoding Sets via Optimization"
    <https://arxiv.org/abs/2202.12795>`_ paper.
    The output of this layer :math:`\mathbf{y}` is defined implicitly by a
    potential function :math:`F(\mathbf{x}, \mathbf{y})`, a regularization
    function :math:`R(\mathbf{y})`, and the condition

    .. math::
        \mathbf{y} = \operatorname{argmin}_\mathbf{y} \, R(\mathbf{y}) +
        \sum_{i} F(\mathbf{x}_i, \mathbf{y})

    This implementation uses a ResNet-like model for the potential function
    and a simple L2 norm for the regularizer, with a learnable weight
    :math:`\lambda`.

    Args:
        in_channels (int): The number of channels in the input to the layer.
        out_channels (int): The number of channels in the output.
        num_layers (List[int]): A list of the number of hidden units in the
            potential function.
        grad_iter (int): The number of steps to take in the internal gradient
            descent. (default: :obj:`5`)
        lamb (float): The initial value of the learnable regularization
            constant. (default: :obj:`0.1`)
    """
    def __init__(self, in_channels: int, out_channels: int,
                 num_layers: List[int], grad_iter: int = 5,
                 lamb: float = 0.1):
        super().__init__()

        self.potential = ResNetPotential(in_channels + out_channels, 1,
                                         num_layers)
        self.optimizer = MomentumOptimizer()
        self._initial_lambda = lamb
        self._lambda = torch.nn.Parameter(Tensor([lamb]), requires_grad=True)
        self.softplus = torch.nn.Softplus()
        self.grad_iter = grad_iter
        self.output_dim = out_channels
        self.reset_parameters()

    def reset_parameters(self):
        # Reset the underlying parameter, not its softplus-transformed view:
        self._lambda.data.fill_(self._initial_lambda)
        reset(self.optimizer)
        reset(self.potential)

    @property
    def lamb(self):
        return self.softplus(self._lambda)

    def init_output(self, index: Optional[Tensor] = None) -> Tensor:
        index_size = 1 if index is None else int(index.max().item() + 1)
        return torch.zeros(index_size, self.output_dim, dtype=torch.float,
                           requires_grad=True)

    def reg(self, y: Tensor) -> Tensor:
        return self.lamb * y.square().mean(dim=-2).sum(dim=0)

    def energy(self, x: Tensor, y: Tensor, index: Optional[Tensor]) -> Tensor:
        return self.potential(x, y, index) + self.reg(y)

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        self.assert_index_present(index)

        index_size = 1 if index is None else int(index.max().item() + 1)
        dim_size = index_size if dim_size is None else dim_size

        if dim_size < index_size:
            raise ValueError("`dim_size` is less than the size "
                             "implied by `index`")

        # Gradients must be enabled even at inference time, since the output
        # is defined through an inner autograd-based optimization:
        with torch.enable_grad():
            y = self.optimizer(x, self.init_output(index), index, self.energy,
                               iterations=self.grad_iter)

        if dim_size > index_size:
            zero = torch.zeros(dim_size - index_size, *y.size()[1:])
            y = torch.cat([y, zero])

        return y

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
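
For orientation, a minimal usage sketch of the new layer; the shapes mirror the tests above, and the random inputs are illustrative:

import torch

from torch_geometric.nn.aggr import EquilibriumAggregation

# Two graphs with 4 and 6 nodes, 3 input channels, pooled to 2 channels.
x = torch.randn(10, 3)
batch = torch.tensor([0] * 4 + [1] * 6)

model = EquilibriumAggregation(3, 2, num_layers=[10, 10], grad_iter=5)
out = model(x, batch)  # one aggregated vector per graph
print(out.size())  # torch.Size([2, 2])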
