diff --git a/molpipeline/explainability/explainer.py b/molpipeline/explainability/explainer.py index f279a765..d50cc4f3 100644 --- a/molpipeline/explainability/explainer.py +++ b/molpipeline/explainability/explainer.py @@ -330,7 +330,7 @@ def explain(self, X: Any, **kwargs: Any) -> _SHAPExplainer_return_type_: raise ValueError( "Featurization element does not have a get_feature_names method." ) - explanation_data["feature_names"] = featurization_element.feature_names + explanation_data["feature_names"] = featurization_element.feature_names # type: ignore[union-attr] if issubclass(self.return_element_type_, FeatureExplanationMixin): explanation_data["feature_weights"] = feature_weights diff --git a/tests/test_explainability/test_shap_explainers.py b/tests/test_explainability/test_shap_explainers.py index f7ea0557..0918006b 100644 --- a/tests/test_explainability/test_shap_explainers.py +++ b/tests/test_explainability/test_shap_explainers.py @@ -20,10 +20,7 @@ from molpipeline import ErrorFilter, FilterReinserter, Pipeline, PostPredictionWrapper from molpipeline.abstract_pipeline_elements.core import RDKitMol from molpipeline.any2mol import SmilesToMol -from molpipeline.explainability.explainer import ( - SHAPKernelExplainer, - SHAPTreeExplainer, -) +from molpipeline.explainability.explainer import SHAPKernelExplainer, SHAPTreeExplainer from molpipeline.explainability.explanation import ( AtomExplanationMixin, SHAPFeatureAndAtomExplanation, @@ -128,11 +125,11 @@ def _test_valid_explanation( self.assertTrue( all( isinstance(name, str) and len(name) > 0 - for name in explanation.feature_names + for name in explanation.feature_names # type: ignore[union-attr] ) ) self.assertEqual( - len(explanation.feature_names), explanation.feature_vector.shape[0] + len(explanation.feature_names), explanation.feature_vector.shape[0] # type: ignore ) self.assertIsInstance(explanation.molecule, RDKitMol) diff --git a/tests/test_explainability/test_visualization/test_visualization.py b/tests/test_explainability/test_visualization/test_visualization.py index 98488504..3c75278c 100644 --- a/tests/test_explainability/test_visualization/test_visualization.py +++ b/tests/test_explainability/test_visualization/test_visualization.py @@ -12,14 +12,12 @@ from molpipeline.any2mol import SmilesToMol from molpipeline.explainability import ( SHAPFeatureAndAtomExplanation, + SHAPFeatureExplanation, SHAPTreeExplainer, structure_heatmap, structure_heatmap_shap, - SHAPFeatureExplanation, -) -from molpipeline.explainability.explainer import ( - SHAPKernelExplainer, ) +from molpipeline.explainability.explainer import SHAPKernelExplainer from molpipeline.mol2any import MolToMorganFP from molpipeline.utils.subpipeline import get_featurization_subpipeline