Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TDC Oracle integration #145

Merged
merged 7 commits into from
Oct 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/regression_transformer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ conda activate gt4sd
To launch a finetuning of an RT pretrained on drug-like molecules from ChEMBL, execute the following from the GT4SD root:

```console
gt4sd-trainer --training_pipeline_name regression-transformer-trainer --model_path ~/.gt4sd/algorithms/conditional_generation/RegressionTransformer/RegressionTransformerMolecules/qed --do_train --output_dir dummy_regression_transformer --train_data_path src/gt4sd/training_pipelines/tests/regression_transformer_raw.csv --test_data_path src/gt4sd/training_pipelines/tests/regression_transformer_raw.csv --overwrite_output_dir --eval_steps 2 --augment 10 --test_fraction 0.2 --eval_accumulation_steps 1
gt4sd-trainer --training_pipeline_name regression-transformer-trainer --model_path ~/.gt4sd/algorithms/conditional_generation/RegressionTransformer/RegressionTransformerMolecules/qed --do_train --output_dir dummy_regression_transformer --train_data_path src/gt4sd/training_pipelines/tests/regression_transformer_raw.csv --test_data_path src/gt4sd/training_pipelines/tests/regression_transformer_raw.csv --overwrite_output_dir --eval_steps 2 --augment 10 --eval_accumulation_steps 1
```
*NOTE*: This is a *dummy* example, do not use "as is" :warning:

Expand Down
6 changes: 4 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
accelerate>=0.12
datasets>=1.11.0
diffusers>=0.2.4
ipaddress>=1.0.23
joblib>=1.1.0
keras==2.3.1
keybert==0.2.0
Expand All @@ -11,13 +12,13 @@ numpy>=1.16.5
protobuf<3.20
pytorch_lightning<=1.5.0
pydantic>=1.7.3,<=1.9.2
PyTDC>=0.3.6
PyTDC>=0.3.7
pyyaml>=5.4.1
rdkit-pypi>=2020.9.5.2,<=2021.9.4
regex>=2.5.91
reinvent-chemistry==0.0.38
sacremoses>=0.0.41
scikit-learn<0.24.0
scikit-learn>=1.0.0
scikit-optimize>=0.8.1
sentencepiece>=0.1.95
sympy>=1.10.1
Expand All @@ -31,3 +32,4 @@ torchvision>=0.12.0
transformers>=4.2.1
typing_extensions>=3.7.4.3
wheel>=0.26
importlib-metadata<5.0.0 # temporary: https://github.com/python/importlib_metadata/issues/409
10 changes: 10 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ install_requires =
accelerate
datasets
diffusers
ipaddress
joblib
keras
keybert
Expand Down Expand Up @@ -229,3 +230,12 @@ ignore_missing_imports = True

[mypy-sympy.*]
ignore_missing_imports = True

[mypy-openbabel.*]
ignore_missing_imports = True

[mypy-pyscreener.*]
ignore_missing_imports = True

[mypy-pdbfixer.*]
ignore_missing_imports = True
17 changes: 17 additions & 0 deletions src/gt4sd/properties/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,23 @@ class S3Parameters(PropertyPredictorParameters):
algorithm_application: str = Field(..., example="Tox21")


class ApiTokenParameters(PropertyPredictorParameters):
    # Parameter model for predictors that need an API token/key.
    # `api_token` is required (no default is given).
    api_token: str = Field(
        ...,
        example="apk-c9db......",
        description="The API token/key to access the service",
    )


class IpAdressParameters(PropertyPredictorParameters):
    # Parameter model for predictors that talk to a service at a host IP.
    # `host_ip` is required (no default is given).
    host_ip: str = Field(
        ...,
        example="xx.xx.xxx.xxx",
        description="The host IP address to access the service",
    )


class PropertyPredictor:
"""PropertyPredictor base class."""

Expand Down
18 changes: 16 additions & 2 deletions src/gt4sd/properties/molecules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,24 @@
from ...algorithms.core import PredictorAlgorithm
from ..core import PropertyPredictor, PropertyPredictorParameters
from .core import (
SIDER,
ActivityAgainstTarget,
ActivityAgainstTargetParameters,
Askcos,
AskcosParameters,
Bertz,
ClinTox,
ClinToxParameters,
Docking,
DockingParameters,
DockingTdc,
DockingTdcParameters,
Esol,
IsScaffold,
Lipinski,
Logp,
MolecularWeight,
MoleculeOne,
MoleculeOneParameters,
NumberAromaticRings,
NumberAtoms,
NumberHAcceptors,
Expand All @@ -53,6 +60,7 @@
Sas,
Scscore,
ScscoreConfiguration,
Sider,
SiderParameters,
SimilaritySeed,
SimilaritySeedParameters,
Expand Down Expand Up @@ -96,9 +104,15 @@
"scscore": (Scscore, ScscoreConfiguration),
"activity_against_target": (ActivityAgainstTarget, ActivityAgainstTargetParameters),
"tox21": (Tox21, Tox21Parameters),
"sider": (SIDER, SiderParameters),
"sider": (Sider, SiderParameters),
"organtox": (OrganTox, OrganToxParameters),
"clintox": (ClinTox, ClinToxParameters),
# # properties from models requiring authentication
"askcos": (Askcos, AskcosParameters),
"molecule_one": (MoleculeOne, MoleculeOneParameters),
# # properties from models requiring additional installations
"docking_tdc": (DockingTdc, DockingTdcParameters),
"docking": (Docking, DockingParameters),
}


Expand Down
170 changes: 166 additions & 4 deletions src/gt4sd/properties/molecules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@
from enum import Enum
from typing import List

from paccmann_generator.drug_evaluators import SIDER as _SIDER
from paccmann_generator.drug_evaluators import SIDER
from paccmann_generator.drug_evaluators import ClinTox as _ClinTox
from paccmann_generator.drug_evaluators import OrganDB as _OrganTox
from paccmann_generator.drug_evaluators import SCScore
from paccmann_generator.drug_evaluators import Tox21 as _Tox21
from pydantic import Field
from tdc import Oracle
from tdc.metadata import download_receptor_oracle_name

from ...algorithms.core import (
ConfigurablePropertyAlgorithmConfiguration,
Expand All @@ -38,13 +40,23 @@
)
from ...domains.materials import SmallMolecule
from ..core import (
ApiTokenParameters,
CallablePropertyPredictor,
ConfigurableCallablePropertyPredictor,
DomainSubmodule,
IpAdressParameters,
PropertyPredictorParameters,
PropertyValue,
S3Parameters,
)
from ..utils import get_activity_fn, get_similarity_fn, to_smiles
from ..utils import (
docking_import_check,
get_activity_fn,
get_similarity_fn,
to_smiles,
validate_api_token,
validate_ip,
)
from .functions import (
bertz,
esol,
Expand Down Expand Up @@ -84,6 +96,71 @@ class ActivityAgainstTargetParameters(PropertyPredictorParameters):
target: str = Field(..., example="drd2", description="name of the target.")


class AskcosParameters(IpAdressParameters):
    # Parameter model for the ASKCOS predictor; `host_ip` is inherited from
    # IpAdressParameters.
    class Output(str, Enum):
        # Selectable output kinds.
        # NOTE(review): the member name `plausability` is misspelled while its
        # value is "plausibility"; renaming it would break any code that
        # references the member by name.
        plausability: str = "plausibility"
        num_step: str = "num_step"
        synthesizability: str = "synthesizability"
        price: str = "price"

    # Which output to return; defaults to the plausibility output.
    output: Output = Field(
        default=Output.plausability,
        example=Output.synthesizability,
        description="Main output return type from ASKCOS",
        options=["plausibility", "num_step", "synthesizability", "price"],
    )
    # NOTE(review): the remaining fields look like ASKCOS tree-builder settings
    # — confirm their meaning against the TDC/ASKCOS documentation.
    save_json: bool = Field(default=False)
    file_name: str = Field(default="tree_builder_result.json")
    num_trials: int = Field(default=5)
    max_depth: int = Field(default=9)
    max_branching: int = Field(default=25)
    expansion_time: int = Field(default=60)
    max_ppg: int = Field(default=100)
    template_count: int = Field(default=1000)
    max_cum_prob: float = Field(default=0.999)
    chemical_property_logic: str = Field(default="none")
    max_chemprop_c: int = Field(default=0)
    max_chemprop_n: int = Field(default=0)
    max_chemprop_o: int = Field(default=0)
    max_chemprop_h: int = Field(default=0)
    chemical_popularity_logic: str = Field(default="none")
    min_chempop_reactants: int = Field(default=5)
    min_chempop_products: int = Field(default=5)
    filter_threshold: float = Field(default=0.1)
    # Note: this flag is typed as a string ("true"), not a bool.
    return_first: str = Field(default="true")

    # Convert enum items back to strings
    class Config:
        use_enum_values = True


class MoleculeOneParameters(ApiTokenParameters):
    # Parameter model for the MoleculeOne predictor; `api_token` is inherited
    # from ApiTokenParameters.
    # Name handed to the TDC `Oracle` constructor by the MoleculeOne predictor.
    oracle_name: str = "Molecule One Synthesis"


class DockingTdcParameters(PropertyPredictorParameters):
    # To dock against a receptor defined via TDC
    # `target` is required; the allowed names are taken from
    # tdc.metadata.download_receptor_oracle_name.
    target: str = Field(
        ...,
        example="1iep_docking",
        description="Target for docking, provided via TDC",
        options=download_receptor_oracle_name,
    )


class DockingParameters(PropertyPredictorParameters):
    # To dock against a user-provided receptor
    # Name handed to the TDC `Oracle` constructor by the Docking predictor.
    name: str = Field(default="pyscreener")
    receptor_pdb_file: str = Field(
        example="/tmp/2hbs.pdb", description="Path to receptor PDB file"
    )
    # Box-center coordinates are real-valued (see the example). With the
    # previous List[int] annotation, pydantic 1.x coerced fractional
    # coordinates to integers, silently shifting the docking box.
    # List[float] still accepts integer input.
    box_center: List[float] = Field(
        example=[15.190, 53.903, 16.917], description="Docking box center"
    )
    box_size: List[float] = Field(example=[20, 20, 20], description="Docking box size")


class S3ParametersMolecules(S3Parameters):
domain: DomainSubmodule = DomainSubmodule("molecules")

Expand Down Expand Up @@ -355,6 +432,91 @@ def __init__(self, parameters: ActivityAgainstTargetParameters) -> None:
)


class Askcos(ConfigurableCallablePropertyPredictor):
    """
    A property predictor that uses the ASKCOs API to calculate the synthesizability
    of a molecule.
    """

    def __init__(self, parameters: AskcosParameters):
        """Validate the ASKCOS host address and set up the TDC oracle.

        Args:
            parameters: configuration whose `host_ip` is expected in the form
                'http://xx.xx.xxx.xxx'.

        Raises:
            AttributeError: if `parameters` has no `host_ip` attribute.
            TypeError: if `host_ip` is not a string.
            ValueError: if `host_ip` does not contain 'http' or lacks the
                '//' separator in front of the address.
        """
        msg = (
            "You have to point to an IP address of a running ASKCOS instance. "
            "For details on setting this up, see: https://tdcommons.ai/functions/oracles/#askcos"
        )
        # Presence must be checked before the attribute is read; otherwise the
        # attribute access itself fails and this message is never shown.
        if not hasattr(parameters, "host_ip"):
            raise AttributeError(f"IP adress missing in {parameters}")

        host_ip = parameters.host_ip
        if not isinstance(host_ip, str):
            raise TypeError(f"IP adress must be a string, not {host_ip}")

        # Require the '//' separator as well, so a malformed value such as
        # 'http:1.2.3.4' produces this ValueError instead of a bare IndexError
        # from the split below.
        if "http" not in host_ip or "//" not in host_ip:
            raise ValueError(
                "ASKCOS requires an IP prepended with a http, e.g., "
                f"'http://xx.xx.xxx.xxx' and not {host_ip}."
            )
        ip = host_ip.split("//")[1]

        # `msg` is handed to validate_ip, presumably to report an invalid
        # address — confirm against ..utils.validate_ip.
        validate_ip(ip, message=msg)
        super().__init__(callable_fn=Oracle(name="ASKCOS"), parameters=parameters)


class MoleculeOne(CallablePropertyPredictor):
    """
    Synthesizability predictor for a molecule, backed by the MoleculeOne API.
    """

    def __init__(self, parameters: MoleculeOneParameters):
        """Check the API token and wrap the corresponding TDC oracle.

        Args:
            parameters: configuration providing `oracle_name` and `api_token`.
        """
        # Only performs type checking on API key
        validate_api_token(
            parameters,
            message=(
                "You have to provide a valid API key, for details on setting this up, see: "
                "https://tdcommons.ai/functions/oracles/#moleculeone"
            ),
        )

        oracle = Oracle(name=parameters.oracle_name, api_token=parameters.api_token)
        super().__init__(callable_fn=oracle, parameters=parameters)


class DockingTdc(ConfigurableCallablePropertyPredictor):
    """
    Docking-score predictor for a target provided via the TDC package
    (see: https://tdcommons.ai/functions/oracles/#docking-scores)
    """

    def __init__(self, parameters: DockingTdcParameters):
        """Build the TDC oracle for the configured target.

        Args:
            parameters: configuration providing the docking `target` name.
        """
        # Runs before the oracle is built, as in the original ordering.
        docking_import_check()
        super().__init__(
            callable_fn=Oracle(name=parameters.target), parameters=parameters
        )


class Docking(ConfigurableCallablePropertyPredictor):
    """
    Docking-score predictor for a user-defined target.
    Relies on TDC backend, see https://tdcommons.ai/functions/oracles/#docking-scores for setup.
    """

    def __init__(self, parameters: DockingParameters):
        """Build the TDC oracle from the user-provided receptor settings.

        Args:
            parameters: configuration providing `name`, `receptor_pdb_file`,
                `box_center` and `box_size`.
        """
        # Runs before the oracle is built, as in the original ordering.
        docking_import_check()

        oracle_kwargs = {
            "name": parameters.name,
            "receptor_pdb_file": parameters.receptor_pdb_file,
            "box_center": parameters.box_center,
            "box_size": parameters.box_size,
        }
        oracle = Oracle(**oracle_kwargs)
        super().__init__(callable_fn=oracle, parameters=parameters)


class _MCA(PredictorAlgorithm):
"""Base class for all MCA-based predictive algorithms."""

Expand Down Expand Up @@ -440,7 +602,7 @@ def get_description(cls) -> str:
return text


class SIDER(_MCA):
class Sider(_MCA):
def get_model(self, resources_path: str) -> Predictor:
"""Instantiate the actual model.

Expand All @@ -451,7 +613,7 @@ def get_model(self, resources_path: str) -> Predictor:
Predictor: the model.
"""
# This model returns a singular reward and not a prediction for both classes.
model = _SIDER(model_path=resources_path)
model = SIDER(model_path=resources_path)

# Wrapper to get toxicity-endpoint-level predictions
def informative_model(x: SmallMolecule) -> List[PropertyValue]:
Expand Down
Loading