diff --git a/tests/test_explainability/test_shap_explainers.py b/tests/test_explainability/test_shap_explainers.py index 71e9aeca..56cbc02f 100644 --- a/tests/test_explainability/test_shap_explainers.py +++ b/tests/test_explainability/test_shap_explainers.py @@ -23,7 +23,6 @@ from molpipeline.explainability.explainer import ( SHAPKernelExplainer, SHAPTreeExplainer, - SHAPExplainerAdapter, ) from molpipeline.explainability.explanation import ( AtomExplanationMixin, diff --git a/tests/test_explainability/test_visualization/test_gaussian_grid.py b/tests/test_explainability/test_visualization/test_gaussian_grid.py index 398e068f..481cb2f1 100644 --- a/tests/test_explainability/test_visualization/test_gaussian_grid.py +++ b/tests/test_explainability/test_visualization/test_gaussian_grid.py @@ -31,7 +31,7 @@ class TestSumOfGaussiansGrid(unittest.TestCase): test_pipeline: ClassVar[Pipeline] test_explainer: ClassVar[SHAPTreeExplainer] test_explanations: ClassVar[ - list[SHAPFeatureAndAtomExplanation] | list[SHAPFeatureExplanation] + list[SHAPFeatureAndAtomExplanation | SHAPFeatureExplanation] ] @classmethod diff --git a/tests/test_explainability/test_visualization/test_visualization.py b/tests/test_explainability/test_visualization/test_visualization.py index 6390664e..df0055be 100644 --- a/tests/test_explainability/test_visualization/test_visualization.py +++ b/tests/test_explainability/test_visualization/test_visualization.py @@ -19,7 +19,6 @@ ) from molpipeline.explainability.explainer import ( SHAPKernelExplainer, - SHAPExplainerAdapter, ) from molpipeline.mol2any import MolToMorganFP from molpipeline.utils.subpipeline import get_featurization_subpipeline @@ -100,7 +99,7 @@ def test_structure_heatmap_fingerprint_based_atom_coloring(self) -> None: self.assertIsInstance(explanation.atom_weights, np.ndarray) # type: ignore[union-attr] image = structure_heatmap( explanation.molecule, - explanation.atom_weights, # type: ignore[arg-type] + explanation.atom_weights, # type: ignore[union-type] width=128, height=128, ) # type: ignore[union-attr] @@ -115,9 +114,10 @@ def test_structure_heatmap_shap_explanation(self) -> None: ]: for explanation in explanation_list: self.assertTrue(explanation.is_valid()) + self.assertIsInstance(explanation, SHAPFeatureAndAtomExplanation) self.assertIsInstance(explanation.atom_weights, np.ndarray) # type: ignore[union-attr] image = structure_heatmap_shap( - explanation=explanation, + explanation=explanation, # type: ignore[arg-type] width=128, height=128, ) # type: ignore[union-attr] @@ -143,9 +143,9 @@ def test_explicit_hydrogens(self) -> None: self.assertEqual(len(explanations1), 1) self.assertEqual(len(explanations2), 1) self.assertEqual(len(explanations3), 1) - self.assertIsInstance(explanations1[0].atom_weights, np.ndarray) - self.assertIsInstance(explanations2[0].atom_weights, np.ndarray) - self.assertIsInstance(explanations3[0].atom_weights, np.ndarray) + self.assertIsInstance(explanations1[0].atom_weights, np.ndarray) # type: ignore[union-attr] + self.assertIsInstance(explanations2[0].atom_weights, np.ndarray) # type: ignore[union-attr] + self.assertIsInstance(explanations3[0].atom_weights, np.ndarray) # type: ignore[union-attr] self.assertEqual(len(explanations1[0].atom_weights), 1) # type: ignore[arg-type] self.assertEqual(len(explanations2[0].atom_weights), 1) # type: ignore[arg-type] self.assertEqual(len(explanations3[0].atom_weights), 1) # type: ignore[arg-type] @@ -156,7 +156,7 @@ def test_explicit_hydrogens(self) -> None: self.assertTrue(explanation.is_valid()) image = structure_heatmap( explanation.molecule, - explanation.atom_weights, # type: ignore[arg-type] + explanation.atom_weights, # type: ignore[union-attr] width=128, height=128, ) # type: ignore[union-attr]