Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

explainability: new module #44

Open
wants to merge 84 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
8e24c7a
explainability: new module
JochenSiegWork Jun 27, 2024
8b39bf7
explainability: changes for numpy2
JochenSiegWork Aug 13, 2024
6573676
explanability: add new visualization
JochenSiegWork Aug 13, 2024
3bd54d1
explainability: linting vis code
JochenSiegWork Aug 13, 2024
36f2517
explainability: fix linting
JochenSiegWork Aug 13, 2024
3fdeee5
explainability: vis linting
JochenSiegWork Aug 14, 2024
bd78f18
explainability: more linting
JochenSiegWork Aug 14, 2024
2f6c521
linting
JochenSiegWork Aug 14, 2024
83df2dc
linting again
JochenSiegWork Aug 14, 2024
787e79f
explaonability: add matplotlib as dependency for visualization
JochenSiegWork Aug 14, 2024
368b9af
explainability: improve speed
JochenSiegWork Aug 14, 2024
59b550f
explainability: speed up unittests
JochenSiegWork Aug 14, 2024
3cbf568
explainability: suppress already checked mypy warning
JochenSiegWork Aug 14, 2024
d69d2b1
mypy
JochenSiegWork Aug 14, 2024
c8ac95f
mypy + rename interface atom_weights
JochenSiegWork Aug 14, 2024
206af69
explainability: add more visualization
JochenSiegWork Aug 29, 2024
76395c4
explainability: refactored shap heatmap visualization
JochenSiegWork Aug 30, 2024
9db8ca0
linting
JochenSiegWork Aug 30, 2024
e1176e4
linting
JochenSiegWork Aug 30, 2024
2e89391
linting
JochenSiegWork Aug 30, 2024
e9e0102
linting
JochenSiegWork Aug 30, 2024
45008c9
linting
JochenSiegWork Aug 30, 2024
4f5c186
linting
JochenSiegWork Aug 30, 2024
d883236
explainability: handle fill values better
JochenSiegWork Oct 9, 2024
01828ce
explainability: linting
JochenSiegWork Oct 9, 2024
de39c3d
explainability: remove enumerate call
JochenSiegWork Oct 9, 2024
ed830a6
explainability vis: use heavy atoms instead of all atoms
JochenSiegWork Nov 15, 2024
04f7794
explainability: test explicit/implicit hydrogens
JochenSiegWork Nov 15, 2024
52b93aa
linters
JochenSiegWork Nov 15, 2024
e753d6f
mypy
JochenSiegWork Nov 15, 2024
ee11e7b
mypy
JochenSiegWork Nov 15, 2024
6c7c637
myppy
JochenSiegWork Nov 15, 2024
0949d2f
mypy
JochenSiegWork Nov 15, 2024
0918a47
explainability: use all atoms instead of heavy atoms
JochenSiegWork Nov 20, 2024
9e2ef4b
explainability: Explanation datastructures using mixins
JochenSiegWork Nov 20, 2024
5231f61
linting
JochenSiegWork Nov 20, 2024
27afdb7
linting
JochenSiegWork Nov 20, 2024
0bd5487
try to add further SHAP explainers
JochenSiegWork Nov 20, 2024
0738241
add shap's KernalExplainer fully
JochenSiegWork Nov 21, 2024
ec97416
linting
JochenSiegWork Nov 21, 2024
0adb1ff
linting
JochenSiegWork Nov 25, 2024
ab06fac
linting
JochenSiegWork Nov 25, 2024
44cfe84
linting
JochenSiegWork Nov 25, 2024
2f7bc43
linting
JochenSiegWork Nov 25, 2024
8a95d29
linting
JochenSiegWork Nov 25, 2024
5565805
linting
JochenSiegWork Nov 25, 2024
29daa6d
linting
JochenSiegWork Nov 25, 2024
fe50d74
linting
JochenSiegWork Nov 25, 2024
994fc3e
linitng
JochenSiegWork Nov 25, 2024
e0c3d3c
linting
JochenSiegWork Nov 25, 2024
633def2
linting
JochenSiegWork Nov 25, 2024
fbb855c
linting
JochenSiegWork Nov 25, 2024
725cdb8
linting
JochenSiegWork Nov 25, 2024
f12432a
linting
JochenSiegWork Nov 25, 2024
5603cf0
linting
JochenSiegWork Nov 25, 2024
1981d6a
add feature_names to explainability code
JochenSiegWork Nov 25, 2024
d6bc7c2
linting
JochenSiegWork Nov 25, 2024
398a5dc
add xai notebook and adaptations
JochenSiegWork Nov 28, 2024
0a9acc4
linting
JochenSiegWork Nov 29, 2024
0c14d58
improve xai notebook
JochenSiegWork Nov 29, 2024
ccb43e3
finished notebook
JochenSiegWork Dec 2, 2024
671b4d0
black
JochenSiegWork Dec 3, 2024
964213b
example data
JochenSiegWork Dec 3, 2024
52387f5
Christian comments 1
JochenSiegWork Dec 4, 2024
8d3ae06
linting
JochenSiegWork Dec 5, 2024
55036aa
pydocstyle
JochenSiegWork Dec 5, 2024
a970d3d
rename xai notebook and update import
JochenSiegWork Dec 5, 2024
64220eb
Merge branch 'main' into explainability_module
JochenSiegWork Jan 14, 2025
0733089
Merge branch 'main' into explainability_module
JochenSiegWork Jan 14, 2025
e7f6653
fix mypy error
JochenSiegWork Jan 14, 2025
054a3a1
fix mypy
JochenSiegWork Jan 14, 2025
0560abf
ignore mypy error
JochenSiegWork Jan 14, 2025
be76892
mypy
JochenSiegWork Jan 14, 2025
3bc87e3
first part code review
JochenSiegWork Jan 14, 2025
dcf95c2
linting
JochenSiegWork Jan 14, 2025
87c4888
linting fix
JochenSiegWork Jan 14, 2025
66014c5
review comments typing
JochenSiegWork Jan 15, 2025
7d89b70
Review comments for notebook
JochenSiegWork Jan 15, 2025
46949b5
rework visualization of present/absent features text
JochenSiegWork Jan 15, 2025
38ec057
mypy
JochenSiegWork Jan 15, 2025
2038e29
isort
JochenSiegWork Jan 15, 2025
ea186cc
mypy
JochenSiegWork Jan 15, 2025
cd59c94
pylint can't parse f-strings correctly
JochenSiegWork Jan 15, 2025
2213bb1
heatmaps: move type defintion outside function
JochenSiegWork Jan 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 18 additions & 35 deletions molpipeline/experimental/explainability/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
from __future__ import annotations

import abc
from typing import Any, TypeAlias
from typing import Any, Callable

import numpy as np
import numpy.typing as npt
import pandas as pd
import shap
from scipy.sparse import issparse, spmatrix
from sklearn.base import BaseEstimator
from typing_extensions import override

try:
from typing import override # type: ignore[attr-defined]
except ImportError:
from typing_extensions import override

from molpipeline import Pipeline
from molpipeline.abstract_pipeline_elements.core import OptionalMol
Expand Down Expand Up @@ -51,27 +55,9 @@ def _to_dense(
return feature_matrix


def _convert_to_array(value: Any) -> npt.NDArray[np.float64]:
"""Convert a value to a numpy array.

Parameters
----------
value : Any
The value to convert.

Returns
-------
npt.NDArray[np.float64]
The value as a numpy array.
"""
if isinstance(value, np.ndarray):
return value
if np.isscalar(value):
return np.array([value])
raise ValueError("Value is not a scalar or numpy array.")


def _get_prediction_function(pipeline: Pipeline | BaseEstimator) -> Any:
def _get_prediction_function(
pipeline: Pipeline | BaseEstimator,
) -> Callable[[npt.ArrayLike], npt.ArrayLike]:
"""Get the prediction function of a model.

Parameters
Expand Down Expand Up @@ -166,18 +152,13 @@ def _convert_shap_feature_weights_to_atom_weights(
return atom_weights


ShapExplanation: TypeAlias = list[
SHAPFeatureExplanation | SHAPFeatureAndAtomExplanation
]


class AbstractSHAPExplainer(abc.ABC): # pylint: disable=too-few-public-methods
"""Abstract class for SHAP explainer objects."""

@abc.abstractmethod
def explain(
self, X: Any, **kwargs: Any # pylint: disable=invalid-name,unused-argument
) -> ShapExplanation:
) -> list[SHAPFeatureExplanation | SHAPFeatureAndAtomExplanation]:
"""Explain the predictions for the input data.

Parameters
Expand All @@ -199,6 +180,9 @@ class SHAPExplainerAdapter(
): # pylint: disable=too-few-public-methods
"""Adapter for SHAP explainer wrappers for handling molecules and pipelines."""

# used for dynamically defining the return type of the explain method
return_element_type_: type[SHAPFeatureExplanation | SHAPFeatureAndAtomExplanation]

def __init__(
self,
pipeline: Pipeline,
Expand Down Expand Up @@ -234,9 +218,6 @@ def __init__(

# determine type of returned explanation
featurization_element = self.featurization_subpipeline.steps[-1][1] # type: ignore[union-attr]
self.return_element_type_: type[
SHAPFeatureExplanation | SHAPFeatureAndAtomExplanation
]
if isinstance(featurization_element, MolToMorganFP):
self.return_element_type_ = SHAPFeatureAndAtomExplanation
else:
Expand Down Expand Up @@ -271,7 +252,7 @@ def _prediction_is_valid(prediction: Any) -> bool:
@override
def explain(
self, X: Any, **kwargs: Any # pylint: disable=invalid-name,unused-argument
) -> ShapExplanation:
) -> list[SHAPFeatureExplanation | SHAPFeatureAndAtomExplanation]:
"""Explain the predictions for the input data.

If the calculation of the SHAP values for an input sample fails, the explanation will be invalid.
Expand All @@ -291,7 +272,9 @@ def explain(
"""
featurization_element = self.featurization_subpipeline.steps[-1][1] # type: ignore[union-attr]

explanation_results: ShapExplanation = []
explanation_results: list[
SHAPFeatureExplanation | SHAPFeatureAndAtomExplanation
] = []
for input_sample in X:

input_sample = [input_sample]
Expand Down Expand Up @@ -359,7 +342,7 @@ def explain(
if issubclass(self.return_element_type_, BondExplanationMixin):
explanation_data["bond_weights"] = bond_weights
if issubclass(self.return_element_type_, SHAPExplanationMixin):
explanation_data["expected_value"] = _convert_to_array(
explanation_data["expected_value"] = np.atleast_1d(
self.explainer.expected_value
)

Expand Down
20 changes: 15 additions & 5 deletions molpipeline/experimental/explainability/visualization/heatmaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,19 @@ def __init__(
self.y_lim = y_lim
self.x_res = x_res
self.y_res = y_res
self._dx = (max(self.x_lim) - min(self.x_lim)) / self.x_res
self._dy = (max(self.y_lim) - min(self.y_lim)) / self.y_res
self.values = np.zeros((self.x_res, self.y_res))

@property
def dx(self) -> float:
c-w-feldmann marked this conversation as resolved.
Show resolved Hide resolved
"""Length of cell in x-direction."""
return (max(self.x_lim) - min(self.x_lim)) / self.x_res
return self._dx

@property
def dy(self) -> float:
"""Length of cell in y-direction."""
return (max(self.y_lim) - min(self.y_lim)) / self.y_res
return self._dy

def grid_field_center(self, x_idx: int, y_idx: int) -> tuple[float, float]:
"""Center of cell specified by index along x and y.
Expand Down Expand Up @@ -149,6 +151,9 @@ def __init__(
y_lim: Sequence[float],
x_res: int,
y_res: int,
function_list: (
list[Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]]] | None
) = None,
):
c-w-feldmann marked this conversation as resolved.
Show resolved Hide resolved
"""Initialize the ValueGrid with limits and resolution of the axes.

Expand All @@ -162,11 +167,16 @@ def __init__(
Resolution (number of cells) along x-axis.
y_res: int
Resolution (number of cells) along y-axis.
function_list: list[Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]]], optional
List of functions to be evaluated for each cell, by default None.
"""
super().__init__(x_lim, y_lim, x_res, y_res)
self.function_list: list[
Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]]
] = []
if function_list is not None:
self.function_list: list[
Copy link
Collaborator

@c-w-feldmann c-w-feldmann Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type hints for class/instance vars should be defined outside of functions.
...
I read a bit into this and it seems that its not actually a PEP, but only done in this way in PEP 526.

The PEP itself ist now also marked as historical document.
But the rest of the molpipeline code is done in the same way and I think that it's still a good Idea.
Current examples still do it this way: https://typing.readthedocs.io/en/latest/spec/class-compat.html

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]]
] = function_list
else:
self.function_list = []
self.values = np.zeros((self.x_res, self.y_res))

def add_function(
Expand Down
135 changes: 125 additions & 10 deletions molpipeline/experimental/explainability/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from matplotlib import colors
from matplotlib import pyplot as plt
from matplotlib.colors import Colormap
from matplotlib.figure import Figure
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw
Expand Down Expand Up @@ -176,7 +177,7 @@ def _add_gaussians_for_bonds(
bond_center = (a1_coords + a2_coords) / 2

func = GaussFunctor2D(
center=bond_center,
center=bond_center, # type: ignore
std1=bond_width,
std2=bond_length,
scale=bond_weights[i],
Expand Down Expand Up @@ -265,6 +266,127 @@ def make_sum_of_gaussians_grid(
return value_grid


def _add_shap_present_absent_features_text(
fig: Figure,
explanation: SHAPFeatureAndAtomExplanation,
sum_present_shap: float,
sum_absent_shap: float,
) -> None:
"""Add text to the figure to display the SHAP prediction composition.

The added text includes the prediction value, the expected value, the sum of the SHAP values for present features,
and the sum of the SHAP values for absent features.

Parameters
----------
fig: Figure
The figure.
explanation: SHAPFeatureAndAtomExplanation
The SHAP explanation.
sum_present_shap: float
The sum of the SHAP values for present features.
sum_absent_shap: float
The sum of the SHAP values for absent features.
"""
if explanation.prediction is None:
raise AssertionError("Prediction value is None.")
if explanation.expected_value is None:
raise AssertionError("Expected value is None.")

color1 = "black"
color2 = "green"
color3 = "darkorchid"

fontsize_numbers = 11
delta = 0.04
offset = 0.375
fig.text(
offset + delta,
0.18,
f"{explanation.prediction[-1]:.2f} =",
fontsize=fontsize_numbers,
ha="center",
)
fig.text(
offset + 2 * delta,
0.18,
f" {'' if explanation.expected_value[-1] >= 0 else '-'}",
ha="center",
fontsize=fontsize_numbers,
color=color1,
)
fig.text(
offset + 3 * delta,
0.18,
f" {abs(explanation.expected_value[-1]):.2f}",
ha="center",
fontsize=fontsize_numbers,
color=color1,
)
fig.text(
offset + 4 * delta,
0.18,
f" {'+' if sum_present_shap >= 0 else '-'}",
ha="center",
fontsize=fontsize_numbers,
color=color2,
)
fig.text(
offset + 5 * delta,
0.18,
f" {abs(sum_present_shap):.2f}",
ha="center",
fontsize=fontsize_numbers,
color=color2,
)
fig.text(
offset + 6 * delta,
0.18,
f" {'+' if sum_absent_shap >= 0 else '-'}",
fontsize=fontsize_numbers,
ha="center",
color=color3,
)
fig.text(
offset + 7 * delta,
0.18,
f" {abs(sum_absent_shap):.2f}",
ha="center",
fontsize=fontsize_numbers,
color=color3,
)

delta = 0.05
offset = offset + 0.0165
fig.text(offset, 0.13, "prediction =", ha="center", fontsize=10)
fig.text(
offset + 2 * delta,
0.12,
"expected\nvalue",
ha="center",
fontsize=10,
color=color1,
)
fig.text(offset + 3 * delta, 0.13, " + ", ha="center", fontsize=10, color=color2)
fig.text(
offset + 4 * delta,
0.12,
"features\npresent",
ha="center",
fontsize=10,
color=color2,
)
fig.text(offset + 5 * delta, 0.13, " + ", ha="center", fontsize=10, color=color3)
fig.text(
offset + 6 * delta,
0.12,
"features\nabsent",
ha="center",
fontsize=10,
color=color3,
)


def _structure_heatmap(
mol: RDKitMol,
atom_weights: npt.NDArray[np.float64],
Expand Down Expand Up @@ -462,16 +584,9 @@ def structure_heatmap_shap( # pylint: disable=too-many-branches, too-many-local

fig.colorbar(im, ax=ax, orientation="vertical", fraction=0.015, pad=0.0)

# note: the prediction/expected value of the last array element is used
text = (
f"$Prediction = {explanation.prediction[-1]:.2f}$ ="
"\n"
"\n"
f" $expected \ value={explanation.expected_value[-1]:.2f}$ + " # noqa: W605 # pylint: disable=anomalous-backslash-in-string
f"$features_{{present}}= {sum_present_shap:.2f}$ + "
f"$features_{{absent}}={sum_absent_shap:.2f}$"
_add_shap_present_absent_features_text(
fig, explanation, sum_present_shap, sum_absent_shap
)
fig.text(0.5, 0.18, text, ha="center")

image = plt_to_pil(fig)
# clear the figure and memory
Expand Down
Loading
Loading