From 3bc87e3d84f4a92b687cbf1405ab18c42cf78c46 Mon Sep 17 00:00:00 2001 From: Jochen Sieg Date: Tue, 14 Jan 2025 17:32:40 +0100 Subject: [PATCH] first part code review --- .../experimental/explainability/explainer.py | 34 ++++++------------- .../explainability/visualization/heatmaps.py | 22 +++++++++--- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/molpipeline/experimental/explainability/explainer.py b/molpipeline/experimental/explainability/explainer.py index 27ec824c..ed312d5f 100644 --- a/molpipeline/experimental/explainability/explainer.py +++ b/molpipeline/experimental/explainability/explainer.py @@ -3,7 +3,7 @@ from __future__ import annotations import abc -from typing import Any, TypeAlias +from typing import Any, TypeAlias, Callable import numpy as np import numpy.typing as npt @@ -11,7 +11,11 @@ import shap from scipy.sparse import issparse, spmatrix from sklearn.base import BaseEstimator -from typing_extensions import override + +try: + from typing import override # type: ignore[attr-defined] +except ImportError: + from typing_extensions import override from molpipeline import Pipeline from molpipeline.abstract_pipeline_elements.core import OptionalMol @@ -51,27 +55,9 @@ def _to_dense( return feature_matrix -def _convert_to_array(value: Any) -> npt.NDArray[np.float64]: - """Convert a value to a numpy array. - - Parameters - ---------- - value : Any - The value to convert. - - Returns - ------- - npt.NDArray[np.float64] - The value as a numpy array. - """ - if isinstance(value, np.ndarray): - return value - if np.isscalar(value): - return np.array([value]) - raise ValueError("Value is not a scalar or numpy array.") - - -def _get_prediction_function(pipeline: Pipeline | BaseEstimator) -> Any: +def _get_prediction_function( + pipeline: Pipeline | BaseEstimator, +) -> Callable[[npt.Arraylike], npt.Arraylike]: """Get the prediction function of a model. Parameters @@ -359,7 +345,7 @@ def explain( if issubclass(self.return_element_type_, BondExplanationMixin): explanation_data["bond_weights"] = bond_weights if issubclass(self.return_element_type_, SHAPExplanationMixin): - explanation_data["expected_value"] = _convert_to_array( + explanation_data["expected_value"] = np.atleast_1d( self.explainer.expected_value ) diff --git a/molpipeline/experimental/explainability/visualization/heatmaps.py b/molpipeline/experimental/explainability/visualization/heatmaps.py index 4fdae123..8456b4ca 100644 --- a/molpipeline/experimental/explainability/visualization/heatmaps.py +++ b/molpipeline/experimental/explainability/visualization/heatmaps.py @@ -50,17 +50,19 @@ def __init__( self.y_lim = y_lim self.x_res = x_res self.y_res = y_res + self._dx = (max(self.x_lim) - min(self.x_lim)) / self.x_res + self._dy = (max(self.y_lim) - min(self.y_lim)) / self.y_res self.values = np.zeros((self.x_res, self.y_res)) @property def dx(self) -> float: """Length of cell in x-direction.""" - return (max(self.x_lim) - min(self.x_lim)) / self.x_res + return self._dx @property def dy(self) -> float: """Length of cell in y-direction.""" - return (max(self.y_lim) - min(self.y_lim)) / self.y_res + return self._dy def grid_field_center(self, x_idx: int, y_idx: int) -> tuple[float, float]: """Center of cell specified by index along x and y. @@ -149,6 +151,9 @@ def __init__( y_lim: Sequence[float], x_res: int, y_res: int, + function_list: ( + list[Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]]] | None + ) = None, ): """Initialize the ValueGrid with limits and resolution of the axes. @@ -162,11 +167,18 @@ def __init__( Resolution (number of cells) along x-axis. y_res: int Resolution (number of cells) along y-axis. + function_list: list[Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]]], optional + List of functions to be evaluated for each cell, by default None. """ super().__init__(x_lim, y_lim, x_res, y_res) - self.function_list: list[ - Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]] - ] = [] + if function_list is not None: + self.function_list: list[ + Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]] + ] = function_list + else: + self.function_list: list[ + Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]] + ] = [] self.values = np.zeros((self.x_res, self.y_res)) def add_function(