From a0ed00e8134080b7843f9072074195fd5975ce55 Mon Sep 17 00:00:00 2001 From: Jochen Sieg Date: Wed, 20 Nov 2024 15:44:56 +0100 Subject: [PATCH] linting --- molpipeline/explainability/explainer.py | 7 +- molpipeline/explainability/explanation.py | 5 +- .../test_visualization/test_gaussian_grid.py | 64 +++++++++++++++++++ .../test_visualization/test_visualization.py | 43 ------------- 4 files changed, 71 insertions(+), 48 deletions(-) create mode 100644 tests/test_explainability/test_visualization/test_gaussian_grid.py diff --git a/molpipeline/explainability/explainer.py b/molpipeline/explainability/explainer.py index 55823970..fa21d272 100644 --- a/molpipeline/explainability/explainer.py +++ b/molpipeline/explainability/explainer.py @@ -135,7 +135,7 @@ class AbstractSHAPExplainer(abc.ABC): @abc.abstractmethod def explain( self, X: Any, **kwargs: Any - ) -> list[SHAPFeatureExplanation, SHAPFeatureAndAtomExplanation]: + ) -> list[SHAPFeatureExplanation] | list[SHAPFeatureAndAtomExplanation]: """Explain the predictions for the input data. Parameters @@ -166,6 +166,8 @@ class SHAPTreeExplainer(AbstractSHAPExplainer): None if these failed instances should not be explained. """ + return_type: type[SHAPFeatureExplanation] | type[SHAPFeatureAndAtomExplanation] + def __init__(self, pipeline: Pipeline, **kwargs: Any) -> None: """Initialize the SHAPTreeExplainer. @@ -204,6 +206,7 @@ def __init__(self, pipeline: Pipeline, **kwargs: Any) -> None: if self.featurization_subpipeline is None: raise ValueError("Could not determine the featurization subpipeline.") + # determine type of returned explanation featurization_element = self.featurization_subpipeline.steps[-1][1] # type: ignore[union-attr] if isinstance(featurization_element, MolToMorganFP): self.return_type = SHAPFeatureAndAtomExplanation @@ -238,7 +241,7 @@ def _prediction_is_valid(self, prediction: Any) -> bool: # pylint: disable=C0103,W0613 def explain( self, X: Any, **kwargs: Any - ) -> list[SHAPFeatureExplanation, SHAPFeatureAndAtomExplanation]: + ) -> list[SHAPFeatureExplanation] | list[SHAPFeatureAndAtomExplanation]: """Explain the predictions for the input data. If the calculation of the SHAP values for an input sample fails, the explanation will be invalid. diff --git a/molpipeline/explainability/explanation.py b/molpipeline/explainability/explanation.py index b594153d..ca48023a 100644 --- a/molpipeline/explainability/explanation.py +++ b/molpipeline/explainability/explanation.py @@ -65,7 +65,7 @@ class SHAPFeatureExplanation( SHAPExplanationMixin, _AbstractMoleculeExplanation, # base-class should be the last element https://www.ianlewis.org/en/mixins-and-python ): - """Explanation of a molecular prediction using feature importance scores and SHAP.""" + """Explanation using feature importance scores from SHAP.""" def is_valid(self) -> bool: """Check if the explanation is valid. @@ -94,8 +94,7 @@ class SHAPFeatureAndAtomExplanation( AtomExplanationMixin, _AbstractMoleculeExplanation, ): - """Explanation of a molecular prediction using feature importance scores, - atom importance scores and SHAP.""" + """Explanation using feature and atom importance scores from SHAP.""" def is_valid(self) -> bool: """Check if the explanation is valid. diff --git a/tests/test_explainability/test_visualization/test_gaussian_grid.py b/tests/test_explainability/test_visualization/test_gaussian_grid.py new file mode 100644 index 00000000..461186e7 --- /dev/null +++ b/tests/test_explainability/test_visualization/test_gaussian_grid.py @@ -0,0 +1,64 @@ +"""Test gaussian grid visualization.""" + +import unittest +from typing import ClassVar + +import numpy as np +from rdkit import Chem +from rdkit.Chem import Draw + +from molpipeline import Pipeline +from molpipeline.explainability import ( + SHAPFeatureAndAtomExplanation, + SHAPFeatureExplanation, + SHAPTreeExplainer, +) +from molpipeline.explainability.visualization.visualization import ( + make_sum_of_gaussians_grid, +) +from tests.test_explainability.test_visualization.test_visualization import ( + _get_test_morgan_rf_pipeline, +) + +TEST_SMILES = ["CC", "CCO", "COC", "c1ccccc1(N)", "CCC(-O)O", "CCCN"] +CONTAINS_OX = [0, 1, 1, 0, 1, 0] + + +class TestSumOfGaussiansGrid(unittest.TestCase): + """Test sum of gaussian grid .""" + + test_pipeline: ClassVar[Pipeline] + test_explainer: ClassVar[SHAPTreeExplainer] + test_explanations: ClassVar[ + list[SHAPFeatureAndAtomExplanation] | list[SHAPFeatureExplanation] + ] + + @classmethod + def setUpClass(cls) -> None: + """Set up the tests.""" + cls.test_pipeline = _get_test_morgan_rf_pipeline() + cls.test_pipeline.fit(TEST_SMILES, CONTAINS_OX) + cls.test_explainer = SHAPTreeExplainer(cls.test_pipeline) + cls.test_explanations = cls.test_explainer.explain(TEST_SMILES) + + def test_grid_with_shap_atom_weights(self) -> None: + """Test grid with SHAP atom weights.""" + for explanation in self.test_explanations: + self.assertTrue(explanation.is_valid()) + self.assertIsInstance(explanation.atom_weights, np.ndarray) + + mol_copy = Chem.Mol(explanation.molecule) + mol_copy = Draw.PrepareMolForDrawing(mol_copy) + value_grid = make_sum_of_gaussians_grid( + mol_copy, + atom_weights=explanation.atom_weights, + atom_width=np.inf, + grid_resolution=[64, 64], + padding=[0.4, 0.4], + ) + self.assertIsNotNone(value_grid) + self.assertEqual(value_grid.values.size, 64 * 64) + + # test that the range of summed gaussian values is as expected for SHAP + self.assertTrue(value_grid.values.min() >= -1) + self.assertTrue(value_grid.values.max() <= 1) diff --git a/tests/test_explainability/test_visualization/test_visualization.py b/tests/test_explainability/test_visualization/test_visualization.py index c5ba88cf..1ea3896b 100644 --- a/tests/test_explainability/test_visualization/test_visualization.py +++ b/tests/test_explainability/test_visualization/test_visualization.py @@ -17,9 +17,6 @@ structure_heatmap, structure_heatmap_shap, ) -from molpipeline.explainability.visualization.visualization import ( - make_sum_of_gaussians_grid, -) from molpipeline.mol2any import MolToMorganFP TEST_SMILES = ["CC", "CCO", "COC", "c1ccccc1(N)", "CCC(-O)O", "CCCN"] @@ -131,43 +128,3 @@ def test_explicit_hydrogens(self) -> None: ) # type: ignore[union-attr] self.assertIsNotNone(image) self.assertEqual(image.format, "PNG") - - -class TestSumOfGaussiansGrid(unittest.TestCase): - """Test visualization methods for explanations.""" - - test_pipeline: ClassVar[Pipeline] - test_explainer: ClassVar[SHAPTreeExplainer] - test_explanations: ClassVar[ - list[SHAPFeatureAndAtomExplanation] | list[SHAPFeatureExplanation] - ] - - @classmethod - def setUpClass(cls) -> None: - """Set up the tests.""" - cls.test_pipeline = _get_test_morgan_rf_pipeline() - cls.test_pipeline.fit(TEST_SMILES, CONTAINS_OX) - cls.test_explainer = SHAPTreeExplainer(cls.test_pipeline) - cls.test_explanations = cls.test_explainer.explain(TEST_SMILES) - - def test_grid_with_shap_atom_weights(self) -> None: - """Test grid with SHAP atom weights.""" - for explanation in self.test_explanations: - self.assertTrue(explanation.is_valid()) - self.assertIsInstance(explanation.atom_weights, np.ndarray) - - mol_copy = Chem.Mol(explanation.molecule) - mol_copy = Draw.PrepareMolForDrawing(mol_copy) - value_grid = make_sum_of_gaussians_grid( - mol_copy, - atom_weights=explanation.atom_weights, - atom_width=np.inf, - grid_resolution=[64, 64], - padding=[0.4, 0.4], - ) - self.assertIsNotNone(value_grid) - self.assertEqual(value_grid.values.size, 64 * 64) - - # test that the range of summed gaussian values is as expected for SHAP - self.assertTrue(value_grid.values.min() >= -1) - self.assertTrue(value_grid.values.max() <= 1)