Skip to content

Commit

Permalink
Allowing to use Memory in Pipeline (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w-feldmann authored Aug 6, 2024
1 parent 501964d commit cc98b76
Show file tree
Hide file tree
Showing 7 changed files with 268 additions and 7 deletions.
3 changes: 1 addition & 2 deletions molpipeline/mol2any/mol2rdkit_phys_chem.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
import numpy as np
import numpy.typing as npt
from loguru import logger
from rdkit import rdBase
from rdkit import Chem
from rdkit import Chem, rdBase
from rdkit.Chem import Descriptors
from sklearn.preprocessing import StandardScaler

Expand Down
6 changes: 6 additions & 0 deletions molpipeline/pipeline/_skl_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ def __init__(
"""
super().__init__(steps, memory=memory, verbose=verbose)
self.n_jobs = n_jobs
self._set_error_resinserter()

def _set_error_resinserter(self) -> None:
"""Connect the error resinserters with the error filters."""
error_replacer_list = [
e_filler
for _, e_filler in self.steps
Expand Down Expand Up @@ -288,6 +291,9 @@ def _fit(
self.steps[idx_i] = (name_i, ele_i)
if y is not None:
y = fitted_transformer.co_transform(y)
for idx_i, name_i, ele_i in zip(step_idx, name, ele_list):
self.steps[idx_i] = (name_i, ele_i)
self._set_error_resinserter()
elif isinstance(name, list) or isinstance(step_idx, list):
raise AssertionError()
else:
Expand Down
10 changes: 6 additions & 4 deletions test_extras/test_chemprop/test_chemprop_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from molpipeline.pipeline import Pipeline
from molpipeline.post_prediction import PostPredictionWrapper
from test_extras.test_chemprop.chemprop_test_utils.compare_models import compare_params
from tests import TEST_DATA_DIR


# pylint: disable=duplicate-code
Expand Down Expand Up @@ -256,8 +257,8 @@ def test_prediction(self) -> None:
"""Test the prediction of the regression model."""

molecule_net_logd_df = pd.read_csv(
"https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/Lipophilicity.csv"
).head(1000)
TEST_DATA_DIR / "molecule_net_logd.tsv.gz", sep="\t", nrows=100
)
regression_model = get_regression_pipeline()
regression_model.fit(
molecule_net_logd_df["smiles"].tolist(),
Expand All @@ -279,8 +280,9 @@ def test_prediction(self) -> None:
"""Test the prediction of the classification model."""

molecule_net_bbbp_df = pd.read_csv(
"https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv"
).head(1000)
TEST_DATA_DIR / "molecule_net_bbbp.tsv.gz", sep="\t", nrows=100
)
molecule_net_bbbp_df.to_csv("molecule_net_bbbp.tsv.gz", sep="\t", index=False)
classification_model = get_classification_pipeline()
classification_model.fit(
molecule_net_bbbp_df["smiles"].tolist(),
Expand Down
Binary file added tests/test_data/molecule_net_bbbp.tsv.gz
Binary file not shown.
Binary file added tests/test_data/molecule_net_logd.tsv.gz
Binary file not shown.
101 changes: 100 additions & 1 deletion tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@

from __future__ import annotations

import tempfile
import unittest
from itertools import combinations
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
from joblib import Memory
from sklearn.base import BaseEstimator
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

Expand All @@ -21,6 +27,8 @@
)
from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json
from molpipeline.utils.matrices import are_equal
from tests import TEST_DATA_DIR
from tests.utils.execution_count import get_exec_counted_rf_regressor
from tests.utils.fingerprints import make_sparse_fp

TEST_SMILES = ["CC", "CCO", "COC", "CCCCC", "CCC(-O)O", "CCCN"]
Expand Down Expand Up @@ -275,6 +283,97 @@ def test_gridsearchcv(self) -> None:
for k, value in param_grid.items():
self.assertIn(grid_search_cv.best_params_[k], value)

def test_caching(self) -> None:
    """Test that caching keeps the predictions and skips repeated transformations.

    For each setting (without and with a cache location) the pipeline is fitted
    twice, with the final estimator replaced by an untrained one in between.
    Without caching the counted ``fit_transform`` is executed on every fit,
    with caching only on the first one. All collected predictions have to agree.
    """
    logd_df = pd.read_csv(
        TEST_DATA_DIR / "molecule_net_logd.tsv.gz", sep="\t", nrows=20
    )
    smiles_list = logd_df["smiles"].tolist()
    target_list = logd_df["exp"].tolist()

    all_predictions = []
    for use_cache in (False, True):
        pipeline = get_exec_counted_rf_regressor(_RANDOM_STATE)
        with tempfile.TemporaryDirectory() as tmp_dir:
            # No location is passed for the run without caching.
            location = Path(tmp_dir) / ".cache" if use_cache else None
            mem = Memory(location=location, verbose=0)
            pipeline.memory = mem

            # First fit and prediction.
            pipeline.fit(smiles_list, target_list)
            all_predictions.append(pipeline.predict(smiles_list))

            # Swap in an untrained model before fitting a second time.
            pipeline.steps[-1] = (
                "rf",
                RandomForestRegressor(random_state=_RANDOM_STATE, n_jobs=1),
            )

            # Second fit and prediction.
            pipeline.fit(smiles_list, target_list)
            all_predictions.append(pipeline.predict(smiles_list))

            # Fit is called twice; with caching the second transformation is
            # taken from the cache, so the counter is increased only once.
            expected_count = 1 if use_cache else 2
            counted_step = pipeline.named_steps["mol2concat"]
            self.assertEqual(counted_step.n_transformations, expected_count)

            mem.clear(warn=False)

    # Every run, cached or not, must yield the same predictions.
    for first, second in combinations(all_predictions, 2):
        self.assertTrue(np.allclose(first, second))

def test_gridsearch_cache(self) -> None:
    """Check that a short GridSearchCV gives the same results with and without caching."""
    data_df = pd.read_csv(
        TEST_DATA_DIR / "molecule_net_logd.tsv.gz", sep="\t", nrows=20
    )
    smiles_list = data_df["smiles"].tolist()
    target_list = data_df["exp"].tolist()
    param_grid = {"rf__n_estimators": [1, 2]}

    # Results of both runs, keyed by whether caching was activated.
    best_params = {}
    predictions = {}
    for use_cache in (True, False):
        pipeline = get_exec_counted_rf_regressor(_RANDOM_STATE)
        with tempfile.TemporaryDirectory() as tmp_dir:
            # No location is passed for the run without caching.
            location = Path(tmp_dir) / ".cache" if use_cache else None
            mem = Memory(location=location, verbose=0)
            pipeline.memory = mem

            search = GridSearchCV(
                estimator=pipeline,
                param_grid=param_grid,
                cv=2,
                scoring="neg_mean_squared_error",
                n_jobs=1,
                error_score="raise",
                refit=True,
                pre_dispatch=1,
            )
            search.fit(smiles_list, target_list)

            best_params[use_cache] = search.best_params_
            predictions[use_cache] = search.predict(smiles_list)
            mem.clear(warn=False)

    self.assertEqual(best_params[True], best_params[False])
    self.assertTrue(np.allclose(predictions[True], predictions[False]))


# Run the tests of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
155 changes: 155 additions & 0 deletions tests/utils/execution_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""Functions for counting the number of times a function is executed."""

from __future__ import annotations

from typing import Any

try:
from typing import Self # type: ignore[attr-defined]
except ImportError:
from typing_extensions import Self

from sklearn.base import BaseEstimator
from sklearn.ensemble import RandomForestRegressor

from molpipeline import Pipeline
from molpipeline.abstract_pipeline_elements.core import ABCPipelineElement
from molpipeline.any2mol import SmilesToMol
from molpipeline.mol2any import MolToMorganFP


class CountingTransformerWrapper(BaseEstimator):
    """Wrap a pipeline element and count how often ``fit_transform`` is executed.

    Only ``fit_transform`` increases ``n_transformations``. ``fit`` and
    ``transform`` delegate to the wrapped element without counting.
    """

    def __init__(self, element: ABCPipelineElement):
        """Initialize the wrapper.

        Parameters
        ----------
        element : ABCPipelineElement
            The element to wrap.
        """
        self.element = element
        # Execution counter of fit_transform; every new wrapper starts at zero.
        self.n_transformations = 0

    def fit(self, X: Any, y: Any = None) -> Self:  # pylint: disable=invalid-name
        """Fit the wrapped element.

        Parameters
        ----------
        X : Any
            The input data.
        y : Any, optional
            The target data. It is forwarded to the wrapped element as is.

        Returns
        -------
        Self
            The wrapper itself.
        """
        self.element.fit(X, y)
        return self

    def transform(self, X: Any) -> Any:  # pylint: disable=invalid-name
        """Transform the data with the wrapped element.

        Transform is called during prediction, which is not cached.
        Therefore, the counter is deliberately not increased here.

        Parameters
        ----------
        X : Any
            The input data.

        Returns
        -------
        Any
            The transformed data.
        """
        return self.element.transform(X)

    def fit_transform(
        self, X: Any, y: Any = None  # pylint: disable=invalid-name
    ) -> Any:
        """Fit the wrapped element, transform the data, and count the execution.

        Parameters
        ----------
        X : Any
            The input data.
        y : Any, optional
            The target data. It is forwarded to the wrapped element as is.

        Returns
        -------
        Any
            The transformed data.
        """
        # Count first, so the call is registered even if the element raises.
        self.n_transformations += 1
        return self.element.fit_transform(X, y)

    def get_params(self, deep: bool = True) -> dict[str, Any]:
        """Get the parameters of the wrapper.

        Parameters
        ----------
        deep : bool
            If True, the parameters of the wrapped element are also returned.

        Returns
        -------
        dict[str, Any]
            The wrapped element under the key "element". If deep is True, the
            parameters of the wrapped element are added under their own names,
            i.e. without an "element__" prefix.
        """
        params: dict[str, Any] = {"element": self.element}
        if deep:
            # Element parameters are merged unprefixed, matching set_params,
            # which forwards all remaining parameters to the element.
            params.update(self.element.get_params(deep))
        return params

    def set_params(self, **params: Any) -> Self:
        """Set the parameters of the wrapper.

        Parameters
        ----------
        **params
            The parameters to set. "element" replaces the wrapped element
            (a value of None is ignored); all remaining parameters are
            forwarded to the wrapped element.

        Returns
        -------
        Self
            The wrapper with the set parameters.
        """
        new_element = params.pop("element", None)
        if new_element is not None:
            self.element = new_element
        # Remaining parameters belong to the (possibly replaced) element.
        self.element.set_params(**params)
        return self


def get_exec_counted_rf_regressor(random_state: int) -> Pipeline:
    """Build a Morgan fingerprint + random forest pipeline with a counted fingerprint step.

    Parameters
    ----------
    random_state : int
        The random state passed to the random forest.

    Returns
    -------
    Pipeline
        A pipeline consisting of SmilesToMol, a MolToMorganFP wrapped in a
        CountingTransformerWrapper (step name "mol2concat"), and a
        RandomForestRegressor.
    """
    steps = [
        ("smi2mol", SmilesToMol()),
        (
            "mol2concat",
            CountingTransformerWrapper(MolToMorganFP(radius=2, n_bits=2048)),
        ),
        ("rf", RandomForestRegressor(random_state=random_state, n_jobs=1)),
    ]
    return Pipeline(steps, n_jobs=1)

0 comments on commit cc98b76

Please sign in to comment.