Skip to content

Commit

Permalink
explainability: Ignore mypy's arg-type error
Browse files Browse the repository at this point in the history
    - Ignore mypy's union-type errors manifesting as arg-type error.
  • Loading branch information
JochenSiegWork committed Jun 7, 2024
1 parent b70269a commit ee1340b
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions tests/test_explainability/test_shap_tree_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def _test_valid_explanation(

self.assertIsInstance(explanation.feature_vector, np.ndarray)
self.assertEqual(
(nof_features,), explanation.feature_vector.shape
) # type: ignore[union-attr]
(nof_features,), explanation.feature_vector.shape # type: ignore[union-attr]
)

# feature names are not implemented yet
self.assertIsNone(explanation.feature_names)
Expand All @@ -94,22 +94,22 @@ def _test_valid_explanation(
if is_regressor(estimator):
self.assertTrue((1,), explanation.prediction.shape) # type: ignore[union-attr]
self.assertEqual(
(nof_features,), explanation.feature_weights.shape
) # type: ignore[union-attr]
(nof_features,), explanation.feature_weights.shape # type: ignore[union-attr]
)
elif is_classifier(estimator):
self.assertTrue((2,), explanation.prediction.shape) # type: ignore[union-attr]
if isinstance(estimator, GradientBoostingClassifier):
# there is currently a bug in SHAP's TreeExplainer for GradientBoostingClassifier
# https://github.com/shap/shap/issues/3177 returning only one feature weight
# which is also based on log odds. This check is a workaround until the bug is fixed.
self.assertEqual(
(nof_features,), explanation.feature_weights.shape
) # type: ignore[union-attr]
(nof_features,), explanation.feature_weights.shape # type: ignore[union-attr]
)
else:
# normal binary classification case
self.assertEqual(
(nof_features, 2), explanation.feature_weights.shape
) # type: ignore[union-attr]
(nof_features, 2), explanation.feature_weights.shape # type: ignore[union-attr]
)
else:
raise ValueError("Error in unittest. Unsupported estimator.")

Expand All @@ -118,7 +118,7 @@ def _test_valid_explanation(
self.assertEqual(
explanation.atom_weights.shape, # type: ignore[union-attr]
(explanation.molecule.GetNumAtoms(),), # type: ignore[union-attr]
) # type: ignore[union-attr]
)
else:
self.assertIsNone(explanation.atom_weights)

Expand Down Expand Up @@ -162,11 +162,11 @@ def test_explanations_fingerprint_pipeline(self) -> None:
self._test_valid_explanation(
explanation,
estimator,
mol_reader_subpipeline, # type: ignore[union-attr]
mol_reader_subpipeline, # type: ignore[arg-type]
n_bits,
TEST_SMILES[i],
is_morgan_fingerprint=True,
) # type: ignore[union-attr]
)

def test_explanations_pipeline_with_invalid_inputs(self) -> None:
"""Test SHAP's TreeExplainer wrapper with invalid inputs."""
Expand Down Expand Up @@ -234,11 +234,11 @@ def test_explanations_pipeline_with_invalid_inputs(self) -> None:
self._test_valid_explanation(
explanation,
estimator,
mol_reader_subpipeline, # type: ignore[union-attr]
mol_reader_subpipeline, # type: ignore[arg-type]
n_bits,
TEST_SMILES_WITH_BAD_SMILES[i],
is_morgan_fingerprint=True,
) # type: ignore[union-attr]
)

def test_explanations_pipeline_with_physchem(self) -> None:
"""Test SHAP's TreeExplainer wrapper on physchem feature vector."""
Expand Down Expand Up @@ -276,11 +276,11 @@ def test_explanations_pipeline_with_physchem(self) -> None:
self._test_valid_explanation(
explanation,
estimator,
mol_reader_subpipeline, # type: ignore[union-attr]
mol_reader_subpipeline, # type: ignore[arg-type]
pipeline.named_steps["physchem"].n_features,
TEST_SMILES[i],
is_morgan_fingerprint=False,
) # type: ignore[union-attr]
)

def test_explanations_pipeline_with_concatenated_features(self) -> None:
"""Test SHAP's TreeExplainer wrapper on concatenated feature vector."""
Expand Down Expand Up @@ -334,8 +334,8 @@ def test_explanations_pipeline_with_concatenated_features(self) -> None:
self._test_valid_explanation(
explanation,
estimator,
mol_reader_subpipeline, # type: ignore[union-attr]
mol_reader_subpipeline, # type: ignore[arg-type]
pipeline.named_steps["features"].n_features,
TEST_SMILES[i],
is_morgan_fingerprint=False,
) # type: ignore[union-attr]
)

0 comments on commit ee1340b

Please sign in to comment.