EquilibriumAggregation global aggregation layer (#4522)

* wip equilibrium aggregation
* update with docs and batch
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* example of median
* update init
* update changelog
* add details to median example
* match parameters from paper
* match parameters from paper
* wip
* add nesterov loss
* change loss to be a mean across a single batch
* move to aggreation package
* move example import
* add assertions for unsupported
* update changelog
* update changelog
* add suggestions and test update
* Update test/nn/aggr/test_equilibrium.py (Co-authored-by: Guohao Li <lightaime@gmail.com>)
* Update test/nn/aggr/test_equilibrium.py (Co-authored-by: Guohao Li <lightaime@gmail.com>)
* fix error message
* fix indent error
* rebase to fix tests
* minor update for agg class
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Guohao Li <lightaime@gmail.com>
1 parent 0c24277 · commit 333d3d3
Showing 5 changed files with 287 additions and 1 deletion.
@@ -0,0 +1,41 @@
r"""
Replicates the median experiment from the `"Equilibrium Aggregation:
Encoding Sets via Optimization" <https://arxiv.org/abs/2202.12795>`_
paper, teaching :class:`EquilibriumAggregation` to compute the median
of a set of numbers.

This example converges slowly towards predicting the median, similar
to what is observed in the paper.
"""

import numpy as np
import torch

from torch_geometric.nn.aggr import EquilibriumAggregation

input_size = 100
steps = 10000000
embedding_size = 10
eval_each = 1000

model = EquilibriumAggregation(1, embedding_size, num_layers=[256, 256],
                               grad_iter=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

norm = torch.distributions.normal.Normal(0.5, 0.4)
gamma = torch.distributions.gamma.Gamma(0.2, 0.5)
uniform = torch.distributions.uniform.Uniform(0, 1)
total_loss = 0
n_loss = 0

for i in range(steps):
    optimizer.zero_grad()
    dist = np.random.choice([norm, gamma, uniform])
    x = dist.sample((input_size, 1))
    y = model(x)
    loss = (y - x.median()).norm(2) / input_size
    loss.backward()
    optimizer.step()
    total_loss += loss.item()  # accumulate a float, not a graph-attached tensor
    n_loss += 1
    if i % eval_each == (eval_each - 1):
        print(f"Average loss at step {i} is {total_loss / n_loss}")
@@ -0,0 +1,53 @@
import pytest
import torch

from torch_geometric.nn.aggr import EquilibriumAggregation


@pytest.mark.parametrize('iter', [0, 1, 5])
@pytest.mark.parametrize('alpha', [0, .1, 5])
def test_equilibrium(iter, alpha):
    batch_size = 10
    feature_channels = 3
    output_channels = 2
    x = torch.randn(batch_size, feature_channels)
    model = EquilibriumAggregation(feature_channels, output_channels,
                                   num_layers=[10, 10], grad_iter=iter)

    assert model.__repr__() == 'EquilibriumAggregation()'
    out = model(x)
    assert out.size() == (1, 2)

    with pytest.raises(ValueError):
        model(x, dim_size=0)

    out = model(x, dim_size=3)
    assert out.size() == (3, 2)
    assert torch.all(out[1:, :] == 0)


@pytest.mark.parametrize('iter', [0, 1, 5])
@pytest.mark.parametrize('alpha', [0, .1, 5])
def test_equilibrium_batch(iter, alpha):
    batch_1, batch_2 = 4, 6
    feature_channels = 3
    output_channels = 2
    x = torch.randn(batch_1 + batch_2, feature_channels)
    batch = torch.tensor([0 for _ in range(batch_1)] +
                         [1 for _ in range(batch_2)])

    model = EquilibriumAggregation(feature_channels, output_channels,
                                   num_layers=[10, 10], grad_iter=iter)

    assert model.__repr__() == 'EquilibriumAggregation()'
    out = model(x, batch)
    assert out.size() == (2, 2)

    with pytest.raises(ValueError):
        model(x, dim_size=0)

    out = model(x, dim_size=3)
    assert out.size() == (3, 2)
    assert torch.all(out[1:, :] == 0)
@@ -0,0 +1,189 @@
from typing import Callable, List, Optional

import torch
from torch import Tensor
from torch_scatter import scatter

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.inits import reset


class ResNetPotential(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int,
                 num_layers: List[int]):
        super().__init__()
        sizes = [in_channels] + num_layers + [out_channels]
        self.layers = torch.nn.ModuleList([
            torch.nn.Sequential(torch.nn.Linear(in_size, out_size),
                                torch.nn.LayerNorm(out_size), torch.nn.Tanh())
            for in_size, out_size in zip(sizes[:-2], sizes[1:-1])
        ])
        self.layers.append(torch.nn.Linear(sizes[-2], sizes[-1]))

        self.res_trans = torch.nn.ModuleList([
            torch.nn.Linear(in_channels, layer_size)
            for layer_size in num_layers + [out_channels]
        ])

    def forward(self, x: Tensor, y: Tensor, index: Optional[Tensor]) -> Tensor:
        if index is None:
            inp = torch.cat([x, y.expand(x.size(0), -1)], dim=1)
        else:
            inp = torch.cat([x, y[index]], dim=1)

        h = inp
        for layer, res in zip(self.layers, self.res_trans):
            h = layer(h)
            h = res(inp) + h

        if index is None:
            return h.mean()

        # Average the per-element potentials within each set, then sum over sets.
        size = int(index.max().item() + 1)
        return scatter(h, index, dim=0, dim_size=size, reduce='mean').sum()


class MomentumOptimizer(torch.nn.Module):
    r"""Provides an inner loop optimizer for the implicitly defined output
    layer. It is based on an unrolled Nesterov momentum algorithm.

    Args:
        learning_rate (float): learning rate for the optimizer.
            (default: :obj:`0.1`)
        momentum (float): momentum for the optimizer.
            (default: :obj:`0.9`)
        learnable (bool): If :obj:`True`, the :obj:`learning_rate` and
            :obj:`momentum` will be learnable parameters. If :obj:`False`,
            they are fixed. (default: :obj:`True`)
    """
    def __init__(self, learning_rate: float = 0.1, momentum: float = 0.9,
                 learnable: bool = True):
        super().__init__()

        self._initial_lr = learning_rate
        self._initial_mom = momentum
        self._lr = torch.nn.Parameter(Tensor([learning_rate]),
                                      requires_grad=learnable)
        self._mom = torch.nn.Parameter(Tensor([momentum]),
                                       requires_grad=learnable)
        self.softplus = torch.nn.Softplus()
        self.sigmoid = torch.nn.Sigmoid()

    def reset_parameters(self):
        self._lr.data.fill_(self._initial_lr)
        self._mom.data.fill_(self._initial_mom)

    @property
    def learning_rate(self):
        return self.softplus(self._lr)

    @property
    def momentum(self):
        return self.sigmoid(self._mom)

    def forward(
        self,
        x: Tensor,
        y: Tensor,
        index: Optional[Tensor],
        func: Callable[[Tensor, Tensor, Optional[Tensor]], Tensor],
        iterations: int = 5,
    ) -> Tensor:
        momentum_buffer = torch.zeros_like(y)
        for _ in range(iterations):
            val = func(x, y, index)
            grad = torch.autograd.grad(val, y, create_graph=True,
                                       retain_graph=True)[0]
            delta = self.learning_rate * grad
            momentum_buffer = self.momentum * momentum_buffer - delta
            y = y + momentum_buffer
        return y
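For reference, the unrolled update that :obj:`MomentumOptimizer.forward` performs on the implicit output, read directly off the loop above (with :math:`\eta = \mathrm{softplus}(\mathrm{lr})` and :math:`\mu = \mathrm{sigmoid}(\mathrm{momentum})`, and :math:`E` the energy passed in as :obj:`func`), is

.. math::
    \mathbf{v}_{t+1} = \mu \, \mathbf{v}_t - \eta \,
    \nabla_{\mathbf{y}} E(\mathbf{x}, \mathbf{y}_t), \qquad
    \mathbf{y}_{t+1} = \mathbf{y}_t + \mathbf{v}_{t+1}.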
class EquilibriumAggregation(Aggregation):
    r"""The graph global pooling layer from the
    `"Equilibrium Aggregation: Encoding Sets via Optimization"
    <https://arxiv.org/abs/2202.12795>`_ paper.

    The output of this layer :math:`\mathbf{y}` is defined implicitly by a
    potential function :math:`F(\mathbf{x}, \mathbf{y})`, a regularization
    function :math:`R(\mathbf{y})`, and the condition

    .. math::
        \mathbf{y} = \underset{\mathbf{y}}{\mathrm{argmin}} \;
        R(\mathbf{y}) + \sum_{i} F(\mathbf{x}_i, \mathbf{y}).

    This implementation uses a ResNet-like model for the potential function
    and a simple :math:`L_2` norm with a learnable weight :math:`\lambda` for
    the regularizer.

    Args:
        in_channels (int): The number of channels in the input to the layer.
        out_channels (int): The number of channels in the output.
        num_layers (List[int]): A list of hidden layer sizes for the
            potential function.
        grad_iter (int): The number of steps to take in the internal gradient
            descent. (default: :obj:`5`)
        lamb (float): The initial value of the learnable regularization
            constant. (default: :obj:`0.1`)
    """
    def __init__(self, in_channels: int, out_channels: int,
                 num_layers: List[int], grad_iter: int = 5, lamb: float = 0.1):
        super().__init__()

        self.potential = ResNetPotential(in_channels + out_channels, 1,
                                         num_layers)
        self.optimizer = MomentumOptimizer()
        self._initial_lambda = lamb
        self._lambda = torch.nn.Parameter(Tensor([lamb]), requires_grad=True)
        self.softplus = torch.nn.Softplus()
        self.grad_iter = grad_iter
        self.output_dim = out_channels
        self.reset_parameters()

    def reset_parameters(self):
        # Reset the underlying parameter, not the softplus-transformed view.
        self._lambda.data.fill_(self._initial_lambda)
        reset(self.optimizer)
        reset(self.potential)

    @property
    def lamb(self):
        return self.softplus(self._lambda)

    def init_output(self, index: Optional[Tensor] = None) -> Tensor:
        index_size = 1 if index is None else int(index.max().item() + 1)
        return torch.zeros(index_size, self.output_dim, dtype=torch.float,
                           requires_grad=True)

    def reg(self, y: Tensor) -> Tensor:
        return self.lamb * y.square().mean(dim=-2).sum(dim=0)

    def energy(self, x: Tensor, y: Tensor, index: Optional[Tensor]) -> Tensor:
        return self.potential(x, y, index) + self.reg(y)

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:

        self.assert_index_present(index)

        index_size = 1 if index is None else int(index.max()) + 1
        dim_size = index_size if dim_size is None else dim_size

        if dim_size < index_size:
            raise ValueError("`dim_size` is less than `index` implied size")

        with torch.enable_grad():
            y = self.optimizer(x, self.init_output(index), index, self.energy,
                               iterations=self.grad_iter)

        if dim_size > index_size:
            zero = y.new_zeros(dim_size - index_size, *y.size()[1:])
            y = torch.cat([y, zero])

        return y

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
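As a quick reference, here is a minimal usage sketch of the new aggregation layer. It assumes only the constructor and forward signatures shown above; the shapes mirror the unit tests in this commit.

import torch
from torch_geometric.nn.aggr import EquilibriumAggregation

x = torch.randn(10, 3)                    # ten elements with three features each
index = torch.tensor([0] * 4 + [1] * 6)   # two sets, of size 4 and 6
aggr = EquilibriumAggregation(3, 2, num_layers=[10, 10], grad_iter=5)
out = aggr(x, index)                      # one aggregated vector per set: shape [2, 2]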