
explainability: Explanation datastructures using mixins
JochenSiegWork committed Nov 20, 2024
1 parent ef338d3 commit adaa3e3
Showing 6 changed files with 151 additions and 61 deletions.
9 changes: 6 additions & 3 deletions molpipeline/explainability/__init__.py
@@ -1,15 +1,18 @@
"""Explainability module for the molpipeline package."""

from molpipeline.explainability.explainer import SHAPTreeExplainer
from molpipeline.explainability.explanation import Explanation, SHAPExplanation
from molpipeline.explainability.explanation import (
SHAPFeatureAndAtomExplanation,
SHAPFeatureExplanation,
)
from molpipeline.explainability.visualization.visualization import (
structure_heatmap,
structure_heatmap_shap,
)

__all__ = [
"Explanation",
"SHAPExplanation",
"SHAPFeatureExplanation",
"SHAPFeatureAndAtomExplanation",
"SHAPTreeExplainer",
"structure_heatmap",
"structure_heatmap_shap",
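For downstream code, the net effect of this file's diff is a rename of the public exports. A minimal before/after import sketch:

```python
# Before this commit (removed exports):
# from molpipeline.explainability import Explanation, SHAPExplanation

# After this commit (new exports):
from molpipeline.explainability import (
    SHAPFeatureAndAtomExplanation,
    SHAPFeatureExplanation,
)
```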
62 changes: 44 additions & 18 deletions molpipeline/explainability/explainer.py
@@ -13,7 +13,15 @@

from molpipeline import Pipeline
from molpipeline.abstract_pipeline_elements.core import OptionalMol
from molpipeline.explainability.explanation import SHAPExplanation
from molpipeline.explainability.explanation import (
AtomExplanationMixin,
BondExplanationMixin,
FeatureExplanationMixin,
FeatureInfoMixin,
SHAPExplanationMixin,
SHAPFeatureAndAtomExplanation,
SHAPFeatureExplanation,
)
from molpipeline.explainability.fingerprint_utils import fingerprint_shap_to_atomweights
from molpipeline.mol2any import MolToMorganFP
from molpipeline.utils.subpipeline import SubpipelineExtractor
@@ -125,7 +133,9 @@ class AbstractSHAPExplainer(abc.ABC):

# pylint: disable=C0103,W0613
@abc.abstractmethod
def explain(self, X: Any, **kwargs: Any) -> list[SHAPExplanation]:
def explain(
self, X: Any, **kwargs: Any
) -> list[SHAPFeatureExplanation | SHAPFeatureAndAtomExplanation]:
"""Explain the predictions for the input data.
Parameters
@@ -194,6 +204,12 @@ def __init__(self, pipeline: Pipeline, **kwargs: Any) -> None:
if self.featurization_subpipeline is None:
raise ValueError("Could not determine the featurization subpipeline.")

featurization_element = self.featurization_subpipeline.steps[-1][1] # type: ignore[union-attr]
if isinstance(featurization_element, MolToMorganFP):
self.return_type = SHAPFeatureAndAtomExplanation
else:
self.return_type = SHAPFeatureExplanation

def _prediction_is_valid(self, prediction: Any) -> bool:
"""Check if the prediction is valid using some heuristics.
@@ -220,7 +236,9 @@ def _prediction_is_valid(self, prediction: Any) -> bool:
return True

# pylint: disable=C0103,W0613
def explain(self, X: Any, **kwargs: Any) -> list[SHAPExplanation]:
def explain(
self, X: Any, **kwargs: Any
) -> list[SHAPFeatureExplanation | SHAPFeatureAndAtomExplanation]:
"""Explain the predictions for the input data.
If the calculation of the SHAP values for an input sample fails, the explanation will be invalid.
@@ -249,7 +267,7 @@ def explain(self, X: Any, **kwargs: Any) -> list[SHAPExplanation]:
prediction = _get_predictions(self.pipeline, input_sample)
if not self._prediction_is_valid(prediction):
# we use the prediction to check if the input is valid. If not, we cannot explain it.
explanation_results.append(SHAPExplanation())
explanation_results.append(self.return_type())
continue

if prediction.ndim > 1:
@@ -267,7 +285,7 @@ def explain(self, X: Any, **kwargs: Any) -> list[SHAPExplanation]:
# if the feature vector is empty, we cannot explain the prediction.
# This happens for failed instances in pipelines with fill values
# that could be valid predictions, like 0.
explanation_results.append(SHAPExplanation())
explanation_results.append(self.return_type())
continue

# Feature names should also be extracted from the Pipeline.
@@ -282,7 +300,9 @@ def explain(self, X: Any, **kwargs: Any) -> list[SHAPExplanation]:
atom_weights = None
bond_weights = None

if isinstance(featurization_element, MolToMorganFP):
if issubclass(self.return_type, AtomExplanationMixin) and isinstance(
featurization_element, MolToMorganFP
):
# for Morgan fingerprints, we can map the SHAP values to atom weights
atom_weights = _convert_shap_feature_weights_to_atom_weights(
feature_weights,
@@ -291,17 +311,23 @@ def explain(self, X: Any, **kwargs: Any) -> list[SHAPExplanation]:
feature_vector,
)

explanation_results.append(
SHAPExplanation(
feature_vector=feature_vector,
feature_names=feature_names,
molecule=molecule,
prediction=prediction,
feature_weights=feature_weights,
atom_weights=atom_weights,
bond_weights=bond_weights,
expected_value=self.explainer.expected_value,
)
)
# gather all input data for the explanation type to be returned
explanation_data = {
"molecule": molecule,
"prediction": prediction,
}
if issubclass(self.return_type, FeatureInfoMixin):
explanation_data["feature_vector"] = feature_vector
explanation_data["feature_names"] = feature_names
if issubclass(self.return_type, FeatureExplanationMixin):
explanation_data["feature_weights"] = feature_weights
if issubclass(self.return_type, AtomExplanationMixin):
explanation_data["atom_weights"] = atom_weights
if issubclass(self.return_type, BondExplanationMixin):
explanation_data["bond_weights"] = bond_weights
if issubclass(self.return_type, SHAPExplanationMixin):
explanation_data["expected_value"] = self.explainer.expected_value

explanation_results.append(self.return_type(**explanation_data))

return explanation_results
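The rewritten `explain` method keys everything off `self.return_type`: the explanation class chosen in `__init__` decides, via `issubclass` checks against the mixins, which fields get gathered. A minimal, self-contained sketch of that dispatch pattern (the classes below are simplified stand-ins, not molpipeline's own):

```python
import dataclasses
from typing import Any

import numpy as np


@dataclasses.dataclass(kw_only=True)
class FeatureExplanationMixin:
    feature_weights: np.ndarray | None = None


@dataclasses.dataclass(kw_only=True)
class AtomExplanationMixin:
    atom_weights: np.ndarray | None = None


@dataclasses.dataclass(kw_only=True)
class FeatureExplanation(FeatureExplanationMixin):
    prediction: float | None = None


@dataclasses.dataclass(kw_only=True)
class FeatureAndAtomExplanation(AtomExplanationMixin, FeatureExplanationMixin):
    prediction: float | None = None


class MorganFingerprintStandIn:
    """Stand-in for a featurizer whose bits can be mapped back to atoms."""


def select_return_type(featurization_element: Any) -> type:
    # Atom attributions are only possible when features map back to atoms.
    if isinstance(featurization_element, MorganFingerprintStandIn):
        return FeatureAndAtomExplanation
    return FeatureExplanation


return_type = select_return_type(MorganFingerprintStandIn())
data: dict[str, Any] = {"prediction": 0.7, "feature_weights": np.array([0.1, -0.2])}
if issubclass(return_type, AtomExplanationMixin):
    data["atom_weights"] = np.array([0.05, -0.1])
explanation = return_type(**data)
print(type(explanation).__name__)  # FeatureAndAtomExplanation
```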
95 changes: 75 additions & 20 deletions molpipeline/explainability/explanation.py
@@ -2,6 +2,7 @@

from __future__ import annotations

import abc
import dataclasses

import numpy as np
@@ -11,22 +12,61 @@


@dataclasses.dataclass(kw_only=True)
class Explanation:
"""Class representing explanations of a prediction."""
class _AbstractMoleculeExplanation(abc.ABC):
"""Abstract class representing an explanation for a prediction for a molecule."""

# input data
feature_vector: npt.NDArray[np.float64] | None = None
feature_names: list[str] | None = None
molecule: RDKitMol | None = None
prediction: float | npt.NDArray[np.float64] | None = None

# explanation results mappable to the feature vector

@dataclasses.dataclass(kw_only=True)
class FeatureInfoMixin:
"""Mixin providing additional information about the features used in the explanation."""

feature_vector: npt.NDArray[np.float64] | None = None
feature_names: list[str] | None = None


@dataclasses.dataclass(kw_only=True)
class FeatureExplanationMixin:
"""Explanation based on feature importance scores, e.g. Shapley Values."""

# explanation scores for individual features
feature_weights: npt.NDArray[np.float64] | None = None

# explanation results mappable to the molecule.

@dataclasses.dataclass(kw_only=True)
class AtomExplanationMixin:
"""Atom score based explanation."""

# explanation scores for individual atoms
atom_weights: npt.NDArray[np.float64] | None = None


@dataclasses.dataclass(kw_only=True)
class BondExplanationMixin:
"""Bond score based explanation."""

# explanation scores for individual bonds
bond_weights: npt.NDArray[np.float64] | None = None


@dataclasses.dataclass(kw_only=True)
class SHAPExplanationMixin:
"""Mixin providing additional information only present in SHAP explanations."""

expected_value: npt.NDArray[np.float64] | None = None


@dataclasses.dataclass(kw_only=True)
class SHAPFeatureExplanation(
FeatureInfoMixin,
FeatureExplanationMixin,
SHAPExplanationMixin,
_AbstractMoleculeExplanation,  # base class should be the last element; see https://www.ianlewis.org/en/mixins-and-python
):
"""Explanation of a molecular prediction using feature importance scores and SHAP."""

def is_valid(self) -> bool:
"""Check if the explanation is valid.
@@ -38,25 +78,40 @@ def is_valid(self) -> bool:
return all(
[
self.feature_vector is not None,
# self.feature_names is not None,
# self.feature_names is not None, # TODO uncomment when PR is merged
self.molecule is not None,
self.prediction is not None,
any(
[
self.feature_weights is not None,
self.atom_weights is not None,
self.bond_weights is not None,
]
),
self.feature_weights is not None,
]
)


@dataclasses.dataclass(kw_only=True)
class SHAPExplanation(Explanation):
"""Class representing SHAP explanations of a prediction.
class SHAPFeatureAndAtomExplanation(
FeatureInfoMixin,
FeatureExplanationMixin,
SHAPExplanationMixin,
AtomExplanationMixin,
_AbstractMoleculeExplanation,
):
"""Explanation of a molecular prediction using feature importance scores,
atom importance scores, and SHAP."""

This Explanation holds additional information only present in SHAP explanations.
"""
def is_valid(self) -> bool:
"""Check if the explanation is valid.
expected_value: npt.NDArray[np.float64] | None = None
Returns
-------
bool
True if the explanation is valid, False otherwise.
"""
return all(
[
self.feature_vector is not None,
# self.feature_names is not None, # TODO uncomment when PR is merged
self.molecule is not None,
self.prediction is not None,
self.feature_weights is not None,
self.atom_weights is not None,
]
)
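A short usage sketch of the dataclasses defined above. Every field defaults to `None`, so `is_valid()` only reports `True` once all checked fields are populated (the values here are illustrative):

```python
import numpy as np

from molpipeline.explainability.explanation import SHAPFeatureAndAtomExplanation

exp = SHAPFeatureAndAtomExplanation(
    feature_vector=np.array([1.0, 0.0]),
    feature_names=["bit_0", "bit_1"],
    prediction=0.7,
    feature_weights=np.array([0.2, -0.1]),
    atom_weights=np.array([0.05, 0.15]),
)
print(exp.is_valid())  # False: molecule is still None
```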
16 changes: 8 additions & 8 deletions molpipeline/explainability/visualization/visualization.py
@@ -20,7 +20,7 @@
from rdkit.Chem.Draw import rdMolDraw2D

from molpipeline.abstract_pipeline_elements.core import RDKitMol
from molpipeline.explainability.explanation import SHAPExplanation
from molpipeline.explainability import SHAPFeatureAndAtomExplanation
from molpipeline.explainability.visualization.gauss import GaussFunctor2D
from molpipeline.explainability.visualization.heatmaps import (
ValueGrid,
@@ -378,7 +378,7 @@ def structure_heatmap(


def structure_heatmap_shap(
explanation: SHAPExplanation,
explanation: SHAPFeatureAndAtomExplanation,
color: str | Colormap | tuple[RGBAtuple, RGBAtuple, RGBAtuple] | None = None,
width: int = 600,
height: int = 600,
@@ -405,17 +405,17 @@
The image as PNG.
"""
if explanation.feature_weights is None:
raise ValueError("SHAPExplanation does not contain feature weights.")
raise ValueError("Explanation does not contain feature weights.")
if explanation.feature_vector is None:
raise ValueError("SHAPExplanation does not contain feature_vector.")
raise ValueError("Explanation does not contain feature_vector.")
if explanation.expected_value is None:
raise ValueError("SHAPExplanation does not contain expected value.")
raise ValueError("Explanation does not contain expected value.")
if explanation.prediction is None:
raise ValueError("SHAPExplanation does not contain prediction.")
raise ValueError("Explanation does not contain prediction.")
if explanation.molecule is None:
raise ValueError("SHAPExplanation does not contain molecule.")
raise ValueError("Explanation does not contain molecule.")
if explanation.atom_weights is None:
raise ValueError("SHAPExplanation does not contain atom weights.")
raise ValueError("Explanation does not contain atom weights.")

present_shap = explanation.feature_weights[:, 1] * explanation.feature_vector
absent_shap = explanation.feature_weights[:, 1] * (1 - explanation.feature_vector)
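The `present_shap`/`absent_shap` split at the end of this hunk separates the SHAP contributions of fingerprint bits that are set from those that are absent. A small NumPy sketch with made-up numbers, assuming (as in the function above) that column 1 of `feature_weights` holds the per-bit SHAP values being visualized:

```python
import numpy as np

feature_weights = np.array([[-0.2, 0.2], [0.1, -0.1], [-0.3, 0.3]])  # (n_bits, 2)
feature_vector = np.array([1.0, 0.0, 1.0])  # binary fingerprint

present_shap = feature_weights[:, 1] * feature_vector       # contributions of set bits
absent_shap = feature_weights[:, 1] * (1 - feature_vector)  # contributions of absent bits
```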
19 changes: 10 additions & 9 deletions tests/test_explainability/test_shap_tree_explainer.py
@@ -17,7 +17,10 @@
from molpipeline.abstract_pipeline_elements.core import RDKitMol
from molpipeline.any2mol import SmilesToMol
from molpipeline.explainability.explainer import SHAPTreeExplainer
from molpipeline.explainability.explanation import Explanation
from molpipeline.explainability.explanation import (
SHAPFeatureAndAtomExplanation,
SHAPFeatureExplanation,
)
from molpipeline.mol2any import (
MolToConcatenatedVector,
MolToMorganFP,
@@ -49,7 +52,7 @@ class TestSHAPTreeExplainer(unittest.TestCase):

def _test_valid_explanation(
self,
explanation: Explanation,
explanation: SHAPFeatureExplanation | SHAPFeatureAndAtomExplanation,
estimator: BaseEstimator,
molecule_reader_subpipeline: Pipeline,
nof_features: int,
@@ -73,6 +76,8 @@ def _test_valid_explanation(
is_morgan_fingerprint : bool
Whether the feature vector is a Morgan fingerprint or not.
"""
self.assertTrue(explanation.is_valid())

self.assertIsInstance(explanation.feature_vector, np.ndarray)
@@ -114,18 +119,14 @@ def _test_valid_explanation(
else:
raise ValueError("Error in unittest. Unsupported estimator.")

if is_morgan_fingerprint:
if (
is_morgan_fingerprint
):  # TODO: replace with isinstance(explanation, AtomExplanationMixin)
self.assertIsInstance(explanation.atom_weights, np.ndarray)
self.assertEqual(
explanation.atom_weights.shape, # type: ignore[union-attr]
(explanation.molecule.GetNumAtoms(),), # type: ignore[union-attr]
)
else:
self.assertIsNone(explanation.atom_weights)

self.assertIsNone(
explanation.bond_weights
) # SHAPTreeExplainer doesn't set bond weights yet
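The TODO above suggests dispatching on the explanation's type rather than threading an `is_morgan_fingerprint` flag through the test. A self-contained sketch of that check (a hypothetical rewrite, not part of this commit):

```python
import numpy as np
from rdkit import Chem

from molpipeline.explainability.explanation import (
    AtomExplanationMixin,
    SHAPFeatureAndAtomExplanation,
)

mol = Chem.MolFromSmiles("CCO")
explanation = SHAPFeatureAndAtomExplanation(
    molecule=mol,
    prediction=0.5,
    atom_weights=np.zeros(mol.GetNumAtoms()),
)
# The explanation's own type advertises which attributions it carries.
if isinstance(explanation, AtomExplanationMixin):
    assert explanation.atom_weights is not None
    assert explanation.atom_weights.shape == (mol.GetNumAtoms(),)
```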

def test_explanations_fingerprint_pipeline(self) -> None:
"""Test SHAP's TreeExplainer wrapper on MolPipeline's pipelines with fingerprints."""
@@ -11,7 +11,8 @@
from molpipeline import Pipeline
from molpipeline.any2mol import SmilesToMol
from molpipeline.explainability import (
SHAPExplanation,
SHAPFeatureAndAtomExplanation,
SHAPFeatureExplanation,
SHAPTreeExplainer,
structure_heatmap,
structure_heatmap_shap,
@@ -53,7 +54,9 @@ class TestExplainabilityVisualization(unittest.TestCase):

test_pipeline: ClassVar[Pipeline]
test_explainer: ClassVar[SHAPTreeExplainer]
test_explanations: ClassVar[list[SHAPExplanation]]
test_explanations: ClassVar[
list[SHAPFeatureAndAtomExplanation] | list[SHAPFeatureExplanation]
]

@classmethod
def setUpClass(cls) -> None:
@@ -135,7 +138,9 @@ class TestSumOfGaussiansGrid(unittest.TestCase):

test_pipeline: ClassVar[Pipeline]
test_explainer: ClassVar[SHAPTreeExplainer]
test_explanations: ClassVar[list[SHAPExplanation]]
test_explanations: ClassVar[
list[SHAPFeatureAndAtomExplanation] | list[SHAPFeatureExplanation]
]

@classmethod
def setUpClass(cls) -> None:
