Skip to content

Commit

Permalink
MoLeR support for latent encoding (#205)
Browse files Browse the repository at this point in the history
* feat: MoLeR seed smiles support

* test: expand moler tests

* chore: ignore dataset imports

* chore: formatting

* refactor: move from rdkit-pypi to rdkit

* fix: seed based decoding fixed

* chore: remove old file

* feat: specifying noise for random latent sampling

* fix: latent shape

* chore: isort
  • Loading branch information
jannisborn authored Mar 9, 2023
1 parent 45a1484 commit f658502
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 16 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pymatgen>=2022.11.7
PyTDC==0.3.7
pytorch_lightning<=1.7.7
pyyaml>=5.4.1
rdkit-pypi>=2020.9.5.2,<=2021.9.4
rdkit>=2022.3.5
regex>=2.5.91
reinvent-chemistry==0.0.38
sacremoses>=0.0.41
Expand Down
5 changes: 4 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ install_requires =
pymatgen
pyTDC!=0.3.8
pyyaml
rdkit-pypi
rdkit
regex
reinvent-chemistry
sacremoses
Expand Down Expand Up @@ -260,3 +260,6 @@ ignore_missing_imports = True

[mypy-pymatgen.*]
ignore_missing_imports = True

[mypy-datasets.*]
ignore_missing_imports = True
14 changes: 14 additions & 0 deletions src/gt4sd/algorithms/generation/moler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,18 @@ class MoLeRDefaultGenerator(AlgorithmConfiguration[SMILES, Any]):
default=6,
metadata=dict(description="Number of workers used for generation."),
)
seed_smiles: str = field(
default="",
metadata=dict(
description="Dot-separated SMILES used to initialize the encoder. If empty, random codes are used."
),
)
sigma: float = field(
default=0.0,
metadata=dict(
description="Variance of Gaussian noise being added to latent code."
),
)

def get_target_description(self) -> Optional[Dict[str, str]]:
"""Get description of the target for generation.
Expand All @@ -159,6 +171,8 @@ def get_conditional_generator(self, resources_path: str) -> MoLeRGenerator:
beam_size=self.beam_size,
seed=self.seed,
num_workers=self.num_workers,
seed_smiles=self.seed_smiles,
sigma=self.sigma,
)

def validate_item(self, item: str) -> SMILES:
Expand Down
42 changes: 33 additions & 9 deletions src/gt4sd/algorithms/generation/moler/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from itertools import cycle, islice
from typing import List

import numpy as np
from rdkit import Chem
from molecule_generation import VaeWrapper

Expand All @@ -45,6 +46,8 @@ def __init__(
beam_size: int,
seed: int,
num_workers: int,
seed_smiles: str,
sigma: float,
) -> None:
"""Instantiate a MoLeR generator.
Expand All @@ -55,30 +58,43 @@ def __init__(
beam_size: beam size to use during decoding.
seed: seed used for random number generation.
num_workers: number of workers used for generation.
seed_smiles: dot-separated SMILES used to initialize the decoder. If empty,
random codes are sampled from the latent space.
sigma: standard deviation of the gaussian noise added to the latent code.
Raises:
RuntimeError: in the case extras are disabled.
"""
# loading artifacts
self.resources_path = resources_path
self.scaffolds = scaffolds
self.num_samples = num_samples
self.beam_size = beam_size
self.num_workers = num_workers
self._seed = seed
self.sigma = sigma

# Process context
self.seed_smiles = [
smi for smi in seed_smiles.split(".") if Chem.MolFromSmiles(smi) is not None
]
self.scaffolds = [
scaffold
for scaffold in scaffolds.split(".")
if Chem.MolFromSmiles(scaffold) is not None
]
# Repeat scaffolds if needed
if self.scaffolds != [""] and len(self.scaffolds) < self.num_samples:
self.scaffolds = list(islice(cycle(self.scaffolds), self.num_samples))
# Repeat seed smiles if needed
if self.seed_smiles != [""] and len(self.seed_smiles) < self.num_samples:
self.seed_smiles = list(islice(cycle(self.seed_smiles), self.num_samples))

def generate(self) -> List[str]:
"""Sample molecules using MoLeR.
Returns:
sampled molecule (SMILES).
"""
# process scaffolds
valid_scaffolds = [
scaffold
for scaffold in self.scaffolds.split(".")
if Chem.MolFromSmiles(scaffold) is not None
]
# generate molecules
logger.info("running MoLeR...")
with VaeWrapper(
Expand All @@ -87,8 +103,16 @@ def generate(self) -> List[str]:
seed=self._seed,
num_workers=self.num_workers,
) as model:
latents = model.sample_latents(self.num_samples)
scaffolds = list(islice(cycle(valid_scaffolds), self.num_samples))
if self.seed_smiles == [""]:
latents = model.sample_latents(self.num_samples)
else:
latents = np.stack(model.encode(self.seed_smiles))

# Add noise to latent codes
latents = latents + self.sigma * np.random.randn(*latents.shape).astype(
np.float32
)
scaffolds = list(islice(cycle(self.scaffolds), self.num_samples))
samples = model.decode(
latents=latents,
scaffolds=scaffolds if len(scaffolds) == self.num_samples else None,
Expand Down
48 changes: 44 additions & 4 deletions src/gt4sd/algorithms/generation/tests/test_moler.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,56 @@ def test_available_versions(config_class: Type[AlgorithmConfiguration]):


@pytest.mark.parametrize(
"config, algorithm",
"config, algorithm, params",
[
# Test unconditional generation
(MoLeRDefaultGenerator, MoLeR, {"seed": 42, "num_workers": 1}),
# Test soft-conditioning on latent codes
(
MoLeRDefaultGenerator,
MoLeR,
)
{
"seed": 0,
"seed_smiles": "c1ccccc1.CNC=O.CN1C=NC2=C1C(=O)N(C(=O)N2C)C.[O-]C(=O)[O-].CCOCN",
"num_workers": 1,
},
),
# Test hard-conditioning on scaffolds
(
MoLeRDefaultGenerator,
MoLeR,
{
"seed": 0,
"scaffolds": "CN.O=C1C2=CC=C(C3=CC=CC=C3)C=C=C2OC2=CC=CC=C12",
"num_workers": 1,
},
),
# Test both together
(
MoLeRDefaultGenerator,
MoLeR,
{
"seed": 0,
"seed_smiles": "c1ccccc1.CNC=O.CN1C=NC2=C1C(=O)N(C(=O)N2C)C.[O-]C(=O)[O-].CCOCN",
"scaffolds": "CN.CCC",
"num_workers": 1,
},
),
# Test both together with unequal sizes
(
MoLeRDefaultGenerator,
MoLeR,
{
"seed": 0,
"seed_smiles": "c1ccccc1.CNC=O.CN1C=NC2=C1C(=O)N(C(=O)N2C)C.[O-]C(=O)[O-].CCOCN",
"scaffolds": "CN.CCC.CNO",
"num_workers": 1,
},
),
],
)
def test_generation_via_import(config, algorithm):
algorithm = algorithm(configuration=config())
def test_generation_via_import(config, algorithm, params):
algorithm = algorithm(configuration=config(**params))
items = list(algorithm.sample(5))
assert len(items) == 5

Expand Down
1 change: 0 additions & 1 deletion src/gt4sd/algorithms/generation/tests/test_paccmann_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def test_config_instance(config_class: Type[AlgorithmConfiguration]):
)
def test_available_versions(config_class: Type[AlgorithmConfiguration]):
versions = config_class.list_versions()
print("HERE", versions)
assert "v0" in versions


Expand Down

0 comments on commit f658502

Please sign in to comment.