From 1981d6a69d458e0d6e509755a89c284c503ac1b4 Mon Sep 17 00:00:00 2001 From: Jochen Sieg Date: Mon, 25 Nov 2024 16:28:38 +0100 Subject: [PATCH] add feature_names to explainability code --- molpipeline/explainability/explainer.py | 12 ++++----- molpipeline/explainability/explanation.py | 4 +-- .../test_shap_explainers.py | 27 +++++++++++++++---- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/molpipeline/explainability/explainer.py b/molpipeline/explainability/explainer.py index d930735e..f279a765 100644 --- a/molpipeline/explainability/explainer.py +++ b/molpipeline/explainability/explainer.py @@ -301,11 +301,6 @@ def explain(self, X: Any, **kwargs: Any) -> _SHAPExplainer_return_type_: explanation_results.append(self.return_element_type_()) continue - # Feature names should also be extracted from the Pipeline. - # But first, we need to add the names to the pipelines. - # Therefore, feature_names is just None currently. - feature_names = None - # compute the shap values for the features feature_weights = self.explainer.shap_values(feature_vector, **kwargs) feature_weights = np.asarray(feature_weights).squeeze() @@ -331,7 +326,12 @@ def explain(self, X: Any, **kwargs: Any) -> _SHAPExplainer_return_type_: } if issubclass(self.return_element_type_, FeatureInfoMixin): explanation_data["feature_vector"] = feature_vector - explanation_data["feature_names"] = feature_names + if not hasattr(featurization_element, "feature_names"): + raise ValueError( + "Featurization element does not have a get_feature_names method." + ) + explanation_data["feature_names"] = featurization_element.feature_names + if issubclass(self.return_element_type_, FeatureExplanationMixin): explanation_data["feature_weights"] = feature_weights if issubclass(self.return_element_type_, AtomExplanationMixin): diff --git a/molpipeline/explainability/explanation.py b/molpipeline/explainability/explanation.py index b12de8e4..917a872b 100644 --- a/molpipeline/explainability/explanation.py +++ b/molpipeline/explainability/explanation.py @@ -78,7 +78,7 @@ def is_valid(self) -> bool: return all( [ self.feature_vector is not None, - # self.feature_names is not None, + self.feature_names is not None, self.molecule is not None, self.prediction is not None, self.feature_weights is not None, @@ -107,7 +107,7 @@ def is_valid(self) -> bool: return all( [ self.feature_vector is not None, - # self.feature_names is not None, # TODO uncomment when PR is merged + self.feature_names is not None, self.molecule is not None, self.prediction is not None, self.feature_weights is not None, diff --git a/tests/test_explainability/test_shap_explainers.py b/tests/test_explainability/test_shap_explainers.py index 25d29c51..f7ea0557 100644 --- a/tests/test_explainability/test_shap_explainers.py +++ b/tests/test_explainability/test_shap_explainers.py @@ -124,9 +124,16 @@ def _test_valid_explanation( (nof_features,), explanation.feature_vector.shape # type: ignore[union-attr] ) - # feature names are not implemented yet - self.assertIsNone(explanation.feature_names) - # self.assertEqual(len(explanation.feature_names), explanation.feature_vector.shape[0]) + # feature names should be a list of not empty strings + self.assertTrue( + all( + isinstance(name, str) and len(name) > 0 + for name in explanation.feature_names + ) + ) + self.assertEqual( + len(explanation.feature_names), explanation.feature_vector.shape[0] + ) self.assertIsInstance(explanation.molecule, RDKitMol) self.assertEqual( @@ -176,9 +183,9 @@ def _test_valid_explanation( (explanation.molecule.GetNumAtoms(),), # type: ignore[union-attr] ) - def test_explanations_fingerprint_pipeline( + def test_explanations_fingerprint_pipeline( # pylint: disable=too-many-locals self, - ) -> None: # pylint: disable=too-many-locals + ) -> None: """Test SHAP's TreeExplainer wrapper on MolPipeline's pipelines with fingerprints.""" tree_estimators = [ @@ -371,6 +378,11 @@ def test_explanations_pipeline_with_physchem(self) -> None: explainer=explainer, ) + self.assertEqual( + explanation.feature_names, + pipeline.named_steps["physchem"].feature_names, + ) + def test_explanations_pipeline_with_concatenated_features(self) -> None: """Test SHAP's TreeExplainer wrapper on concatenated feature vector.""" @@ -428,3 +440,8 @@ def test_explanations_pipeline_with_concatenated_features(self) -> None: TEST_SMILES[i], explainer=explainer, ) + + self.assertEqual( + explanation.feature_names, + pipeline.named_steps["features"].feature_names, + )