EquilibriumAggregation global aggregation layer #4522
Merged
Changes from all commits (27 commits)
fbafde5  wip equilibrium aggregation (Padarn)
3c1d0ba  update with docs and batch (Padarn)
9b66232  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
fcef12c  example of median (Padarn)
b21760c  update init (Padarn)
58e1678  update changelog (Padarn)
3b1bc45  add details to median example (Padarn)
a02eea8  match parameters from paper (Padarn)
3f0a23b  match parameters from paper (Padarn)
8d79424  wip (Padarn)
3f1216a  add nesterov loss (Padarn)
894a50f  change loss to be a mean across a single batch (Padarn)
5091ffe  move to aggreation package (Padarn)
ed94b33  move example import (Padarn)
3b3aaca  add assertions for unsupported (Padarn)
99f020f  update changelog (Padarn)
8561c8b  update changelog (Padarn)
8460438  add suggestions and test update (Padarn)
6a38ff9  Update test/nn/aggr/test_equilibrium.py (Padarn)
b34b3a5  Update test/nn/aggr/test_equilibrium.py (Padarn)
4171372  fix error message (Padarn)
cc0eac1  fix indent error (Padarn)
fa47205  rebase to fix tests (Padarn)
7359edc  minor update for agg class (Padarn)
2236dc1  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
d504053  Merge branch 'master' into padarn/optim-embedding (Padarn)
a3a3a96  Merge branch 'master' into padarn/optim-embedding (Padarn)
@@ -0,0 +1,41 @@
r"""
Replicates the median experiment from the `"Equilibrium Aggregation:
Encoding Sets via Optimization" <https://arxiv.org/abs/2202.12795>`_ paper,
teaching `EquilibriumAggregation` to take the median of a set of numbers.

This example converges slowly to being able to predict the median,
similar to what is observed in the paper.
"""
import numpy as np
import torch

from torch_geometric.nn.aggr import EquilibriumAggregation

input_size = 100
steps = 10000000
embedding_size = 10
eval_each = 1000

model = EquilibriumAggregation(1, embedding_size, num_layers=[256, 256],
                               grad_iter=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

norm = torch.distributions.normal.Normal(0.5, 0.4)
gamma = torch.distributions.gamma.Gamma(0.2, 0.5)
uniform = torch.distributions.uniform.Uniform(0, 1)

total_loss = 0
n_loss = 0

for i in range(steps):
    optimizer.zero_grad()
    # Draw each batch from a randomly chosen distribution.
    dist = np.random.choice([norm, gamma, uniform])
    x = dist.sample((input_size, 1))
    y = model(x)
    loss = (y - x.median()).norm(2) / input_size
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
    n_loss += 1
    if i % eval_each == (eval_each - 1):
        print(f"Average loss at epoch {i} is {total_loss / n_loss}")
@@ -0,0 +1,53 @@
import pytest
import torch

from torch_geometric.nn.aggr import EquilibriumAggregation


@pytest.mark.parametrize('iter', [0, 1, 5])
@pytest.mark.parametrize('alpha', [0, .1, 5])
def test_equilibrium(iter, alpha):
    batch_size = 10
    feature_channels = 3
    output_channels = 2
    x = torch.randn(batch_size, feature_channels)
    model = EquilibriumAggregation(feature_channels, output_channels,
                                   num_layers=[10, 10], grad_iter=iter)

    assert model.__repr__() == 'EquilibriumAggregation()'
    out = model(x)
    assert out.size() == (1, 2)

    with pytest.raises(ValueError):
        model(x, dim_size=0)

    out = model(x, dim_size=3)
    assert out.size() == (3, 2)
    assert torch.all(out[1:, :] == 0)


@pytest.mark.parametrize('iter', [0, 1, 5])
@pytest.mark.parametrize('alpha', [0, .1, 5])
def test_equilibrium_batch(iter, alpha):
    batch_1, batch_2 = 4, 6
    feature_channels = 3
    output_channels = 2
    x = torch.randn(batch_1 + batch_2, feature_channels)
    batch = torch.tensor([0 for _ in range(batch_1)] +
                         [1 for _ in range(batch_2)])

    model = EquilibriumAggregation(feature_channels, output_channels,
                                   num_layers=[10, 10], grad_iter=iter)

    assert model.__repr__() == 'EquilibriumAggregation()'
    out = model(x, batch)
    assert out.size() == (2, 2)

    with pytest.raises(ValueError):
        model(x, dim_size=0)

    out = model(x, dim_size=3)
    assert out.size() == (3, 2)
    assert torch.all(out[1:, :] == 0)
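Beyond these unit tests, the intended use of the layer is as a learnable global readout over node features. The snippet below is a hypothetical usage sketch, not code from this PR: the feature sizes, the linear head, and the dim_size value are illustrative, and only the constructor and forward signature introduced here are assumed.

import torch

from torch_geometric.nn.aggr import EquilibriumAggregation

# Two toy graphs with 4 and 6 nodes and 3 features per node.
x = torch.randn(10, 3)
batch = torch.tensor([0] * 4 + [1] * 6)

readout = EquilibriumAggregation(3, 8, num_layers=[16, 16], grad_iter=5)
head = torch.nn.Linear(8, 2)  # illustrative downstream classifier

graph_emb = readout(x, batch)  # [2, 8], one embedding per graph
logits = head(graph_emb)  # [2, 2]

# If more graphs are expected than `batch` implies (e.g. trailing empty
# graphs), `dim_size` pads the missing rows with zeros:
padded = readout(x, batch, dim_size=4)  # [4, 8]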
@@ -0,0 +1,189 @@
from typing import Callable, List, Optional, Tuple

import torch
from torch import Tensor
from torch_scatter import scatter

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.inits import reset


class ResNetPotential(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int,
                 num_layers: List[int]):
        super().__init__()
        sizes = [in_channels] + num_layers + [out_channels]
        self.layers = torch.nn.ModuleList([
            torch.nn.Sequential(torch.nn.Linear(in_size, out_size),
                                torch.nn.LayerNorm(out_size), torch.nn.Tanh())
            for in_size, out_size in zip(sizes[:-2], sizes[1:-1])
        ])
        self.layers.append(torch.nn.Linear(sizes[-2], sizes[-1]))

        # Residual connections map the raw input to every hidden layer.
        self.res_trans = torch.nn.ModuleList([
            torch.nn.Linear(in_channels, layer_size)
            for layer_size in num_layers + [out_channels]
        ])

    def forward(self, x: Tensor, y: Tensor, index: Optional[Tensor]) -> Tensor:
        if index is None:
            inp = torch.cat([x, y.expand(x.size(0), -1)], dim=1)
        else:
            inp = torch.cat([x, y[index]], dim=1)

        h = inp
        for layer, res in zip(self.layers, self.res_trans):
            h = layer(h)
            h = res(inp) + h

        if index is None:
            return h.mean()

        size = int(index.max().item() + 1)
        return scatter(h, index, dim=0, dim_size=size, reduce='mean').sum()


class MomentumOptimizer(torch.nn.Module):
    r"""
    Provides an inner loop optimizer for the implicitly defined output
    layer. It is based on an unrolled Nesterov momentum algorithm.

    Args:
        learning_rate (float): learning rate for optimizer.
        momentum (float): momentum for optimizer.
        learnable (bool): If :obj:`True` then the :obj:`learning_rate` and
            :obj:`momentum` will be learnable parameters. If :obj:`False`
            they are fixed. (default: :obj:`True`)
    """
    def __init__(self, learning_rate: float = 0.1, momentum: float = 0.9,
                 learnable: bool = True):
        super().__init__()

        self._initial_lr = learning_rate
        self._initial_mom = momentum
        self._lr = torch.nn.Parameter(Tensor([learning_rate]),
                                      requires_grad=learnable)
        self._mom = torch.nn.Parameter(Tensor([momentum]),
                                       requires_grad=learnable)
        self.softplus = torch.nn.Softplus()
        self.sigmoid = torch.nn.Sigmoid()

    def reset_parameters(self):
        self._lr.data.fill_(self._initial_lr)
        self._mom.data.fill_(self._initial_mom)

    @property
    def learning_rate(self):
        return self.softplus(self._lr)

    @property
    def momentum(self):
        return self.sigmoid(self._mom)

    def forward(
        self,
        x: Tensor,
        y: Tensor,
        index: Optional[Tensor],
        func: Callable[[Tensor, Tensor, Optional[Tensor]], Tensor],
        iterations: int = 5,
    ) -> Tensor:
        momentum_buffer = torch.zeros_like(y)
        for _ in range(iterations):
            val = func(x, y, index)
            grad = torch.autograd.grad(val, y, create_graph=True,
                                       retain_graph=True)[0]
            delta = self.learning_rate * grad
            momentum_buffer = self.momentum * momentum_buffer - delta
            y = y + momentum_buffer
        return y


class EquilibriumAggregation(Aggregation):
    r"""
    The graph global pooling layer from the
    `"Equilibrium Aggregation: Encoding Sets via Optimization"
    <https://arxiv.org/abs/2202.12795>`_ paper.
    The output of this layer :math:`\mathbf{y}` is defined implicitly via a
    potential function :math:`F(\mathbf{x}, \mathbf{y})`, a regularization
    function :math:`R(\mathbf{y})`, and the condition

    .. math::
        \mathbf{y} = \operatorname*{argmin}_\mathbf{y} \; R(\mathbf{y}) +
        \sum_{i} F(\mathbf{x}_i, \mathbf{y}).

    This implementation uses a ResNet-like model for the potential function
    and a simple :math:`L_2` norm for the regularizer, with a learnable
    weight :math:`\lambda`.

    Args:
        in_channels (int): The number of channels in the input to the layer.
        out_channels (int): The number of channels in the output.
        num_layers (List[int]): A list of the number of hidden units in the
            potential function.
        grad_iter (int): The number of steps to take in the internal gradient
            descent. (default: :obj:`5`)
        lamb (float): The initial regularization constant.
            Is learnable. (default: :obj:`0.1`)
    """
    def __init__(self, in_channels: int, out_channels: int,
                 num_layers: List[int], grad_iter: int = 5,
                 lamb: float = 0.1):
        super().__init__()

        self.potential = ResNetPotential(in_channels + out_channels, 1,
                                         num_layers)
        self.optimizer = MomentumOptimizer()
        self._initial_lambda = lamb
        self._lambda = torch.nn.Parameter(Tensor([lamb]), requires_grad=True)
        self.softplus = torch.nn.Softplus()
        self.grad_iter = grad_iter
        self.output_dim = out_channels
        self.reset_parameters()

    def reset_parameters(self):
        self._lambda.data.fill_(self._initial_lambda)
        reset(self.optimizer)
        reset(self.potential)

    @property
    def lamb(self):
        return self.softplus(self._lambda)

    def init_output(self, index: Optional[Tensor] = None) -> Tensor:
        index_size = 1 if index is None else int(index.max().item() + 1)
        return torch.zeros(index_size, self.output_dim, requires_grad=True)

    def reg(self, y: Tensor) -> Tensor:
        return self.lamb * y.square().mean(dim=-2).sum(dim=0)

    def energy(self, x: Tensor, y: Tensor, index: Optional[Tensor]) -> Tensor:
        return self.potential(x, y, index) + self.reg(y)

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        self.assert_index_present(index)

        index_size = 1 if index is None else int(index.max().item() + 1)
        dim_size = index_size if dim_size is None else dim_size

        if dim_size < index_size:
            raise ValueError("`dim_size` is less than `index` implied size")

        # Solve the inner optimization problem for the output embedding.
        with torch.enable_grad():
            y = self.optimizer(x, self.init_output(index), index, self.energy,
                               iterations=self.grad_iter)

        # Pad with zeros if more output rows were requested than `index`
        # implies.
        if dim_size > index_size:
            zero = torch.zeros(dim_size - index_size, *y.size()[1:])
            y = torch.cat([y, zero])

        return y

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
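To make the implicit definition above concrete, here is a standalone sketch of what forward() computes internally: a few gradient steps on the learned energy, starting from zeros. It uses plain gradient descent instead of MomentumOptimizer, and the step size and tensor sizes are illustrative rather than values from the PR.

import torch

from torch_geometric.nn.aggr import EquilibriumAggregation

model = EquilibriumAggregation(3, 2, num_layers=[10, 10])
x = torch.randn(10, 3)
index = torch.tensor([0] * 4 + [1] * 6)

y = model.init_output(index)  # zeros of shape [2, 2] with requires_grad=True
for _ in range(model.grad_iter):
    # Energy is the summed potential plus the lambda-weighted squared norm.
    energy = model.energy(x, y, index)
    # create_graph mirrors MomentumOptimizer, so training could
    # differentiate through the unrolled loop.
    grad, = torch.autograd.grad(energy, y, create_graph=True)
    y = y - 0.1 * grad  # fixed step size, no momentum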
It would be great to add a few comments about the convergence of EquilibriumAggregation.
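On the convergence point, one rough check is to compare the learned energy at the zero initialization against the energy at the returned embedding, reusing the energy and init_output helpers added in this PR. With an untrained potential the numbers are arbitrary, but as training progresses the inner loop should tend to drive the energy down, and increasing grad_iter should change the output less and less. A hypothetical sketch with illustrative sizes:

import torch

from torch_geometric.nn.aggr import EquilibriumAggregation

model = EquilibriumAggregation(3, 2, num_layers=[10, 10], grad_iter=5)
x = torch.randn(10, 3)
index = torch.tensor([0] * 4 + [1] * 6)

y0 = model.init_output(index)
e_start = model.energy(x, y0, index)  # energy at the zero initialization
y = model(x, index)  # embedding after grad_iter inner optimization steps
e_end = model.energy(x, y, index)  # energy at the returned embedding
print(f"energy: {e_start.item():.4f} -> {e_end.item():.4f}")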