Skip to content

Commit

Permalink
Fix: moltochemprop not compatible with chemprop 2.0.3 (#37)
Browse files Browse the repository at this point in the history
adapt to chemprop 2.0.3
  • Loading branch information
c-w-feldmann authored Jul 3, 2024
1 parent 41bb336 commit 58984cd
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 40 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Wrapper for Chemprop Featurizer."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Wrapper for Chemprop GraphFeaturizer."""

from dataclasses import InitVar
from typing import Any

try:
from typing import Self # type: ignore[attr-defined]
except ImportError:
from typing_extensions import Self

from chemprop.featurizers.molgraph import (
SimpleMoleculeMolGraphFeaturizer as _SimpleMoleculeMolGraphFeaturizer,
)


class SimpleMoleculeMolGraphFeaturizer(_SimpleMoleculeMolGraphFeaturizer):
"""Wrapper for Chemprop SimpleMoleculeMolGraphFeaturizer."""

extra_atom_fdim: InitVar[int]
extra_bond_fdim: InitVar[int]

def get_params(
self, deep: bool = True # pylint: disable=unused-argument
) -> dict[str, InitVar[int]]:
"""Get parameters for the featurizer.
Parameters
----------
deep: bool, optional (default=True)
Used for compatibility with scikit-learn.
Returns
-------
dict[str, int]
Parameters of the featurizer.
"""
return {}

def set_params(self, **parameters: Any) -> Self: # pylint: disable=unused-argument
"""Set the parameters of the featurizer.
Parameters
----------
parameters: Any
Parameters to set. Only used for compatibility with scikit-learn.
Returns
-------
Self
This featurizer with the parameters set.
"""
return self
102 changes: 93 additions & 9 deletions molpipeline/mol2any/mol2chemprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,25 @@
from __future__ import annotations

import warnings
from typing import Iterable, Optional
from typing import Any, Iterable, Optional

import numpy as np
import numpy.typing as npt

try:
from chemprop.data import MoleculeDatapoint, MoleculeDataset
from chemprop.featurizers.molecule import MoleculeFeaturizer
from chemprop.featurizers.base import GraphFeaturizer, VectorFeaturizer

from molpipeline.estimators.chemprop.featurizer_wrapper.graph_wrapper import (
SimpleMoleculeMolGraphFeaturizer,
)
except ImportError:
warnings.warn(
"chemprop not installed. MolToChemprop will not work.",
ImportWarning,
)


from molpipeline.abstract_pipeline_elements.core import MolToAnyPipelineElement
from molpipeline.utils.molpipeline_types import RDKitMol

Expand All @@ -31,11 +39,13 @@ class MolToChemprop(MolToAnyPipelineElement):
[1] https://github.com/chemprop/chemprop/
"""

featurizer_list: list[MoleculeFeaturizer] | None
graph_featurizer: GraphFeaturizer[RDKitMol] | None
mol_featurizer: VectorFeaturizer[RDKitMol] | None

def __init__(
self,
featurizer_list: list[MoleculeFeaturizer] | None = None,
graph_featurizer: GraphFeaturizer[RDKitMol] | None = None,
mol_featurizer: VectorFeaturizer[RDKitMol] | None = None,
name: str = "Mol2Chemprop",
n_jobs: int = 1,
uuid: Optional[str] = None,
Expand All @@ -44,16 +54,21 @@ def __init__(
Parameters
----------
featurizer_list: list[MoleculeFeaturizer] | None, optional (default=None)
List of featurizers to use.
graph_featurizer: GraphFeaturizer[RDKitMol] | None, optional (default=None)
Defines how the graph is featurized. Defaults to None.
mol_featurizer: MoleculeFeaturizer | None, optional (default=None)
In contrast to graph_featurizer, features from the mol_featurizer are not used during the message passing.
These features are concatenated to the neural fingerprints before the feedforward layers.
name: str, optional (default="Mol2Chemprop")
Name of the pipeline element. Defaults to "Mol2Chemprop".
n_jobs: int
Number of parallel jobs to use. Defaults to 1.
uuid: str | None, optional (default=None)
UUID of the pipeline element.
"""
self.featurizer_list = featurizer_list
self.graph_featurizer = graph_featurizer or SimpleMoleculeMolGraphFeaturizer()
self.mol_featurizer = mol_featurizer

super().__init__(
name=name,
n_jobs=n_jobs,
Expand All @@ -73,7 +88,10 @@ def pretransform_single(self, value: RDKitMol) -> MoleculeDatapoint:
MoleculeDatapoint
Molecular representation used as input for ChemProp. None if transformation failed.
"""
return MoleculeDatapoint(mol=value, mfs=self.featurizer_list)
mol_features: npt.NDArray[np.float64] | None = None
if self.mol_featurizer is not None:
mol_features = np.array(self.mol_featurizer(value), dtype=np.float64)
return MoleculeDatapoint(mol=value, x_d=mol_features)

def assemble_output(
self, value_list: Iterable[MoleculeDatapoint]
Expand All @@ -90,4 +108,70 @@ def assemble_output(
Any
Assembled output.
"""
return MoleculeDataset(data=list(value_list))
return MoleculeDataset(data=list(value_list), featurizer=self.graph_featurizer)

def get_params(self, deep: bool = True) -> dict[str, Any]:
"""Get parameters for this pipeline element.
Parameters
----------
deep: bool, optional (default=True)
If True, will return the parameters for this pipeline element and its subobjects.
Returns
-------
dict
Parameters of this pipeline element.
"""
params = super().get_params(deep=deep)
params["graph_featurizer"] = self.graph_featurizer
params["mol_featurizer"] = self.mol_featurizer

if deep:
if hasattr(self.graph_featurizer, "get_params"):
graph_featurizer_params = self.graph_featurizer.get_params(deep=deep) # type: ignore
for key, value in graph_featurizer_params.items():
params[f"graph_featurizer__{key}"] = value
if hasattr(self.mol_featurizer, "get_params"):
mol_featurizer_params = self.mol_featurizer.get_params(deep=deep) # type: ignore
for key, value in mol_featurizer_params.items():
params[f"mol_featurizer__{key}"] = value
return params

def set_params(self, **parameters: Any) -> MolToChemprop:
"""Set the parameters of this pipeline element.
Parameters
----------
**parameters: Any
Parameters to set.
Returns
-------
MolToChemprop
This pipeline element with the parameters set.
"""
param_copy = dict(parameters)
graph_featurizer = param_copy.pop("graph_featurizer", None)
mol_featurizer = param_copy.pop("mol_featurizer", None)
if graph_featurizer is not None:
self.graph_featurizer = graph_featurizer
if mol_featurizer is not None:
self.mol_featurizer = mol_featurizer
graph_featurizer_params = {}
mol_featurizer_params = {}
for key in list(param_copy.keys()):
if "__" not in key:
continue
component_name, _, param_name = key.partition("__")
if component_name == "graph_featurizer":
graph_featurizer_params[param_name] = param_copy.pop(key)
elif component_name == "mol_featurizer":
mol_featurizer_params[param_name] = param_copy.pop(key)
if hasattr(self.graph_featurizer, "set_params"):
self.graph_featurizer.set_params(**graph_featurizer_params) # type: ignore
if hasattr(self.mol_featurizer, "set_params"):
self.mol_featurizer.set_params(**mol_featurizer_params) # type: ignore

super().set_params(**param_copy)
return self
2 changes: 1 addition & 1 deletion requirements_chemprop.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
chemprop >= 2.0.0, < 2.0.3
chemprop>=2.0.3
lightning
30 changes: 5 additions & 25 deletions test_extras/test_chemprop/test_chemprop_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,12 @@
import joblib
import numpy as np
import pandas as pd
from chemprop.nn.loss import LossFunction
from lightning import pytorch as pl
from lightning.pytorch.accelerators import Accelerator
from lightning.pytorch.profilers.base import PassThroughProfiler
from sklearn.base import clone
from torch import nn

from molpipeline.any2mol import SmilesToMol
from molpipeline.error_handling import ErrorFilter, FilterReinserter
from molpipeline.estimators.chemprop.abstract import ABCChemprop
from molpipeline.estimators.chemprop.component_wrapper import (
MPNN,
BinaryClassificationFFN,
Expand All @@ -30,6 +27,7 @@
from molpipeline.mol2any.mol2chemprop import MolToChemprop
from molpipeline.pipeline import Pipeline
from molpipeline.post_prediction import PostPredictionWrapper
from test_extras.test_chemprop.chemprop_test_utils.compare_models import compare_params


# pylint: disable=duplicate-code
Expand Down Expand Up @@ -199,6 +197,9 @@ def test_clone(self) -> None:
self.assertEqual(step.__class__, cloned_step.__class__)
params = step.get_params(deep=True) # type: ignore
cloned_params = cloned_step.get_params(deep=True)
if isinstance(step, ABCChemprop):
compare_params(self, step, cloned_step)
continue
for param_name, param in params.items():
# If parm implements get_params, it was cloned as well and we need
# to compare the parameters. Since all parameters are listed flat in
Expand All @@ -207,27 +208,6 @@ def test_clone(self) -> None:
self.assertEqual(
param.__class__, cloned_params[param_name].__class__
)
elif param_name == "lightning_trainer":
# Lightning trainer does not implement get_params so things are a bit tricky
# at the moment. We can only check if the classes are the same.
self.assertEqual(
param.__class__, cloned_params[param_name].__class__
)
elif isinstance(param, LossFunction):
self.assertEqual(
param.state_dict()["task_weights"],
cloned_params[param_name].state_dict()["task_weights"],
)
self.assertEqual(type(param), type(cloned_params[param_name]))
elif isinstance(param, (nn.Identity, Accelerator, PassThroughProfiler)):
self.assertEqual(type(param), type(cloned_params[param_name]))
elif param_name == "lightning_trainer__callbacks":
self.assertIsInstance(cloned_params[param_name], list)
self.assertEqual(len(param), len(cloned_params[param_name]))
for callback, cloned_callback in zip(
param, cloned_params[param_name]
):
self.assertEqual(type(callback), type(cloned_callback))
else:
self.assertEqual(
param, cloned_params[param_name], f"Failed for {param_name}"
Expand Down
9 changes: 6 additions & 3 deletions test_extras/test_chemprop/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json

# pylint: disable=relative-beyond-top-level
from .chemprop_test_utils.compare_models import compare_params
from .chemprop_test_utils.constant_vars import DEFAULT_PARAMS, NO_IDENTITY_CHECK
from .chemprop_test_utils.default_models import (
from test_extras.test_chemprop.chemprop_test_utils.compare_models import compare_params
from test_extras.test_chemprop.chemprop_test_utils.constant_vars import (
DEFAULT_PARAMS,
NO_IDENTITY_CHECK,
)
from test_extras.test_chemprop.chemprop_test_utils.default_models import (
get_chemprop_model_binary_classification_mpnn,
)

Expand Down
6 changes: 4 additions & 2 deletions test_extras/test_chemprop/test_neural_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json

# pylint: disable=relative-beyond-top-level
from .chemprop_test_utils.compare_models import compare_params
from .chemprop_test_utils.default_models import get_neural_fp_encoder
from test_extras.test_chemprop.chemprop_test_utils.compare_models import compare_params
from test_extras.test_chemprop.chemprop_test_utils.default_models import (
get_neural_fp_encoder,
)

logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING)

Expand Down

0 comments on commit 58984cd

Please sign in to comment.