Skip to content

Commit

Permalink
add feature_names to explainability code
Browse files — browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Nov 25, 2024
1 parent 5603cf0 commit 1981d6a
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 13 deletions.
12 changes: 6 additions & 6 deletions molpipeline/explainability/explainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -301,11 +301,6 @@ def explain(self, X: Any, **kwargs: Any) -> _SHAPExplainer_return_type_:
explanation_results.append(self.return_element_type_())
continue

- # Feature names should also be extracted from the Pipeline.
- # But first, we need to add the names to the pipelines.
- # Therefore, feature_names is just None currently.
- feature_names = None

# compute the shap values for the features
feature_weights = self.explainer.shap_values(feature_vector, **kwargs)
feature_weights = np.asarray(feature_weights).squeeze()
Expand All @@ -331,7 +326,12 @@ def explain(self, X: Any, **kwargs: Any) -> _SHAPExplainer_return_type_:
}
if issubclass(self.return_element_type_, FeatureInfoMixin):
explanation_data["feature_vector"] = feature_vector
- explanation_data["feature_names"] = feature_names
+ if not hasattr(featurization_element, "feature_names"):
+ raise ValueError(
+ "Featurization element does not have a get_feature_names method."
+ )
+ explanation_data["feature_names"] = featurization_element.feature_names

if issubclass(self.return_element_type_, FeatureExplanationMixin):
explanation_data["feature_weights"] = feature_weights
if issubclass(self.return_element_type_, AtomExplanationMixin):
Expand Down
4 changes: 2 additions & 2 deletions molpipeline/explainability/explanation.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -78,7 +78,7 @@ def is_valid(self) -> bool:
return all(
[
self.feature_vector is not None,
- # self.feature_names is not None,
+ self.feature_names is not None,
self.molecule is not None,
self.prediction is not None,
self.feature_weights is not None,
Expand Down Expand Up @@ -107,7 +107,7 @@ def is_valid(self) -> bool:
return all(
[
self.feature_vector is not None,
- # self.feature_names is not None, # TODO uncomment when PR is merged
+ self.feature_names is not None,
self.molecule is not None,
self.prediction is not None,
self.feature_weights is not None,
Expand Down
27 changes: 22 additions & 5 deletions tests/test_explainability/test_shap_explainers.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -124,9 +124,16 @@ def _test_valid_explanation(
(nof_features,), explanation.feature_vector.shape # type: ignore[union-attr]
)

- # feature names are not implemented yet
- self.assertIsNone(explanation.feature_names)
- # self.assertEqual(len(explanation.feature_names), explanation.feature_vector.shape[0])
+ # feature names should be a list of not empty strings
+ self.assertTrue(
+ all(
+ isinstance(name, str) and len(name) > 0
+ for name in explanation.feature_names
+ )
+ )
+ self.assertEqual(
+ len(explanation.feature_names), explanation.feature_vector.shape[0]
+ )

self.assertIsInstance(explanation.molecule, RDKitMol)
self.assertEqual(
Expand Down Expand Up @@ -176,9 +183,9 @@ def _test_valid_explanation(
(explanation.molecule.GetNumAtoms(),), # type: ignore[union-attr]
)

- def test_explanations_fingerprint_pipeline(
+ def test_explanations_fingerprint_pipeline( # pylint: disable=too-many-locals
  self,
- ) -> None: # pylint: disable=too-many-locals
+ ) -> None:
"""Test SHAP's TreeExplainer wrapper on MolPipeline's pipelines with fingerprints."""

tree_estimators = [
Expand Down Expand Up @@ -371,6 +378,11 @@ def test_explanations_pipeline_with_physchem(self) -> None:
explainer=explainer,
)

+ self.assertEqual(
+ explanation.feature_names,
+ pipeline.named_steps["physchem"].feature_names,
+ )

def test_explanations_pipeline_with_concatenated_features(self) -> None:
"""Test SHAP's TreeExplainer wrapper on concatenated feature vector."""

Expand Down Expand Up @@ -428,3 +440,8 @@ def test_explanations_pipeline_with_concatenated_features(self) -> None:
TEST_SMILES[i],
explainer=explainer,
)

+ self.assertEqual(
+ explanation.feature_names,
+ pipeline.named_steps["features"].feature_names,
+ )

0 comments on commit 1981d6a

Please sign in to comment.