Skip to content

Commit

Permalink
explainability: adapt to mypy suggestions
Browse files (browse the repository at this point in the history)
  • Loading branch information
JochenSiegWork committed Jun 7, 2024
1 parent 1460ac9 commit c1919a4
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 10 deletions.
2 changes: 1 addition & 1 deletion molpipeline/explainability/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from molpipeline.abstract_pipeline_elements.core import OptionalMol
from molpipeline.explainability.explanation import Explanation
from molpipeline.explainability.fingerprint_utils import fingerprint_shap_to_atomweights
from molpipeline.utils.subpipeline import SubpipelineExtractor
from molpipeline.mol2any import MolToMorganFP
from molpipeline.utils.subpipeline import SubpipelineExtractor


# pylint: disable=C0103,W0613
Expand Down
1 change: 0 additions & 1 deletion molpipeline/explainability/fingerprint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def assign_prediction_importance(
dict[int, float]
The atom contribution.
"""

atom_contribution: dict[int, float] = defaultdict(lambda: 0)
for bit, atom_env_list in bit_dict.items(): # type: int, Sequence[AtomEnvironment]
n_machtes = len(atom_env_list)
Expand Down
2 changes: 1 addition & 1 deletion molpipeline/explainability/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_similaritymap_from_weights(
contour_lines: int = 10,
contour_params: Draw.ContourParams | None = None,
) -> Draw.MolDraw2D:
"""Generates the similarity map for a molecule given the atomic weights.
"""Generate the similarity map for a molecule given the atomic weights.
Strongly inspired from Chem.Draw.SimilarityMaps.
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ rdkit >= 2023.9.1
scipy
setuptools
scikit-learn >= 1.4.0
shap
typing_extensions
8 changes: 6 additions & 2 deletions tests/test_explainability/test_shap_tree_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
from molpipeline.any2mol import SmilesToMol
from molpipeline.explainability.explainer import SHAPTreeExplainer
from molpipeline.explainability.explanation import Explanation
from molpipeline.utils.subpipeline import SubpipelineExtractor
from molpipeline.mol2any import (
MolToConcatenatedVector,
MolToMorganFP,
MolToRDKitPhysChem,
)
from molpipeline.mol2mol import SaltRemover
from molpipeline.utils.subpipeline import SubpipelineExtractor

TEST_SMILES = ["CC", "CCO", "COC", "c1ccccc1(N)", "CCC(-O)O", "CCCN"]
CONTAINS_OX = [0, 1, 1, 0, 1, 0]
Expand Down Expand Up @@ -74,6 +74,7 @@ def _test_valid_explanation(
"""
self.assertTrue(explanation.is_valid())

self.assertIsInstance(explanation.feature_vector, np.ndarray)
self.assertEqual((nof_features,), explanation.feature_vector.shape)

# feature names are not implemented yet
Expand All @@ -86,6 +87,8 @@ def _test_valid_explanation(
Chem.MolToInchi(explanation.molecule),
)

self.assertIsInstance(explanation.prediction, np.ndarray)
self.assertIsInstance(explanation.feature_weights, np.ndarray)
if is_regressor(estimator):
self.assertTrue((1,), explanation.prediction.shape)
self.assertEqual((nof_features,), explanation.feature_weights.shape)
Expand All @@ -103,6 +106,7 @@ def _test_valid_explanation(
raise ValueError("Error in unittest. Unsupported estimator.")

if is_morgan_fingerprint:
self.assertIsInstance(explanation.atom_weights, np.ndarray)
self.assertEqual(
explanation.atom_weights.shape,
(explanation.molecule.GetNumAtoms(),),
Expand Down Expand Up @@ -144,7 +148,7 @@ def test_explanations_fingerprint_pipeline(self) -> None:
mol_reader_subpipeline = SubpipelineExtractor(
pipeline
).get_molecule_reader_subpipeline()
self.assertIsNotNone(mol_reader_subpipeline)
self.assertIsInstance(mol_reader_subpipeline, Pipeline)

for i, explanation in enumerate(explanations):
self._test_valid_explanation(
Expand Down
2 changes: 2 additions & 0 deletions tests/test_explainability/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import unittest

import numpy as np
from sklearn.ensemble import RandomForestClassifier

from molpipeline import Pipeline
Expand Down Expand Up @@ -36,6 +37,7 @@ def test_test_fingerprint_based_atom_coloring(self) -> None:

for explanation in explanations:
self.assertTrue(explanation.is_valid())
self.assertIsInstance(explanation.atom_weights, np.ndarray)
drawer = rdkit_gaussplot(
explanation.molecule, explanation.atom_weights.tolist()
)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_utils/test_subpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from molpipeline import ErrorFilter, FilterReinserter, Pipeline, PostPredictionWrapper
from molpipeline.any2mol import SmilesToMol
from molpipeline.utils.subpipeline import SubpipelineExtractor
from molpipeline.mol2any import MolToMorganFP, MolToSmiles
from molpipeline.utils.subpipeline import SubpipelineExtractor


class TestSubpipelineExtractor(unittest.TestCase):
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_get_molecule_reader_subpipeline(self) -> None:
)
extractor = SubpipelineExtractor(pipeline)
subpipeline = extractor.get_molecule_reader_subpipeline()
self.assertIsNotNone(subpipeline)
self.assertIsInstance(subpipeline, Pipeline)
self.assertEqual(len(subpipeline.steps), 1)
self.assertIs(subpipeline.steps[0], pipeline.steps[0])

Expand All @@ -104,7 +104,7 @@ def test_get_molecule_reader_subpipeline(self) -> None:
)
extractor = SubpipelineExtractor(pipeline)
subpipeline = extractor.get_molecule_reader_subpipeline()
self.assertIsNotNone(subpipeline)
self.assertIsInstance(subpipeline, Pipeline)
self.assertEqual(len(subpipeline.steps), 3)
for i, subpipe_step in enumerate(subpipeline.steps):
self.assertIs(subpipe_step, pipeline.steps[i])
Expand All @@ -122,7 +122,7 @@ def test_get_model_subpipeline(self) -> None:
)
extractor = SubpipelineExtractor(pipeline)
subpipeline = extractor.get_model_subpipeline()
self.assertIsNotNone(subpipeline)
self.assertIsInstance(subpipeline, Pipeline)
self.assertEqual(len(subpipeline.steps), 3)
for i, subpipe_step in enumerate(subpipeline.steps):
self.assertIs(subpipe_step, pipeline.steps[i])
Expand All @@ -146,7 +146,7 @@ def test_get_model_subpipeline(self) -> None:
)
extractor = SubpipelineExtractor(pipeline)
subpipeline = extractor.get_model_subpipeline()
self.assertIsNotNone(subpipeline)
self.assertIsInstance(subpipeline, Pipeline)
self.assertEqual(len(subpipeline.steps), 4)
for i, subpipe_step in enumerate(subpipeline.steps):
self.assertIs(subpipe_step, pipeline.steps[i])
Expand Down

0 comments on commit c1919a4

Please sign in to comment.