Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 23, 2023
1 parent 76a12d1 commit 733bdbc
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
10 changes: 6 additions & 4 deletions test/nn/models/test_neural_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

from torch_geometric.nn import NeuralFingerprint


def test_neural_fingerprint():
x = torch.randn(3,7)
edge_index = torch.tensor([[0,1,1,2],[1,0,2,1]])
model = NeuralFingerprint(num_features=7, fingerprint_length=5, num_layers=4)
x = torch.randn(3, 7)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
model = NeuralFingerprint(num_features=7, fingerprint_length=5,
num_layers=4)
fingerprint = model.forward(x, edge_index)
assert fingerprint.size() == torch.Size([5])

test_neural_fingerprint()

test_neural_fingerprint()
18 changes: 11 additions & 7 deletions torch_geometric/nn/models/neural_fingerprint.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import torch
from torch import tensor
import torch.nn.functional as F
from torch import tensor

from torch_geometric.nn.conv import MFConv
from torch_geometric.nn.dense import Linear


class NeuralFingerprint(torch.nn.Module):
def __init__(self, num_features, fingerprint_length, num_layers):
super().__init__()
Expand All @@ -13,16 +14,19 @@ def __init__(self, num_features, fingerprint_length, num_layers):
self.num_layers = num_layers
self.layers = torch.nn.ModuleList()
for i in range(self.num_layers):
self.layers.append(MFConv(in_channels=self.num_features, out_channels=self.num_features))
self.layers.append(Linear(in_channels=self.num_features, out_channels=self.fingerprint_length))

self.layers.append(
MFConv(in_channels=self.num_features,
out_channels=self.num_features))
self.layers.append(
Linear(in_channels=self.num_features,
out_channels=self.fingerprint_length))

def forward(self, x, edge_index):
fingerprint = torch.zeros(self.fingerprint_length)
for i in range(0,2*self.num_layers,2):
for i in range(0, 2 * self.num_layers, 2):
x = self.layers[i](x, edge_index)
x = torch.sigmoid(x)
y = F.softmax(self.layers[i+1](x), dim=1)
y = F.softmax(self.layers[i + 1](x), dim=1)
for j in range(y.shape[0]):
fingerprint += y[j]
return fingerprint

0 comments on commit 733bdbc

Please sign in to comment.