Skip to content

Commit

Permalink
first part code review
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Jan 14, 2025
1 parent be76892 commit 3bc87e3
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 29 deletions.
34 changes: 10 additions & 24 deletions molpipeline/experimental/explainability/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
from __future__ import annotations

import abc
from typing import Any, TypeAlias
from collections.abc import Callable
from typing import Any, TypeAlias

import numpy as np
import numpy.typing as npt
import pandas as pd
import shap
from scipy.sparse import issparse, spmatrix
from sklearn.base import BaseEstimator
from typing_extensions import override

try:
from typing import override # type: ignore[attr-defined]
except ImportError:
from typing_extensions import override

from molpipeline import Pipeline
from molpipeline.abstract_pipeline_elements.core import OptionalMol
Expand Down Expand Up @@ -51,27 +55,9 @@ def _to_dense(
return feature_matrix


def _convert_to_array(value: Any) -> npt.NDArray[np.float64]:
"""Convert a value to a numpy array.
Parameters
----------
value : Any
The value to convert.
Returns
-------
npt.NDArray[np.float64]
The value as a numpy array.
"""
if isinstance(value, np.ndarray):
return value
if np.isscalar(value):
return np.array([value])
raise ValueError("Value is not a scalar or numpy array.")


def _get_prediction_function(pipeline: Pipeline | BaseEstimator) -> Any:
def _get_prediction_function(
pipeline: Pipeline | BaseEstimator,
) -> Callable[[npt.ArrayLike], npt.ArrayLike]:
"""Get the prediction function of a model.
Parameters
Expand Down Expand Up @@ -359,7 +345,7 @@ def explain(
if issubclass(self.return_element_type_, BondExplanationMixin):
explanation_data["bond_weights"] = bond_weights
if issubclass(self.return_element_type_, SHAPExplanationMixin):
explanation_data["expected_value"] = _convert_to_array(
explanation_data["expected_value"] = np.atleast_1d(
self.explainer.expected_value
)

Expand Down
22 changes: 17 additions & 5 deletions molpipeline/experimental/explainability/visualization/heatmaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,19 @@ def __init__(
self.y_lim = y_lim
self.x_res = x_res
self.y_res = y_res
self._dx = (max(self.x_lim) - min(self.x_lim)) / self.x_res
self._dy = (max(self.y_lim) - min(self.y_lim)) / self.y_res
self.values = np.zeros((self.x_res, self.y_res))

@property
def dx(self) -> float:
"""Length of cell in x-direction."""
return (max(self.x_lim) - min(self.x_lim)) / self.x_res
return self._dx

@property
def dy(self) -> float:
"""Length of cell in y-direction."""
return (max(self.y_lim) - min(self.y_lim)) / self.y_res
return self._dy

def grid_field_center(self, x_idx: int, y_idx: int) -> tuple[float, float]:
"""Center of cell specified by index along x and y.
Expand Down Expand Up @@ -149,6 +151,9 @@ def __init__(
y_lim: Sequence[float],
x_res: int,
y_res: int,
function_list: (
list[Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]]] | None
) = None,
):
"""Initialize the ValueGrid with limits and resolution of the axes.
Expand All @@ -162,11 +167,18 @@ def __init__(
Resolution (number of cells) along x-axis.
y_res: int
Resolution (number of cells) along y-axis.
function_list: list[Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]]], optional
List of functions to be evaluated for each cell, by default None.
"""
super().__init__(x_lim, y_lim, x_res, y_res)
self.function_list: list[
Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]]
] = []
        self.function_list: list[
            Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]]
        ] = (function_list if function_list is not None else [])
self.values = np.zeros((self.x_res, self.y_res))

def add_function(
Expand Down

0 comments on commit 3bc87e3

Please sign in to comment.