diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml
index 7986de7f..df576a14 100644
--- a/.github/workflows/linting.yml
+++ b/.github/workflows/linting.yml
@@ -30,7 +30,6 @@ jobs:
       - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
-          pip install $(find . -name "requirement*" -type f -printf ' -r %p')
          pip install mypy
          mypy . || exit_code=$?
          mypy --install-types --non-interactive
@@ -148,3 +147,69 @@ jobs:
       - name: Analysing the code with isort
        run: |
          isort --profile black .
+
+  test_basis:
+    needs:
+      - pylint
+      - mypy
+      - pydocstyle
+      - docsig
+      - black
+      - flake8
+      - interrogate
+      - bandit
+      - isort
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10", "3.11", "3.12"]
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v3
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install package
+        run: |
+          python -m pip install --upgrade pip
+          pip install .
+      - name: Run unit-tests
+        run: |
+          python -m unittest discover -v -s tests -t .
+
+  test_chemprop:
+    needs:
+      - test_basis
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python 3.11
+        uses: actions/setup-python@v3
+        with:
+          python-version: "3.11"
+      - name: Install package
+        run: |
+          python -m pip install --upgrade pip
+          pip install torch
+          pip install .[chemprop]
+      - name: Run unit-tests for chemprop
+        run: |
+          python -m unittest discover -v -s test_extras/test_chemprop -t .
+
+  test_notebooks:
+    needs:
+      - test_basis
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python 3.11
+        uses: actions/setup-python@v3
+        with:
+          python-version: "3.11"
+      - name: Install package
+        run: |
+          python -m pip install --upgrade pip
+          pip install .[notebooks]
+      - name: Run unit-tests for notebooks
+        run: |
+          python test_extras/test_notebooks/test_notebooks.py --continue-on-failure
diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml
deleted file mode 100644
index 1faaf0f4..00000000
--- a/.github/workflows/unittests.yml
+++ /dev/null
@@ -1,56 +0,0 @@
-name: Unit-Tests
-
-on: [push]
-
-jobs:
-  test_basis:
-    runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python-version: ["3.10", "3.11", "3.12"]
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Install package
-        run: |
-          python -m pip install --upgrade pip
-          pip install .
-      - name: Run unit-tests
-        run: |
-          python -m unittest discover -v -s tests -t .
-
-  test_chemprop:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python 3.11
-        uses: actions/setup-python@v3
-        with:
-          python-version: "3.11"
-      - name: Install package
-        run: |
-          python -m pip install --upgrade pip
-          pip install torch
-          pip install .[chemprop]
-      - name: Run unit-tests for chemprop
-        run: |
-          python -m unittest discover -v -s test_extras/test_chemprop -t .
-
-  test_notebooks:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python 3.11
-        uses: actions/setup-python@v3
-        with:
-          python-version: "3.11"
-      - name: Install package
-        run: |
-          python -m pip install --upgrade pip
-          pip install .[notebooks]
-      - name: Run unit-tests for notebooks
-        run: |
-          python test_extras/test_notebooks/test_notebooks.py --continue-on-failure
diff --git a/molpipeline/estimators/chemprop/__init__.py b/molpipeline/estimators/chemprop/__init__.py
index 4c32d13b..656d1833 100644
--- a/molpipeline/estimators/chemprop/__init__.py
+++ b/molpipeline/estimators/chemprop/__init__.py
@@ -1,11 +1,18 @@
 """Initialize Chemprop module."""
 
-import pkgutil
+try:
+    from molpipeline.estimators.chemprop.models import (  # noqa: F401
+        ChempropClassifier,
+        ChempropModel,
+        ChempropNeuralFP,
+        ChempropRegressor,
+    )
 
-installed_packages = {pkg.name for pkg in pkgutil.iter_modules()}
-if "chemprop" in installed_packages:
-    from molpipeline.estimators.chemprop.models import ChempropModel  # noqa
-
-    __all__ = ["ChempropModel"]
-else:
+    __all__ = [
+        "ChempropClassifier",
+        "ChempropModel",
+        "ChempropNeuralFP",
+        "ChempropRegressor",
+    ]
+except ImportError:
     __all__ = []
diff --git a/molpipeline/estimators/chemprop/abstract.py b/molpipeline/estimators/chemprop/abstract.py
index 262b5a60..c9018cac 100644
--- a/molpipeline/estimators/chemprop/abstract.py
+++ b/molpipeline/estimators/chemprop/abstract.py
@@ -13,7 +13,7 @@ import numpy.typing as npt
 
 try:
-    from chemprop.data import MoleculeDataset, MolGraphDataLoader
+    from chemprop.data import MoleculeDataset, build_dataloader
     from chemprop.models.model import MPNN
     from lightning import pytorch as pl
 except ImportError:
@@ -90,7 +90,7 @@ def fit(
         if y.ndim == 1:
             y = y.reshape(-1, 1)
         X.Y = y
-        training_data = MolGraphDataLoader(
+        training_data = build_dataloader(
             X, batch_size=self.batch_size, num_workers=self.n_jobs
         )
         self.lightning_trainer.fit(self.model, training_data)
diff --git a/molpipeline/estimators/chemprop/component_wrapper.py b/molpipeline/estimators/chemprop/component_wrapper.py
index 93e1581e..f74c8824 100644
--- a/molpipeline/estimators/chemprop/component_wrapper.py
+++ b/molpipeline/estimators/chemprop/component_wrapper.py
@@ -1,5 +1,6 @@
 """Wrapper classes for the chemprop components to make them compatible with scikit-learn."""
 
+import abc
 from typing import Any, Iterable, Self
 
 import torch
@@ -9,12 +10,40 @@
 from chemprop.nn.agg import MeanAggregation as _MeanAggregation
 from chemprop.nn.agg import SumAggregation as _SumAggregation
 from chemprop.nn.ffn import MLP
-from chemprop.nn.loss import LossFunction
+from chemprop.nn.loss import (
+    BCELoss,
+    BinaryDirichletLoss,
+    CrossEntropyLoss,
+    EvidentialLoss,
+    LossFunction,
+    MSELoss,
+    MulticlassDirichletLoss,
+    MVELoss,
+    SIDLoss,
+)
 from chemprop.nn.message_passing import BondMessagePassing as _BondMessagePassing
 from chemprop.nn.message_passing import MessagePassing
-from chemprop.nn.metrics import BCELoss, Metric
+from chemprop.nn.metrics import (
+    BinaryAUROCMetric,
+    CrossEntropyMetric,
+    Metric,
+    MSEMetric,
+    SIDMetric,
+)
 from chemprop.nn.predictors import BinaryClassificationFFN as _BinaryClassificationFFN
-from chemprop.nn.predictors import Predictor
+from chemprop.nn.predictors import BinaryDirichletFFN as _BinaryDirichletFFN
+from chemprop.nn.predictors import EvidentialFFN as _EvidentialFFN
+from chemprop.nn.predictors import (
+    MulticlassClassificationFFN as _MulticlassClassificationFFN,
+)
+from chemprop.nn.predictors import MulticlassDirichletFFN as _MulticlassDirichletFFN
+from chemprop.nn.predictors import MveFFN as _MveFFN
+from chemprop.nn.predictors import RegressionFFN as _RegressionFFN
+from chemprop.nn.predictors import SpectralFFN as _SpectralFFN
+from chemprop.nn.predictors import (
+    _FFNPredictorBase as _Predictor,  # pylint: disable=protected-access
+)
+from chemprop.nn.transforms import UnscaleTransform
 from chemprop.nn.utils import Activation, get_activation_function
 from sklearn.base import BaseEstimator
 from torch import Tensor, nn
@@ -115,11 +144,11 @@ def set_params(self, **params: Any) -> Self:
 
 
 # pylint: disable=too-many-ancestors, too-many-instance-attributes
-class BinaryClassificationFFN(_BinaryClassificationFFN, BaseEstimator):
-    """A wrapper for the BinaryClassificationFFN class."""
+class PredictorWrapper(_Predictor, BaseEstimator, abc.ABC):  # type: ignore
+    """Abstract wrapper for the Predictor class."""
 
-    n_targets: int = 1
-    _default_criterion = BCELoss()
+    _T_default_criterion: LossFunction
+    _T_default_metric: Metric
 
     def __init__(
         self,
@@ -130,6 +159,9 @@ def __init__(
         dropout: float = 0,
         activation: str = "relu",
         criterion: LossFunction | None = None,
+        task_weights: Tensor | None = None,
+        threshold: float | None = None,
+        output_transform: UnscaleTransform | None = None,
     ):
         """Initialize the PredictorWrapper class.
 
@@ -149,7 +181,15 @@ def __init__(
             Activation function.
         criterion : LossFunction or None, optional (default=None)
             Loss function. None defaults to the subclass's `_T_default_criterion`.
+        task_weights : Tensor or None, optional (default=None)
+            Task weights.
+        threshold : float or None, optional (default=None)
+            Threshold for binary classification.
+        output_transform : UnscaleTransform or None, optional (default=None)
+            Transformation applied to the output. None defaults to no transformation (identity).
         """
+        if task_weights is None:
+            task_weights = torch.ones(n_tasks)
         super().__init__(
             n_tasks=n_tasks,
             input_dim=input_dim,
@@ -158,6 +198,7 @@ def __init__(
             dropout=dropout,
             activation=activation,
             criterion=criterion,
+            output_transform=output_transform,
         )
         self.n_tasks = n_tasks
         self._input_dim = input_dim
@@ -165,6 +206,8 @@ def __init__(
         self.n_layers = n_layers
         self.dropout = dropout
         self.activation = activation
+        self.task_weights = task_weights
+        self.threshold = threshold
 
     @property
     def input_dim(self) -> int:
@@ -218,13 +261,13 @@ def reinitialize_fnn(self) -> Self:
         Self
             The reinitialized feedforward network.
""" - self.ffn = MLP( - self.input_dim, - self.n_tasks * self.n_targets, - self.hidden_dim, - self.n_layers, - self.dropout, - self.activation, + self.ffn = MLP.build( + input_dim=self.input_dim, + output_dim=self.n_tasks * self.n_targets, + hidden_dim=self.hidden_dim, + n_layers=self.n_layers, + dropout=self.dropout, + activation=self.activation, ) return self @@ -246,6 +289,68 @@ def set_params(self, **params: Any) -> Self: return self +class RegressionFFN(PredictorWrapper, _RegressionFFN): # type: ignore + """A wrapper for the RegressionFFN class.""" + + n_targets: int = 1 + _T_default_criterion = MSELoss + _T_default_metric = MSEMetric + + +class MveFFN(PredictorWrapper, _MveFFN): # type: ignore + """A wrapper for the MveFFN class.""" + + n_targets: int = 2 + _T_default_criterion = MVELoss + + +class EvidentialFFN(PredictorWrapper, _EvidentialFFN): # type: ignore + """A wrapper for the EvidentialFFN class.""" + + n_targets: int = 4 + _T_default_criterion = EvidentialLoss + + +class BinaryClassificationFFN(PredictorWrapper, _BinaryClassificationFFN): # type: ignore + """A wrapper for the BinaryClassificationFFN class.""" + + n_targets: int = 1 + _T_default_criterion = BCELoss + _T_default_metric = BinaryAUROCMetric + + +class BinaryDirichletFFN(PredictorWrapper, _BinaryDirichletFFN): # type: ignore + """A wrapper for the BinaryDirichletFFN class.""" + + n_targets: int = 2 + _T_default_criterion = BinaryDirichletLoss + _T_default_metric = BinaryAUROCMetric + + +class MulticlassClassificationFFN(PredictorWrapper, _MulticlassClassificationFFN): # type: ignore + """A wrapper for the MulticlassClassificationFFN class.""" + + n_targets: int = 1 + _T_default_criterion = CrossEntropyLoss + _T_default_metric = CrossEntropyMetric + + +class MulticlassDirichletFFN(PredictorWrapper, _MulticlassDirichletFFN): # type: ignore + """A wrapper for the MulticlassDirichletFFN class.""" + + n_targets: int = 1 + _T_default_criterion = MulticlassDirichletLoss + _T_default_metric = CrossEntropyMetric + + +class SpectralFFN(PredictorWrapper, _SpectralFFN): # type: ignore + """A wrapper for the SpectralFFN class.""" + + n_targets: int = 1 + _T_default_criterion = SIDLoss + _T_default_metric = SIDMetric + + class MPNN(_MPNN, BaseEstimator): """A wrapper for the MPNN class. @@ -253,14 +358,15 @@ class MPNN(_MPNN, BaseEstimator): and a feedforward network for prediction. """ + bn: nn.BatchNorm1d | nn.Identity + def __init__( self, message_passing: MessagePassing, agg: Aggregation, - predictor: Predictor, + predictor: PredictorWrapper, batch_norm: bool = True, metric_list: Iterable[Metric] | None = None, - task_weight: Tensor | None = None, warmup_epochs: int = 2, init_lr: float = 1e-4, max_lr: float = 1e-3, @@ -280,8 +386,6 @@ def __init__( Whether to use batch normalization. metric_list : Iterable[Metric] | None, optional (default=None) The metrics to use for evaluation. - task_weight : Tensor | None, optional (default=None) - The weights to use for each task during training. If None, use uniform weights. warmup_epochs : int, optional (default=2) The number of epochs to use for the learning rate warmup. init_lr : float, optional (default=1e-4) @@ -292,20 +396,18 @@ def __init__( The final learning rate. 
""" super().__init__( - message_passing, - agg, - predictor, - batch_norm, - metric_list, - task_weight, - warmup_epochs, - init_lr, - max_lr, - final_lr, + message_passing=message_passing, + agg=agg, + predictor=predictor, + batch_norm=batch_norm, + metrics=metric_list, + warmup_epochs=warmup_epochs, + init_lr=init_lr, + max_lr=max_lr, + final_lr=final_lr, ) self.metric_list = metric_list self.batch_norm = batch_norm - self.task_weight = task_weight def reinitialize_network(self) -> Self: """Reinitialize the network with the current parameters. @@ -315,21 +417,17 @@ def reinitialize_network(self) -> Self: Self The reinitialized network. """ - self.bn = ( - nn.BatchNorm1d(self.message_passing.output_dim) - if self.batch_norm - else nn.Identity() - ) + if self.batch_norm: + self.bn = nn.BatchNorm1d(self.message_passing.output_dim) + else: + self.bn = nn.Identity() + if self.metric_list is None: # pylint: disable=protected-access - self.metrics = [self.predictor._default_metric, self.criterion] + self.metrics = [self.predictor._T_default_metric, self.criterion] else: self.metrics = list(self.metric_list) + [self.criterion] - if self.task_weight is None: - w_t = torch.ones(self.n_tasks) - else: - w_t = torch.tensor(self.task_weight) - self.w_t = nn.Parameter(w_t.unsqueeze(0), False) + return self def set_params(self, **params: Any) -> Self: diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index 6a74cc87..6178b2e0 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -9,18 +9,22 @@ import numpy as np import numpy.typing as npt +from loguru import logger from sklearn.base import clone from sklearn.utils.metaestimators import available_if try: - from chemprop.data import MoleculeDataset, MolGraphDataLoader + from chemprop.data import MoleculeDataset, build_dataloader from chemprop.nn.predictors import ( BinaryClassificationFFNBase, MulticlassClassificationFFN, ) from lightning import pytorch as pl -except ImportError: - pass +except ImportError as error: + logger.error( + "Chemprop is not installed. Please install it using `pip install chemprop`." + ) + logger.info(error) from molpipeline.estimators.chemprop.abstract import ABCChemprop @@ -28,6 +32,7 @@ MPNN, BinaryClassificationFFN, BondMessagePassing, + RegressionFFN, SumAggregation, ) from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP @@ -86,9 +91,10 @@ def _predict( The predictions for the input data. """ self.model.eval() - test_data = MolGraphDataLoader(X, num_workers=self.n_jobs, shuffle=False) + test_data = build_dataloader(X, num_workers=self.n_jobs, shuffle=False) predictions = self.lightning_trainer.predict(self.model, test_data) - prediction_array = np.array([pred.numpy() for pred in predictions]) # type: ignore + prediction_array = np.vstack(predictions) # type: ignore + prediction_array = prediction_array.squeeze() # Check if the predictions have the same length as the input dataset if prediction_array.shape[0] != len(X): @@ -98,11 +104,10 @@ def _predict( # If the model is a binary classifier, return the probability of the positive class if self._is_binary_classifier(): - if prediction_array.shape[1] != 1 or prediction_array.shape[2] != 1: + if prediction_array.ndim != 1: raise ValueError( "Binary classification model should output a single probability." 
                 )
-            prediction_array = prediction_array[:, 0, 0]
         return prediction_array
 
     def predict(
@@ -170,7 +175,7 @@ def to_encoder(self) -> ChempropNeuralFP:
 
 
 class ChempropClassifier(ChempropModel):
-    """Wrap Chemprop in a sklearn like classifier."""
+    """Chemprop model with default parameters for binary classification tasks."""
 
     def __init__(
         self,
@@ -178,7 +183,7 @@ def __init__(
         lightning_trainer: pl.Trainer | None = None,
         batch_size: int = 64,
         n_jobs: int = 1,
-        **kwargs: Any,  # pylint: disable=unused-argument
+        **kwargs: Any,
     ) -> None:
         """Initialize the chemprop classifier model.
 
@@ -206,6 +211,7 @@ def __init__(
             lightning_trainer=lightning_trainer,
             batch_size=batch_size,
             n_jobs=n_jobs,
+            **kwargs,
         )
         if not self._is_binary_classifier():
             raise ValueError("ChempropClassifier should be a binary classifier.")
@@ -227,3 +233,44 @@ def set_params(self, **params: Any) -> Self:
         if not self._is_binary_classifier():
             raise ValueError("ChempropClassifier should be a binary classifier.")
         return self
+
+
+class ChempropRegressor(ChempropModel):
+    """Chemprop model with default parameters for regression tasks."""
+
+    def __init__(
+        self,
+        model: MPNN | None = None,
+        lightning_trainer: pl.Trainer | None = None,
+        batch_size: int = 64,
+        n_jobs: int = 1,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the chemprop regressor model.
+
+        Parameters
+        ----------
+        model : MPNN | None, optional
+            The chemprop model to wrap. If None, a default model will be used.
+        lightning_trainer : pl.Trainer | None, optional (default=None)
+            The lightning trainer to use.
+        batch_size : int, optional (default=64)
+            The batch size to use.
+        n_jobs : int, optional (default=1)
+            The number of jobs to use.
+        kwargs : Any
+            Parameters set using `set_params`.
+            Can be used to modify components of the model.
+        """
+        if model is None:
+            bond_encoder = BondMessagePassing()
+            agg = SumAggregation()
+            predictor = RegressionFFN()
+            model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor)
+        super().__init__(
+            model=model,
+            lightning_trainer=lightning_trainer,
+            batch_size=batch_size,
+            n_jobs=n_jobs,
+            **kwargs,
+        )
diff --git a/molpipeline/mol2any/__init__.py b/molpipeline/mol2any/__init__.py
index 96509f73..da7bb760 100644
--- a/molpipeline/mol2any/__init__.py
+++ b/molpipeline/mol2any/__init__.py
@@ -1,7 +1,5 @@
 """Init the module for mol2any pipeline elements."""
 
-import pkgutil
-
 from molpipeline.mol2any.mol2bin import MolToBinary
 from molpipeline.mol2any.mol2concatinated_vector import MolToConcatenatedVector
 from molpipeline.mol2any.mol2inchi import MolToInchi, MolToInchiKey
@@ -21,8 +19,9 @@
     "MolToRDKitPhysChem",
 ]
 
-installed_packages = {pkg.name for pkg in pkgutil.iter_modules()}
-if "chemprop" in installed_packages:
+try:
     from molpipeline.mol2any.mol2chemprop import MolToChemprop  # noqa
 
     __all__.append("MolToChemprop")
+except ImportError:
+    pass
diff --git a/requirements_chemprop.txt b/requirements_chemprop.txt
index 41ccfc6a..0bd7e545 100644
--- a/requirements_chemprop.txt
+++ b/requirements_chemprop.txt
@@ -1,2 +1,2 @@
-chemprop @ https://github.com//c-w-feldmann/chemprop/archive/v2/molpipeline_requirement.zip
-lightning < 1000
\ No newline at end of file
+chemprop
+lightning
\ No newline at end of file
diff --git a/test_extras/test_chemprop/test_chemprop_pipeline.py b/test_extras/test_chemprop/test_chemprop_pipeline.py
index 31de13e6..92fe69cc 100644
--- a/test_extras/test_chemprop/test_chemprop_pipeline.py
+++ b/test_extras/test_chemprop/test_chemprop_pipeline.py
@@ -1,9 +1,16 @@
 """Test module for behavior of chemprop in a pipeline."""
 
 import unittest
+from io import BytesIO
+from typing import TypeVar
 
+import joblib
 import numpy as np
+import pandas as pd
+from chemprop.nn.loss import LossFunction
+from lightning import pytorch as pl
 from sklearn.base import clone
+from torch import nn
 
 from molpipeline.any2mol import SmilesToMol
 from molpipeline.error_handling import ErrorFilter, FilterReinserter
@@ -13,7 +20,11 @@
     BondMessagePassing,
     SumAggregation,
 )
-from molpipeline.estimators.chemprop.models import ChempropModel
+from molpipeline.estimators.chemprop.models import (
+    ChempropClassifier,
+    ChempropModel,
+    ChempropRegressor,
+)
 from molpipeline.mol2any.mol2chemprop import MolToChemprop
 from molpipeline.pipeline import Pipeline
 from molpipeline.post_prediction import PostPredictionWrapper
@@ -61,6 +72,100 @@ def get_model_pipeline() -> Pipeline:
     return model_pipeline
 
 
+DEFAULT_TRAINER = pl.Trainer(
+    accelerator="cpu",
+    logger=False,
+    enable_checkpointing=False,
+    max_epochs=5,
+    enable_model_summary=False,
+    enable_progress_bar=False,
+    val_check_interval=0.0,
+)
+
+
+def get_regression_pipeline() -> Pipeline:
+    """Get the Chemprop model pipeline for regression.
+
+    Returns
+    -------
+    Pipeline
+        The Chemprop model pipeline for regression.
+    """
+
+    smiles2mol = SmilesToMol()
+    mol2chemprop = MolToChemprop()
+    error_filter = ErrorFilter(filter_everything=True)
+    filter_reinserter = FilterReinserter.from_error_filter(
+        error_filter, fill_value=np.nan
+    )
+    chemprop_model = ChempropRegressor(lightning_trainer=DEFAULT_TRAINER)
+    model_pipeline = Pipeline(
+        steps=[
+            ("smiles2mol", smiles2mol),
+            ("mol2chemprop", mol2chemprop),
+            ("error_filter", error_filter),
+            ("model", chemprop_model),
+            ("filter_reinserter", PostPredictionWrapper(filter_reinserter)),
+        ],
+    )
+    return model_pipeline
+
+
+def get_classification_pipeline() -> Pipeline:
+    """Get the Chemprop model pipeline for classification.
+
+    Returns
+    -------
+    Pipeline
+        The Chemprop model pipeline for classification.
+    """
+    smiles2mol = SmilesToMol()
+    mol2chemprop = MolToChemprop()
+    error_filter = ErrorFilter(filter_everything=True)
+    filter_reinserter = FilterReinserter.from_error_filter(
+        error_filter, fill_value=np.nan
+    )
+    chemprop_model = ChempropClassifier(lightning_trainer=DEFAULT_TRAINER)
+    model_pipeline = Pipeline(
+        steps=[
+            ("smiles2mol", smiles2mol),
+            ("mol2chemprop", mol2chemprop),
+            ("error_filter", error_filter),
+            ("model", chemprop_model),
+            ("filter_reinserter", PostPredictionWrapper(filter_reinserter)),
+        ],
+    )
+    return model_pipeline
+
+
+_T = TypeVar("_T")
+
+
+def joblib_dump_load(obj: _T) -> _T:
+    """Dump and load an object using joblib.
+
+    Notes
+    -----
+    The object is not dumped to disk but to a BytesIO object.
+
+    Parameters
+    ----------
+    obj : _T
+        The object to dump and load.
+
+    Returns
+    -------
+    _T
+        The loaded object.
+    """
+    bytes_container = BytesIO()
+    joblib.dump(obj, bytes_container)
+    bytes_container.seek(0)  # update to enable reading
+    bytes_model = bytes_container.read()
+
+    return joblib.load(BytesIO(bytes_model))
+
+
 class TestChempropPipeline(unittest.TestCase):
     """Test the Chemprop model pipeline."""
 
@@ -106,6 +211,14 @@ def test_clone(self) -> None:
                     self.assertEqual(
                         param.__class__, cloned_params[param_name].__class__
                     )
+                elif isinstance(param, LossFunction):
+                    self.assertEqual(
+                        param.state_dict()["task_weights"],
+                        cloned_params[param_name].state_dict()["task_weights"],
+                    )
+                    self.assertEqual(type(param), type(cloned_params[param_name]))
+                elif isinstance(param, nn.Identity):
+                    self.assertEqual(type(param), type(cloned_params[param_name]))
                 else:
                     self.assertEqual(
                         param, cloned_params[param_name], f"Failed for {param_name}"
                     )
@@ -145,3 +258,60 @@ def test_error_handling(self) -> None:
         proba = pipeline.predict_proba(smiles)
         self.assertEqual(len(proba), 4)
         self.assertTrue(np.isnan(proba[-1]).all())
+
+
+class TestRegressionPipeline(unittest.TestCase):
+    """Test the Chemprop model pipeline for regression."""
+
+    def test_prediction(self) -> None:
+        """Test the prediction of the regression model."""
+
+        molecule_net_logd_df = pd.read_csv(
+            "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/Lipophilicity.csv"
+        ).head(1000)
+        regression_model = get_regression_pipeline()
+        regression_model.fit(
+            molecule_net_logd_df["smiles"].tolist(),
+            molecule_net_logd_df["exp"].to_numpy(),
+        )
+        pred = regression_model.predict(molecule_net_logd_df["smiles"].tolist())
+
+        self.assertEqual(len(pred), len(molecule_net_logd_df))
+
+        model_copy = joblib_dump_load(regression_model)
+        pred_copy = model_copy.predict(molecule_net_logd_df["smiles"].tolist())
+        self.assertTrue(np.allclose(pred, pred_copy))
+
+
+class TestClassificationPipeline(unittest.TestCase):
+    """Test the Chemprop model pipeline for classification."""
+
+    def test_prediction(self) -> None:
+        """Test the prediction of the classification model."""
+
+        molecule_net_bbbp_df = pd.read_csv(
+            "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv"
+        ).head(1000)
+        classification_model = get_classification_pipeline()
+        classification_model.fit(
+            molecule_net_bbbp_df["smiles"].tolist(),
+            molecule_net_bbbp_df["p_np"].to_numpy(),
+        )
+        pred = classification_model.predict(molecule_net_bbbp_df["smiles"].tolist())
+        proba = classification_model.predict_proba(
+            molecule_net_bbbp_df["smiles"].tolist()
+        )
+        self.assertEqual(len(pred), len(molecule_net_bbbp_df))
+        self.assertEqual(proba.shape[1], 2)
+        self.assertEqual(proba.shape[0], len(molecule_net_bbbp_df))
+
+        model_copy = joblib_dump_load(classification_model)
+        pred_copy = model_copy.predict(molecule_net_bbbp_df["smiles"].tolist())
+        proba_copy = model_copy.predict_proba(molecule_net_bbbp_df["smiles"].tolist())
+
+        nan_indices = np.isnan(pred)
+        self.assertListEqual(nan_indices.tolist(), np.isnan(pred_copy).tolist())
+        self.assertTrue(np.allclose(pred[~nan_indices], pred_copy[~nan_indices]))
+
+        self.assertEqual(proba.shape, proba_copy.shape)
+        self.assertTrue(np.allclose(proba[~nan_indices], proba_copy[~nan_indices]))
diff --git a/test_extras/test_chemprop/test_component_wrapper.py b/test_extras/test_chemprop/test_component_wrapper.py
index 02dcdd54..53a901aa 100644
--- a/test_extras/test_chemprop/test_component_wrapper.py
+++ b/test_extras/test_chemprop/test_component_wrapper.py
@@ -2,7 +2,9 @@
 
 import unittest
 
+from chemprop.nn.loss import LossFunction
 from sklearn.base import clone
+from torch import nn
 
 from molpipeline.estimators.chemprop.component_wrapper import (
     MPNN,
@@ -113,6 +115,14 @@ def test_clone(self) -> None:
             clone_param = mpnn_clone.get_params(deep=True)[param_name]
             if hasattr(param, "get_params"):
                 self.assertEqual(param.__class__, clone_param.__class__)
+            elif isinstance(param, LossFunction):
+                self.assertEqual(
+                    param.state_dict()["task_weights"],
+                    clone_param.state_dict()["task_weights"],
+                )
+                self.assertEqual(type(param), type(clone_param))
+            elif isinstance(param, nn.Identity):
+                self.assertEqual(type(param), type(clone_param))
             else:
                 self.assertEqual(param, clone_param)
diff --git a/test_extras/test_chemprop/test_models.py b/test_extras/test_chemprop/test_models.py
index f3a88e09..7445102a 100644
--- a/test_extras/test_chemprop/test_models.py
+++ b/test_extras/test_chemprop/test_models.py
@@ -3,17 +3,24 @@
 import logging
 import unittest
 
+from chemprop.nn.loss import BCELoss, LossFunction, MSELoss
 from lightning import pytorch as pl
 from sklearn.base import clone
+from torch import Tensor, nn
 
 from molpipeline.estimators.chemprop.component_wrapper import (
     MPNN,
     BinaryClassificationFFN,
     BondMessagePassing,
     MeanAggregation,
+    RegressionFFN,
     SumAggregation,
 )
-from molpipeline.estimators.chemprop.models import ChempropModel
+from molpipeline.estimators.chemprop.models import (
+    ChempropClassifier,
+    ChempropModel,
+    ChempropRegressor,
+)
 from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP
 
 logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING)
@@ -39,6 +46,53 @@ def get_model() -> ChempropModel:
     return chemprop_model
 
 
+DEFAULT_PARAMS = {
+    "batch_size": 64,
+    "lightning_trainer": pl.Trainer,
+    "model": MPNN,
+    "model__agg__dim": 0,
+    "model__agg": SumAggregation,
+    "model__batch_norm": True,
+    "model__final_lr": 0.0001,
+    "model__init_lr": 0.0001,
+    "model__max_lr": 0.001,
+    "model__message_passing__activation": "relu",
"relu", + "model__message_passing__bias": False, + "model__message_passing__d_e": 14, + "model__message_passing__d_h": 300, + "model__message_passing__d_v": 72, + "model__message_passing__d_vd": None, + "model__message_passing__depth": 3, + "model__message_passing__dropout_rate": 0.0, + "model__message_passing__undirected": False, + "model__message_passing": BondMessagePassing, + "model__metric_list": None, + "model__predictor__activation": "relu", + "model__warmup_epochs": 2, + "model__predictor": BinaryClassificationFFN, + "model__predictor__criterion": BCELoss, + "model__predictor__dropout": 0, + "model__predictor__hidden_dim": 300, + "model__predictor__input_dim": 300, + "model__predictor__n_layers": 1, + "model__predictor__n_tasks": 1, + "model__predictor__output_transform": nn.Identity, + "model__predictor__task_weights": Tensor([1.0]), + "model__predictor__threshold": None, + "n_jobs": 1, +} + +NO_IDENTITY_CHECK = [ + "model__agg", + "model__message_passing", + "lightning_trainer", + "model", + "model__predictor", + "model__predictor__criterion", + "model__predictor__output_transform", +] + + class TestChempropModel(unittest.TestCase): """Test the Chemprop model.""" @@ -46,40 +100,19 @@ def test_get_params(self) -> None: """Test the get_params and set_params methods.""" chemprop_model = get_model() orig_params = chemprop_model.get_params(deep=True) + expected_params = dict(DEFAULT_PARAMS) # Shallow copy - expected_params = { - "batch_size": 64, - "lightning_trainer": pl.Trainer, - "model__agg__dim": 0, - "model__agg": SumAggregation, - "model__batch_norm": True, - "model__final_lr": 0.0001, - "model__init_lr": 0.0001, - "model__max_lr": 0.001, - "model__message_passing__activation": "relu", - "model__message_passing__bias": False, - "model__message_passing__d_e": 14, - "model__message_passing__d_h": 300, - "model__message_passing__d_v": 133, - "model__message_passing__d_vd": None, - "model__message_passing__depth": 3, - "model__message_passing__dropout_rate": 0.0, - "model__message_passing__undirected": False, - "model__message_passing": BondMessagePassing, - } - + self.assertSetEqual(set(orig_params), set(expected_params)) # Check if the parameters are as expected for param_name, param in expected_params.items(): - if param_name in [ - "model__agg", - "model__message_passing", - "lightning_trainer", - ]: + if param_name in NO_IDENTITY_CHECK: if not isinstance(param, type): raise ValueError(f"{param_name} should be a type.") self.assertIsInstance(orig_params[param_name], param) else: - self.assertEqual(orig_params[param_name], param) + self.assertEqual( + orig_params[param_name], param, f"Test failed for {param_name}" + ) new_params = { "batch_size": 32, @@ -102,7 +135,6 @@ def test_clone(self) -> None: cloned_model = clone(chemprop_model) self.assertIsInstance(cloned_model, ChempropModel) cloned_model_params = cloned_model.get_params(deep=True) - for param_name, param in chemprop_model.get_params(deep=True).items(): cloned_param = cloned_model_params[param_name] if hasattr(param, "get_params"): @@ -110,8 +142,16 @@ def test_clone(self) -> None: self.assertNotEqual(id(param), id(cloned_param)) elif isinstance(param, pl.Trainer): self.assertIsInstance(cloned_param, pl.Trainer) + elif isinstance(param, LossFunction): + self.assertEqual( + param.state_dict()["task_weights"], + cloned_param.state_dict()["task_weights"], + ) + self.assertEqual(type(param), type(cloned_param)) + elif isinstance(param, nn.Identity): + self.assertEqual(type(param), type(cloned_param)) else: - 
self.assertEqual(param, cloned_param) + self.assertEqual(param, cloned_param, f"Test failed for {param_name}") def test_classifier_methods(self) -> None: """Test the classifier methods.""" @@ -131,3 +171,45 @@ def test_neural_fp(self) -> None: # the model should be cloned self.assertNotEqual(id(chemprop_model.model), id(neural_fp.model)) self.assertEqual(neural_fp.disable_fitting, True) + + +class TestChempropClassifier(unittest.TestCase): + """Test the Chemprop classifier model.""" + + def test_get_params(self) -> None: + """Test the get_params and set_params methods.""" + chemprop_model = ChempropClassifier() + param_dict = chemprop_model.get_params(deep=True) + expected_params = dict(DEFAULT_PARAMS) # Shallow copy + self.assertSetEqual(set(param_dict.keys()), set(expected_params.keys())) + for param_name, param in expected_params.items(): + if param_name in NO_IDENTITY_CHECK: + if not isinstance(param, type): + raise ValueError(f"{param_name} should be a type.") + self.assertIsInstance(param_dict[param_name], param) + else: + self.assertEqual( + param_dict[param_name], param, f"Test failed for {param_name}" + ) + + +class TestChempropRegressor(unittest.TestCase): + """Test the Chemprop regressor model.""" + + def test_get_params(self) -> None: + """Test the get_params and set_params methods.""" + chemprop_model = ChempropRegressor() + param_dict = chemprop_model.get_params(deep=True) + expected_params = dict(DEFAULT_PARAMS) + expected_params["model__predictor"] = RegressionFFN + expected_params["model__predictor__criterion"] = MSELoss + self.assertSetEqual(set(param_dict.keys()), set(expected_params.keys())) + for param_name, param in expected_params.items(): + if param_name in NO_IDENTITY_CHECK: + if not isinstance(param, type): + raise ValueError(f"{param_name} should be a type.") + self.assertIsInstance(param_dict[param_name], param) + else: + self.assertEqual( + param_dict[param_name], param, f"Test failed for {param_name}" + )
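
Example usage of the regression API introduced by this patch (a minimal sketch assembled from the `get_regression_pipeline` test helper above; the SMILES strings, target values, and trainer settings are illustrative assumptions, not part of the patch):

import numpy as np
from lightning import pytorch as pl

from molpipeline.any2mol import SmilesToMol
from molpipeline.error_handling import ErrorFilter, FilterReinserter
from molpipeline.estimators.chemprop.models import ChempropRegressor
from molpipeline.mol2any.mol2chemprop import MolToChemprop
from molpipeline.pipeline import Pipeline
from molpipeline.post_prediction import PostPredictionWrapper

# Lightweight CPU trainer, mirroring DEFAULT_TRAINER from the tests (illustrative values).
trainer = pl.Trainer(
    accelerator="cpu", logger=False, max_epochs=5, enable_progress_bar=False
)

error_filter = ErrorFilter(filter_everything=True)
pipeline = Pipeline(
    steps=[
        ("smiles2mol", SmilesToMol()),  # SMILES -> RDKit molecules
        ("mol2chemprop", MolToChemprop()),  # molecules -> chemprop graph datapoints
        ("error_filter", error_filter),  # drop inputs that failed to parse
        ("model", ChempropRegressor(lightning_trainer=trainer)),
        (
            "filter_reinserter",  # re-insert NaN for filtered inputs
            PostPredictionWrapper(
                FilterReinserter.from_error_filter(error_filter, fill_value=np.nan)
            ),
        ),
    ],
)

# Toy data for illustration only.
smiles = ["CCO", "c1ccccc1", "CC(=O)O"]
targets = np.array([0.2, 1.5, 0.9])
pipeline.fit(smiles, targets)
print(pipeline.predict(smiles))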