Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Code Coverage] models/dimenet.py #6781

Merged
merged 9 commits into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Properly reset the `data_list` cache of an `InMemoryDataset` when accessing `dataset.data` ([#6685](https://github.com/pyg-team/pytorch_geometric/pull/6685))
- Fixed a bug in `Data.subgraph()` and `HeteroData.subgraph()` ([#6613](https://github.com/pyg-team/pytorch_geometric/pull/6613))
- Fixed a bug in `PNAConv` and `DegreeScalerAggregation` to correctly incorporate degree statistics of isolated nodes ([#6609](https://github.com/pyg-team/pytorch_geometric/pull/6609))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6653](https://github.com/pyg-team/pytorch_geometric/pull/6653), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683), [#6703](https://github.com/pyg-team/pytorch_geometric/pull/6703), [#6720](https://github.com/pyg-team/pytorch_geometric/pull/6720), [#6735](https://github.com/pyg-team/pytorch_geometric/pull/6735), [#6736](https://github.com/pyg-team/pytorch_geometric/pull/6736), [#6763](https://github.com/pyg-team/pytorch_geometric/pull/6763))
- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), [#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6653](https://github.com/pyg-team/pytorch_geometric/pull/6653), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683), [#6703](https://github.com/pyg-team/pytorch_geometric/pull/6703), [#6720](https://github.com/pyg-team/pytorch_geometric/pull/6720), [#6735](https://github.com/pyg-team/pytorch_geometric/pull/6735), [#6736](https://github.com/pyg-team/pytorch_geometric/pull/6736), [#6763](https://github.com/pyg-team/pytorch_geometric/pull/6763), [#6781](https://github.com/pyg-team/pytorch_geometric/pull/6781))
- Fixed a bug in which `data.to_heterogeneous()` filtered attributs in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522))
- Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517))
- Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512))
Expand Down
62 changes: 42 additions & 20 deletions test/nn/models/test_dimenet.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,47 @@
import pytest
import torch
import torch.nn.functional as F

from torch_geometric.nn import DimeNetPlusPlus
from torch_geometric.testing import onlyFullTest
from torch_geometric.nn import DimeNet, DimeNetPlusPlus
from torch_geometric.nn.models.dimenet import (
BesselBasisLayer,
Envelope,
ResidualLayer,
)
from torch_geometric.testing import is_full_test


@onlyFullTest
def test_dimenet_plus_plus():
def test_dimenet_modules():
env = Envelope(exponent=5)
x = torch.randn(10, 3)
assert env(x).size() == (10, 3) # Isotonic layer.

bbl = BesselBasisLayer(5)
x = torch.randn(10, 3)
assert bbl(x).size() == (10, 3, 5) # Non-isotonic layer.

rl = ResidualLayer(128, torch.nn.functional.relu)
x = torch.randn(128, 128)
assert rl(x).size() == (128, 128) # Isotonic layer.


@pytest.mark.parametrize('Model', [DimeNet, DimeNetPlusPlus])
def test_dimenet(Model):
z = torch.randint(1, 10, (20, ))
pos = torch.randn(20, 3)

model = DimeNetPlusPlus(
if Model == DimeNet:
kwargs = dict(num_bilinear=3)
else:
kwargs = dict(out_emb_channels=3, int_emb_size=5, basis_emb_size=5)

model = Model(
hidden_channels=5,
out_channels=1,
num_blocks=5,
out_emb_channels=3,
int_emb_size=5,
basis_emb_size=5,
num_spherical=5,
num_radial=5,
num_before_skip=2,
num_after_skip=2,
**kwargs,
)
model.reset_parameters()

Expand All @@ -31,14 +52,15 @@ def test_dimenet_plus_plus():
jit = torch.jit.export(model)
assert torch.allclose(jit(z, pos), out)

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
if is_full_test():
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

min_loss = float('inf')
for i in range(100):
optimizer.zero_grad()
out = model(z, pos)
loss = F.l1_loss(out, torch.tensor([1.]))
loss.backward()
optimizer.step()
min_loss = min(float(loss), min_loss)
assert min_loss < 2
min_loss = float('inf')
for i in range(100):
optimizer.zero_grad()
out = model(z, pos)
loss = F.l1_loss(out, torch.tensor([1.0]))
loss.backward()
optimizer.step()
min_loss = min(float(loss), min_loss)
assert min_loss < 2
111 changes: 81 additions & 30 deletions torch_geometric/nn/models/dimenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os.path as osp
from math import pi as PI
from math import sqrt
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Dict, Optional, Tuple, Union

import numpy as np
import torch
Expand All @@ -17,7 +17,7 @@
from torch_geometric.typing import OptTensor, SparseTensor
from torch_geometric.utils import scatter

qm9_target_dict = {
qm9_target_dict: Dict[int, str] = {
0: 'mu',
1: 'alpha',
2: 'homo',
Expand Down Expand Up @@ -45,7 +45,7 @@ def forward(self, x: Tensor) -> Tensor:
x_pow_p0 = x.pow(p - 1)
x_pow_p1 = x_pow_p0 * x
x_pow_p2 = x_pow_p1 * x
return (1. / x + a * x_pow_p0 + b * x_pow_p1 +
return (1.0 / x + a * x_pow_p0 + b * x_pow_p1 +
c * x_pow_p2) * (x < 1.0).to(x.dtype)


Expand All @@ -66,13 +66,18 @@ def reset_parameters(self):
self.freq.requires_grad_()

def forward(self, dist: Tensor) -> Tensor:
dist = (dist.unsqueeze(-1) / self.cutoff)
dist = dist.unsqueeze(-1) / self.cutoff
return self.envelope(dist) * (self.freq * dist).sin()


class SphericalBasisLayer(torch.nn.Module):
def __init__(self, num_spherical: int, num_radial: int,
cutoff: float = 5.0, envelope_exponent: int = 5):
def __init__(
self,
num_spherical: int,
num_radial: int,
cutoff: float = 5.0,
envelope_exponent: int = 5,
):
super().__init__()
import sympy as sym

Expand Down Expand Up @@ -159,9 +164,16 @@ def forward(self, x: Tensor) -> Tensor:


class InteractionBlock(torch.nn.Module):
def __init__(self, hidden_channels: int, num_bilinear: int,
num_spherical: int, num_radial: int, num_before_skip: int,
num_after_skip: int, act: Callable):
def __init__(
self,
hidden_channels: int,
num_bilinear: int,
num_spherical: int,
num_radial: int,
num_before_skip: int,
num_after_skip: int,
act: Callable,
):
super().__init__()
self.act = act

Expand Down Expand Up @@ -223,9 +235,17 @@ def forward(self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor,


class InteractionPPBlock(torch.nn.Module):
def __init__(self, hidden_channels: int, int_emb_size: int,
basis_emb_size: int, num_spherical: int, num_radial: int,
num_before_skip: int, num_after_skip: int, act: Callable):
def __init__(
self,
hidden_channels: int,
int_emb_size: int,
basis_emb_size: int,
num_spherical: int,
num_radial: int,
num_before_skip: int,
num_after_skip: int,
act: Callable,
):
super().__init__()
self.act = act

Expand Down Expand Up @@ -311,8 +331,14 @@ def forward(self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor,


class OutputBlock(torch.nn.Module):
def __init__(self, num_radial: int, hidden_channels: int,
out_channels: int, num_layers: int, act: Callable):
def __init__(
self,
num_radial: int,
hidden_channels: int,
out_channels: int,
num_layers: int,
act: Callable,
):
super().__init__()
self.act = act

Expand Down Expand Up @@ -341,9 +367,15 @@ def forward(self, x: Tensor, rbf: Tensor, i: Tensor,


class OutputPPBlock(torch.nn.Module):
def __init__(self, num_radial: int, hidden_channels: int,
out_emb_channels: int, out_channels: int, num_layers: int,
act: Callable):
def __init__(
self,
num_radial: int,
hidden_channels: int,
out_emb_channels: int,
out_channels: int,
num_layers: int,
act: Callable,
):
super().__init__()
self.act = act

Expand Down Expand Up @@ -450,7 +482,7 @@ def __init__(
num_blocks: int,
num_bilinear: int,
num_spherical: int,
num_radial,
num_radial: int,
cutoff: float = 5.0,
max_num_neighbors: int = 32,
envelope_exponent: int = 5,
Expand All @@ -462,7 +494,7 @@ def __init__(
super().__init__()

if num_spherical < 2:
raise ValueError("num_spherical should be greater than 1")
raise ValueError("'num_spherical' should be greater than 1")

act = activation_resolver(act)

Expand All @@ -482,9 +514,15 @@ def __init__(
])

self.interaction_blocks = torch.nn.ModuleList([
InteractionBlock(hidden_channels, num_bilinear, num_spherical,
num_radial, num_before_skip, num_after_skip, act)
for _ in range(num_blocks)
InteractionBlock(
hidden_channels,
num_bilinear,
num_spherical,
num_radial,
num_before_skip,
num_after_skip,
act,
) for _ in range(num_blocks)
])

def reset_parameters(self):
Expand All @@ -502,7 +540,7 @@ def from_qm9_pretrained(
root: str,
dataset: Dataset,
target: int,
) -> Tuple['DimeNet', Dataset, Dataset, Dataset]:
) -> Tuple['DimeNet', Dataset, Dataset, Dataset]: # pragma: no cover
r"""Returns a pre-trained :class:`DimeNet` model on the
:class:`~torch_geometric.datasets.QM9` dataset, trained on the
specified target :obj:`target`."""
Expand Down Expand Up @@ -729,15 +767,27 @@ def __init__(
# variable `num_bilinear` does not have any purpose as it is used
# solely in the `OutputBlock` of DimeNet:
self.output_blocks = torch.nn.ModuleList([
OutputPPBlock(num_radial, hidden_channels, out_emb_channels,
out_channels, num_output_layers, act)
for _ in range(num_blocks + 1)
OutputPPBlock(
num_radial,
hidden_channels,
out_emb_channels,
out_channels,
num_output_layers,
act,
) for _ in range(num_blocks + 1)
])

self.interaction_blocks = torch.nn.ModuleList([
InteractionPPBlock(hidden_channels, int_emb_size, basis_emb_size,
num_spherical, num_radial, num_before_skip,
num_after_skip, act) for _ in range(num_blocks)
InteractionPPBlock(
hidden_channels,
int_emb_size,
basis_emb_size,
num_spherical,
num_radial,
num_before_skip,
num_after_skip,
act,
) for _ in range(num_blocks)
])

self.reset_parameters()
Expand All @@ -748,7 +798,8 @@ def from_qm9_pretrained(
root: str,
dataset: Dataset,
target: int,
) -> Tuple['DimeNetPlusPlus', Dataset, Dataset, Dataset]:
) -> Tuple['DimeNetPlusPlus', Dataset, Dataset,
Dataset]: # pragma: no cover
r"""Returns a pre-trained :class:`DimeNetPlusPlus` model on the
:class:`~torch_geometric.datasets.QM9` dataset, trained on the
specified target :obj:`target`."""
Expand Down