
Implementing Neural Fingerprint from Duvenaud et al. #7919

Merged (26 commits, Sep 12, 2023)
Commits (26):
76a12d1  Neural FingerPrint Model added (Aug 23, 2023)
e3221fe  Comments added (Aug 23, 2023)
c01ca01  Minor Change (Aug 23, 2023)
ee5bda6  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 23, 2023)
7251ba6  Update neural_fingerprint.py (harshit5674, Sep 3, 2023)
71672c8  Update test_neural_fingerprint.py (harshit5674, Sep 4, 2023)
7421b0f  Update neural_fingerprint.py (harshit5674, Sep 4, 2023)
dece87a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 4, 2023)
eea4936  Update test_neural_fingerprint.py (harshit5674, Sep 4, 2023)
045016c  Update neural_fingerprint.py (harshit5674, Sep 5, 2023)
24465ae  Update neural_fingerprint.py (harshit5674, Sep 5, 2023)
caf6ee0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 5, 2023)
fc85130  Update test_neural_fingerprint.py (harshit5674, Sep 9, 2023)
dca3d41  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 9, 2023)
440ad30  Update CHANGELOG.md (harshit5674, Sep 9, 2023)
a48fe05  Merge branch 'master' into master (EdisonLeeeee, Sep 9, 2023)
e26dbd5  Update neural_fingerprint.py (EdisonLeeeee, Sep 9, 2023)
49165a2  Update neural_fingerprint.py (EdisonLeeeee, Sep 11, 2023)
55fd802  Merge branch 'master' into master (EdisonLeeeee, Sep 11, 2023)
89f1aa6  Update test_neural_fingerprint.py (harshit5674, Sep 11, 2023)
5d5d206  Merge branch 'master' into master (harshit5674, Sep 11, 2023)
0ea927c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 11, 2023)
076ae2c  Update CHANGELOG.md (EdisonLeeeee, Sep 12, 2023)
7a2740c  Merge branch 'master' into master (EdisonLeeeee, Sep 12, 2023)
abffe00  update (rusty1s, Sep 12, 2023)
0608f3d  update (rusty1s, Sep 12, 2023)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `NeuralFingerprint` model for learning fingerprints of molecules ([#7919](https://github.com/pyg-team/pytorch_geometric/pull/7919))
- Added `SparseTensor` support to `WLConvContinuous`, `GeneralConv`, `PDNConv` and `ARMAConv` ([#8013](https://github.com/pyg-team/pytorch_geometric/pull/8013))
- Added `LCMAggregation`, an implementation of Learnable Commutative Monoids ([#7976](https://github.com/pyg-team/pytorch_geometric/pull/7976))
- Added a warning for isolated/non-existing node types in `HeteroData.validate()` ([#7995](https://github.com/pyg-team/pytorch_geometric/pull/7995))
28 changes: 28 additions & 0 deletions test/nn/models/test_neural_fingerprint.py
@@ -0,0 +1,28 @@
import pytest
import torch

import torch_geometric.typing
from torch_geometric.nn import NeuralFingerprint
from torch_geometric.testing import is_full_test
from torch_geometric.typing import SparseTensor


@pytest.mark.parametrize('batch', [None, torch.tensor([0, 1, 1])])
def test_neural_fingerprint(batch):
x = torch.randn(3, 7)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

model = NeuralFingerprint(7, 16, out_channels=5, num_layers=4)
assert str(model) == 'NeuralFingerprint(7, 5, num_layers=4)'
model.reset_parameters()

out = model(x, edge_index, batch)
    assert out.size() == ((1, 5) if batch is None else (2, 5))  # one graph without a batch vector, two with it

if torch_geometric.typing.WITH_TORCH_SPARSE:
adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(3, 3))
assert torch.allclose(model(x, adj.t(), batch), out)

if is_full_test():
jit = torch.jit.export(model)
assert torch.allclose(jit(x, edge_index, batch), out)
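
For readers skimming the diff, here is a minimal sketch (not part of the PR) of how the `batch` argument exercised by this test behaves: `batch[i]` is the graph index of node `i`, so the pooled output contains one fingerprint row per graph, which is why the test expects shape `(1, 5)` without a batch vector and `(2, 5)` with one.

```python
import torch
from torch_geometric.nn import NeuralFingerprint

x = torch.randn(3, 7)                                    # 3 nodes, 7 features each
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # undirected path 0-1-2
model = NeuralFingerprint(7, 16, out_channels=5, num_layers=4)

# Without a batch vector, all nodes are treated as one graph.
print(model(x, edge_index).shape)         # torch.Size([1, 5])

# With batch = [0, 1, 1], node 0 forms graph 0 and nodes 1-2 form graph 1.
batch = torch.tensor([0, 1, 1])
print(model(x, edge_index, batch).shape)  # torch.Size([2, 5])
```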
2 changes: 2 additions & 0 deletions torch_geometric/nn/models/__init__.py
@@ -24,6 +24,7 @@
from .rev_gnn import GroupAddRev
from .gnnff import GNNFF
from .pmlp import PMLP
from .neural_fingerprint import NeuralFingerprint

__all__ = classes = [
'MLP',
@@ -64,4 +65,5 @@
'GroupAddRev',
'GNNFF',
'PMLP',
'NeuralFingerprint',
]
72 changes: 72 additions & 0 deletions torch_geometric/nn/models/neural_fingerprint.py
@@ -0,0 +1,72 @@
from typing import Optional

import torch
from torch import Tensor

from torch_geometric.nn import Linear, MFConv, global_add_pool
from torch_geometric.typing import Adj


class NeuralFingerprint(torch.nn.Module):
r"""The Neural Fingerprint model from the
`"Convolutional Networks on Graphs for Learning Molecular Fingerprints"
<https://arxiv.org/abs/1509.09292>`__ paper to generate fingerprints
of molecules.

Args:
in_channels (int): Size of each input sample.
hidden_channels (int): Size of each hidden sample.
out_channels (int): Size of each output fingerprint.
num_layers (int): Number of layers.
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MFConv`.
"""
def __init__(
self,
in_channels: int,
hidden_channels: int,
out_channels: int,
num_layers: int,
**kwargs,
):
super().__init__()

self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.num_layers = num_layers

self.convs = torch.nn.ModuleList()
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.hidden_channels
self.convs.append(MFConv(in_channels, hidden_channels, **kwargs))

self.lins = torch.nn.ModuleList()
for _ in range(self.num_layers):
self.lins.append(Linear(hidden_channels, out_channels, bias=False))

def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
for conv in self.convs:
conv.reset_parameters()
for lin in self.lins:
lin.reset_parameters()

def forward(
self,
x: Tensor,
edge_index: Adj,
batch: Optional[Tensor] = None,
batch_size: Optional[int] = None,
) -> Tensor:
""""""
        outs = []
        for conv, lin in zip(self.convs, self.lins):
            # Update node embeddings and squash them with a sigmoid, then map
            # each node to a soft fingerprint "bit" distribution and pool the
            # contributions of all nodes belonging to the same graph:
            x = conv(x, edge_index).sigmoid()
            y = lin(x).softmax(dim=-1)
            outs.append(global_add_pool(y, batch, batch_size))
        # The final fingerprint is the sum of the per-layer graph summaries.
        return sum(outs)

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, num_layers={self.num_layers})')
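
Taken together, the forward pass above is a differentiable variant of the Duvenaud et al. fingerprint: each `MFConv` layer updates node states, a sigmoid and a softmax-normalized linear projection turn every node into a soft contribution to the fingerprint, and `global_add_pool` sums these contributions per graph before the per-layer results are added up. Below is a hedged usage sketch (not part of the PR; the toy graphs, sizes, and variable names are made up for illustration) showing how the new model would typically be driven from a PyG `DataLoader`, which builds the `batch` vector consumed by `global_add_pool` automatically:

```python
import torch

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import NeuralFingerprint

# Two toy "molecules" with 7-dimensional node features each.
graphs = [
    Data(x=torch.randn(4, 7),
         edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])),
    Data(x=torch.randn(3, 7),
         edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])),
]

model = NeuralFingerprint(in_channels=7, hidden_channels=16,
                          out_channels=5, num_layers=4)

loader = DataLoader(graphs, batch_size=2)
for data in loader:
    # `data.batch` maps every node to its molecule, so the model returns
    # one fingerprint row per molecule in the mini-batch.
    fingerprint = model(data.x, data.edge_index, data.batch)
    print(fingerprint.shape)  # torch.Size([2, 5])
```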