Skip to content

Commit

Permalink
add xai notebook and adaptations
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Nov 28, 2024
1 parent d6bc7c2 commit 398a5dc
Show file tree
Hide file tree
Showing 8 changed files with 1,443 additions and 134 deletions.
24 changes: 23 additions & 1 deletion molpipeline/explainability/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,26 @@ def _to_dense(
return feature_matrix


def _convert_to_array(value: Any) -> npt.NDArray[np.float64]:
"""Convert a value to a numpy array.
Parameters
----------
value : Any
The value to convert.
Returns
-------
npt.NDArray[np.float64]
The value as a numpy array.
"""
if isinstance(value, np.ndarray):
return value
if np.isscalar(value):
return np.array([value])
raise ValueError("Value is not a scalar or numpy array.")


def _get_prediction_function(pipeline: Pipeline | BaseEstimator) -> Any:
"""Get the prediction function of a model.
Expand Down Expand Up @@ -339,7 +359,9 @@ def explain(self, X: Any, **kwargs: Any) -> _SHAPExplainer_return_type_:
if issubclass(self.return_element_type_, BondExplanationMixin):
explanation_data["bond_weights"] = bond_weights
if issubclass(self.return_element_type_, SHAPExplanationMixin):
explanation_data["expected_value"] = self.explainer.expected_value
explanation_data["expected_value"] = _convert_to_array(
self.explainer.expected_value
)

explanation_results.append(self.return_element_type_(**explanation_data))

Expand Down
2 changes: 1 addition & 1 deletion molpipeline/explainability/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class _AbstractMoleculeExplanation(abc.ABC):
"""Abstract class representing an explanation for a prediction for a molecule."""

molecule: RDKitMol | None = None
prediction: float | npt.NDArray[np.float64] | None = None
prediction: npt.NDArray[np.float64] | None = None


@dataclasses.dataclass(kw_only=True)
Expand Down
120 changes: 65 additions & 55 deletions molpipeline/explainability/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,68 +417,78 @@ def structure_heatmap_shap(
if explanation.atom_weights is None:
raise ValueError("Explanation does not contain atom weights.")

present_shap = explanation.feature_weights[:, 1] * explanation.feature_vector
absent_shap = explanation.feature_weights[:, 1] * (1 - explanation.feature_vector)
if explanation.feature_vector.max() > 1 or explanation.feature_vector.min() < 0:
raise ValueError(
f"Feature vector must be binary. Alternatively, use the structure_heatmap function instead."
)

if explanation.prediction.ndim > 2:
raise ValueError(
"Unsupported shape for prediction. Maximum 2 dimension is supported."
)

if explanation.feature_weights.ndim == 1:
feature_weights = explanation.feature_weights
elif explanation.feature_weights.ndim == 2:
feature_weights = explanation.feature_weights[:, 1]
else:
raise ValueError("Unsupported shape for feature weights.")

# determine present/absent features using the binary feature vector
present_shap = feature_weights * explanation.feature_vector
absent_shap = feature_weights * (1 - explanation.feature_vector)
sum_present_shap = sum(present_shap)
sum_absent_shap = sum(absent_shap)

drawer, _, _, normalizer, color_map = _structure_heatmap(
explanation.molecule,
explanation.atom_weights,
color=color,
width=width,
height=height,
color_limits=color_limits,
)
figure_bytes = drawer.GetDrawingText()
image = to_png(figure_bytes)
image_array = np.array(image)
with plt.ioff():

fig, ax = plt.subplots(figsize=(8, 8))
drawer, _, _, normalizer, color_map = _structure_heatmap(
explanation.molecule,
explanation.atom_weights,
color=color,
width=width,
height=height,
color_limits=color_limits,
)
figure_bytes = drawer.GetDrawingText()
image_heatmap = to_png(figure_bytes)
image_array = np.array(image_heatmap)

im = ax.imshow(
image_array,
cmap=color_map,
norm=normalizer,
)
# remove ticks
ax.set_xticks([])
ax.set_yticks([])
# remove border
for spine in ax.spines.values():
spine.set_visible(False)

fig.colorbar(im, ax=ax, orientation="vertical", fraction=0.015, pad=0.0)

if isinstance(explanation.prediction, float):
# regression case
raise NotImplementedError("Regression case not yet implemented.")
if isinstance(explanation.prediction, np.ndarray):
if len(explanation.prediction) == 2:
# binary classification case
text = (
f"$P(y=1|X) = {explanation.prediction[1]:.2f}$ ="
"\n"
"\n"
f" $expected \ value={explanation.expected_value[1]:.2f}$ + " # noqa: W605 # pylint: disable=W1401
f"$features_{{present}}= {sum_present_shap:.2f}$ + "
f"$features_{{absent}}={sum_absent_shap:.2f}$"
)
else:
raise ValueError(
"Unsupported number of classes for prediction. Only binary classification is supported."
)
else:
raise ValueError(
"Unsupported type for prediction. Expected float or np.ndarray."
fig, ax = plt.subplots(figsize=(8, 8))

im = ax.imshow(
image_array,
cmap=color_map,
norm=normalizer,
)
# remove ticks
ax.set_xticks([])
ax.set_yticks([])
# remove border
for spine in ax.spines.values():
spine.set_visible(False)

fig.colorbar(im, ax=ax, orientation="vertical", fraction=0.015, pad=0.0)

# note: the prediction/expected value of the last array element is used
text = (
f"$P(y=1|X) = {explanation.prediction[-1]:.2f}$ ="
"\n"
"\n"
f" $expected \ value={explanation.expected_value[-1]:.2f}$ + " # noqa: W605 # pylint: disable=W1401
f"$features_{{present}}= {sum_present_shap:.2f}$ + "
f"$features_{{absent}}={sum_absent_shap:.2f}$"
)
fig.text(0.5, 0.18, text, ha="center")

fig.text(0.5, 0.18, text, ha="center")
image = plt_to_pil(fig)
# clear the figure and memory
plt.close(fig)

image = plt_to_pil(fig)
# clear the figure and memory
plt.close()
plt.clf()
plt.cla()
# remove dpi info because it crashes ipython's display function
if "dpi" in image.info:
del image.info["dpi"]
# keep RDKit's image info
image.info.update(image_heatmap.info)

return image
Loading

0 comments on commit 398a5dc

Please sign in to comment.