From 8d3ae06f29a9e7b01cc7021fc8c30bd8aaf0648a Mon Sep 17 00:00:00 2001 From: Jochen Sieg Date: Thu, 5 Dec 2024 10:35:23 +0100 Subject: [PATCH] linting --- .../experimental/explainability/explainer.py | 31 ++++----- .../explainability/visualization/utils.py | 65 +++++++++++++++++++ .../visualization/visualization.py | 36 ++++------ .../test_shap_explainers.py | 2 +- 4 files changed, 95 insertions(+), 39 deletions(-) diff --git a/molpipeline/experimental/explainability/explainer.py b/molpipeline/experimental/explainability/explainer.py index e39b5a1e..27ec824c 100644 --- a/molpipeline/experimental/explainability/explainer.py +++ b/molpipeline/experimental/explainability/explainer.py @@ -166,18 +166,18 @@ def _convert_shap_feature_weights_to_atom_weights( return atom_weights -_SHAPExplainer_return_type_: TypeAlias = list[ +ShapExplanation: TypeAlias = list[ SHAPFeatureExplanation | SHAPFeatureAndAtomExplanation ] -class AbstractSHAPExplainer(abc.ABC): +class AbstractSHAPExplainer(abc.ABC): # pylint: disable=too-few-public-methods """Abstract class for SHAP explainer objects.""" @abc.abstractmethod def explain( - self, X: Any, **kwargs: Any - ) -> _SHAPExplainer_return_type_: # pylint: disable=invalid-name,unused-argument + self, X: Any, **kwargs: Any # pylint: disable=invalid-name,unused-argument + ) -> ShapExplanation: """Explain the predictions for the input data. Parameters @@ -194,14 +194,15 @@ def explain( """ -class SHAPExplainerAdapter(AbstractSHAPExplainer, abc.ABC): +class SHAPExplainerAdapter( + AbstractSHAPExplainer, abc.ABC +): # pylint: disable=too-few-public-methods """Adapter for SHAP explainer wrappers for handling molecules and pipelines.""" def __init__( self, pipeline: Pipeline, explainer: shap.TreeExplainer | shap.KernelExplainer, - **kwargs: Any, ) -> None: """Initialize the SHAPTreeExplainer. @@ -211,8 +212,6 @@ def __init__( The pipeline containing the model to explain. explainer : shap.TreeExplainer | shap.KernelExplainer The shap explainer object. - kwargs : Any - Additional keyword arguments for SHAP's TreeExplainer. """ self.pipeline = pipeline self.explainer = explainer @@ -271,8 +270,8 @@ def _prediction_is_valid(prediction: Any) -> bool: @override def explain( - self, X: Any, **kwargs: Any - ) -> _SHAPExplainer_return_type_: # pylint: disable=invalid-name,unused-argument + self, X: Any, **kwargs: Any # pylint: disable=invalid-name,unused-argument + ) -> ShapExplanation: """Explain the predictions for the input data. If the calculation of the SHAP values for an input sample fails, the explanation will be invalid. @@ -292,7 +291,7 @@ def explain( """ featurization_element = self.featurization_subpipeline.steps[-1][1] # type: ignore[union-attr] - explanation_results: _SHAPExplainer_return_type_ = [] + explanation_results: ShapExplanation = [] for input_sample in X: input_sample = [input_sample] @@ -369,7 +368,7 @@ def explain( return explanation_results -class SHAPTreeExplainer(SHAPExplainerAdapter): +class SHAPTreeExplainer(SHAPExplainerAdapter): # pylint: disable=too-few-public-methods """Wrapper for SHAP's TreeExplainer that can handle pipelines and molecules. Wraps SHAP's TreeExplainer to explain predictions of a pipeline containing a @@ -397,7 +396,7 @@ def __init__( Additional keyword arguments for SHAP's Explainer. """ explainer = self._create_explainer(pipeline, **kwargs) - super().__init__(pipeline, explainer, **kwargs) + super().__init__(pipeline, explainer) @staticmethod def _create_explainer(pipeline: Pipeline, **kwargs: Any) -> shap.TreeExplainer: @@ -423,7 +422,9 @@ def _create_explainer(pipeline: Pipeline, **kwargs: Any) -> shap.TreeExplainer: return explainer -class SHAPKernelExplainer(SHAPExplainerAdapter): +class SHAPKernelExplainer( + SHAPExplainerAdapter +): # pylint: disable=too-few-public-methods """Wrapper for SHAP's KernelExplainer that can handle pipelines and molecules.""" def __init__( @@ -441,7 +442,7 @@ def __init__( Additional keyword arguments for SHAP's Explainer. """ explainer = self._create_explainer(pipeline, **kwargs) - super().__init__(pipeline, explainer, **kwargs) + super().__init__(pipeline, explainer) @staticmethod def _create_explainer(pipeline: Pipeline, **kwargs: Any) -> shap.KernelExplainer: diff --git a/molpipeline/experimental/explainability/visualization/utils.py b/molpipeline/experimental/explainability/visualization/utils.py index d6ef2d0d..338821aa 100644 --- a/molpipeline/experimental/explainability/visualization/utils.py +++ b/molpipeline/experimental/explainability/visualization/utils.py @@ -170,3 +170,68 @@ def plt_to_pil(figure: plt.Figure) -> Image.Image: bio.seek(0) img = Image.open(bio) return img + + +def get_atom_coords_of_bond( + bond: Chem.Bond, conf: Chem.Conformer +) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: + """Get the two atom coordinates of a bond in the conformation. + + Parameters + ---------- + bond: Chem.Bond + The bond. + conf: Chem.Conformer + The conformation. + + Returns + ------- + tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]] + The atom coordinates. + """ + a1 = bond.GetBeginAtom().GetIdx() + a1_pos = conf.GetAtomPosition(a1) + a1_coords = np.array([a1_pos.x, a1_pos.y]) + + a2 = bond.GetEndAtom().GetIdx() + a2_pos = conf.GetAtomPosition(a2) + a2_coords = np.array([a2_pos.x, a2_pos.y]) + + return a1_coords, a2_coords + + +def calc_present_and_absent_shap_contributions( + feature_vector: npt.NDArray[np.float64], feature_weights: npt.NDArray[np.float64] +) -> tuple[float, float]: + """Get the sum of present and absent SHAP values. + + Parameters + ---------- + feature_vector: npt.NDArray[np.float64] + The feature vector. + feature_weights: npt.NDArray[np.float64] + The feature weights. + + Raises + ------ + ValueError + If the feature vector is not binary. + + Returns + ------- + tuple[float, float] + The sum of present and absent SHAP values. + """ + + if feature_vector.max() > 1 or feature_vector.min() < 0: + raise ValueError( + "Feature vector must be binary. Alternatively, use the structure_heatmap function instead." + ) + + # determine present/absent features using the binary feature vector + present_shap = feature_weights * feature_vector + absent_shap = feature_weights * (1 - feature_vector) + sum_present_shap = sum(present_shap) + sum_absent_shap = sum(absent_shap) + + return sum_present_shap, sum_absent_shap diff --git a/molpipeline/experimental/explainability/visualization/visualization.py b/molpipeline/experimental/explainability/visualization/visualization.py index c2372cfc..24216b4b 100644 --- a/molpipeline/experimental/explainability/visualization/visualization.py +++ b/molpipeline/experimental/explainability/visualization/visualization.py @@ -20,10 +20,9 @@ from rdkit.Chem.Draw import rdMolDraw2D from molpipeline.abstract_pipeline_elements.core import RDKitMol -from molpipeline.experimental.explainability import ( +from molpipeline.experimental.explainability.explanation import ( SHAPFeatureAndAtomExplanation, ) - from molpipeline.experimental.explainability.visualization.gauss import GaussFunctor2D from molpipeline.experimental.explainability.visualization.heatmaps import ( ValueGrid, @@ -31,12 +30,14 @@ get_color_normalizer_from_data, ) from molpipeline.experimental.explainability.visualization.utils import ( + RGBAtuple, + calc_present_and_absent_shap_contributions, + get_atom_coords_of_bond, get_color_map_from_input, - plt_to_pil, - to_png, get_mol_lims, pad, - RGBAtuple, + plt_to_pil, + to_png, ) @@ -163,16 +164,11 @@ def _add_gaussians_for_bonds( ValueGrid object with added functions. """ # Adding Gauss-functions centered at bonds (position between the two bonded-atoms) - for i, b in enumerate(mol.GetBonds()): + for i, bond in enumerate(mol.GetBonds()): if bond_weights[i] == 0: continue - a1 = b.GetBeginAtom().GetIdx() - a1_pos = conf.GetAtomPosition(a1) - a1_coords = np.array([a1_pos.x, a1_pos.y]) - a2 = b.GetEndAtom().GetIdx() - a2_pos = conf.GetAtomPosition(a2) - a2_coords = np.array([a2_pos.x, a2_pos.y]) + a1_coords, a2_coords = get_atom_coords_of_bond(bond, conf) diff = a2_coords - a1_coords angle = np.arctan2(diff[0], diff[1]) @@ -379,7 +375,7 @@ def structure_heatmap( return image -def structure_heatmap_shap( # pylint: disable=too-many-branches +def structure_heatmap_shap( # pylint: disable=too-many-branches, too-many-locals explanation: SHAPFeatureAndAtomExplanation, color: str | Colormap | tuple[RGBAtuple, RGBAtuple, RGBAtuple] | None = None, width: int = 600, @@ -419,11 +415,6 @@ def structure_heatmap_shap( # pylint: disable=too-many-branches if explanation.atom_weights is None: raise ValueError("Explanation does not contain atom weights.") - if explanation.feature_vector.max() > 1 or explanation.feature_vector.min() < 0: - raise ValueError( - "Feature vector must be binary. Alternatively, use the structure_heatmap function instead." - ) - if explanation.prediction.ndim > 2: raise ValueError( "Unsupported shape for prediction. Maximum 2 dimension is supported." @@ -436,11 +427,10 @@ def structure_heatmap_shap( # pylint: disable=too-many-branches else: raise ValueError("Unsupported shape for feature weights.") - # determine present/absent features using the binary feature vector - present_shap = feature_weights * explanation.feature_vector - absent_shap = feature_weights * (1 - explanation.feature_vector) - sum_present_shap = sum(present_shap) - sum_absent_shap = sum(absent_shap) + # calculate the sum of the SHAP values for present and absent features + sum_present_shap, sum_absent_shap = calc_present_and_absent_shap_contributions( + explanation.feature_vector, feature_weights + ) with plt.ioff(): diff --git a/tests/test_experimental/test_explainability/test_shap_explainers.py b/tests/test_experimental/test_explainability/test_shap_explainers.py index 7ddf8b51..63a4035f 100644 --- a/tests/test_experimental/test_explainability/test_shap_explainers.py +++ b/tests/test_experimental/test_explainability/test_shap_explainers.py @@ -21,8 +21,8 @@ from molpipeline.experimental.explainability import ( SHAPFeatureAndAtomExplanation, SHAPFeatureExplanation, - SHAPTreeExplainer, SHAPKernelExplainer, + SHAPTreeExplainer, ) from molpipeline.experimental.explainability.explanation import AtomExplanationMixin from molpipeline.mol2any import (