From fbafde5613730178f674da6457a1810dbb2bb154 Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Sun, 24 Apr 2022 10:05:07 +0800 Subject: [PATCH 01/25] wip equilibrium aggregation --- examples/mutag_gin_equib.py | 100 ++++++++++++++++++ .../nn/glob/equilibrium_aggregation.py | 72 +++++++++++++ 2 files changed, 172 insertions(+) create mode 100644 examples/mutag_gin_equib.py create mode 100644 torch_geometric/nn/glob/equilibrium_aggregation.py diff --git a/examples/mutag_gin_equib.py b/examples/mutag_gin_equib.py new file mode 100644 index 000000000000..0c9c70c76ab5 --- /dev/null +++ b/examples/mutag_gin_equib.py @@ -0,0 +1,100 @@ +import os.path as osp + +import torch +import torch.nn.functional as F +from torch.nn import BatchNorm1d, Linear, ReLU, Sequential + +from torch_geometric.datasets import TUDataset +from torch_geometric.loader import DataLoader +from torch_geometric.nn import EquilibriumAggregation, GINConv + +path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'TU') +dataset = TUDataset(path, name='MUTAG') + +train_dataset = dataset[len(dataset) // 10:] +test_dataset = dataset[:len(dataset) // 10] + +train_loader = DataLoader(train_dataset, batch_size=1) +test_loader = DataLoader(test_dataset, batch_size=1) + + +class Net(torch.nn.Module): + def __init__(self, in_channels, dim, out_channels): + super().__init__() + + self.conv1 = GINConv( + Sequential(Linear(in_channels, dim), BatchNorm1d(dim), ReLU(), + Linear(dim, dim), ReLU())) + + self.conv2 = GINConv( + Sequential(Linear(dim, dim), BatchNorm1d(dim), ReLU(), + Linear(dim, dim), ReLU())) + + self.conv3 = GINConv( + Sequential(Linear(dim, dim), BatchNorm1d(dim), ReLU(), + Linear(dim, dim), ReLU())) + + self.conv4 = GINConv( + Sequential(Linear(dim, dim), BatchNorm1d(dim), ReLU(), + Linear(dim, dim), ReLU())) + + self.conv5 = GINConv( + Sequential(Linear(dim, dim), BatchNorm1d(dim), ReLU(), + Linear(dim, dim), ReLU())) + + self.lin1 = Linear(dim, dim) + self.lin2 = Linear(dim, out_channels) + self.readout = EquilibriumAggregation(dim, dim, [256, 256], + grad_iter=10, alpha=0.05) + + def forward(self, x, edge_index, batch): + x = self.conv1(x, edge_index) + x = self.conv2(x, edge_index) + x = self.conv3(x, edge_index) + x = self.conv4(x, edge_index) + x = self.conv5(x, edge_index) + x = self.readout(x).unsqueeze(0) + x = self.lin1(x).relu() + x = F.dropout(x, p=0.5, training=self.training) + x = self.lin2(x) + return F.log_softmax(x, dim=-1) + + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +model = Net(dataset.num_features, 32, dataset.num_classes).to(device) +optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + + +def train(): + model.train() + + total_loss = 0 + for data in train_loader: + data = data.to(device) + optimizer.zero_grad() + output = model(data.x, data.edge_index, data.batch) + loss = F.nll_loss(output, data.y) + loss.backward() + optimizer.step() + total_loss += float(loss) * data.num_graphs + return total_loss / len(train_loader.dataset) + + +@torch.no_grad() +def test(loader): + model.eval() + + total_correct = 0 + for data in loader: + data = data.to(device) + out = model(data.x, data.edge_index, data.batch) + total_correct += int((out.argmax(-1) == data.y).sum()) + return total_correct / len(loader.dataset) + + +for epoch in range(1, 101): + loss = train() + train_acc = test(train_loader) + test_acc = test(test_loader) + print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f} ' + f'Test Acc: {test_acc:.4f}') diff --git 
a/torch_geometric/nn/glob/equilibrium_aggregation.py b/torch_geometric/nn/glob/equilibrium_aggregation.py new file mode 100644 index 000000000000..23451ff92df7 --- /dev/null +++ b/torch_geometric/nn/glob/equilibrium_aggregation.py @@ -0,0 +1,72 @@ +from typing import List + +import torch + +EPS = 1e-15 + + +class ResNetPotential(torch.nn.Module): + def __init__(self, input_size: int, layers: List[int]): + + super().__init__() + output_size = 1 + sizes = [input_size] + layers + [output_size] + self.layers = torch.nn.ModuleList([ + torch.nn.Sequential(torch.nn.Linear(in_size, out_size), + torch.nn.LayerNorm(out_size), torch.nn.Tanh()) + for in_size, out_size in zip(sizes[:-1], sizes[1:]) + ]) + + self.res_trans = torch.nn.ModuleList([ + torch.nn.Linear(input_size, layer_size) + for layer_size in layers + [output_size] + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + for layer, res in zip(self.layers, self.res_trans): + h = layer(h) + h = res(x) + h + + return h + + +class EquilibriumAggregation(torch.nn.Module): + """ + Args: + potential (torch.nn.Module): trainable potenial function + """ + def __init__(self, input_dim: int, output_dim: int, layers: List[int], + grad_iter: int = 5, alpha: float = 0.1): + super().__init__() + + self.potential = ResNetPotential(input_dim + output_dim, layers) + self.lamb = torch.nn.Parameter(torch.Tensor([1]), requires_grad=True) + self.grad_iter = grad_iter + self.alpha = alpha + self.output_dim = output_dim + + def init_output(self): + return torch.zeros(self.output_dim, requires_grad=True) + + def reg(self, y): + return torch.nn.Softplus()( + self.lamb) * (y + EPS).square().mean().sqrt() + + def combine_input(self, x, y): + return torch.cat([x, y.expand(x.size(0), -1)], dim=1) + + def forward(self, x: torch.Tensor): + grad_enabled = torch.is_grad_enabled() + torch.set_grad_enabled(True) + yhat = self.init_output() + for _ in range(self.grad_iter): + erg = self.potential(self.combine_input( + x, yhat)).mean() + self.reg(yhat) + yhat = yhat - self.alpha * torch.autograd.grad( + erg, yhat, create_graph=True, retain_graph=True)[0] + torch.set_grad_enabled(grad_enabled) + return yhat + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}()') From 3c1d0ba611f3c46839f2aabae54929a640e707aa Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Fri, 29 Apr 2022 21:51:20 +0800 Subject: [PATCH 02/25] update with docs and batch --- test/nn/glob/test_equilibrium_aggregation.py | 39 ++++++ torch_geometric/data/dataset.py | 1 + torch_geometric/datasets/ppi.py | 2 + .../nn/glob/equilibrium_aggregation.py | 118 +++++++++++++----- 4 files changed, 126 insertions(+), 34 deletions(-) create mode 100644 test/nn/glob/test_equilibrium_aggregation.py diff --git a/test/nn/glob/test_equilibrium_aggregation.py b/test/nn/glob/test_equilibrium_aggregation.py new file mode 100644 index 000000000000..7ba73a14331f --- /dev/null +++ b/test/nn/glob/test_equilibrium_aggregation.py @@ -0,0 +1,39 @@ +import pytest +import torch + +from torch_geometric.nn import EquilibriumAggregation + + +@pytest.mark.parametrize('iter', [0, 1, 5]) +@pytest.mark.parametrize('alpha', [0, .1, 5]) +def test_equilibrium_aggregation(iter, alpha): + + batch = 10 + feature_channels = 3 + output_channels = 2 + x = torch.randn(batch, feature_channels) + potential = EquilibriumAggregation(feature_channels, output_channels, + num_layers=[10, 10], grad_iter=iter, + alpha=alpha) + + out = potential(x) + assert out.size() == (1, 2) + + +@pytest.mark.parametrize('iter', [0, 1, 5]) 
+@pytest.mark.parametrize('alpha', [0, .1, 5])
+def test_equilibrium_aggregation_batch(iter, alpha):
+
+    batch_1, batch_2 = 4, 6
+    feature_channels = 3
+    output_channels = 2
+    x = torch.randn(batch_1 + batch_2, feature_channels)
+    batch = torch.tensor([0 for _ in range(batch_1)] +
+                         [1 for _ in range(batch_2)])
+
+    potential = EquilibriumAggregation(feature_channels, output_channels,
+                                       num_layers=[10, 10], grad_iter=iter,
+                                       alpha=alpha)
+
+    out = potential(x, batch)
+    assert out.size() == (2, 2)
diff --git a/torch_geometric/data/dataset.py b/torch_geometric/data/dataset.py
index cf006b5f1054..cb795ff42c14 100644
--- a/torch_geometric/data/dataset.py
+++ b/torch_geometric/data/dataset.py
@@ -146,6 +146,7 @@ def _download(self):
 
     def _process(self):
         f = osp.join(self.processed_dir, 'pre_transform.pt')
+        print(_repr(self.pre_transform))
         if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
             warnings.warn(
                 f"The `pre_transform` argument differs from the one used in "
diff --git a/torch_geometric/datasets/ppi.py b/torch_geometric/datasets/ppi.py
index 4ccd3dc9e6cd..b90acce26f61 100644
--- a/torch_geometric/datasets/ppi.py
+++ b/torch_geometric/datasets/ppi.py
@@ -91,6 +91,7 @@ def download(self):
     def process(self):
         import networkx as nx
         from networkx.readwrite import json_graph
+        print("test")
 
         for s, split in enumerate(['train', 'valid', 'test']):
             path = osp.join(self.raw_dir, f'{split}_graph.json')
@@ -123,6 +124,7 @@ def process(self):
                     continue
 
                 if self.pre_transform is not None:
+                    print("test")
                     data = self.pre_transform(data)
 
                 data_list.append(data)
diff --git a/torch_geometric/nn/glob/equilibrium_aggregation.py b/torch_geometric/nn/glob/equilibrium_aggregation.py
index 23451ff92df7..934957350746 100644
--- a/torch_geometric/nn/glob/equilibrium_aggregation.py
+++ b/torch_geometric/nn/glob/equilibrium_aggregation.py
@@ -1,16 +1,17 @@
-from typing import List
+from typing import List, Optional
 
 import torch
+from torch_scatter import scatter
 
-EPS = 1e-15
+from torch_geometric.nn.inits import normal, reset
 
 
 class ResNetPotential(torch.nn.Module):
-    def __init__(self, input_size: int, layers: List[int]):
+    def __init__(self, in_channels: int, num_layers: List[int]):
 
         super().__init__()
         output_size = 1
-        sizes = [input_size] + layers + [output_size]
+        sizes = [in_channels] + num_layers + [output_size]
         self.layers = torch.nn.ModuleList([
             torch.nn.Sequential(torch.nn.Linear(in_size, out_size),
                                 torch.nn.LayerNorm(out_size), torch.nn.Tanh())
@@ -18,8 +19,8 @@ def __init__(self, input_size: int, layers: List[int]):
         ])
 
         self.res_trans = torch.nn.ModuleList([
-            torch.nn.Linear(input_size, layer_size)
-            for layer_size in layers + [output_size]
+            torch.nn.Linear(in_channels, layer_size)
+            for layer_size in num_layers + [output_size]
         ])
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -32,41 +33,90 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class EquilibriumAggregation(torch.nn.Module):
-    """
+    r"""
+    The graph global pooling layer from the
+    `"Equilibrium Aggregation: Encoding Sets via Optimization"
+    `_ paper.
+    The output of this layer :math:`\mathbf{y}` is defined implicitly via
+    a potential function :math:`F(\mathbf{x}, \mathbf{y})`,
+    a regularization function :math:`R(\mathbf{y})`, and the condition
+
+    .. math::
+        \mathbf{y} = \mathrm{argmin}_\mathbf{y} R(\mathbf{y}) +
+        \sum_{i} F(\mathbf{x}_i, \mathbf{y})
+
+    This implementation uses a ResNet-like model for the potential function
+    and a simple L2 norm for the regularizer with learnable weight
+    :math:`\lambda`.
+
+    .. note::
+
+        The forward function of this layer accepts a :obj:`batch` argument that
+        works like the other global pooling layers when working on input that
+        is from multiple graphs.
+
     Args:
-        potential (torch.nn.Module): trainable potenial function
+        in_channels (int): The number of channels in the input to the layer.
+        out_channels (int): The number of channels in the output.
+        num_layers (List[int]): A list of the number of hidden units in the
+            potential function.
+        grad_iter (int): The number of steps to take in the internal gradient
+            descent. (default: :obj:`5`)
+        alpha (float): The step size of the internal gradient descent.
+            (default: :obj:`0.1`)
+
     """
-    def __init__(self, input_dim: int, output_dim: int, layers: List[int],
-                 grad_iter: int = 5, alpha: float = 0.1):
+    def __init__(self, in_channels: int, out_channels: int,
+                 num_layers: List[int], grad_iter: int = 5,
+                 alpha: float = 0.1):
         super().__init__()
 
-        self.potential = ResNetPotential(input_dim + output_dim, layers)
+        self.potential = ResNetPotential(in_channels + out_channels,
+                                         num_layers)
         self.lamb = torch.nn.Parameter(torch.Tensor([1]), requires_grad=True)
+        self.softplus = torch.nn.Softplus()
         self.grad_iter = grad_iter
         self.alpha = alpha
-        self.output_dim = output_dim
-
-    def init_output(self):
-        return torch.zeros(self.output_dim, requires_grad=True)
-
-    def reg(self, y):
-        return torch.nn.Softplus()(
-            self.lamb) * (y + EPS).square().mean().sqrt()
-
-    def combine_input(self, x, y):
-        return torch.cat([x, y.expand(x.size(0), -1)], dim=1)
-
-    def forward(self, x: torch.Tensor):
-        grad_enabled = torch.is_grad_enabled()
-        torch.set_grad_enabled(True)
-        yhat = self.init_output()
-        for _ in range(self.grad_iter):
-            erg = self.potential(self.combine_input(
-                x, yhat)).mean() + self.reg(yhat)
-            yhat = yhat - self.alpha * torch.autograd.grad(
-                erg, yhat, create_graph=True, retain_graph=True)[0]
-        torch.set_grad_enabled(grad_enabled)
-        return yhat
+        self.output_dim = out_channels
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        normal(self.lamb, 0, 0.1)
+        reset(self.potential)
+
+    def init_output(self,
+                    batch: Optional[torch.Tensor] = None) -> torch.Tensor:
+        batch_size = 1 if batch is None else int(batch.max().item() + 1)
+        return torch.zeros(batch_size, self.output_dim, requires_grad=True)
+
+    def reg(self, y: torch.Tensor) -> float:
+        return self.softplus(self.lamb) * y.norm(dim=1, keepdim=True)
+
+    def combine_input(self, x: torch.Tensor, y: torch.Tensor,
+                      batch: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if batch is None:
+            return torch.cat([x, y.expand(x.size(0), -1)], dim=1)
+        return torch.cat([x, y[batch]], dim=1)
+
+    def forward(self, x: torch.Tensor,
+                batch: Optional[torch.Tensor] = None) -> torch.Tensor:
+        with torch.enable_grad():
+            yhat = self.init_output(batch)
+            for _ in range(self.grad_iter):
+                z = self.combine_input(x, yhat, batch)
+                potential = self.potential(z)
+                if batch is None:
+                    potential = potential.mean(axis=0, keepdim=True)
+                else:
+                    size = int(batch.max().item() + 1)
+                    potential = scatter(potential, batch, dim=0, dim_size=size,
+                                        reduce='mean')
+                reg = self.reg(yhat)
+                enrg = (potential + reg).sum()
+                yhat = yhat - self.alpha * torch.autograd.grad(
+                    enrg, yhat, create_graph=True, retain_graph=True)[0]
+        return yhat
 
     def __repr__(self) -> str:
         return (f'{self.__class__.__name__}()')
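A minimal, self-contained sketch of the idea PATCH 02 implements — aggregation as
unrolled gradient descent on a learned energy. This is an illustration only, under
the assumption that plain PyTorch is available; `ToyPotential` and `aggregate` are
hypothetical names invented here and are not part of the patches or of
torch_geometric:

import torch

class ToyPotential(torch.nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_dim + out_dim, 32), torch.nn.Tanh(),
            torch.nn.Linear(32, 1))

    def forward(self, x, y):
        # Energy of a candidate output y given the set x: per-element
        # potentials averaged over the set, plus an L2 penalty on y.
        z = torch.cat([x, y.expand(x.size(0), -1)], dim=1)
        return self.net(z).mean() + 0.1 * y.square().sum()

def aggregate(potential, x, out_dim=2, steps=5, alpha=0.1):
    # y starts at zero and is refined by `steps` gradient steps on the
    # energy; create_graph=True keeps every step differentiable, so the
    # outer loss can train the potential through the inner loop.
    y = torch.zeros(1, out_dim, requires_grad=True)
    for _ in range(steps):
        energy = potential(x, y)
        grad, = torch.autograd.grad(energy, y, create_graph=True)
        y = y - alpha * grad
    return y

potential = ToyPotential(in_dim=3, out_dim=2)
out = aggregate(potential, torch.randn(10, 3))  # shape [1, 2]
out.sum().backward()  # gradients reach ToyPotential's parameters

From 9b662321a71c575c7b407c2f0bb634ba3c4e89c2 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 25 Apr 2022 09:48:35 +0000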
Subject: [PATCH 03/25] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 torch_geometric/nn/glob/__init__.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch_geometric/nn/glob/__init__.py b/torch_geometric/nn/glob/__init__.py
index be0b138ffcd2..1d1c4daccb91 100644
--- a/torch_geometric/nn/glob/__init__.py
+++ b/torch_geometric/nn/glob/__init__.py
@@ -1,17 +1,17 @@
 from .glob import global_add_pool, global_mean_pool, global_max_pool
-from .glob import GlobalPooling
 from .sort import global_sort_pool
 from .attention import GlobalAttention
 from .gmt import GraphMultisetTransformer
+from .equilibrium_aggregation import EquilibriumAggregation

 __all__ = [
     'global_add_pool',
     'global_mean_pool',
     'global_max_pool',
-    'GlobalPooling',
     'global_sort_pool',
     'GlobalAttention',
     'GraphMultisetTransformer',
+    'EquilibriumAggregation',
 ]

 classes = __all__
From fcef12cf9d1ff1e58c8fe7f85dcc0dbc03ec78ce Mon Sep 17 00:00:00 2001
From: Padarn Wilson
Date: Fri, 29 Apr 2022 22:21:44 +0800
Subject: [PATCH 04/25] example of median

---
 examples/equilibrium_aggregation_median.py |  27 +++++
 examples/mutag_gin_equib.py                | 100 ------------------
 .../nn/glob/equilibrium_aggregation.py     |   6 +-
 3 files changed, 30 insertions(+), 103 deletions(-)
 create mode 100644 examples/equilibrium_aggregation_median.py
 delete mode 100644 examples/mutag_gin_equib.py

diff --git a/examples/equilibrium_aggregation_median.py b/examples/equilibrium_aggregation_median.py
new file mode 100644
index 000000000000..976813df8d1d
--- /dev/null
+++ b/examples/equilibrium_aggregation_median.py
@@ -0,0 +1,27 @@
+r"""
+Replicates the experiment from `"Equilibrium Aggregation: Encoding Sets via Optimization"
+`_ to try and teach
+`EquilibriumAggregation` to learn to take the median of
+a set of numbers
+"""
+
+import torch
+
+from torch_geometric.nn import EquilibriumAggregation
+
+input_size = 1000
+epochs = 100
+
+model = EquilibriumAggregation(1, 1, [256, 256], 5, 0.1)
+optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
+
+for i in range(epochs):
+    optimizer.zero_grad()
+    x = torch.rand(input_size, 1)
+    y = model(x)
+    loss = (y - x.median()).norm(2)
+    print(y, x.median())
+    loss.backward()
+    optimizer.step()
+    if i % 10 == 9:
+        print(f"Loss at epoch {i} is {loss}")
diff --git a/examples/mutag_gin_equib.py b/examples/mutag_gin_equib.py
deleted file mode 100644
index 0c9c70c76ab5..000000000000
--- a/examples/mutag_gin_equib.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import os.path as osp
-
-import torch
-import torch.nn.functional as F
-from torch.nn import BatchNorm1d, Linear, ReLU, Sequential
-
-from torch_geometric.datasets import TUDataset
-from torch_geometric.loader import DataLoader
-from torch_geometric.nn import EquilibriumAggregation, GINConv
-
-path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'TU')
-dataset = TUDataset(path, name='MUTAG')
-
-train_dataset = dataset[len(dataset) // 10:]
-test_dataset = dataset[:len(dataset) // 10]
-
-train_loader = DataLoader(train_dataset, batch_size=1)
-test_loader = DataLoader(test_dataset, batch_size=1)
-
-
-class Net(torch.nn.Module):
-    def __init__(self, in_channels, dim, out_channels):
-        super().__init__()
-
-        self.conv1 = GINConv(
-            Sequential(Linear(in_channels, dim), BatchNorm1d(dim), ReLU(),
-                       Linear(dim, dim), ReLU()))
-
-        self.conv2 = GINConv(
-            Sequential(Linear(dim, dim), BatchNorm1d(dim), ReLU(),
-                       Linear(dim, dim), ReLU()))
-
-        self.conv3 = GINConv(
-            Sequential(Linear(dim, dim), BatchNorm1d(dim), ReLU(),
-                       Linear(dim, dim), 
ReLU())) - - self.conv4 = GINConv( - Sequential(Linear(dim, dim), BatchNorm1d(dim), ReLU(), - Linear(dim, dim), ReLU())) - - self.conv5 = GINConv( - Sequential(Linear(dim, dim), BatchNorm1d(dim), ReLU(), - Linear(dim, dim), ReLU())) - - self.lin1 = Linear(dim, dim) - self.lin2 = Linear(dim, out_channels) - self.readout = EquilibriumAggregation(dim, dim, [256, 256], - grad_iter=10, alpha=0.05) - - def forward(self, x, edge_index, batch): - x = self.conv1(x, edge_index) - x = self.conv2(x, edge_index) - x = self.conv3(x, edge_index) - x = self.conv4(x, edge_index) - x = self.conv5(x, edge_index) - x = self.readout(x).unsqueeze(0) - x = self.lin1(x).relu() - x = F.dropout(x, p=0.5, training=self.training) - x = self.lin2(x) - return F.log_softmax(x, dim=-1) - - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -model = Net(dataset.num_features, 32, dataset.num_classes).to(device) -optimizer = torch.optim.Adam(model.parameters(), lr=0.01) - - -def train(): - model.train() - - total_loss = 0 - for data in train_loader: - data = data.to(device) - optimizer.zero_grad() - output = model(data.x, data.edge_index, data.batch) - loss = F.nll_loss(output, data.y) - loss.backward() - optimizer.step() - total_loss += float(loss) * data.num_graphs - return total_loss / len(train_loader.dataset) - - -@torch.no_grad() -def test(loader): - model.eval() - - total_correct = 0 - for data in loader: - data = data.to(device) - out = model(data.x, data.edge_index, data.batch) - total_correct += int((out.argmax(-1) == data.y).sum()) - return total_correct / len(loader.dataset) - - -for epoch in range(1, 101): - loss = train() - train_acc = test(train_loader) - test_acc = test(test_loader) - print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f} ' - f'Test Acc: {test_acc:.4f}') diff --git a/torch_geometric/nn/glob/equilibrium_aggregation.py b/torch_geometric/nn/glob/equilibrium_aggregation.py index 934957350746..7cf2ced64387 100644 --- a/torch_geometric/nn/glob/equilibrium_aggregation.py +++ b/torch_geometric/nn/glob/equilibrium_aggregation.py @@ -3,7 +3,7 @@ import torch from torch_scatter import scatter -from torch_geometric.nn.inits import normal, reset +from torch_geometric.nn.inits import reset class ResNetPotential(torch.nn.Module): @@ -73,7 +73,7 @@ def __init__(self, in_channels: int, out_channels: int, self.potential = ResNetPotential(in_channels + out_channels, num_layers) - self.lamb = torch.nn.Parameter(torch.Tensor([1]), requires_grad=True) + self.lamb = torch.nn.Parameter(torch.Tensor([0.1]), requires_grad=True) self.softplus = torch.nn.Softplus() self.grad_iter = grad_iter self.alpha = alpha @@ -82,7 +82,7 @@ def __init__(self, in_channels: int, out_channels: int, self.reset_parameters() def reset_parameters(self): - normal(self.lamb, 0, 0.1) + self.lamb.data.fill_(0.1) reset(self.potential) def init_output(self, From b21760c7c345bda5bb2a6fd43c70b12d4d6bd5a7 Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Wed, 4 May 2022 14:38:16 +0800 Subject: [PATCH 05/25] update init --- examples/equilibrium_aggregation_median.py | 9 ++++----- torch_geometric/nn/glob/equilibrium_aggregation.py | 5 +++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/equilibrium_aggregation_median.py b/examples/equilibrium_aggregation_median.py index 976813df8d1d..302465864d33 100644 --- a/examples/equilibrium_aggregation_median.py +++ b/examples/equilibrium_aggregation_median.py @@ -9,19 +9,18 @@ from torch_geometric.nn import EquilibriumAggregation -input_size = 
1000
-epochs = 100
+input_size = 100
+epochs = 1000
 
 model = EquilibriumAggregation(1, 1, [256, 256], 5, 0.1)
-optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
+optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
 for i in range(epochs):
     optimizer.zero_grad()
     x = torch.rand(input_size, 1)
     y = model(x)
     loss = (y - x.median()).norm(2)
-    print(y, x.median())
     loss.backward()
     optimizer.step()
     if i % 10 == 9:
-        print(f"Loss at epoch {i} is {loss}")
+        print(f"Loss at epoch {i} is {loss}, Median {x.median()}, Y {y.item()}")
diff --git a/torch_geometric/nn/glob/equilibrium_aggregation.py b/torch_geometric/nn/glob/equilibrium_aggregation.py
index 7cf2ced64387..375b9a1def8a 100644
--- a/torch_geometric/nn/glob/equilibrium_aggregation.py
+++ b/torch_geometric/nn/glob/equilibrium_aggregation.py
@@ -88,10 +88,11 @@ def reset_parameters(self):
     def init_output(self,
                     batch: Optional[torch.Tensor] = None) -> torch.Tensor:
         batch_size = 1 if batch is None else int(batch.max().item() + 1)
-        return torch.zeros(batch_size, self.output_dim, requires_grad=True)
+        return torch.randn(batch_size, self.output_dim,
+                           requires_grad=True) * 0.01
 
     def reg(self, y: torch.Tensor) -> float:
-        return self.softplus(self.lamb) * y.norm(dim=1, keepdim=True)
+        return self.softplus(self.lamb) * y.norm(2, dim=1, keepdim=True)
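A quick aside on the `softplus(self.lamb)` pattern used in the regularizer above:
passing the raw parameter through a softplus keeps the effective regularization
weight strictly positive no matter what value the optimizer drives the parameter
to. A small sketch, assuming only plain PyTorch (the variable names here are
illustrative, not part of the patches):

import torch

lamb = torch.nn.Parameter(torch.tensor([-3.0]))  # raw value may go negative
softplus = torch.nn.Softplus()

y = torch.randn(4, 8)
weight = softplus(lamb)  # log(1 + exp(-3)) ~= 0.049, always > 0
reg = weight * y.norm(2, dim=1, keepdim=True)
print(weight.item() > 0, reg.shape)  # True torch.Size([4, 1])

From 58e1678147b9ef2c13fe3b5f34c46859ae06af13 Mon Sep 17 00:00:00 2001
From: Padarn Wilson
Date: Wed, 4 May 2022 14:40:45 +0800
Subject: [PATCH 06/25] update changelog

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 15090708aa02..194448c66268 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -48,6 +48,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).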
 - Added `HeteroData.is_undirected()` support ([#4604](https://github.com/pyg-team/pytorch_geometric/pull/4604))
 - Added the `Genius` and `Wiki` datasets to `nn.datasets.LINKXDataset` ([#4570](https://github.com/pyg-team/pytorch_geometric/pull/4570), [#4600](https://github.com/pyg-team/pytorch_geometric/pull/4600))
 - Added `nn.glob.GlobalPooling` module with support for multiple aggregations ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
+- Added `nn.glob.EquilibriumAggregation` implicit global layer ([#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522))
 - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
 - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
 ### Changed
From 3b1bc45b180be5d95d409cbbf47b196d03809570 Mon Sep 17 00:00:00 2001
From: Padarn Wilson
Date: Fri, 6 May 2022 09:27:49 +0800
Subject: [PATCH 07/25] add details to median example

---
 examples/equilibrium_aggregation_median.py   | 22 ++++++++++++++-----
 test/nn/glob/test_equilibrium_aggregation.py | 18 ++++++++-------
 torch_geometric/data/dataset.py              |  1 -
 torch_geometric/datasets/ppi.py              |  2 --
 .../nn/glob/equilibrium_aggregation.py       | 14 ++++++------
 5 files changed, 34 insertions(+), 23 deletions(-)

diff --git a/examples/equilibrium_aggregation_median.py b/examples/equilibrium_aggregation_median.py
index 302465864d33..6321e401ac48 100644
--- a/examples/equilibrium_aggregation_median.py
+++ b/examples/equilibrium_aggregation_median.py
@@ -5,22 +5,34 @@ a set of numbers
 """
 
+import numpy as np
 import torch
 
 from torch_geometric.nn import EquilibriumAggregation
 
 input_size = 100
-epochs = 1000
+epochs = 10000
+embedding_size = 10
 
-model = EquilibriumAggregation(1, 1, [256, 256], 5, 0.1)
+model = torch.nn.Sequential(EquilibriumAggregation(1, 10, [256, 256], 5, 0.1),
+                            torch.nn.Linear(10, 1))
 optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
+norm = torch.distributions.normal.Normal(0, 0.5)
+gamma = torch.distributions.gamma.Gamma(1, 2)
+uniform = torch.distributions.uniform.Uniform(-1, 1)
+
+total_loss = 0
+n_loss = 0
 for i in range(epochs):
     optimizer.zero_grad()
-    x = torch.rand(input_size, 1)
+    dist = np.random.choice([norm, gamma, uniform])
+    x = dist.sample((input_size, 1))
     y = model(x)
     loss = (y - x.median()).norm(2)
     loss.backward()
     optimizer.step()
-    if i % 10 == 9:
-        print(f"Loss at epoch {i} is {loss}, Median {x.median()}, Y {y.item()}")
+    total_loss += loss
+    n_loss += 1
+    if i % 500 == 499:
+        print(f"Average loss at epoch {i} is {total_loss/n_loss}")
diff --git a/test/nn/glob/test_equilibrium_aggregation.py b/test/nn/glob/test_equilibrium_aggregation.py
index 7ba73a14331f..1df30ec78a88 100644
--- a/test/nn/glob/test_equilibrium_aggregation.py
+++ b/test/nn/glob/test_equilibrium_aggregation.py
@@ -12,11 +12,12 @@ def test_equilibrium_aggregation(iter, alpha):
     feature_channels = 3
     output_channels = 2
     x = torch.randn(batch, feature_channels)
-    potential = EquilibriumAggregation(feature_channels, output_channels,
-                                       num_layers=[10, 10], grad_iter=iter,
-                                       alpha=alpha)
+    model = EquilibriumAggregation(feature_channels, output_channels,
+                                   num_layers=[10, 10], grad_iter=iter,
+                                   alpha=alpha)
 
-    out = potential(x)
+    assert model.__repr__() == 'EquilibriumAggregation()'
+    out = model(x)
     assert out.size() == (1, 2)
 
 
@@ -31,9 +32,10 @@ def test_equilibrium_aggregation_batch(iter, alpha):
     batch = torch.tensor([0 for _ in range(batch_1)] +
                          [1 for _ in range(batch_2)])
 
-    potential = 
EquilibriumAggregation(feature_channels, output_channels, - num_layers=[10, 10], grad_iter=iter, - alpha=alpha) + model = EquilibriumAggregation(feature_channels, output_channels, + num_layers=[10, 10], grad_iter=iter, + alpha=alpha) - out = potential(x, batch) + assert model.__repr__() == 'EquilibriumAggregation()' + out = model(x, batch) assert out.size() == (2, 2) diff --git a/torch_geometric/data/dataset.py b/torch_geometric/data/dataset.py index cb795ff42c14..cf006b5f1054 100644 --- a/torch_geometric/data/dataset.py +++ b/torch_geometric/data/dataset.py @@ -146,7 +146,6 @@ def _download(self): def _process(self): f = osp.join(self.processed_dir, 'pre_transform.pt') - print(_repr(self.pre_transform)) if osp.exists(f) and torch.load(f) != _repr(self.pre_transform): warnings.warn( f"The `pre_transform` argument differs from the one used in " diff --git a/torch_geometric/datasets/ppi.py b/torch_geometric/datasets/ppi.py index b90acce26f61..4ccd3dc9e6cd 100644 --- a/torch_geometric/datasets/ppi.py +++ b/torch_geometric/datasets/ppi.py @@ -91,7 +91,6 @@ def download(self): def process(self): import networkx as nx from networkx.readwrite import json_graph - print("test") for s, split in enumerate(['train', 'valid', 'test']): path = osp.join(self.raw_dir, f'{split}_graph.json') @@ -124,7 +123,6 @@ def process(self): continue if self.pre_transform is not None: - print("test") data = self.pre_transform(data) data_list.append(data) diff --git a/torch_geometric/nn/glob/equilibrium_aggregation.py b/torch_geometric/nn/glob/equilibrium_aggregation.py index 375b9a1def8a..7ba81ae4abae 100644 --- a/torch_geometric/nn/glob/equilibrium_aggregation.py +++ b/torch_geometric/nn/glob/equilibrium_aggregation.py @@ -6,12 +6,12 @@ from torch_geometric.nn.inits import reset -class ResNetPotential(torch.nn.Module): - def __init__(self, in_channels: int, num_layers: List[int]): +class FCResNetBlock(torch.nn.Module): + def __init__(self, in_channels: int, out_channels: int, + num_layers: List[int]): super().__init__() - output_size = 1 - sizes = [in_channels] + num_layers + [output_size] + sizes = [in_channels] + num_layers + [out_channels] self.layers = torch.nn.ModuleList([ torch.nn.Sequential(torch.nn.Linear(in_size, out_size), torch.nn.LayerNorm(out_size), torch.nn.Tanh()) @@ -20,7 +20,7 @@ def __init__(self, in_channels: int, num_layers: List[int]): self.res_trans = torch.nn.ModuleList([ torch.nn.Linear(in_channels, layer_size) - for layer_size in num_layers + [output_size] + for layer_size in num_layers + [out_channels] ]) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -71,8 +71,8 @@ def __init__(self, in_channels: int, out_channels: int, alpha: float = 0.1): super().__init__() - self.potential = ResNetPotential(in_channels + out_channels, - num_layers) + self.potential = FCResNetBlock(in_channels + out_channels, 1, + num_layers) self.lamb = torch.nn.Parameter(torch.Tensor([0.1]), requires_grad=True) self.softplus = torch.nn.Softplus() self.grad_iter = grad_iter From a02eea8c782c932d8398b307f1ec6943dc059ae3 Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Fri, 6 May 2022 09:31:19 +0800 Subject: [PATCH 08/25] match parameters from paper --- examples/equilibrium_aggregation_median.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/equilibrium_aggregation_median.py b/examples/equilibrium_aggregation_median.py index 6321e401ac48..cd66c622ec6c 100644 --- a/examples/equilibrium_aggregation_median.py +++ b/examples/equilibrium_aggregation_median.py @@ -13,14 +13,15 @@ 
 input_size = 100
 epochs = 10000
 embedding_size = 10
+eval_each = 8000
 
 model = torch.nn.Sequential(EquilibriumAggregation(1, 10, [256, 256], 5, 0.1),
                             torch.nn.Linear(10, 1))
 optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
-norm = torch.distributions.normal.Normal(0, 0.5)
-gamma = torch.distributions.gamma.Gamma(1, 2)
-uniform = torch.distributions.uniform.Uniform(-1, 1)
+norm = torch.distributions.normal.Normal(0.5, 0.4)
+gamma = torch.distributions.gamma.Gamma(0.2, 0.5)
+uniform = torch.distributions.uniform.Uniform(0, 1)
 
 total_loss = 0
 n_loss = 0
@@ -34,5 +35,5 @@
     optimizer.step()
     total_loss += loss
     n_loss += 1
-    if i % 500 == 499:
-        print(f"Average loss at epoch {i} is {total_loss/n_loss}")
+    if i % eval_each == (eval_each - 1):
+        print(f"Average loss at epoch {i} is {total_loss / n_loss}")
From 3f0a23bcecfb47c40753a82c5d2268818102eee4 Mon Sep 17 00:00:00 2001
From: Padarn Wilson
Date: Fri, 6 May 2022 09:37:19 +0800
Subject: [PATCH 09/25] match parameters from paper

---
 examples/equilibrium_aggregation_median.py         | 4 ++--
 torch_geometric/nn/glob/equilibrium_aggregation.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/equilibrium_aggregation_median.py b/examples/equilibrium_aggregation_median.py
index cd66c622ec6c..56b02184509f 100644
--- a/examples/equilibrium_aggregation_median.py
+++ b/examples/equilibrium_aggregation_median.py
@@ -11,7 +11,7 @@
 from torch_geometric.nn import EquilibriumAggregation
 
 input_size = 100
-epochs = 10000
+steps = 10000000
 embedding_size = 10
 eval_each = 8000
@@ -25,7 +25,7 @@
 total_loss = 0
 n_loss = 0
-for i in range(epochs):
+for i in range(steps):
     optimizer.zero_grad()
     dist = np.random.choice([norm, gamma, uniform])
     x = dist.sample((input_size, 1))
diff --git a/torch_geometric/nn/glob/equilibrium_aggregation.py b/torch_geometric/nn/glob/equilibrium_aggregation.py
index 7ba81ae4abae..7af7021ddc2e 100644
--- a/torch_geometric/nn/glob/equilibrium_aggregation.py
+++ b/torch_geometric/nn/glob/equilibrium_aggregation.py
@@ -88,8 +88,8 @@ def reset_parameters(self):
     def init_output(self,
                     batch: Optional[torch.Tensor] = None) -> torch.Tensor:
         batch_size = 1 if batch is None else int(batch.max().item() + 1)
-        return torch.randn(batch_size, self.output_dim,
-                           requires_grad=True) * 0.01
+        return torch.zeros(batch_size, self.output_dim,
+                           requires_grad=True).float() + 1e-15
From 8d794244c51eca32014ce2eaedba95d396621050 Mon Sep 17 00:00:00 2001
From: Padarn Wilson
Date: Mon, 9 May 2022 18:49:37 +0800
Subject: [PATCH 10/25] wip

---
 examples/equilibrium_aggregation_median.py         | 4 ++--
 torch_geometric/nn/glob/equilibrium_aggregation.py | 6 ++++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/examples/equilibrium_aggregation_median.py b/examples/equilibrium_aggregation_median.py
index 56b02184509f..bda17b00bd5c 100644
--- a/examples/equilibrium_aggregation_median.py
+++ b/examples/equilibrium_aggregation_median.py
@@ -13,9 +13,9 @@
 input_size = 100
 steps = 10000000
 embedding_size = 10
-eval_each = 8000
+eval_each = 1000
 
-model = torch.nn.Sequential(EquilibriumAggregation(1, 10, [256, 256], 5, 0.1),
+model = torch.nn.Sequential(EquilibriumAggregation(1, 10, [10, 10], 10, 0.01),
                             torch.nn.Linear(10, 1))
 optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
diff --git a/torch_geometric/nn/glob/equilibrium_aggregation.py b/torch_geometric/nn/glob/equilibrium_aggregation.py
index 7af7021ddc2e..65141c185ccb 100644
--- 
a/torch_geometric/nn/glob/equilibrium_aggregation.py +++ b/torch_geometric/nn/glob/equilibrium_aggregation.py @@ -13,8 +13,10 @@ def __init__(self, in_channels: int, out_channels: int, super().__init__() sizes = [in_channels] + num_layers + [out_channels] self.layers = torch.nn.ModuleList([ - torch.nn.Sequential(torch.nn.Linear(in_size, out_size), - torch.nn.LayerNorm(out_size), torch.nn.Tanh()) + torch.nn.Sequential( + torch.nn.Linear(in_size, out_size), + # torch.nn.LayerNorm(out_size), + torch.nn.Tanh()) for in_size, out_size in zip(sizes[:-1], sizes[1:]) ]) From 3f1216a30673199543f01846822c14bd4f0a24f6 Mon Sep 17 00:00:00 2001 From: Padarn Wilson Date: Sat, 14 May 2022 15:12:42 +0800 Subject: [PATCH 11/25] add nesterov loss --- examples/equilibrium_aggregation_median.py | 5 +- test/nn/glob/test_equilibrium_aggregation.py | 6 +- .../nn/glob/equilibrium_aggregation.py | 138 +++++++++++++----- 3 files changed, 102 insertions(+), 47 deletions(-) diff --git a/examples/equilibrium_aggregation_median.py b/examples/equilibrium_aggregation_median.py index bda17b00bd5c..02c98c656fe8 100644 --- a/examples/equilibrium_aggregation_median.py +++ b/examples/equilibrium_aggregation_median.py @@ -15,9 +15,8 @@ embedding_size = 10 eval_each = 1000 -model = torch.nn.Sequential(EquilibriumAggregation(1, 10, [10, 10], 10, 0.01), - torch.nn.Linear(10, 1)) -optimizer = torch.optim.Adam(model.parameters(), lr=0.01) +model = EquilibriumAggregation(1, 10, [256, 256], 1) +optimizer = torch.optim.Adam(model.parameters(), lr=0.001) norm = torch.distributions.normal.Normal(0.5, 0.4) gamma = torch.distributions.gamma.Gamma(0.2, 0.5) diff --git a/test/nn/glob/test_equilibrium_aggregation.py b/test/nn/glob/test_equilibrium_aggregation.py index 1df30ec78a88..c6952c6b887a 100644 --- a/test/nn/glob/test_equilibrium_aggregation.py +++ b/test/nn/glob/test_equilibrium_aggregation.py @@ -13,8 +13,7 @@ def test_equilibrium_aggregation(iter, alpha): output_channels = 2 x = torch.randn(batch, feature_channels) model = EquilibriumAggregation(feature_channels, output_channels, - num_layers=[10, 10], grad_iter=iter, - alpha=alpha) + num_layers=[10, 10], grad_iter=iter) assert model.__repr__() == 'EquilibriumAggregation()' out = model(x) @@ -33,8 +32,7 @@ def test_equilibrium_aggregation_batch(iter, alpha): [1 for _ in range(batch_2)]) model = EquilibriumAggregation(feature_channels, output_channels, - num_layers=[10, 10], grad_iter=iter, - alpha=alpha) + num_layers=[10, 10], grad_iter=iter) assert model.__repr__() == 'EquilibriumAggregation()' out = model(x, batch) diff --git a/torch_geometric/nn/glob/equilibrium_aggregation.py b/torch_geometric/nn/glob/equilibrium_aggregation.py index 65141c185ccb..d716bac3ae1b 100644 --- a/torch_geometric/nn/glob/equilibrium_aggregation.py +++ b/torch_geometric/nn/glob/equilibrium_aggregation.py @@ -1,12 +1,11 @@ -from typing import List, Optional +from typing import Callable, List, Optional, Tuple import torch -from torch_scatter import scatter from torch_geometric.nn.inits import reset -class FCResNetBlock(torch.nn.Module): +class ResNetPotential(torch.nn.Module): def __init__(self, in_channels: int, out_channels: int, num_layers: List[int]): @@ -25,13 +24,80 @@ def __init__(self, in_channels: int, out_channels: int, for layer_size in num_layers + [out_channels] ]) - def forward(self, x: torch.Tensor) -> torch.Tensor: - h = x + def forward(self, x: torch.Tensor, y: torch.Tensor, + batch: Optional[torch.Tensor]) -> torch.Tensor: + if batch is None: + inp = torch.cat([x, y.expand(x.size(0), 
-1)], dim=1)
+        else:
+            inp = torch.cat([x, y[batch]], dim=1)
+
+        h = inp
         for layer, res in zip(self.layers, self.res_trans):
             h = layer(h)
-            h = res(x) + h
+            h = res(inp) + h
+        return h.sum()
 
-        return h
+
+class MomentumOptimizer(torch.nn.Module):
+    r"""
+    Provides an inner loop optimizer for the implicitly defined output
+    layer. It is based on an unrolled Nesterov momentum algorithm.
+
+    Args:
+        learning_rate (float): learning rate for optimizer.
+        momentum (float): momentum for optimizer.
+        learnable (bool): If :obj:`True` then the :obj:`learning_rate` and
+            :obj:`momentum` will be learnable parameters. If :obj:`False`
+            they are fixed. (default: :obj:`True`)
+    """
+    def __init__(self, learning_rate: float = 0.1, momentum: float = 0.9,
+                 learnable: bool = True):
+        super().__init__()
+
+        self._initial_lr = learning_rate
+        self._initial_mom = momentum
+        self._lr = torch.nn.Parameter(torch.Tensor([learning_rate]),
+                                      requires_grad=learnable)
+        self._mom = torch.nn.Parameter(torch.Tensor([momentum]),
+                                       requires_grad=learnable)
+        self.softplus = torch.nn.Softplus()
+        self.sigmoid = torch.nn.Sigmoid()
+        self.register_full_backward_hook(self.gradient_hook)
+
+    @staticmethod
+    def gradient_hook(module, grad_input, grad_output):
+        pass
+
+    def reset_parameters(self):
+        self._lr.data.fill_(self._initial_lr)
+        self._mom.data.fill_(self._initial_mom)
+
+    @property
+    def learning_rate(self):
+        return self.softplus(self._lr)
+
+    @property
+    def momentum(self):
+        return self.sigmoid(self._mom)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        y: torch.Tensor,
+        batch: Optional[torch.Tensor],
+        func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]],
+                       torch.Tensor],
+        iterations: int = 5,
+    ) -> Tuple[torch.Tensor, float]:
+
+        momentum = torch.zeros_like(y)
+        for _ in range(iterations):
+            val = func(x, y, batch)
+            grad = torch.autograd.grad(val, y, create_graph=True,
+                                       retain_graph=True)[0]
+            momentum = self.momentum * momentum - self.learning_rate * grad
+            y = y + momentum
+        return y
 
 
 class EquilibriumAggregation(torch.nn.Module):
         potential function.
         grad_iter (int): The number of steps to take in the internal gradient
             descent. (default: :obj:`5`)
-        alpha (float): The step size of the internal gradient descent.
-            (default: :obj:`0.1`)
-
+        lamb (float): The initial regularization constant. Is learnable.
+            (default: :obj:`0.1`)
     """
     def __init__(self, in_channels: int, out_channels: int,
-                 num_layers: List[int], grad_iter: int = 5,
-                 alpha: float = 0.1):
+                 num_layers: List[int], grad_iter: int = 5, lamb: float = 0.1):
         super().__init__()
 
-        self.potential = FCResNetBlock(in_channels + out_channels, 1,
-                                       num_layers)
-        self.lamb = torch.nn.Parameter(torch.Tensor([0.1]), requires_grad=True)
+        self.potential = ResNetPotential(in_channels + out_channels, 1,
+                                         num_layers)
+        self.optimizer = MomentumOptimizer()
+        self._initial_lambda = lamb
+        self._labmda = torch.nn.Parameter(torch.Tensor([lamb]),
+                                          requires_grad=True)
         self.softplus = torch.nn.Softplus()
         self.grad_iter = grad_iter
-        self.alpha = alpha
         self.output_dim = out_channels
         self.reset_parameters()
 
     def reset_parameters(self):
-        self.lamb.data.fill_(0.1)
+        self._labmda.data.fill_(self._initial_lambda)
+        reset(self.optimizer)
         reset(self.potential)
 
+    @property
+    def lamb(self):
+        return self.softplus(self._labmda)
+
     def init_output(self,
                     batch: Optional[torch.Tensor] = None) -> torch.Tensor:
         batch_size = 1 if batch is None else int(batch.max().item() + 1)
         return torch.zeros(batch_size, self.output_dim,
-                           requires_grad=True).float() + 1e-15
+                           requires_grad=True).float()
 
     def reg(self, y: torch.Tensor) -> float:
-        return self.softplus(self.lamb) * y.norm(2, dim=1, keepdim=True)
+        return self.lamb * y.square().sum()
 
-    def combine_input(self, x: torch.Tensor, y: torch.Tensor,
-                      batch: Optional[torch.Tensor] = None) -> torch.Tensor:
-        if batch is None:
-            return torch.cat([x, y.expand(x.size(0), -1)], dim=1)
-        return torch.cat([x, y[batch]], dim=1)
+    def energy(self, x: torch.Tensor, y: torch.Tensor,
+               batch: Optional[torch.Tensor]):
+        return self.potential(x, y, batch) + self.reg(y)
 
     def forward(self, x: torch.Tensor,
                 batch: Optional[torch.Tensor] = None) -> torch.Tensor:
+
         with torch.enable_grad():
-            yhat = self.init_output(batch)
-            for _ in range(self.grad_iter):
-                z = self.combine_input(x, yhat, batch)
-                potential = self.potential(z)
-                if batch is None:
-                    potential = potential.mean(axis=0, keepdim=True)
-                else:
-                    size = int(batch.max().item() + 1)
-                    potential = scatter(potential, batch, dim=0, dim_size=size,
-                                        reduce='mean')
-                reg = self.reg(yhat)
-                enrg = (potential + reg).sum()
-                yhat = yhat - self.alpha * torch.autograd.grad(
-                    enrg, yhat, create_graph=True, retain_graph=True)[0]
-        return yhat
+            y = self.optimizer(x, self.init_output(batch), batch, self.energy,
+                               iterations=self.grad_iter)
+
+        return y
 
     def __repr__(self) -> str:
         return (f'{self.__class__.__name__}()')
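A distilled view of the unrolled optimizer introduced in PATCH 11, assuming only
plain PyTorch — `unrolled_minimize` is a hypothetical name used for illustration,
not the PR's API. Because every inner step is built with `create_graph=True`, the
outer training loss can backpropagate through the whole descent trajectory into
the learnable step size, the momentum coefficient, and the energy's own weights:

import torch

def unrolled_minimize(energy_fn, y0, lr, mom, iterations=5):
    # lr and mom may be Parameters; keeping the graph makes every step
    # differentiable with respect to them and to energy_fn's parameters.
    y, velocity = y0, torch.zeros_like(y0)
    for _ in range(iterations):
        grad, = torch.autograd.grad(energy_fn(y), y, create_graph=True)
        velocity = mom * velocity - lr * grad
        y = y + velocity
    return y

lr = torch.nn.Parameter(torch.tensor(0.1))
mom = torch.nn.Parameter(torch.tensor(0.9))
target = torch.tensor([[1.0, -2.0]])
y0 = torch.zeros(1, 2, requires_grad=True)

y = unrolled_minimize(lambda y: (y - target).square().sum(), y0, lr, mom)
y.square().sum().backward()  # gradients flow back into lr and mom
print(lr.grad, mom.grad)

From 894a50fd5e975282e720541ca8031b6ad18cbd4d Mon Sep 17 00:00:00 2001
From: Padarn Wilson
Date: Sun, 15 May 2022 22:00:21 +0800
Subject: [PATCH 12/25] change loss to be a mean across a single batch

---
 examples/equilibrium_aggregation_median.py |  1 -
 .../nn/glob/equilibrium_aggregation.py     | 19 ++++++++++++-------
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/examples/equilibrium_aggregation_median.py b/examples/equilibrium_aggregation_median.py
index 02c98c656fe8..e96652361e5f 100644
--- a/examples/equilibrium_aggregation_median.py
+++ b/examples/equilibrium_aggregation_median.py
@@ -21,7 +21,6 @@
 norm = torch.distributions.normal.Normal(0.5, 0.4)
 gamma = torch.distributions.gamma.Gamma(0.2, 0.5)
 uniform = torch.distributions.uniform.Uniform(0, 1)
-
 total_loss = 0
 n_loss = 0
diff --git a/torch_geometric/nn/glob/equilibrium_aggregation.py b/torch_geometric/nn/glob/equilibrium_aggregation.py
index d716bac3ae1b..2ff46f74ccba 100644
--- 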
a/torch_geometric/nn/glob/equilibrium_aggregation.py
+++ b/torch_geometric/nn/glob/equilibrium_aggregation.py
@@ -1,6 +1,7 @@
 from typing import Callable, List, Optional, Tuple
 
 import torch
+from torch_scatter import scatter
 
 from torch_geometric.nn.inits import reset
 
@@ -12,12 +13,11 @@ def __init__(self, in_channels: int, out_channels: int,
         super().__init__()
         sizes = [in_channels] + num_layers + [out_channels]
         self.layers = torch.nn.ModuleList([
-            torch.nn.Sequential(
-                torch.nn.Linear(in_size, out_size),
-                # torch.nn.LayerNorm(out_size),
-                torch.nn.Tanh())
-            for in_size, out_size in zip(sizes[:-1], sizes[1:])
+            torch.nn.Sequential(torch.nn.Linear(in_size, out_size),
+                                torch.nn.LayerNorm(out_size), torch.nn.Tanh())
+            for in_size, out_size in zip(sizes[:-2], sizes[1:-1])
         ])
+        self.layers.append(torch.nn.Linear(sizes[-2], sizes[-1]))
 
         self.res_trans = torch.nn.ModuleList([
             torch.nn.Linear(in_channels, layer_size)
@@ -35,7 +35,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor,
         for layer, res in zip(self.layers, self.res_trans):
             h = layer(h)
             h = res(inp) + h
-        return h.sum()
+
+        if batch is None:
+            return h.mean()
+
+        size = int(batch.max().item() + 1)
+        return scatter(h, batch, dim=0, dim_size=size, reduce='mean').sum()
 
 
 class MomentumOptimizer(torch.nn.Module):
@@ -164,7 +169,7 @@ def init_output(self,
                            requires_grad=True).float()
 
     def reg(self, y: torch.Tensor) -> float:
-        return self.lamb * y.square().sum()
+        return self.lamb * y.square().mean(dim=1).sum(dim=0)
 
     def energy(self, x: torch.Tensor, y: torch.Tensor,
                batch: Optional[torch.Tensor]):
From 5091ffeb4b22328185e72e89f2bbe8655a228aa5 Mon Sep 17 00:00:00 2001
From: "padarn.wilson"
Date: Sat, 28 May 2022 13:11:12 +0800
Subject: [PATCH 13/25] move to aggregation package

---
 .../test_equilibrium.py}                   |  6 +-
 torch_geometric/nn/aggr/__init__.py        |  2 +
 .../equilibrium.py}                        | 69 ++++++++-----------
 torch_geometric/nn/glob/__init__.py        |  2 -
 4 files changed, 33 insertions(+), 46 deletions(-)
 rename test/nn/{glob/test_equilibrium_aggregation.py => aggr/test_equilibrium.py} (87%)
 rename torch_geometric/nn/{glob/equilibrium_aggregation.py => aggr/equilibrium.py} (72%)

diff --git a/test/nn/glob/test_equilibrium_aggregation.py b/test/nn/aggr/test_equilibrium.py
similarity index 87%
rename from test/nn/glob/test_equilibrium_aggregation.py
rename to test/nn/aggr/test_equilibrium.py
index c6952c6b887a..10505cd25712 100644
--- a/test/nn/glob/test_equilibrium_aggregation.py
+++ b/test/nn/aggr/test_equilibrium.py
@@ -1,12 +1,12 @@
 import pytest
 import torch
 
-from torch_geometric.nn import EquilibriumAggregation
+from torch_geometric.nn.aggr import EquilibriumAggregation
 
 
 @pytest.mark.parametrize('iter', [0, 1, 5])
 @pytest.mark.parametrize('alpha', [0, .1, 5])
-def test_equilibrium_aggregation(iter, alpha):
+def test_equilibrium(iter, alpha):
 
     batch = 10
     feature_channels = 3
     output_channels = 2
     x = torch.randn(batch, feature_channels)
@@ -22,7 +22,7 @@
 
 @pytest.mark.parametrize('iter', [0, 1, 5])
 @pytest.mark.parametrize('alpha', [0, .1, 5])
-def test_equilibrium_aggregation_batch(iter, alpha):
+def test_equilibrium_batch(iter, alpha):
 
     batch_1, batch_2 = 4, 6
     feature_channels = 3
diff --git a/torch_geometric/nn/aggr/__init__.py b/torch_geometric/nn/aggr/__init__.py
index dbc6c42ec449..0e9733a2ad6a 100644
--- a/torch_geometric/nn/aggr/__init__.py
+++ b/torch_geometric/nn/aggr/__init__.py
@@ -14,6 +14,7 @@
 from .lstm import LSTMAggregation
 from .set2set import Set2Set
 from .scaler import DegreeScalerAggregation
+from .equilibrium import EquilibriumAggregation
 
 __all__ = 
classes = [
     'Aggregation',
@@ -30,4 +31,5 @@
     'LSTMAggregation',
     'Set2Set',
     'DegreeScalerAggregation',
+    'EquilibriumAggregation'
 ]
diff --git a/torch_geometric/nn/glob/equilibrium_aggregation.py b/torch_geometric/nn/aggr/equilibrium.py
similarity index 72%
rename from torch_geometric/nn/glob/equilibrium_aggregation.py
rename to torch_geometric/nn/aggr/equilibrium.py
index 2ff46f74ccba..fb1062da09c2 100644
--- a/torch_geometric/nn/glob/equilibrium_aggregation.py
+++ b/torch_geometric/nn/aggr/equilibrium.py
@@ -1,8 +1,10 @@
 from typing import Callable, List, Optional, Tuple
 
 import torch
+from torch import Tensor
 from torch_scatter import scatter
 
+from torch_geometric.nn.aggr import Aggregation
 from torch_geometric.nn.inits import reset
 
 
@@ -24,23 +26,22 @@ def __init__(self, in_channels: int, out_channels: int,
             for layer_size in num_layers + [out_channels]
         ])
 
-    def forward(self, x: torch.Tensor, y: torch.Tensor,
-                batch: Optional[torch.Tensor]) -> torch.Tensor:
-        if batch is None:
+    def forward(self, x: Tensor, y: Tensor, index: Optional[Tensor]) -> Tensor:
+        if index is None:
             inp = torch.cat([x, y.expand(x.size(0), -1)], dim=1)
         else:
-            inp = torch.cat([x, y[batch]], dim=1)
+            inp = torch.cat([x, y[index]], dim=1)
 
         h = inp
         for layer, res in zip(self.layers, self.res_trans):
             h = layer(h)
             h = res(inp) + h
 
-        if batch is None:
+        if index is None:
             return h.mean()
 
-        size = int(batch.max().item() + 1)
-        return scatter(h, batch, dim=0, dim_size=size, reduce='mean').sum()
+        size = int(index.max().item() + 1)
+        return scatter(h, index, dim=0, dim_size=size, reduce='mean').sum()
 
 
 class MomentumOptimizer(torch.nn.Module):
@@ -61,17 +62,12 @@ def __init__(self, learning_rate: float = 0.1, momentum: float = 0.9,
 
         self._initial_lr = learning_rate
         self._initial_mom = momentum
-        self._lr = torch.nn.Parameter(torch.Tensor([learning_rate]),
+        self._lr = torch.nn.Parameter(Tensor([learning_rate]),
                                       requires_grad=learnable)
-        self._mom = torch.nn.Parameter(torch.Tensor([momentum]),
+        self._mom = torch.nn.Parameter(Tensor([momentum]),
                                        requires_grad=learnable)
         self.softplus = torch.nn.Softplus()
         self.sigmoid = torch.nn.Sigmoid()
-        self.register_full_backward_hook(self.gradient_hook)
-
-    @staticmethod
-    def gradient_hook(module, grad_input, grad_output):
-        pass
 
     def reset_parameters(self):
         self._lr.data.fill_(self._initial_lr)
@@ -87,17 +83,16 @@ def momentum(self):
 
     def forward(
         self,
-        x: torch.Tensor,
-        y: torch.Tensor,
-        batch: Optional[torch.Tensor],
-        func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]],
-                       torch.Tensor],
+        x: Tensor,
+        y: Tensor,
+        index: Optional[Tensor],
+        func: Callable[[Tensor, Tensor, Optional[Tensor]], Tensor],
         iterations: int = 5,
-    ) -> Tuple[torch.Tensor, float]:
+    ) -> Tuple[Tensor, float]:
 
         momentum = torch.zeros_like(y)
         for _ in range(iterations):
-            val = func(x, y, batch)
+            val = func(x, y, index)
             grad = torch.autograd.grad(val, y, create_graph=True,
                                        retain_graph=True)[0]
             momentum = self.momentum * momentum - self.learning_rate * grad
             y = y + momentum
         return y
 
 
-class EquilibriumAggregation(torch.nn.Module):
+class EquilibriumAggregation(Aggregation):
     r"""
     The graph global pooling layer from the
     `"Equilibrium Aggregation: Encoding Sets via Optimization"
     `_ paper.
@@ -122,12 +117,6 @@ class EquilibriumAggregation(torch.nn.Module):
     and a simple L2 norm for the regularizer with learnable weight
     :math:`\lambda`.
 
-    .. note::
-
-        The forward function of this layer accepts a :obj:`batch` argument that
-        works like the other global pooling layers when working on input that
-        is from multiple graphs.
-
     Args:
         in_channels (int): The number of channels in the input to the layer.
         out_channels (int): The number of channels in the output.
         num_layers (List[int]): A list of the number of hidden units in the
             potential function.
         grad_iter (int): The number of steps to take in the internal gradient
             descent. (default: :obj:`5`)
         lamb (float): The initial regularization constant. Is learnable.
             (default: :obj:`0.1`)
     """
     def __init__(self, in_channels: int, out_channels: int,
                  num_layers: List[int], grad_iter: int = 5, lamb: float = 0.1):
         super().__init__()
 
         self.potential = ResNetPotential(in_channels + out_channels, 1,
                                          num_layers)
         self.optimizer = MomentumOptimizer()
         self._initial_lambda = lamb
-        self._labmda = torch.nn.Parameter(torch.Tensor([lamb]),
-                                          requires_grad=True)
+        self._labmda = torch.nn.Parameter(Tensor([lamb]), requires_grad=True)
         self.softplus = torch.nn.Softplus()
         self.grad_iter = grad_iter
         self.output_dim = out_channels
         self.reset_parameters()
@@ -146,24 +135,7 @@ def reset_parameters(self):
 
-    def init_output(self,
-                    batch: Optional[torch.Tensor] = None) -> torch.Tensor:
-        batch_size = 1 if batch is None else int(batch.max().item() + 1)
-        return torch.zeros(batch_size, self.output_dim,
+    def init_output(self, index: Optional[Tensor] = None) -> Tensor:
+        index_size = 1 if index is None else int(index.max().item() + 1)
+        return torch.zeros(index_size, self.output_dim,
                            requires_grad=True).float()
 
-    def reg(self, y: torch.Tensor) -> float:
+    def reg(self, y: Tensor) -> float:
         return self.lamb * y.square().mean(dim=1).sum(dim=0)
 
-    def energy(self, x: torch.Tensor, y: torch.Tensor,
-               batch: Optional[torch.Tensor]):
-        return self.potential(x, y, batch) + self.reg(y)
+    def energy(self, x: Tensor, y: Tensor, index: Optional[Tensor]):
+        return self.potential(x, y, index) + self.reg(y)
 
-    def forward(self, x: torch.Tensor,
-                batch: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
+                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
+                dim: int = -2) -> Tensor:
 
         with torch.enable_grad():
-            y = self.optimizer(x, self.init_output(batch), batch, self.energy,
+            y = self.optimizer(x, self.init_output(index), index, self.energy,
                                iterations=self.grad_iter)
 
         return y
diff --git a/torch_geometric/nn/glob/__init__.py b/torch_geometric/nn/glob/__init__.py
index 1d1c4daccb91..655138a552ca 100644
--- a/torch_geometric/nn/glob/__init__.py
+++ b/torch_geometric/nn/glob/__init__.py
@@ -2,7 +2,6 @@
 from .sort import global_sort_pool
 from .attention import GlobalAttention
 from .gmt import GraphMultisetTransformer
-from .equilibrium_aggregation import EquilibriumAggregation
 
 __all__ = [
     'global_add_pool',
@@ -11,7 +10,6 @@
     'global_sort_pool',
     'GlobalAttention',
     'GraphMultisetTransformer',
-    'EquilibriumAggregation',
 ]
 
 classes = __all__
From ed94b33e176347c86833a5bd6fb12bc385745a53 Mon Sep 17 00:00:00 2001
From: "padarn.wilson"
Date: Sat, 28 May 2022 13:14:01 +0800
Subject: [PATCH 14/25] move example import

---
 ...{equilibrium_aggregation_median.py => equilibrium_median.py} | 2 +-
 torch_geometric/nn/glob/__init__.py                             | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)
 rename examples/{equilibrium_aggregation_median.py => equilibrium_median.py} (94%)

diff --git a/examples/equilibrium_aggregation_median.py b/examples/equilibrium_median.py
similarity index 94%
rename from examples/equilibrium_aggregation_median.py
rename to examples/equilibrium_median.py
index e96652361e5f..afb840901249 100644
--- a/examples/equilibrium_aggregation_median.py
+++ b/examples/equilibrium_median.py
@@ -8,7 +8,7 @@
 import numpy as np
 import torch
 
-from torch_geometric.nn import EquilibriumAggregation
+from torch_geometric.nn.aggr import EquilibriumAggregation
input_size = 100 steps = 10000000 diff --git a/torch_geometric/nn/glob/__init__.py b/torch_geometric/nn/glob/__init__.py index 655138a552ca..be0b138ffcd2 100644 --- a/torch_geometric/nn/glob/__init__.py +++ b/torch_geometric/nn/glob/__init__.py @@ -1,4 +1,5 @@ from .glob import global_add_pool, global_mean_pool, global_max_pool +from .glob import GlobalPooling from .sort import global_sort_pool from .attention import GlobalAttention from .gmt import GraphMultisetTransformer @@ -7,6 +8,7 @@ 'global_add_pool', 'global_mean_pool', 'global_max_pool', + 'GlobalPooling', 'global_sort_pool', 'GlobalAttention', 'GraphMultisetTransformer', From 3b3aaca40fe3019325ee3479a255f3f49ef10ec1 Mon Sep 17 00:00:00 2001 From: "padarn.wilson" Date: Sat, 28 May 2022 13:44:37 +0800 Subject: [PATCH 15/25] add assertions for unsupported --- torch_geometric/nn/aggr/equilibrium.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torch_geometric/nn/aggr/equilibrium.py b/torch_geometric/nn/aggr/equilibrium.py index fb1062da09c2..ea0597a1b327 100644 --- a/torch_geometric/nn/aggr/equilibrium.py +++ b/torch_geometric/nn/aggr/equilibrium.py @@ -156,7 +156,7 @@ def init_output(self, index: Optional[Tensor] = None) -> Tensor: requires_grad=True).float() def reg(self, y: Tensor) -> float: - return self.lamb * y.square().mean(dim=1).sum(dim=0) + return self.lamb * y.square().mean(dim=-2).sum(dim=0) def energy(self, x: Tensor, y: Tensor, index: Optional[Tensor]): return self.potential(x, y, index) + self.reg(y) @@ -165,6 +165,12 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: + if ptr is not None: + raise ValueError(f"{self.__class__} doesn't support `ptr`") + + if dim_size is not None: + raise ValueError(f"{self.__class__} doesn't support `dim_size`") + with torch.enable_grad(): y = self.optimizer(x, self.init_output(index), index, self.energy, iterations=self.grad_iter) From 99f020f61dcf40b18720d395b8109df0fcaaa2d0 Mon Sep 17 00:00:00 2001 From: "padarn.wilson" Date: Sat, 28 May 2022 15:18:47 +0800 Subject: [PATCH 16/25] update changelog --- CHANGELOG.md | 2 +- torch_geometric/nn/aggr/equilibrium.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 194448c66268..976ed99b2c43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755)) - Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926)) - Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723)) -- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935)) +- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522)) - Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800)) - Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487)) - Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730)) diff --git a/torch_geometric/nn/aggr/equilibrium.py b/torch_geometric/nn/aggr/equilibrium.py index ea0597a1b327..891eb77dd23e 100644 --- a/torch_geometric/nn/aggr/equilibrium.py +++ b/torch_geometric/nn/aggr/equilibrium.py @@ -173,7 +173,7 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *, with torch.enable_grad(): y = self.optimizer(x, self.init_output(index), index, self.energy, - iterations=self.grad_iter) + 
From 8561c8bb1dc343041e857a999b00a9d7dba54775 Mon Sep 17 00:00:00 2001
From: "padarn.wilson"
Date: Sat, 28 May 2022 15:23:23 +0800
Subject: [PATCH 17/25] update changelog

---
 torch_geometric/nn/aggr/equilibrium.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch_geometric/nn/aggr/equilibrium.py b/torch_geometric/nn/aggr/equilibrium.py
index 891eb77dd23e..ea0597a1b327 100644
--- a/torch_geometric/nn/aggr/equilibrium.py
+++ b/torch_geometric/nn/aggr/equilibrium.py
@@ -173,7 +173,7 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
 
         with torch.enable_grad():
             y = self.optimizer(x, self.init_output(index), index, self.energy,
-                               dim=dim, iterations=self.grad_iter)
+                               iterations=self.grad_iter)
 
         return y

From 8460438ac35a8e778f115ce932bf0bc4faf03d0a Mon Sep 17 00:00:00 2001
From: "padarn.wilson"
Date: Sat, 4 Jun 2022 08:16:11 +0800
Subject: [PATCH 18/25] add suggestions and test update

---
 examples/equilibrium_median.py         |  6 +++++-
 test/nn/aggr/test_equilibrium.py       | 14 ++++++++++++++
 torch_geometric/nn/aggr/equilibrium.py | 22 ++++++++++++++++------
 3 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/examples/equilibrium_median.py b/examples/equilibrium_median.py
index afb840901249..55a4c06d9a31 100644
--- a/examples/equilibrium_median.py
+++ b/examples/equilibrium_median.py
@@ -3,6 +3,9 @@
 `_ to try and teach `EquilibriumAggregation`
 to learn to take the median of a set of numbers
+
+This example converges slowly to being able to predict the
+median similar to what is observed in the paper.
 """
 
 import numpy as np
@@ -23,12 +26,13 @@
 uniform = torch.distributions.uniform.Uniform(0, 1)
 total_loss = 0
 n_loss = 0
+
 for i in range(steps):
     optimizer.zero_grad()
     dist = np.random.choice([norm, gamma, uniform])
     x = dist.sample((input_size, 1))
     y = model(x)
-    loss = (y - x.median()).norm(2)
+    loss = (y - x.median()).norm(2) / input_size
     loss.backward()
     optimizer.step()
     total_loss += loss

diff --git a/test/nn/aggr/test_equilibrium.py b/test/nn/aggr/test_equilibrium.py
index 10505cd25712..8709358e2ce4 100644
--- a/test/nn/aggr/test_equilibrium.py
+++ b/test/nn/aggr/test_equilibrium.py
@@ -19,6 +19,13 @@ def test_equilibrium(iter, alpha):
     out = model(x)
     assert out.size() == (1, 2)
 
+    with pytest.raises(NotImplementedError):
+        model(x, dim_size=0)
+
+    out = model(x, dim_size=3)
+    assert out.size() == (3, 2)
+    assert torch.all(out[1:, :] == 0)
+
 
 @pytest.mark.parametrize('iter', [0, 1, 5])
 @pytest.mark.parametrize('alpha', [0, .1, 5])
@@ -37,3 +44,10 @@ def test_equilibrium_batch(iter, alpha):
     assert model.__repr__() == 'EquilibriumAggregation()'
     out = model(x, batch)
     assert out.size() == (2, 2)
+
+    with pytest.raises(NotImplementedError):
+        model(x, dim_size=0)
+
+    out = model(x, dim_size=3)
+    assert out.size() == (3, 2)
+    assert torch.all(out[1:, :] == 0)

diff --git a/torch_geometric/nn/aggr/equilibrium.py b/torch_geometric/nn/aggr/equilibrium.py
index ea0597a1b327..259f078dae97 100644
--- a/torch_geometric/nn/aggr/equilibrium.py
+++ b/torch_geometric/nn/aggr/equilibrium.py
@@ -90,13 +90,14 @@ def forward(
         iterations: int = 5,
     ) -> Tuple[Tensor, float]:
 
-        momentum = torch.zeros_like(y)
+        momentum_buffer = torch.zeros_like(y)
         for _ in range(iterations):
             val = func(x, y, index)
             grad = torch.autograd.grad(val, y, create_graph=True,
                                        retain_graph=True)[0]
-            momentum = self.momentum * momentum - self.learning_rate * grad
-            y = y + momentum
+            delta = self.learning_rate * grad
+            momentum_buffer = self.momentum * momentum_buffer - delta
+            y = y + momentum_buffer
 
         return y
@@ -166,15 +167,24 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                 dim: int = -2) -> Tensor:
 
         if ptr is not None:
-            raise ValueError(f"{self.__class__} doesn't support `ptr`")
+            raise NotImplementedError(
+                f"{self.__class__} doesn't support `ptr`")
 
-        if dim_size is not None:
-            raise ValueError(f"{self.__class__} doesn't support `dim_size`")
+        index_size = 1 if index is None else index.max() + 1
+        dim_size = index_size if dim_size is None else dim_size
+
+        if dim_size < index_size:
+            raise NotImplementedError(f"{self.__class__} doesn't support "
+                                      f"`dim_size`")
 
         with torch.enable_grad():
             y = self.optimizer(x, self.init_output(index), index, self.energy,
                                iterations=self.grad_iter)
 
+        if dim_size > index_size:
+            zero = torch.zeros(dim_size - index_size, *y.size()[1:])
+            y = torch.cat([y, zero])
+
         return y
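Patch 18 above rewrites the inner update into an explicit heavy-ball form: a decayed `momentum_buffer` accumulates past steps and is added to the iterate on every iteration. A standalone sketch of that update rule, where the function and variable names (`momentum_descent`, `grad_fn`) are illustrative rather than the module's API:

import torch

def momentum_descent(grad_fn, y: torch.Tensor, lr: float = 0.1,
                     momentum: float = 0.9, iterations: int = 5) -> torch.Tensor:
    # Heavy-ball update matching the patched loop: the buffer keeps a
    # decayed history of steps, and the iterate moves by the buffer.
    momentum_buffer = torch.zeros_like(y)
    for _ in range(iterations):
        delta = lr * grad_fn(y)
        momentum_buffer = momentum * momentum_buffer - delta
        y = y + momentum_buffer
    return y

# Minimize f(y) = ||y - 3||^2, whose gradient is 2 * (y - 3).
y = momentum_descent(lambda y: 2.0 * (y - 3.0), torch.zeros(4))

On a convex toy objective like this, the iterate drifts toward the minimizer within the handful of steps that `grad_iter` unrolls.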
From 6a38ff922724f7bbe3302a6b18d6d39cfc7187a8 Mon Sep 17 00:00:00 2001
From: Padarn Wilson
Date: Sat, 4 Jun 2022 07:46:28 +0800
Subject: [PATCH 19/25] Update test/nn/aggr/test_equilibrium.py

Co-authored-by: Guohao Li
---
 test/nn/aggr/test_equilibrium.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/nn/aggr/test_equilibrium.py b/test/nn/aggr/test_equilibrium.py
index 8709358e2ce4..5366ba4f7914 100644
--- a/test/nn/aggr/test_equilibrium.py
+++ b/test/nn/aggr/test_equilibrium.py
@@ -8,7 +8,7 @@
 @pytest.mark.parametrize('alpha', [0, .1, 5])
 def test_equilibrium(iter, alpha):
 
-    batch = 10
+batch_size = 10
     feature_channels = 3
     output_channels = 2
     x = torch.randn(batch, feature_channels)

From b34b3a53d71a354fc86d8b5b10ca40377abf7bf6 Mon Sep 17 00:00:00 2001
From: Padarn Wilson
Date: Sat, 4 Jun 2022 07:46:35 +0800
Subject: [PATCH 20/25] Update test/nn/aggr/test_equilibrium.py

Co-authored-by: Guohao Li
---
 test/nn/aggr/test_equilibrium.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/nn/aggr/test_equilibrium.py b/test/nn/aggr/test_equilibrium.py
index 5366ba4f7914..e55b1c67e950 100644
--- a/test/nn/aggr/test_equilibrium.py
+++ b/test/nn/aggr/test_equilibrium.py
@@ -11,7 +11,7 @@ def test_equilibrium(iter, alpha):
     batch_size = 10
     feature_channels = 3
     output_channels = 2
-    x = torch.randn(batch, feature_channels)
+    x = torch.randn(batch_size, feature_channels)
 
     model = EquilibriumAggregation(feature_channels, output_channels,
                                    num_layers=[10, 10], grad_iter=iter)

From 4171372b29055eb6f79f500f74dc53635f26ee3c Mon Sep 17 00:00:00 2001
From: "padarn.wilson"
Date: Sat, 4 Jun 2022 13:31:01 +0800
Subject: [PATCH 21/25] fix error message

---
 torch_geometric/nn/aggr/equilibrium.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch_geometric/nn/aggr/equilibrium.py b/torch_geometric/nn/aggr/equilibrium.py
index 259f078dae97..f1b823ab5397 100644
--- a/torch_geometric/nn/aggr/equilibrium.py
+++ b/torch_geometric/nn/aggr/equilibrium.py
@@ -174,8 +174,8 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
 
         if dim_size < index_size:
-            raise NotImplementedError(f"{self.__class__} doesn't support "
-                                      f"`dim_size`")
+            raise NotImplementedError("`dim_size` is less than `index` "
+                                      "implied size")
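Patches 18 through 21 settle the `dim_size` contract: `index` implies a minimum number of output rows, requesting fewer is an error, and requesting more zero-pads the tail. A self-contained sketch of that branch logic, where the `pad_output` helper is hypothetical and written only to mirror the patched code:

import torch

def pad_output(y: torch.Tensor, index: torch.Tensor,
               dim_size: int) -> torch.Tensor:
    # `index` assigns each input row to a group, so it implies a floor
    # on the number of output rows.
    index_size = int(index.max()) + 1
    if dim_size < index_size:
        raise ValueError("`dim_size` is less than `index` implied size")
    if dim_size > index_size:
        # Extra requested rows are filled with zeros, as in the patch.
        zero = torch.zeros(dim_size - index_size, *y.size()[1:])
        y = torch.cat([y, zero])
    return y

y = torch.randn(2, 2)               # aggregates for two graphs
index = torch.tensor([0, 0, 1, 1])  # implies index_size == 2
out = pad_output(y, index, dim_size=3)
assert out.size() == (3, 2) and torch.all(out[2:] == 0)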
From cc0eac10af4a0ad2e07a9861a04dad34fa3a3f45 Mon Sep 17 00:00:00 2001
From: "padarn.wilson"
Date: Mon, 6 Jun 2022 17:40:48 +0800
Subject: [PATCH 22/25] fix indent error

---
 test/nn/aggr/test_equilibrium.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/nn/aggr/test_equilibrium.py b/test/nn/aggr/test_equilibrium.py
index e55b1c67e950..92f926020ef3 100644
--- a/test/nn/aggr/test_equilibrium.py
+++ b/test/nn/aggr/test_equilibrium.py
@@ -8,7 +8,7 @@
 @pytest.mark.parametrize('alpha', [0, .1, 5])
 def test_equilibrium(iter, alpha):
 
-batch_size = 10
+    batch_size = 10
     feature_channels = 3
     output_channels = 2
     x = torch.randn(batch_size, feature_channels)

From fa47205513b3806d40681e711ee4c0fd1c3149eb Mon Sep 17 00:00:00 2001
From: "padarn.wilson"
Date: Sat, 18 Jun 2022 12:48:02 +0800
Subject: [PATCH 23/25] rebase to fix tests

---
 test/nn/aggr/test_equilibrium.py       | 4 ++--
 torch_geometric/nn/aggr/equilibrium.py | 7 +++----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/test/nn/aggr/test_equilibrium.py b/test/nn/aggr/test_equilibrium.py
index 92f926020ef3..ceea61d6c303 100644
--- a/test/nn/aggr/test_equilibrium.py
+++ b/test/nn/aggr/test_equilibrium.py
@@ -19,7 +19,7 @@ def test_equilibrium(iter, alpha):
     out = model(x)
     assert out.size() == (1, 2)
 
-    with pytest.raises(NotImplementedError):
+    with pytest.raises(ValueError):
         model(x, dim_size=0)
 
@@ -45,7 +45,7 @@ def test_equilibrium_batch(iter, alpha):
 
-    with pytest.raises(NotImplementedError):
+    with pytest.raises(ValueError):
         model(x, dim_size=0)
 
     out = model(x, dim_size=3)

diff --git a/torch_geometric/nn/aggr/equilibrium.py b/torch_geometric/nn/aggr/equilibrium.py
index f1b823ab5397..23358c4437ef 100644
--- a/torch_geometric/nn/aggr/equilibrium.py
+++ b/torch_geometric/nn/aggr/equilibrium.py
@@ -167,15 +167,14 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                 dim: int = -2) -> Tensor:
 
         if ptr is not None:
-            raise NotImplementedError(
-                f"{self.__class__} doesn't support `ptr`")
+            raise ValueError(f"{self.__class__} doesn't support `ptr`")
 
         index_size = 1 if index is None else index.max() + 1
         dim_size = index_size if dim_size is None else dim_size
 
         if dim_size < index_size:
-            raise NotImplementedError("`dim_size` is less than `index` "
-                                      "implied size")
+            raise ValueError("`dim_size` is less than `index` "
+                             "implied size")
 
         with torch.enable_grad():
             y = self.optimizer(x, self.init_output(index), index, self.energy,

From 7359edc78cadb31e99edb9e68459699fa91f0262 Mon Sep 17 00:00:00 2001
From: "padarn.wilson"
Date: Sat, 9 Jul 2022 17:57:55 +0800
Subject: [PATCH 24/25] minor update for agg class

---
 torch_geometric/nn/aggr/equilibrium.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/torch_geometric/nn/aggr/equilibrium.py b/torch_geometric/nn/aggr/equilibrium.py
index 23358c4437ef..dad2812d106f 100644
--- a/torch_geometric/nn/aggr/equilibrium.py
+++ b/torch_geometric/nn/aggr/equilibrium.py
@@ -162,12 +162,11 @@ def reg(self, y: Tensor) -> float:
     def energy(self, x: Tensor, y: Tensor, index: Optional[Tensor]):
         return self.potential(x, y, index) + self.reg(y)
 
-    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
+    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                 dim: int = -2) -> Tensor:
 
-        if ptr is not None:
-            raise ValueError(f"{self.__class__} doesn't support `ptr`")
+        self.assert_index_present(index)
 
         index_size = 1 if index is None else index.max() + 1
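With patch 24, `EquilibriumAggregation` validates its inputs through the base class's `assert_index_present` instead of a local `ptr` check, so it is invoked like any other `torch_geometric.nn.aggr` module. A usage sketch mirroring the tests above (the channel sizes and index layout are illustrative):

import torch
from torch_geometric.nn import EquilibriumAggregation

# 3 input channels, 2 output channels, a two-layer potential.
aggr = EquilibriumAggregation(3, 2, num_layers=[10, 10], grad_iter=5)

x = torch.randn(6, 3)
index = torch.tensor([0, 0, 0, 1, 1, 1])  # two graphs, three nodes each

out = aggr(x, index)        # one learned aggregate per graph
assert out.size() == (2, 2)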
From 2236dc129143b09010f9fce57183f7b2bad99e81 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 9 Jul 2022 09:59:10 +0000
Subject: [PATCH 25/25] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 torch_geometric/nn/aggr/__init__.py | 18 ++++--------------
 1 file changed, 4 insertions(+), 14 deletions(-)

diff --git a/torch_geometric/nn/aggr/__init__.py b/torch_geometric/nn/aggr/__init__.py
index 0e9733a2ad6a..566684c1caa7 100644
--- a/torch_geometric/nn/aggr/__init__.py
+++ b/torch_geometric/nn/aggr/__init__.py
@@ -17,19 +17,9 @@
 from .equilibrium import EquilibriumAggregation
 
 __all__ = classes = [
-    'Aggregation',
-    'MultiAggregation',
-    'MeanAggregation',
-    'SumAggregation',
-    'MaxAggregation',
-    'MinAggregation',
-    'MulAggregation',
-    'VarAggregation',
-    'StdAggregation',
-    'SoftmaxAggregation',
-    'PowerMeanAggregation',
-    'LSTMAggregation',
-    'Set2Set',
-    'DegreeScalerAggregation',
+    'Aggregation', 'MultiAggregation', 'MeanAggregation', 'SumAggregation',
+    'MaxAggregation', 'MinAggregation', 'MulAggregation', 'VarAggregation',
+    'StdAggregation', 'SoftmaxAggregation', 'PowerMeanAggregation',
+    'LSTMAggregation', 'Set2Set', 'DegreeScalerAggregation',
    'EquilibriumAggregation'
 ]
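The pre-commit reflow above changes only formatting; every public name stays importable. A quick smoke check, assuming the package layout shown in the diff:

import torch_geometric.nn.aggr as aggr

# `__all__` (aliased as `classes`) should list only resolvable attributes.
for name in aggr.__all__:
    assert hasattr(aggr, name)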