Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Dec 5, 2024
1 parent 52387f5 commit 8d3ae06
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 39 deletions.
31 changes: 16 additions & 15 deletions molpipeline/experimental/explainability/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,18 +166,18 @@ def _convert_shap_feature_weights_to_atom_weights(
return atom_weights


_SHAPExplainer_return_type_: TypeAlias = list[
ShapExplanation: TypeAlias = list[
SHAPFeatureExplanation | SHAPFeatureAndAtomExplanation
]


class AbstractSHAPExplainer(abc.ABC):
class AbstractSHAPExplainer(abc.ABC): # pylint: disable=too-few-public-methods
"""Abstract class for SHAP explainer objects."""

@abc.abstractmethod
def explain(
self, X: Any, **kwargs: Any
) -> _SHAPExplainer_return_type_: # pylint: disable=invalid-name,unused-argument
self, X: Any, **kwargs: Any # pylint: disable=invalid-name,unused-argument
) -> ShapExplanation:
"""Explain the predictions for the input data.
Parameters
Expand All @@ -194,14 +194,15 @@ def explain(
"""


class SHAPExplainerAdapter(AbstractSHAPExplainer, abc.ABC):
class SHAPExplainerAdapter(
AbstractSHAPExplainer, abc.ABC
): # pylint: disable=too-few-public-methods
"""Adapter for SHAP explainer wrappers for handling molecules and pipelines."""

def __init__(
self,
pipeline: Pipeline,
explainer: shap.TreeExplainer | shap.KernelExplainer,
**kwargs: Any,
) -> None:
"""Initialize the SHAPTreeExplainer.
Expand All @@ -211,8 +212,6 @@ def __init__(
The pipeline containing the model to explain.
explainer : shap.TreeExplainer | shap.KernelExplainer
The shap explainer object.
kwargs : Any
Additional keyword arguments for SHAP's TreeExplainer.
"""
self.pipeline = pipeline
self.explainer = explainer
Expand Down Expand Up @@ -271,8 +270,8 @@ def _prediction_is_valid(prediction: Any) -> bool:

@override
def explain(
self, X: Any, **kwargs: Any
) -> _SHAPExplainer_return_type_: # pylint: disable=invalid-name,unused-argument
self, X: Any, **kwargs: Any # pylint: disable=invalid-name,unused-argument
) -> ShapExplanation:
"""Explain the predictions for the input data.
If the calculation of the SHAP values for an input sample fails, the explanation will be invalid.
Expand All @@ -292,7 +291,7 @@ def explain(
"""
featurization_element = self.featurization_subpipeline.steps[-1][1] # type: ignore[union-attr]

explanation_results: _SHAPExplainer_return_type_ = []
explanation_results: ShapExplanation = []
for input_sample in X:

input_sample = [input_sample]
Expand Down Expand Up @@ -369,7 +368,7 @@ def explain(
return explanation_results


class SHAPTreeExplainer(SHAPExplainerAdapter):
class SHAPTreeExplainer(SHAPExplainerAdapter): # pylint: disable=too-few-public-methods
"""Wrapper for SHAP's TreeExplainer that can handle pipelines and molecules.
Wraps SHAP's TreeExplainer to explain predictions of a pipeline containing a
Expand Down Expand Up @@ -397,7 +396,7 @@ def __init__(
Additional keyword arguments for SHAP's Explainer.
"""
explainer = self._create_explainer(pipeline, **kwargs)
super().__init__(pipeline, explainer, **kwargs)
super().__init__(pipeline, explainer)

@staticmethod
def _create_explainer(pipeline: Pipeline, **kwargs: Any) -> shap.TreeExplainer:
Expand All @@ -423,7 +422,9 @@ def _create_explainer(pipeline: Pipeline, **kwargs: Any) -> shap.TreeExplainer:
return explainer


class SHAPKernelExplainer(SHAPExplainerAdapter):
class SHAPKernelExplainer(
SHAPExplainerAdapter
): # pylint: disable=too-few-public-methods
"""Wrapper for SHAP's KernelExplainer that can handle pipelines and molecules."""

def __init__(
Expand All @@ -441,7 +442,7 @@ def __init__(
Additional keyword arguments for SHAP's Explainer.
"""
explainer = self._create_explainer(pipeline, **kwargs)
super().__init__(pipeline, explainer, **kwargs)
super().__init__(pipeline, explainer)

@staticmethod
def _create_explainer(pipeline: Pipeline, **kwargs: Any) -> shap.KernelExplainer:
Expand Down
65 changes: 65 additions & 0 deletions molpipeline/experimental/explainability/visualization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,68 @@ def plt_to_pil(figure: plt.Figure) -> Image.Image:
bio.seek(0)
img = Image.open(bio)
return img


def get_atom_coords_of_bond(
    bond: Chem.Bond, conf: Chem.Conformer
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
    """Return the 2D positions of the two atoms joined by a bond.

    Only the ``x`` and ``y`` attributes of each atom position are read.

    Parameters
    ----------
    bond: Chem.Bond
        The bond whose begin and end atoms are looked up.
    conf: Chem.Conformer
        The conformation from which the atom positions are taken.

    Returns
    -------
    tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]
        The ``[x, y]`` coordinates of the begin atom and of the end atom,
        in that order.
    """
    # Resolve each bond end to its atom index, then to its position in `conf`.
    begin_point = conf.GetAtomPosition(bond.GetBeginAtom().GetIdx())
    end_point = conf.GetAtomPosition(bond.GetEndAtom().GetIdx())

    return (
        np.array([begin_point.x, begin_point.y]),
        np.array([end_point.x, end_point.y]),
    )


def calc_present_and_absent_shap_contributions(
    feature_vector: npt.NDArray[np.float64], feature_weights: npt.NDArray[np.float64]
) -> tuple[float, float]:
    """Get the sum of present and absent SHAP values.

    A feature counts as "present" where the feature vector is 1 and as
    "absent" where it is 0. The feature weights are summed separately over
    both groups.

    Parameters
    ----------
    feature_vector: npt.NDArray[np.float64]
        The binary feature vector (every entry must be exactly 0 or 1).
    feature_weights: npt.NDArray[np.float64]
        The feature weights, e.g. SHAP values, aligned with the feature vector.

    Raises
    ------
    ValueError
        If the feature vector is not binary, i.e. contains any value other
        than 0 or 1 (including fractional values and NaN).

    Returns
    -------
    tuple[float, float]
        The sum of present and absent SHAP values.
    """
    # A pure range check (min >= 0 and max <= 1) would let fractional values
    # such as 0.5 or NaN through, and the present/absent split below would then
    # silently distribute a weight over both groups. Require exact 0/1 entries.
    if not np.isin(feature_vector, (0, 1)).all():
        raise ValueError(
            "Feature vector must be binary. Alternatively, use the structure_heatmap function instead."
        )

    # The binary vector acts as a mask: multiplying keeps the weights of the
    # present features, and its complement keeps those of the absent ones.
    sum_present_shap = float(np.sum(feature_weights * feature_vector))
    sum_absent_shap = float(np.sum(feature_weights * (1 - feature_vector)))

    return sum_present_shap, sum_absent_shap
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,24 @@
from rdkit.Chem.Draw import rdMolDraw2D

from molpipeline.abstract_pipeline_elements.core import RDKitMol
from molpipeline.experimental.explainability import (
from molpipeline.experimental.explainability.explanation import (
SHAPFeatureAndAtomExplanation,
)

from molpipeline.experimental.explainability.visualization.gauss import GaussFunctor2D
from molpipeline.experimental.explainability.visualization.heatmaps import (
ValueGrid,
color_canvas,
get_color_normalizer_from_data,
)
from molpipeline.experimental.explainability.visualization.utils import (
RGBAtuple,
calc_present_and_absent_shap_contributions,
get_atom_coords_of_bond,
get_color_map_from_input,
plt_to_pil,
to_png,
get_mol_lims,
pad,
RGBAtuple,
plt_to_pil,
to_png,
)


Expand Down Expand Up @@ -163,16 +164,11 @@ def _add_gaussians_for_bonds(
ValueGrid object with added functions.
"""
# Adding Gauss-functions centered at bonds (position between the two bonded-atoms)
for i, b in enumerate(mol.GetBonds()):
for i, bond in enumerate(mol.GetBonds()):
if bond_weights[i] == 0:
continue
a1 = b.GetBeginAtom().GetIdx()
a1_pos = conf.GetAtomPosition(a1)
a1_coords = np.array([a1_pos.x, a1_pos.y])

a2 = b.GetEndAtom().GetIdx()
a2_pos = conf.GetAtomPosition(a2)
a2_coords = np.array([a2_pos.x, a2_pos.y])
a1_coords, a2_coords = get_atom_coords_of_bond(bond, conf)

diff = a2_coords - a1_coords
angle = np.arctan2(diff[0], diff[1])
Expand Down Expand Up @@ -379,7 +375,7 @@ def structure_heatmap(
return image


def structure_heatmap_shap( # pylint: disable=too-many-branches
def structure_heatmap_shap( # pylint: disable=too-many-branches, too-many-locals
explanation: SHAPFeatureAndAtomExplanation,
color: str | Colormap | tuple[RGBAtuple, RGBAtuple, RGBAtuple] | None = None,
width: int = 600,
Expand Down Expand Up @@ -419,11 +415,6 @@ def structure_heatmap_shap( # pylint: disable=too-many-branches
if explanation.atom_weights is None:
raise ValueError("Explanation does not contain atom weights.")

if explanation.feature_vector.max() > 1 or explanation.feature_vector.min() < 0:
raise ValueError(
"Feature vector must be binary. Alternatively, use the structure_heatmap function instead."
)

if explanation.prediction.ndim > 2:
raise ValueError(
"Unsupported shape for prediction. Maximum 2 dimension is supported."
Expand All @@ -436,11 +427,10 @@ def structure_heatmap_shap( # pylint: disable=too-many-branches
else:
raise ValueError("Unsupported shape for feature weights.")

# determine present/absent features using the binary feature vector
present_shap = feature_weights * explanation.feature_vector
absent_shap = feature_weights * (1 - explanation.feature_vector)
sum_present_shap = sum(present_shap)
sum_absent_shap = sum(absent_shap)
# calculate the sum of the SHAP values for present and absent features
sum_present_shap, sum_absent_shap = calc_present_and_absent_shap_contributions(
explanation.feature_vector, feature_weights
)

with plt.ioff():

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
from molpipeline.experimental.explainability import (
SHAPFeatureAndAtomExplanation,
SHAPFeatureExplanation,
SHAPTreeExplainer,
SHAPKernelExplainer,
SHAPTreeExplainer,
)
from molpipeline.experimental.explainability.explanation import AtomExplanationMixin
from molpipeline.mol2any import (
Expand Down

0 comments on commit 8d3ae06

Please sign in to comment.